#' Doubly-robust estimation for survival outcomes in CRTs
#'
#' @description
#' Fits doubly-robust estimators for cluster-randomized trials with right-censored survival outcomes,
#' including single-state and multi-state outcomes.
#'
#' The outcome is specified as \code{Surv(time, status)}, where
#' \code{status in {0,1,2,...,Q}} and \code{status = 0} denotes censoring.
#' Values \code{1,2,...,Q} are ordered states, with the largest state typically
#' representing an absorbing state (e.g., death).
#'
#' The function supports two estimands:
#' \itemize{
#'   \item \strong{SPCE}: stage-specific survival probabilities
#'     \eqn{S_s(t)} for each state \eqn{s=1,\dots,S_{\max}} at all event times.
#'   \item \strong{RMTIF}: a generalized win-based restricted mean time in favor estimand
#'     constructed from the multi-state survival outcome. When \code{status} is
#'     binary (\code{0/1}), this reduces to an RMST estimand (evaluated on the full
#'     observed-time grid).
#'     \item \strong{RMST}: a special case of \strong{RMTIF} when
#'     \code{status in {0,1}} (one nonzero state). In this case the generalized
#'     RMT-IF reduces to a regular RMST contrast.
#' }
#'
#' Jackknife variance is computed by leave-one-cluster-out re-fitting:
#' \itemize{
#'   \item For \code{estimand = "SPCE"}: variances of \eqn{S_{1}(t)}, \eqn{S_{0}(t)},
#'         and \eqn{S_{1}(t) - S_{0}(t)} at each time and state.
#'   \item For \code{estimand = "RMTIF"}: variances and covariance of
#'         \eqn{R_{1}(\tau)}, \eqn{R_{0}(\tau)}, and \eqn{R_{1}(\tau) - R_{0}(\tau)}
#'         at each event time \eqn{\tau}.
#' }
#'
#' The returned object includes metadata needed for summaries and plotting:
#' final fitted outcome/censoring formulas, the cluster id column, number of clusters,
#' degrees of freedom for jackknife t-intervals (= M - 1), sample sizes, and the
#' cluster-level and individual-level estimators.
#'
#' @param data A \code{data.frame}.
#' @param formula Outcome model: e.g.,
#'   \code{Surv(time, status) ~ W1 + W2 + Z1 + Z2 + cluster(M)}.
#'   The left-hand side must be \code{Surv(time, status)} with
#'   \code{status in {0,1,2,...}} and \code{0} indicating censoring.
#'
#'   The right-hand side \emph{must} include a \code{cluster(<id>)} term specifying
#'   the cluster id for CRTs. All other covariates may be individual- or cluster-level.
#'
#' @param cens_formula Optional censoring model. If \code{NULL}, the censoring model
#'   is built automatically from the outcome formula by:
#'   \itemize{
#'     \item reusing the RHS (excluding \code{cluster()});
#'     \item using LHS \code{Surv(time, event == 0)};
#'   }
#'   If supplied, \code{cens_formula} is used as-is for all stage-specific fits,
#'   but the DR estimating equations still use the stage-specific \code{event}
#'   indicator as described above.
#'
#' @param intv Character: name of the cluster-level treatment column (0/1),
#'   constant within cluster.
#' @param id_var Character: name of the individual id column. If \code{NULL}, considered as single state.
#'
#' @param method \code{"marginal"} or \code{"frailty"}.
#'   \itemize{
#'     \item \code{"marginal"}: fits \code{survival::coxph} models with
#'           \code{cluster(<id>)} robust variance.
#'     \item \code{"frailty"}: fits \code{frailtyEM::emfrail} gamma-frailty models
#'           for outcome and censoring.
#'   }
#'
#' @param estimand \code{"SPCE"}, \code{"RMTIF"}, or \code{"RMST"}.
#'   \itemize{
#'     \item \code{"SPCE"}: returns stage-specific survival arrays
#'           \code{S_stage_cluster} and \code{S_stage_ind} with dimensions
#'           \code{[time × 2 × Q]}.
#'     \item \code{"RMTIF"}: returns the generalized win-based restricted mean time in favor estimand
#'           at each event time, along with stage-wise contributions. For a binary
#'           status, this reduces to an RMST estimand.
#'     \item \code{"RMST"}: restricted mean survival time difference for the
#'           binary case \code{status in {0,1}}. This is a convenience alias for
#'           the \code{"RMTIF"} calculations when there is exactly one nonzero
#'           event state.
#'   }
#'
#' @param trt_prob Optional length-2 numeric vector \code{(p0, p1)} giving the
#'   cluster-level treatment probabilities for arms 0 and 1. If \code{NULL}, they
#'   are computed as the empirical proportion of treatment assignments
#'   per cluster.
#'
#' @param variance \code{"none"} or \code{"jackknife"} for variance estimation.
#'
#' @param fit_controls Optional \code{frailtyEM::emfrail_control()} list, used only
#'   when \code{method = "frailty"}. If \code{NULL}, default fast-fitting controls
#'   are used (no standard errors from the frailtyEM fits are required here).
#'
#' @param verbose Logical; currently unused but kept for future verbosity options.
#'
#' @return An object of class \code{"DRsurvfit"} with fields depending on
#'   \code{estimand}:
#'
#'   \describe{
#'     \item{Common:}{
#'       \itemize{
#'         \item \code{method}: fitted method (\code{"marginal"} or \code{"frailty"}).
#'         \item \code{estimand}: requested estimand (\code{"SPCE"}, \code{"RMTIF"}, or \code{"RMST"}).
#'         \item \code{trt_prob}: numeric vector \code{c(p0, p1)}.
#'         \item \code{event_time}: time grid:
#'           \itemize{
#'             \item SPCE: all event times including \code{0}.
#'             \item RMTIF: positive event times \eqn{\tau} at which the RMT-IF is evaluated.
#'           }
#'         \item \code{max_state}: maximum observed non-zero status.
#'         \item \code{cluster_col}: name of the cluster id column.
#'         \item \code{n_clusters}: number of clusters (\eqn{M}).
#'         \item \code{df_jackknife}: jackknife degrees of freedom (\eqn{M - 1}).
#'         \item \code{n_obs}: total number of observations.
#'         \item \code{n_events}: total number of non-censoring observations
#'               (\code{status != 0}).
#'         \item \code{cluster_trt_counts}: counts of treated and control clusters
#'               \code{c(n_trt0, n_trt1)} based on first row per cluster.
#'         \item \code{formula_outcome}: fully reconstructed outcome formula.
#'         \item \code{cens_formula}: final censoring formula used.
#'         \item \code{call}: the matched call.
#'         \item \code{jackknife}: logical indicating whether jackknife variance
#'               was computed.
#'       }
#'     }
#'
#'     \item{If \code{estimand = "SPCE"}:}{
#'       \itemize{
#'         \item \code{S_stage_cluster}: 3D array \code{[time × 2 × Q]} with
#'               stage-specific cluster-level survival:
#'               \code{S_stage_cluster[ , 1, s]} = \eqn{S_1^{(s)}(t)},
#'               \code{S_stage_cluster[ , 2, s]} = \eqn{S_0^{(s)}(t)}.
#'         \item \code{S_stage_ind}: analogous individual-level survival array.
#'         \item \code{var_stage_cluster}: jackknife variances for
#'               \eqn{S_1^{(s)}(t)}, \eqn{S_0^{(s)}(t)}, and
#'               \eqn{S_1^{(s)}(t) - S_0^{(s)}(t)} as a 3D array
#'               \code{[time × 3 × Q]} with dimension names
#'               \code{comp = c("Var(S1)","Var(S0)","Var(S1-S0)")}, when
#'               \code{variance = "jackknife"}; otherwise \code{NULL}.
#'         \item \code{var_stage_ind}: analogous individual-level variance array.
#'       }
#'     }
#'
#'     \item{If \code{estimand = "RMTIF"} or \code{"RMST"}:}{
#'       \itemize{
#'         \item \code{RMTIF_cluster}: matrix \code{[time × 3]} with columns
#'               \code{c("R1","R0","R1-R0")} giving the cluster-level RMT-IF curves
#'               at each event time \eqn{\tau}.
#'         \item \code{RMTIF_ind}: analogous individual-level RMT-IF matrix.
#'         \item \code{stagewise_cluster}: list of length \code{length(event_time)};
#'               each element is a \code{3 × (Q + 1)} matrix of stage-wise
#'               contributions with rows
#'               \code{c("s1qs0qp1","s0qs1qp1","diff")} and columns
#'               \code{c("stage_1",...,"stage_Q","sum")}.
#'         \item \code{stagewise_ind}: analogous individual-level list.
#'         \item \code{var_rmtif_cluster}: jackknife variance/covariance matrix
#'               \code{[time × 4]} with columns
#'               \code{c("Var(R1)","Var(R0)","Var(R1-R0)","Cov(R1,R0)")},
#'               when \code{variance = "jackknife"}; otherwise \code{NULL}.
#'         \item \code{var_rmtif_ind}: analogous individual-level matrix.
#'         \item \code{S_stage_cluster}, \code{S_stage_ind}: the underlying
#'               stage-specific survival arrays are also returned for convenience.
#'       }
#'     }
#'   }
#'
#' @examples
#' \donttest{
#' data(datm)
#'
#' ## Multi-state RMT-IF (binary reduces to RMST-type)
#' fit_rmtif <- DRsurvfit(
#'   datm,
#'   Surv(time, event) ~ W1 + W2 + Z1 + Z2 + cluster(cluster),
#'   intv    = "trt",
#'   method  = "marginal",
#'   estimand = "RMTIF",
#'   variance = "none"
#' )
#' }
#' @export
DRsurvfit <- function(data,
                      formula,
                      cens_formula = NULL,
                      intv,
                      id_var = NULL,
                      method   = c("marginal", "frailty"),
                      estimand = c("SPCE", "RMTIF", "RMST"),
                      trt_prob = NULL,
                      variance = c("none", "jackknife"),
                      fit_controls = NULL,
                      verbose = FALSE) {

  method   <- match.arg(method)
  estimand <- match.arg(estimand)
  variance <- match.arg(variance)

  # Captured once so both estimand branches store the same call object.
  matched_call  <- match.call()
  use_jackknife <- identical(variance, "jackknife")

  # "RMST" is a user-facing label only; the computations are the "RMTIF" ones.
  is_rmst       <- identical(estimand, "RMST")
  estimand_core <- if (is_rmst) "RMTIF" else estimand

  if (!is.data.frame(data)) {
    stop("'data' must be a data.frame.", call. = FALSE)
  }

  ## --- RMST requires a binary status {0, 1} --------------------------------
  nm <- .surv_lhs(formula)
  if (is_rmst) {
    if (!nm$status %in% names(data)) {
      stop("Status variable not found in 'data' when estimand = 'RMST'.",
           call. = FALSE)
    }
    status_vals <- data[[nm$status]]
    uniq_status <- sort(unique(status_vals[!is.na(status_vals)]))

    # The observed values must be exactly {0, 1}. Checking only that 0 and 1
    # are present would let multi-state data such as {0, 1, 2} slip through.
    if (!setequal(uniq_status, c(0, 1))) {
      stop(
        "estimand = 'RMST' requires binary status with values {0, 1}. ",
        "Observed unique non-missing values: ",
        paste(uniq_status, collapse = ", "),
        call. = FALSE
      )
    }
  }

  ## --- fit core on full grid -----------------------------------------------
  est <- .DR_est_core(
    data         = data,
    formula      = formula,
    cens_formula = cens_formula,
    intv         = intv,
    method       = method,
    estimand     = estimand_core,
    trt_prob     = trt_prob,
    fit_controls = fit_controls,
    e_time       = NULL,
    id_var       = id_var
  )

  et          <- est$event_time
  S_stage_c   <- est$S_stage_cluster   # [time x 2 x max_state] or NULL
  S_stage_i   <- est$S_stage_ind
  max_s       <- est$max_state
  cluster_col <- est$cluster_col

  ## --- reconstruct final formulas and cluster info -------------------------
  # The formula environments are re-parented only after the core fit, so the
  # jackknife call below receives the re-parented formulas.
  environment(formula) <- .surv_env(parent = environment(formula) %||% parent.frame())
  if (!is.null(cens_formula)) {
    environment(cens_formula) <- .surv_env(parent = environment(cens_formula) %||% parent.frame())
  }

  nm    <- .surv_lhs(formula)
  info  <- .extract_cluster(formula)
  clvar <- info$cluster
  rhs   <- info$rhs_wo_cluster

  if (is.null(clvar) || !nzchar(clvar)) {
    stop("Outcome formula must include cluster(<id>) for CRT.", call. = FALSE)
  }

  # Outcome RHS = covariates (if any) followed by the cluster() term.
  rhs_surv <- if (nzchar(rhs)) {
    paste(rhs, sprintf("+ cluster(%s)", clvar))
  } else {
    sprintf("cluster(%s)", clvar)
  }

  cens_formula_final <- if (is.null(cens_formula)) {
    as.formula(.build_cens_formula_from(nm, rhs, clvar))
  } else {
    cens_formula
  }

  formula_outcome_final <- as.formula(
    sprintf("Surv(%s,%s) ~ %s", nm$time, nm$status, rhs_surv)
  )

  ## --- sample-size metadata ------------------------------------------------
  n_obs <- nrow(data)
  n_events <- if (nm$status %in% names(data)) {
    sum(as.integer(data[[nm$status]] != 0L), na.rm = TRUE)
  } else {
    NA_integer_
  }

  cl_ids  <- unique(data[[cluster_col]])
  K       <- length(cl_ids)
  df_jack <- max(1L, K - 1L)

  # Treatment arm of each cluster, read from the first row of that cluster.
  A_first <- vapply(
    cl_ids,
    function(g) data[[intv]][which(data[[cluster_col]] == g)[1]],
    numeric(1L)
  )
  cluster_trt_counts <- c(
    n_trt0 = sum(A_first == 0, na.rm = TRUE),
    n_trt1 = sum(A_first == 1, na.rm = TRUE)
  )

  ## --- jackknife variance --------------------------------------------------
  var_out <- if (use_jackknife) {
    .DR_var_jackknife(
      data         = data,
      formula      = formula,
      cens_formula = cens_formula,
      intv         = intv,
      method       = method,
      estimand     = estimand_core,
      trt_prob     = trt_prob,
      fit_controls = fit_controls,
      e_time_full  = et[et > 0],   # variances are requested on positive times only
      id_var       = id_var
    )
  } else {
    NULL
  }

  ## --- assemble return object ----------------------------------------------
  is_spce <- identical(estimand_core, "SPCE")

  # Estimates. The RMT-IF slots are kept (as NULL) for SPCE so both estimands
  # expose the same leading field names in the same order.
  estimates <- list(
    method            = method,
    estimand          = if (is_spce) "SPCE" else if (is_rmst) "RMST" else "RMTIF",
    event_time        = et,
    max_state         = max_s,
    S_stage_cluster   = S_stage_c,
    S_stage_ind       = S_stage_i,
    RMTIF_cluster     = if (is_spce) NULL else est$RMTIF_cluster,
    RMTIF_ind         = if (is_spce) NULL else est$RMTIF_ind,
    stagewise_cluster = if (is_spce) NULL else est$stagewise_cluster,
    stagewise_ind     = if (is_spce) NULL else est$stagewise_ind,
    trt_prob          = unname(est$trt_prob)
  )

  # Variance fields are named differently per estimand.
  variances <- if (is_spce) {
    list(
      var_stage_cluster = if (!is.null(var_out)) var_out$Cluster$var_stage else NULL,
      var_stage_ind     = if (!is.null(var_out)) var_out$Individual$var_stage else NULL
    )
  } else {
    list(
      var_rmtif_cluster = if (!is.null(var_out)) var_out$Cluster$var_R else NULL,
      var_rmtif_ind     = if (!is.null(var_out)) var_out$Individual$var_R else NULL
    )
  }

  metadata <- list(
    formula_outcome    = formula_outcome_final,
    cens_formula       = cens_formula_final,
    cluster_col        = cluster_col,
    n_clusters         = K,
    df_jackknife       = df_jack,
    n_obs              = n_obs,
    n_events           = n_events,
    cluster_trt_counts = cluster_trt_counts,
    call               = matched_call,
    jackknife          = use_jackknife
  )

  # c() on lists concatenates top-level elements and keeps NULL entries.
  ans <- c(estimates, variances, metadata)
  class(ans) <- "DRsurvfit"
  ans
}





