/* upsilon_cpp.cpp
 * 
 *  Author: Xuye Luo, Joe Song
 *  
 * Updated: 
 * 
 *   December 20, 2025. Changed the input type
 *     from C++ double vector to integer vector, 
 *     corresponding to factor vector in R.
 *      
 *   December 12, 2025
 */

#include <Rcpp.h>
#include <unordered_map>
using namespace Rcpp;

// [[Rcpp::interfaces(r, cpp)]]

/* 
//' @title Fast Upsilon Statistic (C++ Backend)
//' @description Calculates the Upsilon statistic directly from raw data vectors
//' using hash maps for efficient counting. 
//' Optimized using algebraic expansion to avoid iterating over zero cells.
//' @param x Integer vector of the first variable (e.g., the codes of an R factor).
//' @param y Integer vector of the second variable (e.g., the codes of an R factor).
//' @return A List containing the statistic, sample size (n), row count (nr), and col count (nc).
 */

// Computes the Upsilon statistic of association between two categorical
// variables supplied as parallel integer vectors.
//
// The statistic is Sum((O - E)^2) / avg over every cell of the nr x nc
// contingency table, where O is the observed count of a cell,
// E = RowSum * ColSum / n is its expected count, and avg = n / (nr * nc).
//
// Parameters:
//   x  Integer codes of the first variable (rows of the table).
//   y  Integer codes of the second variable (columns); same length as x.
//
// Returns a List with elements:
//   statistic  the Upsilon statistic (0 when n == 0 or when either variable
//              has fewer than two distinct values),
//   n          number of observations,
//   nr, nc     number of distinct values seen in x and in y.
//
// Stops with an error when x and y differ in length.
//
// NOTE(review): no value is filtered out while counting, so every distinct
// integer code forms its own category -- presumably including R's
// NA_integer_; confirm that callers drop missing values if that is unwanted.
// [[Rcpp::export]]
List upsilon_cpp(
    const IntegerVector &x, 
    const IntegerVector &y
) 
{
  int n = x.size();
  
  // Input validation
  if (n != y.size()) {
    stop("Lengths of 'x' and 'y' must match.");
  }
  
  // NOTE(review): the literal 0s in the early returns are ints, whereas the
  // main path returns a double statistic and size_t counts; the element types
  // are kept exactly as before -- confirm whether callers depend on them.
  if (n == 0) {
    return List::create(Named("statistic") = 0, 
                        Named("n")  = 0,
                        Named("nr") = 0,
                        Named("nc") = 0);
  }

  // Sparse contingency table: only cells that actually occur are stored.
  // Keys are the integer codes themselves (no signed -> unsigned conversion).
  std::unordered_map<int, std::unordered_map<int, unsigned> > observed;
  std::unordered_map<int, unsigned> row_sum;
  std::unordered_map<int, unsigned> col_sum;
  
  for (int i = 0; i < n; ++i) {
    const int val_x = x[i];
    const int val_y = y[i];
    
    ++observed[val_x][val_y];
    ++row_sum[val_x];
    ++col_sum[val_y];
  }
  
  const auto nr = row_sum.size();
  const auto nc = col_sum.size();

  // A table with a single row or a single column carries no association.
  if (nr < 2 || nc < 2) {
    return List::create(Named("statistic") = 0, 
                        Named("n")  = n,
                        Named("nr") = nr,
                        Named("nc") = nc);
  }

  // Expansion used below, so that empty cells never have to be visited:
  //   Sum((O - E)^2) = Sum(O^2) - 2 * Sum(O * E) + Sum(E^2)
  const double n_dbl = static_cast<double>(n);

  // Convert before multiplying so the product cannot wrap in integer
  // arithmetic on platforms with a narrow size_t.
  const double avg =
    n_dbl / (static_cast<double>(nr) * static_cast<double>(nc));
  
  // Sum(O^2) - 2 * Sum(O * E): cells with O == 0 contribute nothing, so
  // iterating over the observed (non-zero) cells is sufficient.
  double sum_O2_minus_2OE = 0.0;
  
  for (const auto& row : observed) {
    const double r_sum = static_cast<double>(row_sum.at(row.first));
    
    for (const auto& cell : row.second) {
      const double O = static_cast<double>(cell.second);
      const double c_sum = static_cast<double>(col_sum.at(cell.first));
      
      // E = (RowSum * ColSum) / N
      const double E = (r_sum * c_sum) / n_dbl;
      
      sum_O2_minus_2OE += (O * O) - (2.0 * O * E);
    }
  }
  
  // Sum(E^2) over ALL cells factorises into the marginals:
  //   Sum(E_ij^2) = (Sum(RowSum^2) * Sum(ColSum^2)) / N^2
  // Each count is converted to double BEFORE squaring: squaring in unsigned
  // arithmetic wraps around once a marginal count reaches 65536.
  double row_sq_sum = 0.0;
  for (const auto& r : row_sum) {
    const double count = static_cast<double>(r.second);
    row_sq_sum += count * count;
  }

  double col_sq_sum = 0.0;
  for (const auto& c : col_sum) {
    const double count = static_cast<double>(c.second);
    col_sq_sum += count * count;
  }
  
  const double sum_E2 = (row_sq_sum * col_sq_sum) / (n_dbl * n_dbl);
  
  // Combine the parts
  const double numerator = sum_O2_minus_2OE + sum_E2;
  double statistic = numerator / avg;

  // A sum of squares cannot be negative; a value below zero here can only
  // come from floating-point cancellation in the expansion, so clamp it.
  if (statistic < 0) statistic = 0;

  return List::create(Named("statistic") = statistic, 
                      Named("n")  = n,
                      Named("nr") = nr,
                      Named("nc") = nc);
}
