#' Estimate the Multivariate Regression Association Measure
#'
#' @param y_data An \eqn{n \times d} matrix of responses, where \eqn{n} is the sample size.
#' @param x_data An \eqn{n \times p} matrix of predictors.
#' @param z_data An \eqn{n \times q} matrix of conditional predictors. The default value is \code{NULL}.
#' @param bootstrap Perform the \eqn{m}-out-of-\eqn{n} bootstrap if \code{TRUE}. The default value is \code{FALSE}.
#' @param B Number of bootstrap replications. The default value is \code{1000}.
#' @param g_vec A vector of candidate values for \eqn{\gamma} between 0 and 1, used to generate a collection of rules for the \eqn{m}-out-of-\eqn{n} bootstrap. The default value is \code{seq(0.4,0.9,by = 0.05)}.
#' @description Compute \eqn{T_n} and its standard error estimates using the nearest neighbor method and the \eqn{m}-out-of-\eqn{n} bootstrap.
#' @details Let \eqn{\{({\bf X}_i,{\bf Y}_i,{\bf Z}_i)\}_{i = 1}^n} be independent and identically distributed data from the population \eqn{({\bf X},{\bf Y},{\bf Z})}. The estimate \eqn{T_n({\bf X},{\bf Y})} for the unconditional measure (\code{z_data = NULL}) is given as
#'
#' \deqn{T_n({\bf X},{\bf Y}) = \binom{n}{2}^{-1} \sum_{i < j} \langle S({{\bf Y}_i - {\bf Y}_j}), S({{\bf Y}_{N(i)} - {\bf Y}_{N(j)}}) \rangle,}
#'
#' where \eqn{\langle \cdot, \cdot \rangle} is the dot product, \eqn{S(\cdot)} is the spatial sign function, and \eqn{N(i)} is the index \eqn{j} such that \eqn{{\bf X}_j} is the nearest neighbor of \eqn{{\bf X}_i} according to the Euclidean distance. The estimate \eqn{T_n({\bf X},{\bf Y} \mid {\bf Z})} for the conditional measure is given as
#'
#' \deqn{T_n({\bf X},{\bf Y} \mid {\bf Z} ) = \frac{T_n(({\bf X},{\bf Z}),{\bf Y} ) - T_n({\bf Z},{\bf Y} )}{1 - T_n({\bf Z},{\bf Y} )}.}
#'
#' See Shih and Chen (2026) for more details.
#'
#' For the \eqn{m}-out-of-\eqn{n} bootstrap, the rule (resample size) is set to be \eqn{m = \lfloor n^\gamma \rfloor}, where \eqn{\lfloor x \rfloor} denotes the largest integer that is smaller than or equal to \eqn{x} and \eqn{0 < \gamma < 1} takes values from the vector \code{g_vec}. It is recommended to use \code{T_se_cluster}, the standard error estimate obtained based on the cluster rule. See Dette and Kroll (2024) for more details.
#'
#' The \code{mram} function is used in \code{\link{vs_mram}} function for variable selection.
#'
#' @return \item{T_est}{The estimate of the multivariate regression association measure. The value returned by \code{T_est} is between \eqn{-1} and \eqn{1}. However, it is between \eqn{0} and \eqn{1} asymptotically. A small value indicates that \code{x_data} has low predictability for \code{y_data} conditional on \code{z_data} in the sense of the considered measure. On the other hand, a large value indicates that \code{x_data} has high predictability for \code{y_data} conditional on \code{z_data}. If \code{z_data = NULL}, the returned value indicates the unconditional predictability.}
#' \item{T_se_cluster}{The standard error estimate based on the cluster rule.}
#' \item{m_vec}{The vector of \eqn{m} generated by \code{g_vec}.}
#' \item{T_se_vec}{The vector of standard error estimates obtained from the \eqn{m}-out-of-\eqn{n} bootstrap, where \eqn{m} is equal to \code{m_vec}.}
#' \item{J_cluster}{The index of the best \code{m_vec} chosen by the cluster rule.}
#'
#' @references Dette and Kroll (2024) A Simple Bootstrap for Chatterjee's Rank Correlation, Biometrika, asae045.
#' @references Shih and Chen (2026) Measuring multivariate regression association via spatial sign, Computational Statistics & Data Analysis, 215, 108288.
#' @seealso \code{\link{vs_mram}}
#'
#' @importFrom stats ks.test sd
#' @importFrom utils combn
#' @importFrom RANN nn2
#' @export
#'
#' @examples
#' library(MRAM)
#'
#' n = 100
#'
#' set.seed(1)
#' x_data = matrix(rnorm(n*2),n,2)
#' y_data = matrix(0,n,2)
#' y_data[,1] = x_data[,1]*x_data[,2]+x_data[,1]+rnorm(n)
#' y_data[,2] = x_data[,1]*x_data[,2]-x_data[,1]+rnorm(n)
#'
#' mram(y_data,x_data[,1],x_data[,2])
#' mram(y_data,x_data[,2],x_data[,1])
#' mram(y_data,x_data[,1])
#' mram(y_data,x_data[,2])
#'
#' \dontrun{
#'
#' # perform the m-out-of-n bootstrap
#' mram(y_data,x_data[,1],x_data[,2],bootstrap = TRUE)
#' mram(y_data,x_data[,2],x_data[,1],bootstrap = TRUE)
#' mram(y_data,x_data[,1],bootstrap = TRUE)
#' mram(y_data,x_data[,2],bootstrap = TRUE)
#' }

