% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/causalDT.R
\name{causalDT}
\alias{causalDT}
\title{Causal Distillation Trees (CDT)}
\usage{
causalDT(
  X,
  Y,
  Z,
  W = NULL,
  holdout_prop = 0.3,
  holdout_idxs = NULL,
  teacher_model = "causal_forest",
  teacher_predict = NULL,
  student_model = "rpart",
  rpart_control = NULL,
  rpart_prune = c("none", "min", "1se"),
  nfolds_crossfit = NULL,
  nreps_crossfit = NULL,
  B_stability = 100,
  max_depth_stability = NULL,
  ...
)
}
\arguments{
\item{X}{A tibble, data.frame, or matrix of covariates.}

\item{Y}{A vector of outcomes.}

\item{Z}{A vector of treatments.}

\item{W}{A vector of weights corresponding to treatment propensities.}

\item{holdout_prop}{Proportion of data to hold out for honest estimation of
treatment effects. Used only if \code{holdout_idxs} is NULL.}

\item{holdout_idxs}{A vector of indices to hold out for honest estimation of
treatment effects. If NULL, a holdout set of size \code{holdout_prop * nrow(X)}
is randomly selected.}

\item{teacher_model}{Teacher model used to estimate individual-level
treatment effects. Should be either "causal_forest" (default),
"bcf", or a function.
If "causal_forest", \code{grf::causal_forest()} is used as the teacher
model. If "bcf", \code{bcf::bcf()} is used as the teacher model.
Otherwise, the function should take in the named arguments
\code{X}, \code{Y}, \code{Z}, optionally \code{W} (corresponding to the covariates,
outcome, treatment, and propensity weights,
respectively), and (optional) additional arguments passed to
the function via \code{...}. Moreover, the function should return a model object
that can be used to predict individual-level treatment effects using
\code{teacher_predict(teacher_model, x)}.}

\item{teacher_predict}{Function used to predict individual-level treatment
effects from the teacher model. Should take in two arguments as input: the
first being the model object returned by \code{teacher_model}, and the second
being a tibble, data.frame, or matrix of covariates. If \code{NULL}, the
default is \code{predict()}.}

\item{student_model}{Student model used to estimate subgroups of individuals
and their corresponding estimated treatment effects. Should be either
"rpart" (default) or a function. If "rpart", \code{rpart::rpart()} is used.
Otherwise, the function should take in two arguments as input: the first
being a tibble, data.frame, or matrix of covariates, and the second being a
vector of predicted individual-level treatment effects. Moreover, the
function should return a list. At a minimum, this list should contain one
element named \code{fit} that is a model object that can be used to output the
leaf membership indices for each observation via
\code{predict(student_model, x, type = 'node')}. In general, we recommend
using the default "rpart".}

\item{rpart_control}{A list of control parameters for the \code{rpart} algorithm.
See \code{?rpart.control} for details. Used only if \code{student_model} is "rpart".}

\item{rpart_prune}{Method for pruning the tree. Default is \code{"none"}.
Options are \code{"none"}, \code{"min"}, and \code{"1se"}. If \code{"min"},
the tree is pruned using the complexity threshold which minimizes the
cross-validation error. If \code{"1se"}, the tree is pruned using the
largest complexity threshold which yields a cross-validation error within
one standard error of the minimum. If \code{"none"}, the tree is not
pruned.}

\item{nfolds_crossfit}{Number of folds in cross-fitting procedure.
If \code{teacher_model} is "causal_forest", the default is 1 (no cross-fitting
is performed). Otherwise, the default is 2 (one fold for training the
teacher model and one fold for estimating the individual-level treatment effects).}

\item{nreps_crossfit}{Number of repetitions of the cross-fitting procedure.
If \code{teacher_model} is "causal_forest", the default is 1 (no cross-fitting
is performed). Otherwise, the default is 50.}

\item{B_stability}{Number of bootstrap samples to use in evaluating stability
diagnostics (which can be used to select an appropriate teacher model).
Default is 100. Stability diagnostics are only performed if
\code{student_model} is an \code{rpart} object. If \code{B_stability} is 0, no stability
diagnostics are performed. We refer to Huang et al. (2025) for additional
details on using the stability diagnostic to select the teacher model.}

\item{max_depth_stability}{Maximum depth of the decision tree used in
evaluating stability diagnostics. If \code{NULL}, the default is
max(4, max depth of fitted student model).}

\item{...}{Additional arguments passed to the \code{teacher_model} function.}
}
\value{
A list with the following elements:
\item{estimate}{Estimated subgroup average treatment effects tibble with the following columns:
\itemize{
\item{leaf_id - Leaf node identifier.}
\item{subgroup - String representation of the subgroup.}
\item{estimate - Estimated conditional average treatment effect for the subgroup.}
\item{variance - Asymptotic variance of the estimated conditional average treatment effect.}
\item{.var1 - Sample variance for treated observations in the subgroup.}
\item{.var0 - Sample variance for control observations in the subgroup.}
\item{.n1 - Number of treated observations in the subgroup.}
\item{.n0 - Number of control observations in the subgroup.}
\item{.sample_idxs - Indices of (holdout) observations in the subgroup.}
}
}
\item{student_fit}{Output of \code{student_model()}, which can vary. If
\code{student_model} is "rpart", the output is a list with the following elements:
\itemize{
\item{fit - The fitted student model. An \code{rpart} model object.}
\item{tree_info - A data.frame with the tree structure/split information.}
\item{subgroups - A list of subgroups given by their string representation.}
\item{predictions - Student model predictions for the training (non-holdout) data.}
}
}
\item{teacher_fit}{A list of (cross-fitted) teacher model fits.}
\item{teacher_predictions}{The predicted individual-level treatment effects, averaged across all cross-fitted teacher models.}
\item{teacher_predictions_ls}{A list of predicted individual-level treatment effects from each (cross-fitted) teacher model fit.}
\item{crossfit_idxs_ls}{A list of fold indices used in each cross-fit.}
\item{stability_diagnostics}{A list of stability diagnostics with the following elements:
\itemize{
\item{jaccard_mean - Vector of mean Jaccard similarity index for each tree depth. The tree depth is given by the vector index.}
\item{jaccard_distribution - List of Jaccard similarity indices across all bootstraps for each tree depth.}
\item{bootstrap_predictions - List of mean student model predictions (for training (non-holdout) data) across all bootstraps for each tree depth.}
\item{bootstrap_predictions_var - List of variance of student model predictions (for training (non-holdout) data) across all bootstraps for each tree depth.}
\item{leaf_ids - List of leaf node identifiers, indicating the leaf membership of each training sample in the (original) fitted student model.}
}
}
\item{holdout_idxs}{Indices of the holdout set.}
}
\description{
This function implements causal distillation trees (CDT),
developed in Huang et al. (2025). Briefly, CDT is a two-stage
procedure that allows researchers to identify interpretable subgroups with
heterogeneous treatment effects. In the first stage, researchers are free
to use any machine learning model or metalearner to predict the
heterogeneous treatment effects for each individual in the dataset. In the
second stage, CDT ``distills'' these predicted heterogeneous treatment
effects into interpretable subgroups by fitting an ordinary decision tree
using the predicted heterogeneous treatment effects from the first stage
as the response variable.
}
\examples{
n <- 50
p <- 3
X <- matrix(rnorm(n * p), nrow = n, ncol = p)
Z <- rbinom(n, 1, 0.5)
Y <- 2 * Z * (X[, 1] > 0) + X[, 2] + rnorm(n, 0.1)

# causal distillation trees using causal forest teacher model
\donttest{
out <- causalDT(X, Y, Z)
}

}
\references{
Huang, M., Tang, T. M., and Kenney, A. M. (2025). Distilling heterogeneous treatment effects: Stable subgroup estimation in causal inference. \emph{arXiv preprint arXiv:2502.07275}.
}
