#' Gather Past Results for Given Assignment Period
#' @name get_past_results
#' @description Summarizes the results of prior periods for use in the current Multi-Arm-Bandit assignment.
#' For each treatment, this function computes the number of successes and the total number of
#' observations assigned to it; these summaries are then used to calculate UCB1 values or
#' Thompson sampling probabilities.
#'
#' @inheritParams single_mab_simulation
#' @inheritParams create_prior
#' @inheritParams cols
#' @param current_data A tibble/data.table containing only the observations from the current sampling period.
#' @param prior_data A tibble/data.table containing only the observations from the prior index.
#' @returns A tibble/data.table with one row per treatment condition, containing the number of
#' successes, the success rate, and the number of observations for that condition.
#'
#' @details
#' When `perfect_assignment` is FALSE, the maximum value of the specified
#' `assignment_date_col` in the current data is taken as the last possible date on which
#' the researchers conducting the experiment could have learned about a treatment outcome.
#' Every success that occurs after this date is masked and treated as a failure when
#' assigning treatments in this period, which simulates the researchers not having
#' received that information yet.
#'
#' This function is an S3 generic: the method is chosen from the class of `current_data`.
#'
#' @seealso
#' * [run_mab_trial()]
#' * [single_mab_simulation()]
#' * [get_bandit()]
#' @keywords internal
#'
get_past_results <- function(
  current_data,
  prior_data,
  perfect_assignment,
  assignment_date_col = NULL,
  conditions
) {
  # Dispatch on the first formal argument, i.e. `current_data`.
  base::UseMethod("get_past_results")
}

#----------------------------------------------------------------------------------
#' @method get_past_results data.frame
#' @title
#' [get_past_results()] for data.frames
#' @inheritParams get_past_results
#' @noRd

get_past_results.data.frame <- function(
  current_data,
  prior_data,
  perfect_assignment,
  assignment_date_col = NULL,
  conditions
) {
  # Step 1: flag the successes that are already "known" at assignment time.
  if (perfect_assignment) {
    # Every recorded outcome is treated as already observed.
    prior_data$known_success <- prior_data$mab_success
  } else {
    # The latest assignment date in the current period is the information
    # cutoff: a success only counts if its date is present and not after it.
    current_date <- base::max(current_data[[assignment_date_col$name]])
    success_dates <- prior_data[["new_success_date"]]
    already_observed <- current_date >= success_dates &
      !base::is.na(success_dates)

    prior_data$known_success <- base::ifelse(already_observed, 1, 0)
  }

  # Step 2: one summary row per treatment arm (`.groups = "drop"` already
  # returns an ungrouped tibble).
  by_condition <- dplyr::group_by(prior_data, mab_condition)
  summary_tbl <- dplyr::summarize(
    by_condition,
    successes = base::sum(known_success, na.rm = TRUE),
    success_rate = base::mean(known_success, na.rm = TRUE),
    n = dplyr::n(),
    .groups = "drop"
  )

  # Step 3: nothing to pad when the row count matches the number of arms.
  if (base::nrow(summary_tbl) == base::length(conditions)) {
    return(summary_tbl)
  }

  # Arms absent from the prior data get an all-zero row, then the table is
  # re-sorted by condition.
  unseen_conditions <- base::setdiff(conditions, summary_tbl$mab_condition)
  zero_rows <- tibble::tibble(
    mab_condition = unseen_conditions,
    successes = 0,
    success_rate = 0,
    n = 0
  )

  summary_tbl <- dplyr::bind_rows(summary_tbl, zero_rows)
  summary_tbl <- summary_tbl[base::order(summary_tbl$mab_condition), ]
  return(summary_tbl)
}
#------------------------------------------------------------------------------

#' @method get_past_results data.table
#' @title
#' [get_past_results()] for data.tables
#' @inheritParams get_past_results
#' @noRd