mram = function(y_data,
                x_data,
                z_data = NULL,
                bootstrap = FALSE,
                B = 1000,
                g_vec = seq(0.4,0.9,by = 0.05)) {

  # Coerce every input to a matrix so that a plain vector (a single response
  # or predictor column) can be row-subset with `[rows, , drop = FALSE]`.
  y_data = as.matrix(y_data)
  x_data = as.matrix(x_data)
  if (!is.null(z_data)) z_data = as.matrix(z_data)

  n = nrow(y_data)

  if (!all(g_vec > 0 & g_vec < 1)) stop("All elements of g_vec must be between 0 and 1.")

  if (n < 2) stop("y_data must have at least 2 rows.")

  if (nrow(x_data) != n || (!is.null(z_data) && nrow(z_data) != n)) {
    stop("y_data, x_data and z_data must have the same number of rows.")
  }

  # All index pairs i < j of the full sample.
  n_combn = combn(n,2)
  s1 = n_combn[1,]
  s2 = n_combn[2,]

  # Predictors used for the first nearest-neighbor search: (X, Z) when z_data
  # is supplied, X alone otherwise (cbind() ignores a NULL argument).
  xz_data = cbind(x_data,z_data)

  T_est = .mram_estimate(y_data,xz_data,z_data,s1,s2)

  if (!bootstrap) {

    return(list(T_est = T_est))

  }

  ### m-out-of-n Bootstrap ###

  # Resample sizes are derived only here, so the plain estimate above does not
  # depend on g_vec producing usable values of m.
  m_vec = floor(n^g_vec)

  # Each resample needs at least one pair and a neighbor other than itself.
  if (any(m_vec < 2)) {
    stop("Every resample size floor(n^g_vec) must be at least 2; use a larger sample or larger g_vec values.")
  }

  g_L = length(g_vec)

  # Row b, column i holds the estimate from replication b at resample size m_vec[i].
  T_est_matrix = matrix(0,B,g_L)

  for (i in seq_len(g_L)) {

    m = m_vec[i]

    # Index pairs of a resample of size m, shared by all B replications.
    m_combn = combn(m,2)
    s1_m = m_combn[1,]
    s2_m = m_combn[2,]

    for (b in seq_len(B)) {

      # m distinct rows drawn without replacement.
      boot = sample.int(n,m)

      z_boot = if (is.null(z_data)) NULL else z_data[boot,,drop = FALSE]

      T_est_matrix[b,i] = .mram_estimate(y_data[boot,,drop = FALSE],
                                         xz_data[boot,,drop = FALSE],
                                         z_boot,
                                         s1_m,
                                         s2_m)

    }

  }

  # Cluster rule: for each resample size, sum the two-sample Kolmogorov-Smirnov
  # statistics between its bootstrap estimates and those of every resample
  # size, then pick the size with the smallest sum. Only the statistic is used,
  # so warnings raised by ks.test() are suppressed.
  T_temp = vapply(seq_len(g_L),function(j) {
    sum(vapply(seq_len(g_L),function(k) {
      suppressWarnings(ks.test(T_est_matrix[,j],T_est_matrix[,k])$statistic)
    },numeric(1)))
  },numeric(1))

  J_cluster = which.min(T_temp)

  # Standard error per resample size: sd of sqrt(m) * T over the replications,
  # divided by sqrt(n).
  T_se_vec = vapply(seq_len(g_L),function(j) {
    sd(sqrt(m_vec[j])*T_est_matrix[,j])/sqrt(n)
  },numeric(1))

  ### result

  list(T_est = T_est,
       T_se_cluster = T_se_vec[J_cluster],
       J_cluster = J_cluster,
       m_vec = m_vec,
       T_se_vec = T_se_vec)

}

# Row-wise spatial sign: each row of `d` divided by its Euclidean norm.
# A zero row has no direction and is mapped to the zero vector instead of the
# NaN that 0/0 would give. Rows containing NA keep their NA values.
.mram_spatial_sign = function(d) {

  norms = sqrt(rowSums(d^2))
  signs = d/norms
  signs[which(norms == 0),] = 0

  signs

}

# Average, over the index pairs (s1, s2), of the dot product between the
# spatial sign of a response difference (`y_sign`, precomputed by the caller)
# and the spatial sign of the difference of the corresponding
# nearest-neighbor responses, with neighbors searched in `nn_data`.
.mram_nn_stat = function(y_sign,y_data,nn_data,s1,s2) {

  # k = 2 and the second column: the closest match of a query point is
  # presumably the point itself, so the second match is taken as its neighbor.
  index = RANN::nn2(nn_data,k = 2)$nn.idx[,2]
  y_prime = y_data[index,,drop = FALSE]

  y_prime_sign = .mram_spatial_sign(y_prime[s1,,drop = FALSE]-y_prime[s2,,drop = FALSE])

  sum(y_sign*y_prime_sign)/length(s1)

}

# Estimate of the measure for one data set (the full sample or a resample).
# Returns the statistic based on `xz_data` when `z_data` is NULL, and
# (T_xz - T_z) / (1 - T_z) otherwise. `s1` and `s2` are the two rows of
# combn(nrow(y_data), 2).
.mram_estimate = function(y_data,xz_data,z_data,s1,s2) {

  # The response signs do not depend on the predictors, so they are computed
  # once and shared by both nearest-neighbor statistics.
  y_sign = .mram_spatial_sign(y_data[s1,,drop = FALSE]-y_data[s2,,drop = FALSE])

  T_xz = .mram_nn_stat(y_sign,y_data,xz_data,s1,s2)

  if (is.null(z_data)) return(T_xz)

  T_z = .mram_nn_stat(y_sign,y_data,z_data,s1,s2)

  (T_xz-T_z)/(1-T_z)

}


