#' @title Multiple imputation of missing covariates for a survival outcome.
#' @description Generates multiply imputed datasets using substantive-model-compatible fully conditional
#' specification (SMC-FCS). The method incorporates cause-specific Cox proportional hazards models, one for the event and
#' one for censoring, into the imputation stage, so the censoring process is also allowed to depend on the covariates with
#' missing values. Without loss of generality, we assume there is only one competing event. Our method is an extension of
#' Bartlett et al. (2015) and Bartlett and Taylor (2016).
#'
#' @param data The input data with missing values. Note: all missing values should be coded as \code{NA}, all binary covariates should be
#' coded as 0/1, and all categorical covariates with more than two categories should have class \code{factor}.
#' @param smformula The substantive model formula, given as a character string (or formula) of the form \code{"Surv(time, delta) ~ covariates"} (see example).
#' @param method A character vector specifying, for each column of \code{data}, the type of regression model used to impute it. Its length must equal the number of columns in \code{data}, and its elements must follow the column order (including the outcome columns).
#' For fully observed covariates and for the outcome columns, the value should be blank (""). The other possible options are \code{"norm"} (linear regression for continuous covariates), \code{"logreg"} (logistic regression for binary covariates),
#' and \code{"mlogit"} (multinomial logistic regression for categorical covariates).
#'
#' @param m Number of complete datasets to generate, default is 10.
#' @param rjlimit Maximum number of candidate draws per subject in the final stage of rejection sampling (used for
#' continuous covariates), default is 5000. If no candidate is accepted for some subjects after reaching this limit,
#' a warning is issued suggesting that the limit be increased.
#'
#' @return A list of \code{m} data frames, each a copy of \code{data} with the missing covariate values imputed.
#' @export
#' @references Bartlett JW, Seaman SR, White IR, Carpenter JR. (2015). Multiple imputation of covariates by fully conditional specification: accommodating the substantive model. \emph{Statistical Methods in Medical Research}. \strong{24(4)}, 462-487.
#'
#' @import survival VGAM MASS
#' @importFrom stats as.formula rmultinom lm glm rchisq
#'
#'
#' @examples
#' \donttest{
#' # Generate data with missing values
#' cox <- generate.surv(n = 500, beta = c(1,1,-1), phi= c(1,-1,-0.5), gamma = c(3,2,-1), seed = 112358)
#' # Impute
#' imputed <- new.smcfcs.surv(data = cox, smformula = "Surv(time, delta)~X1 +X2+X3",
#' method = c("","", "norm","logreg",""), m = 10, rjlimit = 10000)
#' # Fit a Cox regression on each imputed dataset, then produce the final estimates using Rubin's rule.
#' require(mitools)
#' library(survival)
#' imputed.fit <- with(imputationList(imputed), expr = coxph(Surv(time, delta) ~ X1+X2+X3))
#' summary(MIcombine(imputed.fit))}