#' Summary method for DRsurvfit objects (multi-state SPCE / RMT-IF/RMST)
#'
#' @description
#' Produces tabular summaries for multi-state doubly-robust estimators:
#' \itemize{
#'   \item For \code{estimand = "SPCE"}: stage-specific survival probabilities
#'         \eqn{S_s(t)} at selected times \eqn{\tau}, with optional jackknife
#'         t-based confidence intervals.
#'   \item For \code{estimand = "RMTIF"} or \code{"RMST"}: curves
#'         \eqn{R_1(\tau)}, \eqn{R_0(\tau)}, and \eqn{R_1(\tau) - R_0(\tau)}
#'         at the same set of \eqn{\tau}, again with optional jackknife
#'         t-based intervals.
#' }
#'
#' The same argument \code{tau} is used for both estimands. If \code{tau} is
#' \code{NULL}, the function uses the 25\%, 50\%, and 75\% quantiles of the
#' event-time grid (excluding time 0 if present).
#'
#' @param object A \code{DRsurvfit} object.
#' @param level Character: \code{"cluster"} or \code{"individual"} level summary.
#' @param tau Optional numeric vector of times at which to summarize both
#'   \code{SPCE} and \code{RMTIF}/\code{RMST}. If \code{NULL}, the 25\%, 50\%,
#'   and 75\% quantiles of the event-time grid are used.
#' @param states Optional integer vector of states to summarize for
#'   \code{estimand = "SPCE"}. Defaults to all states \code{1:object$max_state}.
#' @param digits Number of digits to print for estimates and confidence limits.
#' @param alpha Nominal type I error for the intervals; coverage is
#'   \code{1 - alpha}. Default is \code{0.05}, giving 95\% confidence intervals.
#' @param ... Additional arguments passed to or from methods. Currently
#'   unused.
#'
#' @return The input object \code{object}, invisibly.
#' @export
summary.DRsurvfit <- function(object,
                              level  = c("cluster", "individual"),
                              tau    = NULL,
                              states = NULL,
                              digits = 4,
                              alpha  = 0.05,
                              ...) {

  level <- match.arg(level)
  use_cluster <- identical(level, "cluster")

  ## --- header -------------------------------------------------------------- ##
  cat(sprintf("DRsurvfit: method = %s, estimand = %s\n",
              object$method, object$estimand))
  if (!is.null(object$trt_prob)) {
    cat("Treatment probs (p0, p1): ",
        paste(signif(object$trt_prob, digits), collapse = ", "),
        "\n", sep = "")
  }
  if (!is.null(object$formula_outcome)) {
    cat("Outcome model:   ", deparse(object$formula_outcome), "\n", sep = "")
  }
  if (!is.null(object$cens_formula)) {
    cat("Censoring model: ", deparse(object$cens_formula), "\n", sep = "")
  }
  if (!is.null(object$cluster_col)) {
    cat("Cluster id col:  ", object$cluster_col, "\n", sep = "")
  }
  if (!is.null(object$n_clusters)) {
    cat("Clusters (M):    ", object$n_clusters, "\n", sep = "")
  }
  if (!is.null(object$n_obs)) {
    cat("Obs (N):         ", object$n_obs, "\n", sep = "")
  }
  if (!is.null(object$n_events)) {
    cat("Events (status != 0): ", object$n_events, "\n", sep = "")
  }
  cat("\n")

  ## --- t critical and df --------------------------------------------------- ##
  K  <- object$n_clusters %||% NA_integer_
  df <- object$df_jackknife %||%
    if (is.finite(K)) max(1L, K - 1L) else NA_integer_
  tcrit <- if (is.finite(df)) stats::qt(1 - alpha / 2, df = df) else NA_real_
  # The df travels with tcrit so the footer can report it.
  if (is.finite(df)) attr(tcrit, "df") <- df

  ## --- local helpers ------------------------------------------------------- ##
  fmt_num <- function(v) formatC(v, digits = digits, format = "fg")
  se_of   <- function(v) sqrt(pmax(0, v))  # clamp tiny negative variances to 0

  # Format "est (lcl, ucl)" row by row. A row whose SE is missing shows only
  # the point estimate; it does not remove the intervals of the other rows.
  fmt_ci <- function(est, se, tcrit) {
    cells <- fmt_num(est)
    if (is.null(se) || is.na(tcrit)) {
      return(cells)
    }
    se   <- rep_len(se, length(est))
    crit <- as.numeric(tcrit)  # drop the "df" attribute before arithmetic
    ok   <- !is.na(se)
    if (any(ok)) {
      cells[ok] <- sprintf("%s (%s, %s)",
                           cells[ok],
                           fmt_num(est[ok] - crit * se[ok]),
                           fmt_num(est[ok] + crit * se[ok]))
    }
    cells
  }

  # Print a three-column table of formatted cells with the time row labels.
  print_table <- function(c1, c2, c3, headers) {
    out <- cbind(c1, c2, c3)
    colnames(out) <- headers
    rownames(out) <- row_lab
    print(noquote(out))
  }

  # Footer describing how the intervals were built (or why there are none).
  print_ci_note <- function(has_var) {
    if (has_var && !is.na(tcrit)) {
      df_here <- attr(tcrit, "df")
      if (!is.null(df_here)) {
        cat(sprintf("  t-intervals with df = %d, alpha = %.3f\n\n",
                    df_here, alpha))
      } else {
        cat(sprintf("  t-intervals (alpha = %.3f)\n\n", alpha))
      }
    } else if (!has_var) {
      cat("  (jackknife variances not available; showing point estimates only)\n\n")
    } else {
      cat("\n")
    }
  }

  ## --- time grid and requested times --------------------------------------- ##
  et <- object$event_time
  if (is.null(et) || !length(et)) {
    warning("No event_time grid stored in DRsurvfit object; nothing to summarize.",
            call. = FALSE)
    return(invisible(object))
  }

  if (is.null(tau)) {
    # Default: quartiles of the positive part of the grid.
    et_pos <- et[et > 0]
    tau <- if (length(et_pos)) {
      as.numeric(stats::quantile(et_pos, probs = c(0.25, 0.50, 0.75), type = 1))
    } else {
      unique(et)
    }
  }
  tau <- sort(unique(as.numeric(tau)))
  if (!length(tau)) {
    stop("'tau' must contain at least one non-missing time.", call. = FALSE)
  }

  # Map each tau to the last grid time <= tau. Times before the first grid
  # point fall back to row 1, and the row label shows the grid time used.
  idx <- pmax(1L, findInterval(tau, et))
  row_lab <- paste0("t=", fmt_num(et[idx]))

  ## ======================================================================== ##
  ## SPCE: stage-specific survival S_s(t)
  ## ======================================================================== ##
  if (identical(object$estimand, "SPCE")) {

    S_arr <- if (use_cluster) object$S_stage_cluster else object$S_stage_ind
    V_arr <- if (use_cluster) object$var_stage_cluster else object$var_stage_ind

    if (is.null(S_arr) || length(dim(S_arr)) != 3L) {
      warning("SPCE summaries unavailable: missing stage-specific survival arrays.",
              call. = FALSE)
      return(invisible(object))
    }

    max_s <- object$max_state %||% dim(S_arr)[3]
    if (is.null(max_s) || is.na(max_s)) max_s <- dim(S_arr)[3]

    if (is.null(states)) {
      states <- seq_len(max_s)
    } else {
      states <- intersect(states, seq_len(max_s))
      if (!length(states)) {
        stop("No valid states to summarize.", call. = FALSE)
      }
    }

    has_var <- !is.null(V_arr)

    # If the variance array has one row fewer than the grid, it is treated as
    # living on et[et > 0]: shift by one and leave the t = 0 row without an SE.
    var_rows <- idx
    if (has_var && dim(V_arr)[1] == length(et) - 1L) {
      var_rows <- idx - 1L
    }
    var_ok <- var_rows >= 1L

    for (s in states) {
      cat(sprintf("Stage-specific SPCE: state %d (%s-level)\n", s, level))

      S_slice <- S_arr[idx, , s, drop = FALSE]
      S1 <- S_slice[, 1L, 1L]
      S0 <- S_slice[, 2L, 1L]
      Sd <- S1 - S0

      se1 <- se0 <- sedif <- rep(NA_real_, length(idx))
      if (has_var) {
        V_slice <- V_arr[var_rows[var_ok], , s, drop = FALSE]
        se1[var_ok]   <- se_of(V_slice[, 1L, 1L])
        se0[var_ok]   <- se_of(V_slice[, 2L, 1L])
        sedif[var_ok] <- se_of(V_slice[, 3L, 1L])
      }

      print_table(
        fmt_ci(S1, se1, tcrit),
        fmt_ci(S0, se0, tcrit),
        fmt_ci(Sd, sedif, tcrit),
        headers = c("S1 (LCL, UCL)", "S0 (LCL, UCL)", "S1-S0 (LCL, UCL)")
      )
      print_ci_note(has_var)
    }

    return(invisible(object))
  }

  ## ======================================================================== ##
  ## RMTIF / RMST: win-based RMST-type curves over event time
  ## ======================================================================== ##
  if (object$estimand %in% c("RMTIF", "RMST")) {

    label_est <- if (identical(object$estimand, "RMST")) "RMST" else "RMT-IF"

    Rmat <- if (use_cluster) object$RMTIF_cluster else object$RMTIF_ind
    Vmat <- if (use_cluster) object$var_rmtif_cluster else object$var_rmtif_ind

    if (is.null(Rmat) || !is.matrix(Rmat) || ncol(Rmat) < 2L) {
      warning(label_est, " summaries unavailable: missing matrices.", call. = FALSE)
      return(invisible(object))
    }

    R1 <- Rmat[idx, 1L, drop = TRUE]
    R0 <- Rmat[idx, 2L, drop = TRUE]
    Rd <- if (ncol(Rmat) >= 3L) Rmat[idx, 3L, drop = TRUE] else (R1 - R0)

    has_var <- !is.null(Vmat)
    se1 <- se0 <- sedif <- NA_real_
    if (has_var) {
      se1 <- se_of(Vmat[idx, 1L, drop = TRUE])
      se0 <- se_of(Vmat[idx, 2L, drop = TRUE])
      if (ncol(Vmat) >= 3L) {
        sedif <- se_of(Vmat[idx, 3L, drop = TRUE])
      }
    }

    cat(sprintf("%s summary (%s-level)\n", label_est, level))
    print_table(
      fmt_ci(R1, se1, tcrit),
      fmt_ci(R0, se0, tcrit),
      fmt_ci(Rd, sedif, tcrit),
      headers = c("R1 (LCL, UCL)", "R0 (LCL, UCL)", "R1-R0 (LCL, UCL)")
    )
    print_ci_note(has_var)

    cat("Stage-wise decompositions are available in:\n")
    cat(sprintf("  object$stagewise_%s[[k]] for each time index k.\n",
                if (use_cluster) "cluster" else "ind"))
    return(invisible(object))
  }

  warning("Unknown estimand in DRsurvfit object.", call. = FALSE)
  invisible(object)
}