# The formals follow the generic's order so that positional calls made through
# `get_past_results()` bind `prior_data` (and everything after it) correctly.
# Note: `prior_data` is modified by reference; a `known_success` column is
# added to (or overwritten in) the table that was passed in.
get_past_results.data.table <- function(
  current_data,
  prior_data,
  perfect_assignment,
  assignment_date_col = NULL,
  conditions
) {
  # Validate first: `if (NA)` would otherwise fail with a generic message
  # before the intended error could ever be raised.
  flag_usable <- base::is.atomic(perfect_assignment) &&
    base::length(perfect_assignment) == 1L &&
    !base::is.na(perfect_assignment)
  if (!flag_usable) {
    rlang::abort("Specify Logical for `perfect_assignment`")
  }

  if (perfect_assignment) {
    # Every recorded outcome is treated as already observed.
    prior_data[, known_success := mab_success]
  } else {
    # The latest assignment date of the current period is the information
    # cutoff; later or missing success dates count as failures.
    current_date <- base::max(current_data[[assignment_date_col$name]])

    prior_data[,
      known_success := data.table::fifelse(
        current_date >= new_success_date &
          !is.na(new_success_date),
        1,
        0
      )
    ]
  }

  # One summary row per treatment arm.
  past_results <- prior_data[,
    .(
      successes = base::sum(known_success, na.rm = TRUE),
      success_rate = base::mean(known_success, na.rm = TRUE),
      n = .N
    ),
    by = mab_condition
  ]

  # Arms absent from the prior data get an all-zero row.
  if (base::nrow(past_results) != base::length(conditions)) {
    conditions_add <- base::setdiff(conditions, past_results$mab_condition)
    zero_rows <- data.table::data.table(
      mab_condition = conditions_add,
      successes = 0,
      success_rate = 0,
      n = 0
    )

    past_results <- data.table::rbindlist(
      list(past_results, zero_rows),
      use.names = TRUE
    )
  }
  data.table::setorder(past_results, mab_condition)
  return(invisible(past_results))
}

#-------------------------------------------------------------------------------
#' Calculate Multi-Arm Bandit Decision Based on Algorithm
#' @description Calculates the best treatment for a given period using either a UCB1 or Thompson sampling algorithm.
#' Thompson sampling is done using [bandit::best_binomial_bandit()] from
#' the \href{https://cran.r-project.org/package=bandit}{bandit}
#' package and UCB1 values are calculated using the well-defined formula that can be found
#' in \href{https://arxiv.org/abs/1402.6028}{Kuleshov and Precup (2014)}.
#'
#' @name get_bandit
#'
#' @inheritParams single_mab_simulation
#' @param past_results A tibble/data.table containing summary of prior periods, with
#' successes, number of observations, and success rates, which is created by [get_past_results()].
#' @param current_period Numeric value of length 1; current period of the adaptive trial simulation.
#'
#' @returns A list of length 2 containing:
#' \itemize{
#' \item `bandit`: Bandit object, either a named numeric vector of Thompson sampling probabilities or a
#' tibble/data.table of UCB1 values.
#' \item `assignment_prob`: Named numeric vector with a value for each condition
#' containing the probability of being assigned that treatment.}
#'
#' @details
#'
#' The Thompson `assignment_prob` vector is the same as the `bandit` vector except when
#' `control_augment` is greater than 0, in which case this function raises the control
#' probability to that threshold. Random assignment through `random_assign_prop` is applied
#' later, in [assign_treatments()].
#'
#' Thompson sampling is calculated using the \href{https://cran.r-project.org/package=bandit}{bandit}
#' package but the direct calculation can result in errors or overflow. If this occurs, a simulation based method
#' from the same package is used instead to estimate the posterior distribution.
#' If this occurs a warning will be presented. `ndraws` specifies the number of iterations for the
#' simulation based method, and the default value is 5000.
#'
#' The UCB1 algorithm only selects 1 treatment at each period, with no probability matching
#' so `assignment_prob` will always have 1 element equal to 1, and the rest equal to 0, unless
#' `control_augment` is greater than 0, which will alter the probabilities of assignment.
#' For example, if the original vector is `(0, 0, 1)`, and `control_augment` = 0.2,
#' the new vector is `(0.2, 0, 0.8)` assuming the first element is control. If instead the 3rd element
#' were the control group the resulting vector would not be changed because it already meets the
#' control group threshold.
#'
#' When `control_augment` is greater than 0, exactly one element of `conditions` must be
#' named `"control"`; otherwise an error is raised. The control arm is located by its
#' position in `conditions`, so the assignment probabilities are expected to follow the
#' same order as `conditions`.
#'
#'
#' @references
#' Kuleshov, Volodymyr, and Doina Precup. 2014. "Algorithms for Multi-Armed Bandit Problems."
#' \emph{arXiv}. \doi{10.48550/arXiv.1402.6028}.
#'
#' Lotze, Thomas, and Markus Loecher. 2022.
#' "Bandit: Functions for Simple a/B Split Test and Multi-Armed Bandit Analysis."
#' \url{https://cran.r-project.org/package=bandit}.
#'
#' @keywords internal