new.smcfcs.surv <- function(data, smformula, method, m = 10, rjlimit = 5000) {
  # Algorithm outline, for each of the m imputations:
  #   1. Fill every missing covariate value with a randomly drawn observed value
  #      of the same column (placeholder).
  #   2. Repeat numit cycles; in each cycle, for every partially observed covariate:
  #      a. fit Cox models for the event hazard (delta == 1) and the censoring
  #         hazard (delta == 0) on the current data and perturb their coefficients
  #         with multivariate normal noise using the estimated covariance matrices;
  #      b. fit the covariate model (lm / logistic glm / multinomial vglm) given the
  #         other covariates and draw its parameters;
  #      c. re-impute the missing values: direct sampling over the categories for
  #         "logreg"/"mlogit", rejection sampling with the drawn normal covariate
  #         model as envelope for "norm".

  originaldata <- data
  n <- nrow(originaldata)
  smform <- as.formula(smformula)

  # Response indicators (TRUE = observed) and number of missing values per column.
  r <- !is.na(originaldata)
  missingCounts <- colSums(!r)

  if (ncol(originaldata) != length(method)) {
    stop("Method argument must have the same length as the number of columns in the data frame.")
  }

  # Column positions of the time and event indicator named in Surv(time, delta).
  timeName <- toString(smform[[2]][[2]])
  dName <- toString(smform[[2]][[3]])
  timeCol <- which(colnames(originaldata) == timeName)
  dCol <- which(colnames(originaldata) == dName)
  if ((length(timeCol) != 1) || (length(dCol) != 1)) {
    stop("The time and event indicator in smformula must each match exactly one column of data.")
  }
  outcomeCol <- c(timeCol, dCol)

  times <- originaldata[[timeCol]]
  d <- originaldata[[dCol]]
  if (anyNA(times) || anyNA(d)) {
    stop("The time and event indicator variables must be fully observed.")
  }
  if (any(times <= 0)) {
    stop("Time can only take positive values")
  }
  # Only 0/1 values are allowed, and at least one event is needed to fit the event model.
  if (!all(d %in% c(0, 1)) || !any(d == 1)) {
    stop("Event indicator for coxph must be coded 0/1 for censoring/event.")
  }

  # Partially observed covariates are the columns given a recognised imputation method.
  imputationMethods <- c("norm", "logreg", "mlogit")
  partialVars <- which(method %in% imputationMethods)

  smcovnames <- attr(terms(smform), "term.labels")

  # Fully observed vars are those that are fully observed and are covariates in the substantive model.
  fullObsVars <- which((missingCounts == 0) & (colnames(originaldata) %in% smcovnames))

  if (length(partialVars) == 0) {
    stop("You have not specified any valid imputation methods in the method argument.")
  }

  # A method that is neither blank nor recognised would otherwise be silently
  # ignored, leaving that column un-imputed.
  unknownMethods <- setdiff(method, c("", imputationMethods))
  if (length(unknownMethods) > 0) {
    stop(paste0("Unrecognised imputation method(s): ", paste(unknownMethods, collapse = ", "),
                ". Valid options are \"norm\", \"logreg\" and \"mlogit\"."))
  }

  # Methods must be given for the partially observed covariates only.
  for (column in seq_len(ncol(originaldata))) {
    colName <- colnames(originaldata)[column]
    if (method[column] != "") {
      if (column %in% outcomeCol) {
        stop(paste0("An imputation method has been specified for ", colName,
                    ". Elements of the method argument corresponding to the outcome covariates should be empty."))
      } else if (missingCounts[column] == 0) {
        stop(paste0("An imputation method has been specified for ", colName,
                    ", but it appears to be fully observed."))
      }
    } else if (missingCounts[column] > 0) {
      stop(paste0("Covariate ", colName,
                  " does not have an imputation method specified, yet appears to have missing values."))
    }
  }

  for (targetCol in partialVars) {
    # The placeholder imputation below samples from the observed values.
    if (missingCounts[targetCol] == n) {
      stop(paste0("Covariate ", colnames(originaldata)[targetCol],
                  " has no observed values, so it cannot be imputed."))
    }
    if ((method[targetCol] == "mlogit") && !is.factor(originaldata[[targetCol]])) {
      stop("Variables to be imputed using method mlogit must be stored as factors.")
    }
  }

  numit <- 10 # number of cycles through the partially observed covariates per imputation
  batchAttempts <- 24L # vectorised rejection-sampling rounds before the per-subject fallback

  # Cause-specific hazard formulas: event (delta == 1) and censoring (delta == 0).
  # collapse guards against deparse() splitting a long right-hand side over several strings.
  rhs <- paste(deparse(smform[[3]], width.cutoff = 500), collapse = " ")
  eventFormula <- as.formula(paste0("Surv(", timeName, ",", dName, " == 1) ~ ", rhs))
  censorFormula <- as.formula(paste0("Surv(", timeName, ",", dName, " == 0) ~ ", rhs))

  # Position of each subject's time on the time grid returned by basehaz(); it
  # depends only on the observed outcome, so it is computed once. match() is used
  # because rank() gives non-integer values when times are tied.
  nullmod <- coxph(Surv(times, d) ~ 1)
  basehaz.index <- match(times, basehaz(nullmod)$time)

  imputations <- rep(list(originaldata), m)

  for (imp in seq_len(m)) {
    message(paste("Imputation ", imp))
    dat <- imputations[[imp]]

    # Placeholder: fill each missing value with a randomly chosen observed value.
    # Index via sample.int() so a single observed value is not treated as 1:x by sample().
    for (targetCol in partialVars) {
      misMask <- !r[, targetCol]
      observed <- dat[[targetCol]][!misMask]
      dat[misMask, targetCol] <- observed[sample.int(length(observed), size = sum(misMask), replace = TRUE)]
    }

    for (cyclenum in seq_len(numit)) {
      for (targetCol in partialVars) {
        # Substantive-model parameter draws for the event hazard ...
        eventmod <- coxph(eventFormula, data = dat)
        beta <- eventmod$coefficients
        newbeta <- beta + MASS::mvrnorm(1, mu = rep(0, length(beta)), Sigma = eventmod$var)
        H1 <- basehaz(eventmod, centered = FALSE)$hazard[basehaz.index]

        # ... and for the censoring hazard.
        censormod <- coxph(censorFormula, data = dat)
        phi <- censormod$coefficients
        newphi <- phi + MASS::mvrnorm(1, mu = rep(0, length(phi)), Sigma = censormod$var)
        H0 <- basehaz(censormod, centered = FALSE)$hazard[basehaz.index]

        predictorCols <- c(partialVars[partialVars != targetCol], fullObsVars)

        if ((imp == 1) && (cyclenum == 1)) {
          message(paste("Imputing: ", colnames(dat)[targetCol], " using ",
                        paste(colnames(dat)[predictorCols], collapse = ','), " plus outcome", collapse = ','))
        }

        # Covariate model; intercept-only when there are no other covariates.
        # Factor covariates must already be stored as factors by the user.
        xrhs <- if (length(predictorCols) > 0) paste(colnames(dat)[predictorCols], collapse = "+") else "1"
        xmodformula <- as.formula(paste0(colnames(dat)[targetCol], " ~ ", xrhs))
        covDraw <- .drawCovariateModel(method[targetCol], xmodformula, dat, targetCol)

        misRows <- which(!r[, targetCol])
        dMis <- d[misRows]

        if (method[targetCol] == "norm") {
          #### rejection sampling for continuous covariates ####
          sdDraw <- sqrt(covDraw$sigmasq)
          pending <- misRows
          for (attempt in seq_len(batchAttempts)) {
            if (length(pending) == 0) break
            # draw x* from the envelope (the drawn normal covariate model)
            dat[pending, targetCol] <- rnorm(length(pending), covDraw$fitted[pending], sdDraw)
            uDraw <- runif(length(pending))
            lp <- .smLinearPredictors(smform, dat, newphi, newbeta)
            prob <- .acceptanceProb(lp$g0[pending], lp$g1[pending], H0[pending], H1[pending], d[pending])
            # an NA/NaN probability counts as a rejection
            accepted <- !is.na(prob) & (uDraw <= prob)
            pending <- pending[!accepted]
          }

          # Remaining subjects have low acceptance probability: draw rjlimit
          # candidates per subject at once and keep the first accepted one.
          failedCount <- 0L
          for (i in pending) {
            candRows <- dat[rep(i, rjlimit), ]
            candRows[[targetCol]] <- rnorm(rjlimit, covDraw$fitted[i], sdDraw)
            uDraw <- runif(rjlimit)
            lp <- .smLinearPredictors(smform, candRows, newphi, newbeta)
            prob <- .acceptanceProb(lp$g0, lp$g1, H0[i], H1[i], d[i])
            firstAccepted <- which(!is.na(prob) & (uDraw <= prob))[1]
            if (is.na(firstAccepted)) {
              failedCount <- failedCount + 1L
            } else {
              dat[i, targetCol] <- candRows[[targetCol]][firstAccepted]
            }
          }
          if (failedCount > 0) {
            warning("rejection sampling failed for some subjects, increase rjlimit")
          }
        } else {
          #### direct sampling for binary/categorical covariates ####
          candidates <- if (method[targetCol] == "logreg") c(0, 1) else levels(dat[[targetCol]])

          # Log of (outcome density given each candidate value) x (covariate model
          # probability). The hazard factor is exp(g1) for events and exp(g0) for
          # censored subjects.
          logDens <- matrix(0, nrow = length(misRows), ncol = length(candidates))
          for (k in seq_along(candidates)) {
            dat[misRows, targetCol] <- candidates[k]
            lp <- .smLinearPredictors(smform, dat, newphi, newbeta)
            g0 <- lp$g0[misRows]
            g1 <- lp$g1[misRows]
            logDens[, k] <- ifelse(dMis == 1, g1, g0) -
              exp(g0) * H0[misRows] - exp(g1) * H1[misRows] +
              log(covDraw$fitted[misRows, k])
          }
          # Normalise on the log scale to avoid 0/0 when the densities underflow.
          probs <- exp(logDens - apply(logDens, 1, max))
          probs <- probs / rowSums(probs)

          if (method[targetCol] == "logreg") {
            dat[misRows, targetCol] <- rbinom(length(misRows), 1, probs[, 2])
          } else {
            dat[misRows, targetCol] <- candidates[apply(probs, 1, catdraw)]
          }
        }
      } # end of var loop
    } # end of cyclenum loop

    imputations[[imp]] <- dat
  } # end of imp loop

  imputations
}

