#' Centroid Decision Forest (robust, single-function version)
#'
#' Grows a bagged forest of shallow "centroid" trees on `xtrain`/`ytrain`
#' and uses it to classify the rows of `xtest`.
#'
#' Algorithm:
#'  - at each node, sample `mtry` candidate features and keep the `k` with
#'    the highest pairwise signal-to-noise ratio (SNR)
#'  - split the node's rows by nearest class centroid over the kept features
#'  - grow `ntrees` such trees on bootstrap resamples; predict by majority
#'    vote
#'
#' Data-robustness measures:
#'  - Coerces inputs to numeric matrices (non-numeric entries become NA)
#'  - Removes all-NA / zero-variance columns (decided on xtrain; the same
#'    columns are removed from xtest)
#'  - Aligns xtest columns to xtrain by name, or by position when either
#'    input has no column names (extra test columns are cropped)
#'  - Fills training columns that xtest lacks with training medians
#'  - Trims labels; drops training rows whose label is NA or blank
#'  - NA-tolerant distances; safe fallbacks for degenerate cases
#'
#' @param xtrain numeric matrix/data.frame (n x p).
#' @param ytrain factor/character labels (length n).
#' @param xtest  numeric matrix/data.frame (m x p, or with same-named cols).
#' @param ntrees number of bootstrap trees (coerced to integer, at least 1).
#' @param depth maximum tree depth (coerced to integer, at least 0).
#' @param mnode nodes with fewer than `mnode` rows become leaves (coerced to
#'   integer, at least 1).
#' @param k number of top-SNR features used per split; clamped to `[1, p]`,
#'   where `p` is the number of columns kept after cleaning.
#' @param mtry number of candidate features sampled per node; clamped to
#'   `[1, p]`.
#' @param seed optional value passed to [set.seed()]. Note that this changes
#'   the global random-number state of the calling session.
#' @return A list with
#'   - `predictions`: character vector of length m,
#'   - `probabilities`: m x K matrix of vote shares, one column per class
#'     (classes in order of first appearance in `ytrain`),
#'   - `feature_importance`: split frequency of each retained column,
#'     normalised to sum to 1 (all zeros if no split was ever made).
#' @importFrom matrixStats colSds
#' @export
CDF <- function(xtrain, ytrain, xtest,
                ntrees = 500,
                depth  = 3,
                mnode  = 3,
                k      = round(2 * log(NCOL(xtrain))),
                mtry   = round(0.2 * NCOL(xtrain)),
                seed   = NULL) {

  ## ---------- helpers: data hygiene ----------

  # matrixStats provides fast column SDs; fall back to base R so the
  # function still works when it is sourced outside the package.
  has_matrix_stats <- requireNamespace("matrixStats", quietly = TRUE)

  # Column standard deviations ignoring NAs (NA for a column with fewer than
  # two non-missing values).
  col_sds <- function(X) {
    if (has_matrix_stats) {
      return(matrixStats::colSds(X, na.rm = TRUE))
    }
    vapply(seq_len(ncol(X)),
           function(j) stats::sd(X[, j], na.rm = TRUE),
           numeric(1))
  }

  # Coerce a matrix / data.frame / vector to a numeric matrix. Values that
  # cannot be read as numbers become NA (such columns are dropped later).
  to_numeric_matrix <- function(X) {
    if (is.matrix(X) && is.numeric(X)) return(X)
    if (is.data.frame(X)) {
      n <- nrow(X)
      cols <- lapply(X, function(col) {
        if (is.factor(col)) col <- as.character(col)  # labels, not codes
        out <- tryCatch(suppressWarnings(as.numeric(col)),
                        error = function(e) NULL)
        # List/matrix or otherwise non-coercible columns become all-NA
        # instead of aborting the whole call.
        if (length(out) != n) rep(NA_real_, n) else out
      })
      # Building via matrix() (not cbind) also handles zero-column frames.
      return(matrix(as.numeric(unlist(cols, use.names = FALSE)),
                    nrow = n, ncol = length(cols),
                    dimnames = list(NULL, names(X))))
    }
    # Last resort: vectors, logical/character matrices, ...
    X <- as.matrix(X)
    if (!is.numeric(X)) {
      X <- matrix(suppressWarnings(as.numeric(X)),
                  nrow = nrow(X), ncol = ncol(X), dimnames = dimnames(X))
    }
    X
  }

  # Median ignoring NAs; 0 when no finite median exists.
  col_median <- function(x) {
    med <- stats::median(x, na.rm = TRUE)
    if (is.na(med) || !is.finite(med)) 0 else med
  }

  # Labels as trimmed character; blank labels become NA (the caller drops
  # the corresponding rows).
  clean_labels <- function(y) {
    y <- as.character(y)
    y <- trimws(y)
    y[y == ""] <- NA_character_
    y
  }

  # Validate a count-like hyperparameter and clamp it to [lower, upper].
  as_count <- function(x, name, lower, upper = .Machine$integer.max) {
    v <- suppressWarnings(as.integer(x))
    if (length(v) != 1L || is.na(v)) {
      stop(sprintf("'%s' must be a single finite number.", name),
           call. = FALSE)
    }
    as.integer(min(max(v, lower), upper))
  }

  # Make xtest's columns match xtrain's: by name when both have column
  # names, otherwise by position (extra test columns are cropped). Training
  # columns with no test counterpart are filled with the training median.
  align_test_columns <- function(Xtr, Xte) {
    n_te <- nrow(Xte)
    med_tr <- vapply(seq_len(ncol(Xtr)),
                     function(j) col_median(Xtr[, j]),
                     numeric(1))
    # n_te x length(cols) matrix, column j constant at med_tr[cols[j]].
    median_block <- function(cols) {
      matrix(rep(med_tr[cols], each = n_te),
             nrow = n_te, ncol = length(cols))
    }
    tr_names <- colnames(Xtr)
    te_names <- colnames(Xte)

    if (!is.null(tr_names) && !is.null(te_names)) {
      src <- match(tr_names, te_names)  # first same-named test column, or NA
      out <- median_block(seq_along(tr_names))
      colnames(out) <- tr_names
      hit <- which(!is.na(src))
      if (length(hit) == 0L && ncol(Xte) > 0L) {
        warning("No xtest column names match xtrain; every test feature ",
                "was imputed with its training median.", call. = FALSE)
      }
      out[, hit] <- Xte[, src[hit], drop = FALSE]
      return(out)
    }

    p_tr <- ncol(Xtr)
    p_te <- ncol(Xte)
    if (p_te > p_tr) {
      Xte <- Xte[, seq_len(p_tr), drop = FALSE]
    } else if (p_te < p_tr) {
      Xte <- cbind(Xte, median_block(seq.int(p_te + 1L, p_tr)))
    }
    colnames(Xte) <- tr_names
    Xte
  }

  ## ---------- coerce & validate ----------
  if (!is.null(seed)) set.seed(seed)

  Xtr <- to_numeric_matrix(xtrain)
  Xte <- to_numeric_matrix(xtest)

  # basic shape checks
  if (nrow(Xtr) < 2L) stop("xtrain must have at least two rows.")
  if (ncol(Xtr) < 1L) stop("xtrain must have at least one column.")

  ytr <- clean_labels(ytrain)
  if (length(ytr) != nrow(Xtr)) stop("ytrain length must equal nrow(xtrain).")
  labelled <- !is.na(ytr)
  Xtr <- Xtr[labelled, , drop = FALSE]
  ytr <- ytr[labelled]

  # Class order = order of first appearance; it fixes the column order of
  # the returned probability matrix.
  class_levels <- unique(ytr)
  if (length(class_levels) < 2L) {
    stop("At least two classes are required in ytrain.")
  }

  Xte <- align_test_columns(Xtr, Xte)

  # Drop all-NA and zero-variance training columns; an all-NA column has an
  # NA standard deviation, so a single test covers both cases.
  sds <- col_sds(Xtr)
  usable <- is.finite(sds) & sds > 0
  Xtr <- Xtr[, usable, drop = FALSE]
  Xte <- Xte[, usable, drop = FALSE]
  if (ncol(Xtr) < 1L) {
    stop("All columns in xtrain were invalid (all-NA or zero-variance).")
  }

  # Recompute p and validate / clamp hyperparameters.
  p <- ncol(Xtr)
  ntrees <- as_count(ntrees, "ntrees", 1L)
  depth  <- as_count(depth, "depth", 0L)
  mnode  <- as_count(mnode, "mnode", 1L)
  mtry   <- as_count(mtry, "mtry", 1L, p)
  k      <- as_count(k, "k", 1L, p)

  ## ---------- core helpers ----------

  # Mean over all class pairs of |mean_a - mean_b| / (sd_a + sd_b + 1e-7),
  # per column. Undefined ratios (e.g. a single-row class) count as 0.
  # Accumulating into a length-ncol(X) vector keeps this correct when X has
  # a single column.
  calculate_snr <- function(X, y) {
    cls <- unique(y)
    n_cls <- length(cls)
    snr <- numeric(ncol(X))
    if (n_cls < 2L) return(snr)
    by_class <- lapply(cls, function(cl) X[y == cl, , drop = FALSE])
    means <- lapply(by_class, colMeans, na.rm = TRUE)
    sds <- lapply(by_class, col_sds)
    pairs <- utils::combn(n_cls, 2L)
    for (pp in seq_len(ncol(pairs))) {
      a <- pairs[1L, pp]
      b <- pairs[2L, pp]
      ratio <- abs(means[[a]] - means[[b]]) / (sds[[a]] + sds[[b]] + 1e-7)
      ratio[is.na(ratio)] <- 0
      snr <- snr + ratio
    }
    (2 / (n_cls * (n_cls - 1))) * snr
  }

  # Per-class column means over the selected columns, named by class.
  compute_centroids <- function(X, y, sel_idx) {
    u <- unique(y)
    cents <- lapply(u, function(cl) {
      colMeans(X[y == cl, sel_idx, drop = FALSE], na.rm = TRUE)
    })
    names(cents) <- as.character(u)
    cents
  }

  # Most frequent label (ties: first in table()'s sorted order).
  majority <- function(y) {
    tab <- table(y)
    if (length(tab) == 0L) return(class_levels[1L])
    names(tab)[which.max(tab)]
  }

  make_leaf <- function(y) list(class = majority(y))

  ## ---------- tree builder & predictor ----------

  # Recursively grow one tree. Returns list(node, feat_imp), where feat_imp
  # is feat_imp_acc plus one count per feature used in a split below here.
  # Leaves are list(class = label); internal nodes hold the class centroids,
  # the selected column indices (into the full training matrix), one child
  # per centroid and the node's majority class as a prediction fallback.
  build_tree <- function(X, y, d = 0L, feat_imp_acc) {
    if (d >= depth || nrow(X) < mnode || length(unique(y)) == 1L ||
        ncol(X) < 1L) {
      return(list(node = make_leaf(y), feat_imp = feat_imp_acc))
    }

    p_here <- ncol(X)
    candidate <- if (p_here <= mtry) {
      seq_len(p_here)
    } else {
      sample.int(p_here, mtry, replace = FALSE)
    }

    snr <- calculate_snr(X[, candidate, drop = FALSE], y)
    snr[!is.finite(snr)] <- -Inf
    topk <- min(k, length(candidate))
    sel <- candidate[order(snr, decreasing = TRUE)[seq_len(topk)]]

    # Simple frequency importance: one count per use in a split.
    feat_imp_acc[sel] <- feat_imp_acc[sel] + 1

    cents <- compute_centroids(X, y, sel)
    X_sel <- X[, sel, drop = FALSE]
    dist_mat <- matrix(NA_real_, nrow = nrow(X), ncol = length(cents))
    for (i in seq_along(cents)) {
      dist_mat[, i] <- rowSums(sweep(X_sel, 2L, cents[[i]])^2, na.rm = TRUE)
    }
    # max.col() breaks (near-)ties at random by default, so routing here may
    # consume RNG draws; prediction below uses which.min() (first minimum).
    cluster <- max.col(-dist_mat)

    branches <- vector("list", length(cents))
    for (cid in sort(unique(cluster))) {
      idx <- which(cluster == cid)
      sub <- build_tree(X[idx, , drop = FALSE], y[idx], d + 1L, feat_imp_acc)
      branches[[cid]] <- sub$node
      feat_imp_acc <- sub$feat_imp
    }
    # A centroid that attracted no training rows still needs a child;
    # otherwise test rows routed to it would get an arbitrary class.
    empty <- vapply(branches, is.null, logical(1))
    branches[empty] <- list(make_leaf(y))

    node <- list(centroids = cents, selected_indices = sel,
                 branches = branches, fallback = majority(y))
    list(node = node, feat_imp = feat_imp_acc)
  }

  # Route one test row down a tree to a leaf label. Falls back to the
  # current node's majority class when no distance is usable.
  predict_tree_one <- function(tree, xrow) {
    node <- tree
    while (is.null(node[["class"]])) {
      fallback <- node[["fallback"]]
      if (is.null(fallback)) fallback <- class_levels[1L]
      sel <- node[["selected_indices"]]
      cents <- node[["centroids"]]
      if (length(sel) == 0L || length(cents) == 0L) return(fallback)
      xs <- xrow[sel]
      dists <- vapply(cents, function(mu) sum((xs - mu)^2, na.rm = TRUE),
                      numeric(1))
      if (!any(is.finite(dists))) return(fallback)
      dists[!is.finite(dists)] <- Inf
      child <- node[["branches"]][[which.min(dists)]]
      if (is.null(child)) return(fallback)
      node <- child
    }
    as.character(node[["class"]])
  }

  ## ---------- forest build ----------
  trees <- vector("list", ntrees)
  feat_imp_global <- numeric(p)
  n_tr <- nrow(Xtr)

  for (t in seq_len(ntrees)) {
    boot <- sample.int(n_tr, size = n_tr, replace = TRUE)
    built <- build_tree(Xtr[boot, , drop = FALSE], ytr[boot], 0L,
                        feat_imp_acc = numeric(p))
    trees[[t]] <- built$node
    feat_imp_global <- feat_imp_global + built$feat_imp
  }

  ## ---------- predictions & probabilities ----------
  # vapply/explicit matrices keep the shapes right for 0 or 1 test rows.
  n_te <- nrow(Xte)
  vote_mat <- matrix(NA_character_, nrow = n_te, ncol = ntrees)
  for (t in seq_len(ntrees)) {
    vote_mat[, t] <- vapply(seq_len(n_te),
                            function(i) predict_tree_one(trees[[t]], Xte[i, ]),
                            character(1))
  }

  probabilities <- matrix(0, nrow = n_te, ncol = length(class_levels),
                          dimnames = list(NULL, class_levels))
  for (j in seq_along(class_levels)) {
    probabilities[, j] <- rowMeans(vote_mat == class_levels[j])
  }
  # Majority vote; ties go to the class listed first in class_levels.
  predictions <- class_levels[max.col(probabilities, ties.method = "first")]

  ## ---------- feature importance ----------
  total_imp <- sum(feat_imp_global)
  feature_importance <- if (total_imp > 0) {
    feat_imp_global / total_imp
  } else {
    feat_imp_global
  }
  names(feature_importance) <- colnames(Xtr)

  list(
    predictions        = predictions,
    probabilities      = probabilities,
    feature_importance = feature_importance
  )
}