get_bandit <- function(
  past_results,
  algorithm,
  conditions,
  current_period,
  control_augment = 0,
  ndraws
) {
  # Only the matching arm of `switch()` is evaluated; anything else aborts.
  bandit <- switch(
    algorithm,
    "thompson" = get_bandit.thompson(
      past_results = past_results,
      conditions = conditions,
      current_period = current_period,
      ndraws = ndraws
    ),
    "ucb1" = get_bandit.ucb1(
      past_results = past_results,
      conditions = conditions,
      current_period = current_period
    ),
    rlang::abort("Invalid `algorithm`. Valid Algorithms: 'thompson', 'ucb1'")
  )

  assignment_prob <- bandit[["assignment_prob"]]

  if (control_augment > 0) {
    ctrl <- names(conditions) == "control"
    # Without exactly one control arm the comparison below would have length
    # 0 (or > 1) and `if` would fail with an uninformative message.
    if (base::sum(ctrl) != 1L) {
      rlang::abort(c(
        "`control_augment` requires exactly one element of `conditions` to be named 'control'",
        "x" = sprintf(
          "Found %d element(s) named 'control'.",
          base::sum(ctrl)
        )
      ))
    }
    if (assignment_prob[ctrl] < control_augment) {
      # Raise control to the floor and rescale the remaining arms so that they
      # share the leftover probability in their original proportions.
      assignment_prob[ctrl] <- control_augment
      assignment_prob[!ctrl] <- (assignment_prob[!ctrl] /
        sum(assignment_prob[!ctrl])) *
        (1 - control_augment)
    }
  }
  # Renormalise if the probabilities do not sum to 1 within tolerance.
  if (!isTRUE(all.equal(sum(assignment_prob), 1))) {
    assignment_prob <- assignment_prob / sum(assignment_prob)
  }
  bandit[["assignment_prob"]] <- assignment_prob

  return(bandit)
}
#-------------------------------------------------------------------
#' @method get_bandit thompson
#' @title Thompson sampling Algorithm
#' @inheritParams get_bandit
#' @details
#' Thompson sampling probabilities are first computed directly with the
#' \href{https://cran.r-project.org/package=bandit}{bandit} package. That direct calculation
#' can fail or return an invalid result; when it does, the simulation based estimate of the
#' posterior distribution from the same package is used instead and the user receives a warning.
#'
#'
#' @returns A named list of length 2, where element 1 is the named numeric vector of Thompson
#' sampling probabilities, and element 2 is an identical copy of that vector. The second element is
#' adjusted later in the simulation based on what the user has set for `control_augment` and `random_assign_prop` to reflect the
#' probability of assignment to a given treatment at that period.
#' @keywords internal

get_bandit.thompson <- function(
  past_results,
  conditions,
  current_period,
  ndraws
) {
  # Turns raw package output into a plain vector named by treatment arm.
  label_by_arm <- function(probabilities) {
    rlang::set_names(
      as.vector(probabilities),
      past_results$mab_condition
    )
  }

  # Fallback used whenever the direct calculation errors or is invalid.
  simulate_posterior <- function(e) {
    rlang::warn(c(
      "Thompson sampling calculation overflowed; simulation based posterior estimate was used instead",
      "i" = sprintf("Period: %d", current_period)
    ))
    label_by_arm(bandit::best_binomial_bandit_sim(
      x = past_results$successes,
      n = past_results$n,
      alpha = 1,
      beta = 1,
      ndraws = ndraws
    ))
  }

  thompson_probs <- tryCatch(
    {
      direct <- label_by_arm(bandit::best_binomial_bandit(
        x = past_results$successes,
        n = past_results$n,
        alpha = 1,
        beta = 1
      ))
      # An invalid direct result is routed to the same fallback as an error.
      if (bandit_invalid(direct)) {
        stop("Invalid Bandit")
      }
      direct
    },
    error = simulate_posterior
  )

  # The simulated estimate can be invalid as well; nothing is left to try.
  if (bandit_invalid(thompson_probs)) {
    rlang::abort(c(
      "Thompson sampling simulation failed",
      "x" = paste0(
        "Most Recent Result:",
        paste0(thompson_probs, collapse = " ")
      ),
      "i" = "Consider setting `ndraws` higher or reducing `prior_periods`."
    ))
  }

  return(list(bandit = thompson_probs, assignment_prob = thompson_probs))
}
#' @name bandit_invalid
#' @title Checks Validity of Thompson sampling probabilities
#' @description Checks whether any of the Thompson sampling probabilities is NA, or whether
#' they sum to (approximately) 0. Either case indicates that the direct calculation failed
#' or did not converge.
#' @param bandit a numeric vector of Thompson sampling probabilities.
#' @returns Logical; TRUE if the vector is invalid, FALSE if valid
#' @keywords internal
bandit_invalid <- function(bandit) {
  # A single missing probability already makes the vector unusable.
  if (base::any(base::is.na(bandit))) {
    return(TRUE)
  }
  # Otherwise it is invalid only when the total is 0 within tolerance.
  total <- base::sum(bandit)
  return(isTRUE(all.equal(total, 0)))
}
#-------------------------------------------------------------------
#' @method get_bandit ucb1
#' @title UCB1 Sampling Algorithm
#' @description Calculates an upper confidence bound for each treatment arm and selects the
#' arm with the highest bound.
#'
#' @inheritParams get_bandit
#' @returns A named list with 2 elements: `bandit`, the tibble/data.table of past results with an
#' added `ucb` column, and `assignment_prob`, a named numeric vector of assignment probabilities
#' in which the treatment with the highest UCB1 is assigned 1, and the rest 0. The list is
#' returned invisibly.
#' @keywords internal