#' Plot method for DRsurvfit objects (SPCE / RMT-IF)
#'
#' @description
#' Produces plots for multi-state doubly-robust estimators:
#' \itemize{
#'   \item For \code{estimand = "SPCE"}: for each state \eqn{s}, plots the
#'         difference curve \eqn{S_{1,s}(t) - S_{0,s}(t)} with jackknife
#'         t-based confidence bands over time.
#'   \item For \code{estimand = "RMTIF"} or \code{"RMST"}: plots the overall RMT-IF difference
#'         curve \eqn{R_1(t) - R_0(t)} (sum of stage-wise contributions) with
#'         jackknife t-based confidence bands over time.
#' }
#'
#' The argument \code{tau} is a truncation time: if supplied, the plot is
#' restricted to \code{event_time <= tau}. If \code{tau} is \code{NULL}, the
#' full event-time grid is used.
#'
#' @param x A \code{DRsurvfit} object.
#' @param level Character: \code{"cluster"} or \code{"individual"}.
#' @param states Optional integer vector of states to plot when
#'   \code{estimand = "SPCE"}. Defaults to all states
#'   \code{1:x$max_state}.
#' @param tau Optional numeric truncation time. If non-\code{NULL}, only
#'   event times \code{<= max(tau)} are plotted. If \code{NULL}, all event
#'   times are plotted.
#' @param alpha Nominal type I error for the intervals; coverage is
#'   \code{1 - alpha}. Default is \code{0.05}.
#' @param ... Unused; included for S3 consistency.
#'
#' @return The input object \code{x}, invisibly.
#' @export
#' @importFrom ggplot2 ggplot aes geom_line geom_ribbon labs theme_minimal
plot.DRsurvfit <- function(x,
                           level  = c("cluster", "individual"),
                           states = NULL,
                           tau    = NULL,
                           alpha  = 0.05,
                           ...) {

  object <- x  # S3 generic requires `x`; the body uses `object`

  level <- match.arg(level)
  use_cluster <- identical(level, "cluster")
  et <- object$event_time

  if (is.null(et) || !length(et)) {
    warning("No event_time grid stored in DRsurvfit object; nothing to plot.",
            call. = FALSE)
    return(invisible(object))
  }

  ## ----- truncation by tau ------------------------------------------------- ##
  if (!is.null(tau)) {
    tau_star <- max(as.numeric(tau), na.rm = TRUE)
    idx_time <- which(et <= tau_star)
    if (!length(idx_time)) {
      stop("No event_time <= tau; cannot plot.", call. = FALSE)
    }
  } else {
    idx_time <- seq_along(et)
    tau_star <- max(et, na.rm = TRUE)
  }
  et_plot <- et[idx_time]

  ## ----- t critical using df from object ---------------------------------- ##
  K  <- object$n_clusters %||% NA_integer_
  df <- object$df_jackknife %||%
    if (is.finite(K)) max(1L, K - 1L) else NA_integer_
  tcrit <- if (is.finite(df)) stats::qt(1 - alpha / 2, df = df) else NA_real_

  ## ----- local helpers ----------------------------------------------------- ##
  # Pointwise t-band around `est`; all-NA limits when no variance or no df.
  ci_band <- function(est, var_est) {
    lcl <- ucl <- rep(NA_real_, length(est))
    if (!is.null(var_est) && is.finite(tcrit)) {
      se  <- sqrt(pmax(0, var_est))
      lcl <- est - tcrit * se
      ucl <- est + tcrit * se
    }
    list(lcl = lcl, ucl = ucl)
  }

  # Rows that can be drawn as a ribbon (both limits finite).
  ribbon_rows <- function(d) {
    d[is.finite(d$lcl) & is.finite(d$ucl), , drop = FALSE]
  }

  # Subtitle mentions CIs only when a band is drawn. A finite df alone does
  # not imply that jackknife variances were computed.
  make_subtitle <- function(has_ci) {
    if (has_ci) {
      sprintf("Truncated at t <= %.3f; t-based %.1f%% CIs, df = %d",
              tau_star, 100 * (1 - alpha), df)
    } else {
      sprintf("Truncated at t <= %.3f; CIs not available", tau_star)
    }
  }

  ## ======================================================================== ##
  ## SPCE: plot only S1 - S0 per state with CI
  ## ======================================================================== ##
  if (identical(object$estimand, "SPCE")) {

    S_arr <- if (use_cluster) object$S_stage_cluster else object$S_stage_ind
    V_arr <- if (use_cluster) object$var_stage_cluster else object$var_stage_ind

    if (is.null(S_arr) || length(dim(S_arr)) != 3L) {
      warning("SPCE plot unavailable: missing stage-specific survival arrays.",
              call. = FALSE)
      return(invisible(object))
    }

    S_arr_sub <- S_arr[idx_time, , , drop = FALSE]

    # If the variance array has one row fewer than the grid, it is treated as
    # living on positive times only: shift by one, leaving t = 0 without an SE.
    var_rows <- idx_time
    if (!is.null(V_arr) && dim(V_arr)[1] == length(et) - 1L) {
      var_rows <- idx_time - 1L
    }
    var_ok <- var_rows >= 1L

    max_s <- object$max_state %||% dim(S_arr_sub)[3]
    if (is.null(max_s) || is.na(max_s)) max_s <- dim(S_arr_sub)[3]

    if (is.null(states)) {
      states <- seq_len(max_s)
    } else {
      states <- intersect(states, seq_len(max_s))
      if (!length(states)) {
        stop("No valid states to plot.", call. = FALSE)
      }
    }

    # One data.frame per state, stacked into long format for ggplot2.
    df_list <- lapply(states, function(s) {
      S1   <- S_arr_sub[, 1L, s, drop = TRUE]
      S0   <- S_arr_sub[, 2L, s, drop = TRUE]
      diff <- S1 - S0

      var_diff <- NULL
      if (!is.null(V_arr)) {
        var_diff <- rep(NA_real_, length(idx_time))
        var_diff[var_ok] <- V_arr[var_rows[var_ok], 3L, s, drop = TRUE]
      }
      band <- ci_band(diff, var_diff)

      data.frame(
        time  = et_plot,
        diff  = diff,
        lcl   = band$lcl,
        ucl   = band$ucl,
        state = factor(s, levels = states,
                       labels = paste0("state ", states))
      )
    })
    df_plot <- do.call(rbind, df_list)

    # The ribbon uses only rows with finite limits to avoid geom_ribbon warnings.
    df_rib <- ribbon_rows(df_plot)
    has_ci <- isTRUE(nrow(df_rib) > 0L)

    p <- ggplot2::ggplot(df_plot,
                         ggplot2::aes(x = time, y = diff,
                                      color = state)) +
      ggplot2::geom_line() +
      ggplot2::labs(
        x     = "time",
        y     = expression(S[1](t) - S[0](t)),
        color = "State",
        fill  = "State",
        title = sprintf("Stage-specific SPCE differences (%s-level)", level),
        subtitle = make_subtitle(has_ci)
      ) +
      ggplot2::theme_minimal()

    if (has_ci) {
      p <- p + ggplot2::geom_ribbon(
        data = df_rib,
        ggplot2::aes(
          x = time,
          ymin = lcl,
          ymax = ucl,
          fill = state
        ),
        alpha = 0.20,
        color = NA,
        inherit.aes = FALSE
      )
    }

    print(p)
    return(invisible(object))
  }

  ## ======================================================================== ##
  ## RMT-IF / RMST: plot only R1 - R0 over time with CI
  ## ======================================================================== ##
  if (object$estimand %in% c("RMTIF", "RMST")) {

    label_est <- if (identical(object$estimand, "RMST")) "RMST" else "RMT-IF"

    Rmat <- if (use_cluster) object$RMTIF_cluster else object$RMTIF_ind
    Vmat <- if (use_cluster) object$var_rmtif_cluster else object$var_rmtif_ind

    if (is.null(Rmat) || !is.matrix(Rmat) || ncol(Rmat) < 2L) {
      warning("RMT-IF plot unavailable: missing RMT-IF matrices.",
              call. = FALSE)
      return(invisible(object))
    }

    Rmat_sub <- Rmat[idx_time, , drop = FALSE]
    Vmat_sub <- if (!is.null(Vmat)) Vmat[idx_time, , drop = FALSE] else NULL

    # Use the stored difference column when present, else R1 - R0.
    diff <- if (ncol(Rmat_sub) >= 3L) {
      Rmat_sub[, 3L, drop = TRUE]
    } else {
      Rmat_sub[, 1L, drop = TRUE] - Rmat_sub[, 2L, drop = TRUE]
    }

    var_diff <- if (!is.null(Vmat_sub) && ncol(Vmat_sub) >= 3L) {
      Vmat_sub[, 3L, drop = TRUE]
    } else {
      NULL
    }
    band <- ci_band(diff, var_diff)

    df_plot <- data.frame(
      time = et_plot,
      diff = diff,
      lcl  = band$lcl,
      ucl  = band$ucl
    )

    df_rib <- ribbon_rows(df_plot)
    has_ci <- isTRUE(nrow(df_rib) > 0L)

    p <- ggplot2::ggplot(df_plot,
                         ggplot2::aes(x = time, y = diff)) +
      ggplot2::geom_line() +
      ggplot2::labs(
        x = "time",
        y = expression(R[1](t) - R[0](t)),
        title = sprintf("%s difference (%s-level)", label_est, level),
        subtitle = make_subtitle(has_ci)
      ) +
      ggplot2::theme_minimal()

    if (has_ci) {
      p <- p + ggplot2::geom_ribbon(
        data = df_rib,
        ggplot2::aes(
          x = time,
          ymin = lcl,
          ymax = ucl
        ),
        alpha = 0.20,
        color = NA,
        inherit.aes = FALSE
      )
    }

    print(p)
    return(invisible(object))
  }

  warning("Unknown estimand in DRsurvfit object; nothing plotted.", call. = FALSE)
  invisible(object)
}

# Register the bare column names referenced inside ggplot2::aes() in
# plot.DRsurvfit, so code checks do not report them as undefined globals.
# The version guard skips the call on R versions older than 2.15.1.
if (getRversion() >= "2.15.1") utils::globalVariables(c("time", "state"))