# Fit the imputation model of one partially observed covariate on the current data
# and draw its parameters (estimate + multivariate normal noise; for "norm" the
# residual variance is also drawn from a scaled inverse chi-squared).
#   type      "norm", "logreg" or "mlogit"
#   formula   covariate model formula
#   dat       current (placeholder-filled) data
#   targetCol column index of the covariate (used for its factor levels)
# Returns list(fitted, sigmasq):
#   "norm":            fitted = vector of conditional means, sigmasq = drawn variance;
#   "logreg"/"mlogit": fitted = n x K matrix of category probabilities (column k is
#                      the k-th category: 0/1 for logreg, factor levels for mlogit).
.drawCovariateModel <- function(type, formula, dat, targetCol) {
  drawCoef <- function(theta, Sigma) {
    theta + MASS::mvrnorm(1, mu = rep(0, length(theta)), Sigma = Sigma)
  }

  if (type == "norm") {
    xmod <- lm(formula, data = dat)
    resDf <- xmod$df.residual
    sigmasq <- summary(xmod)$sigma^2
    newsigmasq <- (sigmasq * resDf) / rchisq(1, resDf)
    newtheta <- drawCoef(xmod$coefficients, (newsigmasq / sigmasq) * vcov(xmod))
    list(fitted = drop(model.matrix(xmod) %*% newtheta), sigmasq = newsigmasq)
  } else if (type == "logreg") {
    xmod <- glm(formula, data = dat, family = "binomial")
    newtheta <- drawCoef(xmod$coefficients, vcov(xmod))
    # plogis() avoids the Inf/Inf = NaN of exp(eta) / (1 + exp(eta)) for large eta
    p1 <- stats::plogis(drop(model.matrix(xmod) %*% newtheta))
    list(fitted = cbind(1 - p1, p1), sigmasq = NULL)
  } else {
    xmod <- VGAM::vglm(formula, VGAM::multinomial(refLevel = 1), data = dat)
    newtheta <- drawCoef(VGAM::coef(xmod), VGAM::vcov(xmod))
    # Reshaped row-wise into an n x (K - 1) matrix of linear predictors relative to
    # the first level; assumes the VLM model matrix rows are ordered subject by subject.
    linpreds <- matrix(VGAM::model.matrix(xmod) %*% newtheta, byrow = TRUE,
                       ncol = nlevels(dat[[targetCol]]) - 1)
    denom <- 1 + rowSums(exp(linpreds))
    list(fitted = cbind(1 / denom, exp(linpreds) / denom), sigmasq = NULL)
  }
}