get_bandit.ucb1 <- function(past_results, conditions, current_period) {
  # Small offset added to the counts so arms with n = 0 do not divide by zero.
  correction <- 1e-10
  # Numerator of the exploration bonus; it is the same for every arm.
  ucb1_log_numerator <- 2 * base::log(current_period - 1)

  if (!data.table::is.data.table(past_results)) {
    past_results$ucb <- past_results$success_rate +
      base::sqrt(ucb1_log_numerator / (past_results$n + correction))
  } else {
    # data.table input: the `ucb` column is added by reference.
    past_results[,
      ucb := success_rate +
        base::sqrt(ucb1_log_numerator / (n + correction))
    ]
  }

  # `which.max()` keeps the first arm when several share the highest bound.
  top_row <- base::which.max(past_results[["ucb"]])
  best_condition <- past_results[["mab_condition"]][top_row]

  # All of the probability mass goes to the selected arm.
  assignment_probs <- rlang::set_names(
    base::numeric(base::length(conditions)),
    conditions
  )
  assignment_probs[[best_condition]] <- 1

  return(invisible(list(
    bandit = past_results,
    assignment_prob = assignment_probs
  )))
}
#-------------------------------------------------------------------------------
#' Adaptively Assign Treatments in a Period
#' @description Assigns new treatments for an assignment wave based on the assignment probabilities provided from
#' [get_bandit()], and the proportion of randomly assigned observations specified in `random_assign_prop`.
#' Assignments are made randomly with the given probabilities using [randomizr::block_ra()] or
#' [randomizr::complete_ra()].
#'
#' @name assign_treatments
#' @inheritParams single_mab_simulation
#' @inheritParams cols
#' @param probs Named numeric Vector; probability of assignment for each treatment condition.
#' @inheritParams get_past_results
#' @returns Updated tibble/data.table with the new treatment conditions for each observation, and whether imputation is required.
#' If the new treatment differs from the one under the original experiment, `impute_req` is 1 for that
#' observation, and otherwise it is 0.
#'
#' @details
#' The number of rows which are randomly assigned in each period is `random_assign_prop` multiplied by
#' the number of rows in the period. If this number is less than 1 (and `random_assign_prop` is greater
#' than 0), a Bernoulli draw with probability `random_assign_prop` is made for each row to determine
#' whether that row will be assigned randomly. Otherwise, the number of random rows is rounded to the
#' nearest whole number, and that many rows are selected at random to be assigned through
#' random assignment with equal probability for every condition. All remaining rows are assigned
#' using the bandit probabilities in `probs`.
#' @seealso
#'* [randomizr::block_ra()]
#'* [randomizr::complete_ra()]
#' @keywords internal

assign_treatments <- function(
  current_data,
  probs,
  blocking = NULL,
  conditions,
  condition_col,
  random_assign_prop
) {
  n_rows <- base::nrow(current_data)
  expected_random <- n_rows * random_assign_prop

  # Fewer than one expected random row: decide row by row instead.
  if (random_assign_prop > 0 && expected_random < 1) {
    coin_flips <- stats::rbinom(n_rows, 1, random_assign_prop)
    rand_idx <- base::which(base::as.logical(coin_flips))
  } else {
    rand_idx <- base::sample(
      x = n_rows,
      size = base::round(expected_random, 0),
      replace = FALSE
    )
  }

  # Rows not picked for random assignment follow the bandit probabilities.
  band_idx <- base::setdiff(base::seq_len(n_rows), rand_idx)

  # Randomly assigned rows use equal probability for every condition.
  n_arms <- base::length(conditions)
  random_probs <- base::rep(1 / n_arms, n_arms)

  # Pick the class-specific worker, then call it with identical arguments.
  worker <- if (data.table::is.data.table(current_data)) {
    assign_treatments.data.table
  } else {
    assign_treatments.data.frame
  }
  current_data <- worker(
    current_data = current_data,
    probs = probs,
    blocking = blocking,
    conditions = conditions,
    condition_col = condition_col,
    rand_idx = rand_idx,
    band_idx = band_idx,
    random_probs = random_probs
  )
  return(current_data)
}
#-----------------------------------------------------

