#' Run bootstrap for nuisance function estimation.
#'
#' For each of the `B` replications the rows of `D`, `X` and `Y` are resampled
#' with replacement, a `DoubleML::DoubleMLPLR` model is fitted on the resample,
#' and `mean(Y - D %*% coef)` is stored, where `coef` is the fitted `all_coef`.
#'
#' @param D Treatment variable(s): a vector of length n, or a matrix with n rows
#'   (one column per treatment).
#' @param X Design matrix with n rows.
#' @param Y Outcome variable of length n (vector or n x 1 matrix).
#' @param B Number of bootstrap replications. `B = 0` returns an empty `E_hat`;
#'   `E_hat_var` is `NA` unless `B >= 2`.
#' @param ml_f Outcome learner for DoubleML (e.g. lrn("regr.cv_glmnet"));
#'   passed to `DoubleMLPLR` as `ml_l`.
#' @param ml_g Treatment learner for DoubleML (same type); passed to
#'   `DoubleMLPLR` as `ml_m`.
#' @return List with: E_hat (bootstrapped samples' nuisance function estimation),
#' E_hat_mean (estimated mean of E_hat), E_hat_var (estimated variance of E_hat)
run_bootstrap_E_hat <- function(D, X, Y, B, ml_f, ml_g) {
  n <- length(Y)
  # A multi-column treatment matrix has to be resampled by rows; plain
  # `D[selector]` indexes it linearly and would only ever reach column one.
  d_has_dims <- !is.null(dim(D))
  E_hat <- numeric(B)
  # seq_len() keeps the loop empty for B = 0 (1:0 would iterate over 1 and 0).
  for (b in seq_len(B)) {
    selector <- sample.int(n, n, replace = TRUE)
    X_b <- X[selector, , drop = FALSE]
    Y_b <- Y[selector]
    D_b <- if (d_has_dims) D[selector, , drop = FALSE] else D[selector]
    dml_data <- DoubleML::double_ml_data_from_matrix(X = X_b, y = Y_b, d = D_b)
    obj_dml <- DoubleML::DoubleMLPLR$new(dml_data, ml_l = ml_f, ml_m = ml_g)
    fit <- obj_dml$fit()
    causal_hat <- fit$all_coef
    # NOTE(review): the residual mean is evaluated on the full sample (Y, D),
    # not on the resample (Y_b, D_b) -- confirm this is the intended estimator.
    E_hat[b] <- mean(Y - D %*% causal_hat)
  }
  list(E_hat = E_hat, E_hat_mean = mean(E_hat), E_hat_var = var(E_hat))
}


#' Bootstrap-based source detection: identify which sources are transferable to target.
#'
#' The target data and every source are each passed to `run_bootstrap_E_hat()`.
#' Source k is flagged as detected when
#' `abs(E_hat_mean_s_k - E_hat_mean_t) <= C0 * max(0.01, sqrt(E_hat_var_t + E_hat_var_s_k))`.
#'
#' @param D_t Target treatment; n_t x 1.
#' @param X_t Target design matrix; n_t x p.
#' @param Y_t Target outcome; n_t x 1.
#' @param D_s_all Source treatments concatenated by row; (sum of source_sizes) x q. Rows are split by source_sizes into K sources.
#' @param X_s_all Source design matrices concatenated by row; (sum of source_sizes) x p. Rows split by source_sizes.
#' @param Y_s_all Source outcomes concatenated by row; (sum of source_sizes) x 1. Rows split by source_sizes.
#' @param source_sizes Integer vector of length K: sample size of each source. Must sum to nrow(Y_s_all).
#' @param B Number of bootstrap replications. The bootstrap variances, and
#'   therefore the decisions, are `NA` unless `B >= 2`.
#' @param ml_f Outcome learner for DoubleML (e.g. lrn("regr.cv_glmnet")).
#' @param ml_g Treatment learner for DoubleML (same type).
#' @param C0 Multiplier applied to the detection threshold. Defaults to 4.
#' @return Data.frame with one row per source: `Source` ("Source k") and
#'   `detected.source`, a list column whose k-th entry is a logical vector of
#'   length `B` holding the (single) decision for source k repeated `B` times.
boot_detection <- function(D_t, X_t, Y_t, D_s_all, X_s_all, Y_s_all,
                           source_sizes, B, ml_f, ml_g, C0 = 4) {
  n_sources <- length(source_sizes)

  # Fail fast on inconsistent inputs instead of silently ignoring trailing rows
  # or hitting a subscript error after the expensive target bootstrap.
  n_expected <- sum(source_sizes)
  if (NROW(Y_s_all) != n_expected ||
      NROW(D_s_all) != n_expected ||
      NROW(X_s_all) != n_expected) {
    stop(
      "`source_sizes` must sum to the number of rows of `D_s_all`, ",
      "`X_s_all` and `Y_s_all`.",
      call. = FALSE
    )
  }

  # Row subset that never collapses a one-column (or one-row) matrix to a
  # vector, and that also accepts plain vectors.
  take_rows <- function(obj, idx) {
    if (is.null(dim(obj))) obj[idx] else obj[idx, , drop = FALSE]
  }

  # target baseline estimation
  boot_t <- run_bootstrap_E_hat(D_t, X_t, Y_t, B, ml_f, ml_g)
  E_hat_mean_t <- boot_t$E_hat_mean
  E_hat_var_t <- boot_t$E_hat_var

  # Convert source_sizes to consecutive index vectors, e.g.
  # c(100, 50, 150, 200) -> 1:100, 101:150, 151:300, 301:500.
  # offset + seq_len(size) yields an empty index for a zero-size source,
  # whereas seq(start, end) would count downwards.
  offsets <- cumsum(source_sizes) - source_sizes
  source_indexes <- Map(
    function(offset, size) offset + seq_len(size),
    offsets, source_sizes
  )

  detected.source <- matrix(NA, nrow = n_sources, ncol = B)
  # source detection
  for (k in seq_len(n_sources)) {
    rows_k <- source_indexes[[k]]
    boot_s_k <- run_bootstrap_E_hat(
      take_rows(D_s_all, rows_k),
      take_rows(X_s_all, rows_k),
      take_rows(Y_s_all, rows_k),
      B, ml_f, ml_g
    )
    threshold_s_k <- C0 * max(0.01, sqrt(E_hat_var_t + boot_s_k$E_hat_var))
    diff_s_k <- abs(boot_s_k$E_hat_mean - E_hat_mean_t)
    # The comparison is a single logical; it is recycled across the B columns.
    detected.source[k, ] <- diff_s_k <= threshold_s_k
  }

  data.frame(
    Source = paste0("Source ", seq_len(n_sources)),
    detected.source = I(lapply(seq_len(n_sources), function(i) detected.source[i, ])),
    stringsAsFactors = FALSE
  )
}