# Linear predictors of the censoring (g0, coefficients newphi) and event (g1,
# coefficients newbeta) hazards for the rows of dat, from the substantive-model
# design matrix without its intercept column.
.smLinearPredictors <- function(smform, dat, newphi, newbeta) {
  X <- model.matrix(smform, dat)[, -1, drop = FALSE]
  list(g0 = drop(X %*% newphi), g1 = drop(X %*% newbeta))
}

# Rejection-sampling acceptance probability: the subject's outcome density scaled so
# that it is bounded by 1 (using x * exp(-x) <= exp(-1)). Event subjects (delta == 1)
# use the event hazard term, censored subjects the censoring hazard term. H0, H1 and
# delta may be scalars; they are recycled to the length of g0. Only the branch
# selected for each subject is used, so an overflow in the other branch cannot
# leak in as NaN.
.acceptanceProb <- function(g0, g1, H0, H1, delta) {
  len <- length(g0)
  event <- rep_len(delta == 1, len)
  H0 <- rep_len(H0, len)
  H1 <- rep_len(H1, len)
  logSurv <- -exp(g0) * H0 - exp(g1) * H1
  ifelse(event, H1 * exp(1 + g1 + logSurv), H0 * exp(1 + g0 + logSurv))
}

# Draw a single category from a categorical distribution.
#   prob  vector of category weights (normalised internally by rmultinom()).
# Returns the integer index, between 1 and length(prob), of the category drawn.
catdraw <- function(prob) {
  oneDraw <- rmultinom(n = 1, size = 1, prob = prob)
  seq_along(prob)[oneDraw == 1]
}