#' @method assign_treatments data.frame
#' @title [assign_treatments()] for data.frames
#' @noRd
# Writes `assignment_type`, `mab_condition` and `impute_req` into `current_data`.
# `rand_idx` rows are drawn with `random_probs`, `band_idx` rows with `probs`.
# The random rows are drawn first, then the bandit rows. `blocking = NULL`
# (the default) is treated as no blocking.
assign_treatments.data.frame <- function(
  current_data,
  probs,
  blocking = NULL,
  conditions,
  condition_col,
  rand_idx,
  band_idx,
  random_probs
) {
  # `if (NULL)` is an error, so the NULL default has to be handled explicitly.
  use_blocks <- !base::is.null(blocking) && blocking

  current_data$assignment_type[band_idx] <- "bandit"
  current_data$assignment_type[rand_idx] <- "random"

  # Draws one condition per row in `idx`, blocked on the `block` column when
  # blocking is requested and by complete random assignment otherwise.
  draw_conditions <- function(idx, prob_each) {
    drawn <- if (use_blocks) {
      randomizr::block_ra(
        blocks = current_data$block[idx],
        prob_each = prob_each,
        conditions = conditions,
        check_inputs = TRUE
      )
    } else {
      randomizr::complete_ra(
        N = base::length(idx),
        prob_each = prob_each,
        conditions = conditions,
        check_inputs = TRUE
      )
    }
    base::as.character(drawn)
  }

  # An empty index set is skipped rather than passed to randomizr.
  if (base::length(rand_idx) > 0) {
    current_data$mab_condition[rand_idx] <- draw_conditions(
      rand_idx,
      random_probs
    )
  }
  if (base::length(band_idx) > 0) {
    current_data$mab_condition[band_idx] <- draw_conditions(band_idx, probs)
  }

  # 1 when the new assignment differs from the originally recorded condition.
  current_data$impute_req <- base::ifelse(
    base::as.character(current_data$mab_condition) !=
      base::as.character(current_data[[condition_col$name]]),
    1,
    0
  )
  return(current_data)
}

#' @method assign_treatments data.table
#' @title [assign_treatments()] for data.tables
#' @noRd
# Updates `assignment_type`, `mab_condition` and `impute_req` in `current_data`
# by reference and returns the table invisibly. `rand_idx` rows are drawn with
# `random_probs`, `band_idx` rows with `probs`. The random rows are drawn
# first, then the bandit rows. `blocking = NULL` (the default) is treated as
# no blocking.
assign_treatments.data.table <- function(
  current_data,
  probs,
  blocking = NULL,
  conditions,
  condition_col,
  rand_idx,
  band_idx,
  random_probs
) {
  # `if (NULL)` is an error, so the NULL default has to be handled explicitly.
  use_blocks <- !base::is.null(blocking) && blocking

  current_data[band_idx, assignment_type := "bandit"]
  current_data[rand_idx, assignment_type := "random"]

  # Draws one condition per row in `idx`, blocked on the `block` column when
  # blocking is requested and by complete random assignment otherwise.
  draw_conditions <- function(idx, prob_each) {
    drawn <- if (use_blocks) {
      randomizr::block_ra(
        blocks = current_data[["block"]][idx],
        prob_each = prob_each,
        conditions = conditions,
        check_inputs = TRUE
      )
    } else {
      randomizr::complete_ra(
        N = base::length(idx),
        prob_each = prob_each,
        conditions = conditions,
        check_inputs = TRUE
      )
    }
    base::as.character(drawn)
  }

  # An empty index set is skipped rather than passed to randomizr.
  if (base::length(rand_idx) > 0) {
    random_draws <- draw_conditions(rand_idx, random_probs)
    current_data[rand_idx, mab_condition := random_draws]
  }
  if (base::length(band_idx) > 0) {
    bandit_draws <- draw_conditions(band_idx, probs)
    current_data[band_idx, mab_condition := bandit_draws]
  }

  # 1 when the new assignment differs from the originally recorded condition.
  current_data[,
    impute_req := data.table::fifelse(
      base::as.character(mab_condition) !=
        base::as.character(base::get(condition_col$name)),
      1,
      0
    )
  ]
  return(invisible(current_data))
}
