# Schuster (2001) symmetric rater agreement model

#' Fit Schuster's (2001) symmetric rater agreement model.
#'
#' Computes the model that has kappa as a coefficient and symmetry. Starting
#' values come from Schuster_compute_starting_values(); the parameters are then
#' updated by repeated calls to Schuster_newton_raphson(). Iteration stops when,
#' after at least `min_iter` iterations, log(likelihood) decreases or its
#' relative change is at most `criterion`, or when `max_iter` iterations have
#' been performed.
#'
#' Schuster, C. (2001). Kappa as a parameter of a symmetry model for rater agreement.
#' Journal of Educational and Behavioral Statistics, 26(3), 331-342.
#'
#' @param n the matrix of observed counts
#' @param verbose logical. should cycle-by-cycle information be printed out
#'   (via message()). Nothing is printed when FALSE.
#' @param max_iter integer. maximum number of iterations to perform
#' @param criterion number. convergence threshold on the relative change in
#'   log(likelihood), (logL - previous logL) / abs(logL)
#' @param min_iter integer. minimum number of iterations to perform before the
#'   convergence and decrease checks are applied
#' @returns a list containing
#'    marginal_pi: vector of expected proportions for each category
#'    kappa: kappa coefficient
#'    v: matrix of symmetry parameters
#'    chisq: Pearson X^2
#'    g_squared: likelihood ratio G^2
#'    df: degrees of freedom
#'    predicted: matrix of model-predicted counts, sum(n) * pi
#' @export
Schuster_symmetric_rater_agreement_model <- function(n, verbose = FALSE, max_iter = 10000,
                                                     criterion = 1.0e-7, min_iter = 1000) {
  N <- sum(n)
  start <- Schuster_compute_starting_values(n)
  marginal_pi <- start$marginal_pi
  kappa <- start$kappa
  v <- start$v

  pi <- Schuster_compute_pi(marginal_pi, kappa, v)
  logL <- log_likelihood(n, pi)
  if (verbose) {
    message("iter  log(likelihood)   G^2          X^2         criterion\n")
    g_squared <- likelihood_ratio_chisq(n, pi)
    chisq <- pearson_chisq(n, pi)
    message(paste("0", "    ", logL, "    ", g_squared, "    ", chisq, "\n"))
  }

  old_log_l <- logL
  for (iter in seq_len(max_iter)) {
    parms <- Schuster_newton_raphson(n, marginal_pi, kappa, v)
    marginal_pi <- parms$marginal_pi
    kappa <- parms$kappa
    v <- parms$v
    pi <- Schuster_compute_pi(marginal_pi, kappa, v)
    pi <- pi / sum(pi)
    logL <- log_likelihood(n, pi)
    if (is.nan(logL)) {
      stop(paste("NaN encountered", pi, "\n",
          marginal_pi, "\n",
          kappa, "\n",
          v))
    }
    # Relative change in log(likelihood) since the previous iteration.
    change <- (logL - old_log_l) / abs(logL)

    if (verbose) {
      x2 <- pearson_chisq(n, pi)
      g2 <- likelihood_ratio_chisq(n, pi)
      message(paste(iter, "    ", logL, "    ", g2, "    ", x2, "   ",
          change, "\n"))
    }

    if (iter >= min_iter) {
      # The decrease test must come first: a negative change always satisfies
      # `change <= criterion`, so checking convergence first would report every
      # decrease as convergence.
      if (logL < old_log_l) {
        if (verbose) {
          x2 <- pearson_chisq(n, pi)
          g2 <- likelihood_ratio_chisq(n, pi)
          message(paste(iter, "    ", logL, "  ", g2, "  ", x2, "\n"))
          message(paste("\nlogL decreased: iteration", iter, ", log(L)  =  ", logL,
            ", G^2 =", g2, ", X^2 =", x2, "\n"))
        }
        break
      }
      if (change <= criterion) {
        if (verbose) {
          x2 <- pearson_chisq(n, pi)
          g2 <- likelihood_ratio_chisq(n, pi)
          message(paste("\nconverged at iteration", iter, ", log(L) =", logL,
            ", G^2 =", g2, ", X^2 =", x2, "\n"))
        }
        break
      }
    }
    old_log_l <- logL
  }

  # Final fit statistics are computed from the normalised model proportions.
  pi <- Schuster_compute_pi(marginal_pi, kappa, v)
  pi <- pi / sum(pi)
  chisq <- pearson_chisq(n, pi)
  g_squared <- likelihood_ratio_chisq(n, pi)
  df <- Schuster_compute_df(marginal_pi)
  list(marginal_pi = marginal_pi, kappa = kappa, v = v,
       chisq = chisq, g_squared = g_squared, df = df,
       predicted = N * pi)
}


#' Computes starting values for the model.
#'
#' Patterned after example in code in appendix to article. The marginal
#' proportions are the average of the observed row and column proportions.
#' The starting kappa is Cohen's kappa computed from the observed table.
#' Off-diagonal symmetry parameters are (n[i, j] + n[j, i]) divided by
#' (2 * N * row_prop[i] * col_prop[j]). Every diagonal entry is set to
#' Schuster_v_tilde(marginal_pi, kappa).
#'
#' @param n matrix of observed counts
#' @returns a list containing
#'    marginal_pi: vector of expected proportions for each category
#'    kappa: kappa coefficient of agreement
#'    v: matrix of symmetry parameters (plain matrix, no dimnames)
Schuster_compute_starting_values <- function(n) {
  n_cat <- nrow(n)
  total <- sum(n)
  row_prop <- rowSums(n) / total
  col_prop <- colSums(n) / total
  marginal_pi <- (row_prop + col_prop) / 2.0

  # chance agreement = sum over categories of col_prop[i] * row_prop[i]
  chance <- sum(diag(col_prop %*% t(row_prop)))
  kappa <- (sum(diag(n) / total) - chance) / (1.0 - chance)

  diagonal_value <- Schuster_v_tilde(marginal_pi, kappa)
  # Fill a fresh plain matrix so that no dimnames/class leak in from `n`.
  v <- matrix(0.0, nrow = n_cat, ncol = n_cat)
  v[] <- (n + t(n)) / outer(2.0 * total * row_prop, col_prop)
  diag(v) <- diagonal_value
  list(marginal_pi = marginal_pi, kappa = kappa, v = v)
}


#' Compute matrix of model-based proportions pi.
#'
#' marginal_pi[r] is first replaced by 1 - sum(marginal_pi[1:(r - 1)]).
#' Off-diagonal cells are marginal_pi[i] * marginal_pi[j] * v[i, j], floored at
#' 0.0001. Diagonal cells are marginal_pi[i]^2 * (1 + kappa * (1 - pe) / pe),
#' where pe = sum(marginal_pi^2). The result is not renormalised to sum to 1.
#'
#' @param marginal_pi expected proportions for each category
#' @param kappa current estimate of the kappa coefficient
#' @param v symmetry matrix
#' @param validate logical. should the cells be validated within this function?
#' Defaults to TRUE. When TRUE, a last proportion outside (0, 1) is an error and
#' v is first passed through Schuster_enforce_constraints_on_v().
#' @returns matrix of model-based cell proportions
Schuster_compute_pi <- function(marginal_pi, kappa, v, validate=TRUE) {
  k_cat <- length(marginal_pi)
  # The last proportion is not free: it is whatever makes the vector sum to 1.
  marginal_pi[k_cat] <- 1.0 - sum(marginal_pi[1:(k_cat - 1)])
  if (validate) {
    if (marginal_pi[k_cat] <= 0.0 || marginal_pi[k_cat] >= 1.0) {
      stop(paste("out of range value of marginal_pi[r]", marginal_pi[k_cat]))
    }
    v <- Schuster_enforce_constraints_on_v(marginal_pi, kappa, v)
  }

  pe <- sum(marginal_pi^2)
  cells <- matrix(0.0, nrow = k_cat, ncol = k_cat)
  cells[] <- pmax(outer(marginal_pi, marginal_pi) * v, 0.0001)
  # Diagonal cells are governed by kappa rather than by v (and are not floored).
  diag(cells) <- marginal_pi^2 + marginal_pi^2 * kappa * (1.0 - pe) / pe
  cells
}


#' Computes the degrees of freedom for the model.
#'
#' Only the number of categories r = length(pi_margin) is used; the values in
#' pi_margin are ignored. The result is (r - 1) * (r + 2) / 2.
#'
#' @param pi_margin expected proportions for each of the categories
#' @returns the df for the model
Schuster_compute_df <- function(pi_margin) {
  n_categories <- length(pi_margin)
  (n_categories - 1) * (n_categories + 2) / 2
}


#' Derivative of pi[i, j] wrt kappa coefficient.
#'
#' Diagonal cells have derivative marginal_pi[i]^2 * (1 - pe) / pe, with
#' pe = sum(marginal_pi^2). Off-diagonal cells outside row/column r do not
#' depend on kappa. Off-diagonal cells in row or column r are also handled;
#' for them the derivative is scaled by -marginal_pi[other] / marginal_pi[r],
#' where `other` is the cell's non-r index.
#'
#' @param i first index into pi
#' @param j second index into pi
#' @param marginal_pi expected proportions in each category
#' @param kappa current value of kappa coefficient
#' @param v symmetry matrix
#' @returns the derivative of pi[i, j] wrt kappa
Schuster_derivative_pi_wrt_kappa <- function(i, j, marginal_pi, kappa, v) {
  r <- length(marginal_pi)
  pe <- sum(marginal_pi^2)
  # slope of the diagonal factor kappa * (1 - pe) / pe with respect to kappa
  slope <- (1.0 - pe) / pe

  if (i == j) {
    return(marginal_pi[i]^2 * slope)
  }
  if (i != r && j != r) {
    return(0.0)
  }
  if (i == r) {
    marginal_pi[i] * marginal_pi[j] * (-marginal_pi[j] * slope / marginal_pi[r])
  } else {
    marginal_pi[i] * marginal_pi[j] * (-marginal_pi[i] * slope / marginal_pi[r])
  }
}


#' Derivative of pi[i, j] wrt marginal_pi[k].
#'
#' marginal_pi[r] is treated as 1 - sum(marginal_pi[1:(r - 1)]), so only
#' k < r is a free parameter; it changes with marginal_pi[k]. v is passed
#' through Schuster_enforce_constraints_on_v() after marginal_pi[r] is
#' recomputed.
#'
#' @param i first index into pi
#' @param j second index into pi
#' @param k index into marginal_pi (must be <= r - 1)
#' @param marginal_pi expected proportions for each category
#' @param kappa current estimate of kappa coefficient
#' @param v symmetry matrix
#' @returns derivative of pi[i, j] wrt marginal_pi[k]
Schuster_derivative_pi_wrt_marginal_pi <- function(i, j, k, marginal_pi, kappa, v) {
  r <- length(marginal_pi)
  if (r - 1 < k) {
    stop(paste("out of bounds value for k", k, r))
  }

  sum_rest <- sum(marginal_pi[1:(r - 1)])
  marginal_pi[r] <- 1.0 - sum_rest
  if (marginal_pi[r] <= 0.0) {
    stop(paste("out of range value for marginal_pi[r]", marginal_pi[r]))
  }
  v <- Schuster_enforce_constraints_on_v(marginal_pi, kappa, v)

  pe <- sum(marginal_pi^2)
  pi_k <- marginal_pi[k]
  # d pe / d marginal_pi[k]; marginal_pi[r] moves opposite to marginal_pi[k]
  der_pe_der_marginal_pi <- 2.0 * pi_k - 2.0 * (1.0 - sum_rest)

  if (j == i) {
    if (j == r) {  # [r, r] case
      deriv <- -2.0 * marginal_pi[r] - 2.0 * marginal_pi[r] * kappa * (1.0 - pe) / pe
      deriv <- deriv + (marginal_pi[r]^2 * kappa
                * (-1.0 / pe^2) * der_pe_der_marginal_pi)
    } else if (k != j) {  # [j, j], j != k case
      deriv <- (marginal_pi[i]^2 * kappa
                  * (-1.0 / pe^2) * der_pe_der_marginal_pi)
    } else {  # [j, j], j == k
      deriv <- 2.0 * pi_k + 2.0 * pi_k * kappa * (1.0 - pe) / pe
      deriv <- deriv + (pi_k^2 * kappa
                        * (-1.0 / pe^2) * der_pe_der_marginal_pi)
    }
  } else if (j == r || i == r) {
    # d/d marginal_pi[k] of the diagonal factor kappa * (1 - pe) / pe
    der_v_tilde_der_pk <- -2.0 * kappa * (marginal_pi[k] - marginal_pi[r]) / pe^2
    if (j == r) {
      if (i != k) {
        deriv <- (marginal_pi[i] * marginal_pi[r]
                  * (-v[i, k] - der_v_tilde_der_pk * marginal_pi[i]) / marginal_pi[r])
      } else {
        deriv <- v[k, r] * marginal_pi[r]
        deriv <- deriv + (marginal_pi[k] * marginal_pi[r]
                          * (-v[k, k] - marginal_pi[k] * der_v_tilde_der_pk) / marginal_pi[r])
      }
    } else {
      if (j != k) {
        deriv <- (marginal_pi[j] * marginal_pi[r]
                  * (-v[j, k] - der_v_tilde_der_pk * marginal_pi[j]) / marginal_pi[r])
      } else {
        deriv <- v[k, r] * marginal_pi[r]
        deriv <- deriv + (marginal_pi[k] * marginal_pi[r]
                          * (-v[k, k] - marginal_pi[k] * der_v_tilde_der_pk) / marginal_pi[r])
      }
    }
  } else if (j != k && i != k) {
    # off-diagonal cell away from row/column r that does not involve k
    deriv <- 0.0
  } else {
    # off-diagonal cell marginal_pi[i] * marginal_pi[j] * v[i, j] with k in {i, j}
    if (i == k) {
      deriv <- marginal_pi[j] * v[i, j]
    } else {
      deriv <- marginal_pi[i] * v[i, j]
    }
  }
  # Return explicitly (and visibly); previously the value leaked out of the
  # trailing if/else as an invisible assignment result.
  deriv
}


#' Computes derivative of pi[i, j] wrt v[i1, j1]
#'
#' Diagonal cells and diagonal v entries contribute nothing. An off-diagonal
#' cell depends on v[i1, j1] either directly (it is that entry or its mirror)
#' or, for cells in row/column r, through Schuster_derivative_v_wrt_v().
#'
#' @param i first index into pi
#' @param j second index into pi
#' @param i1 first index into v
#' @param j1 second index into v
#' @param marginal_pi expected marginal proportions
#' @param kappa current estimate of kappa coefficient
#' @param v symmetry matrix
#' @returns value of derivative of specified pi wrt specified element of v
Schuster_derivative_pi_wrt_v <- function(i, j, i1, j1, marginal_pi, kappa, v) {
  r <- length(marginal_pi)
  if (i == j || i1 == j1) {
    return(0.0)
  }

  same_entry <- (i == i1 && j == j1) || (i == j1 && j == i1)
  in_last_row_or_col <- i == r || j == r
  if (!same_entry && !in_last_row_or_col) {
    return(0.0)
  }
  marginal_pi[i] * marginal_pi[j] *
    Schuster_derivative_v_wrt_v(i, j, i1, j1, marginal_pi, kappa, v)
}


#' Computes derivative of v[i1, j1] wrt v[i2, j2]
#'
#' Needed because of computed v terms in column r. Diagonal entries are never
#' free, so any diagonal index pair gives 0. A free entry has derivative 1
#' wrt itself or its mirror image. An entry in row/column r (v[r, m] or
#' v[m, r]) has derivative -marginal_pi[other] / marginal_pi[r] wrt any free
#' entry that shares index m, where `other` is that entry's second index.
#'
#' @param i1 first index into target v
#' @param j1 second index into target v
#' @param i2 first index into the v entry differentiated against
#' @param j2 second index into the v entry differentiated against
#' @param marginal_pi expected marginal proportions
#' @param kappa current estimate of kappa coefficient
#' @param v matrix of symmetry parameters
#' @returns derivative of v[i1, j1] wrt v[i2, j2]
Schuster_derivative_v_wrt_v <- function(i1, j1, i2, j2, marginal_pi, kappa, v) {
  r <- length(marginal_pi)
  if (i1 == j1 || i2 == j2) {
    return(0.0)
  }

  touches_last <- i1 == r || j1 == r
  same_entry <- (i1 == i2 && j1 == j2) || (i1 == j2 && j1 == i2)
  if (same_entry) {
    return(if (touches_last) 0.0 else 1.0)
  }
  if (!touches_last) {
    return(0.0)
  }

  # Exactly one of i1, j1 equals r here; `partner` is the other one.
  partner <- if (i1 == r) j1 else i1
  if (i2 == partner) {
    return(-marginal_pi[j2] / marginal_pi[r])
  }
  if (j2 == partner) {
    return(-marginal_pi[i2] / marginal_pi[r])
  }
  0
}


#' Second order partial wrt kappa, kappa
#'
#' Always 0: the first derivative (Schuster_derivative_pi_wrt_kappa) does not
#' involve kappa, so differentiating it again wrt kappa gives nothing.
#' @param i first index of pi
#' @param j second index of pi
#' @param marginal_pi expected proportions for each category
#' @param kappa current estimate of the kappa coefficient
#' @param v symmetry matrix
#' @returns second order partial derivative
Schuster_second_deriv_pi_wrt_kappa_2 <- function(i, j, marginal_pi, kappa, v) {
  return(0.0)
}


#' Second order partial wrt kappa, marginal_pi
#'
#' Derivative wrt marginal_pi[k] of the kappa-derivative of pi[i, j] (see
#' Schuster_derivative_pi_wrt_kappa). This is generally NOT zero. It is zero
#' only for off-diagonal cells away from row/column r. In d pe / d marginal_pi[k],
#' the last proportion is taken as 1 - sum(marginal_pi[1:(r - 1)]).
#' @param i first index of pi
#' @param j second index of pi
#' @param k index of marginal_pi
#' @param marginal_pi expected proportions for each category
#' @param kappa current estimate of the kappa coefficient
#' @param v symmetry matrix
#' @returns second order partial derivative
Schuster_second_deriv_pi_wrt_marginal_pi_kappa <- function(i, j, k, marginal_pi, kappa, v) {
  r <- length(marginal_pi)
  pe <- sum(marginal_pi^2)
  sum_rest <- sum(marginal_pi[1:(r - 1)])
  # d/d kappa of the diagonal factor kappa * (1 - pe) / pe
  der_v_tilde_der_kappa <- (1.0 - pe) / pe
  # d pe / d marginal_pi[k]; the last proportion moves opposite to marginal_pi[k]
  der_pe_der_marginal_pi <- 2.0 * marginal_pi[k] - 2.0 * (1.0 - sum_rest)

  last <- (i == r) || (j == r)
  if (!last) {
    if (i != j) {
      deriv <- 0.0
    } else {  # i == j
      if (k != i) {
        deriv <- -marginal_pi[i]^2 * der_pe_der_marginal_pi / pe^2
      } else { # k == i == j != r, e.g. 1, 1, 1
        deriv <- 2.0 * marginal_pi[k] * (1.0 - pe) / pe - marginal_pi[k]^2 * der_pe_der_marginal_pi / pe^2
      }
    }
  } else { # cell in row or column r
    if (i == r && j != r) {
      # kappa-derivative of pi[r, j] is -marginal_pi[j]^2 * (1 - pe) / pe
      if (k == j) {
        deriv <- -2.0 * marginal_pi[j] * der_v_tilde_der_kappa
        deriv <- deriv + marginal_pi[j]^2 * der_pe_der_marginal_pi / pe^2
      } else {
        deriv <- marginal_pi[j]^2 * der_pe_der_marginal_pi / pe^2
      }
    } else if (j == r && i != r) {
      # kappa-derivative of pi[i, r] is -marginal_pi[i]^2 * (1 - pe) / pe
      if (k == i) {
        deriv <- -2.0 * marginal_pi[i] * der_v_tilde_der_kappa
        deriv <- deriv + marginal_pi[i]^2 * der_pe_der_marginal_pi / pe^2
      } else {
        deriv <- marginal_pi[i]^2 * der_pe_der_marginal_pi / pe^2
      }
    } else {  # i == j == r; marginal_pi[r] itself depends on marginal_pi[k]
      deriv <- -marginal_pi[r]^2 * der_pe_der_marginal_pi / pe^2
      deriv <- deriv - 2.0 * marginal_pi[r] * der_v_tilde_der_kappa
    }
  }
  deriv
}


#' Second order partial wrt kappa, v
#'
#' Always 0: the v-derivative of pi[i, j] (Schuster_derivative_pi_wrt_v) is
#' built from marginal proportions only and does not involve kappa.
#' @param i first index of pi
#' @param j second index of pi
#' @param i1 first index of v
#' @param j1 second index of v
#' @param marginal_pi expected proportions for each category
#' @param kappa current estimate of the kappa coefficient
#' @param v symmetry matrix
#' @returns second order partial derivative
Schuster_second_deriv_pi_wrt_kappa_v <- function(i, j, i1, j1, marginal_pi, kappa, v) {
  return(0.0)
}


#' Second order partial wrt v^2
#'
#' Always 0: the v-derivative of pi[i, j] (Schuster_derivative_pi_wrt_v) does
#' not depend on the values stored in v.
#' @param i first index of pi
#' @param j second index of pi
#' @param i1 first index of first v
#' @param j1 second index of first v
#' @param i2 first index of second v
#' @param j2 second index of second v
#' @param marginal_pi expected proportions for each category
#' @param kappa current estimate of the kappa coefficient
#' @param v symmetry matrix
#' @returns second order partial derivative
Schuster_second_deriv_pi_wrt_v_2 <- function(i, j, i1, j1, i2, j2, marginal_pi, kappa, v) {
  return(0.0)
}


#' Second derivative of pi[i, j] wrt marginal_pi[k] and marginal_pi[k2]
#'
#' marginal_pi[r] is recomputed as 1 - sum(marginal_pi[1:(r - 1)]) before v is
#' passed through Schuster_enforce_constraints_on_v(). This is the same order
#' used by Schuster_derivative_pi_wrt_marginal_pi, so first and second
#' derivatives see the same v. The work is then dispatched to a case-specific
#' helper.
#'
#' @param i first index into pi
#' @param j second index into pi
#' @param k index into marginal_pi (must be <= r - 1)
#' @param k2 second index into marginal_pi (must be <= r - 1)
#' @param marginal_pi expected proportions for each category
#' @param kappa current estimate of kappa coefficient
#' @param v symmetry matrix
#' @returns second derivative of pi[i, j] wrt marginal_pi[k] and marginal_pi[k2]
Schuster_second_deriv_pi_wrt_marginal_pi_2 <- function(i, j, k, k2, marginal_pi, kappa, v) {
  r <- length(marginal_pi)
  if (r - 1 < k) {
    stop(paste("out of bounds value for k", k, r))
  }
  if (r - 1 < k2) {
    stop(paste("out of bounds value for k2", k2, r))
  }

  sum_rest <- sum(marginal_pi[1:(r - 1)])
  marginal_pi[r] <- 1.0 - sum_rest
  if (marginal_pi[r] <= 0.0) {
    stop(paste("out of range value for marginal_pi[r]", marginal_pi[r]))
  }
  v <- Schuster_enforce_constraints_on_v(marginal_pi, kappa, v)

  if (j == i) {
    if (j != r) {
      deriv <- handle_tied_below_maximum(j, k, k2, marginal_pi, kappa, v)
    } else {
      deriv <- handle_tied_maximum(k, k2, marginal_pi, kappa, v)
    }
  } else if (i == r || j == r) {
    deriv <- handle_one_maximum(i, j, k, k2, marginal_pi, kappa, v)
  } else {
    deriv <- handle_untied_below_maximum(i, j, k, k2, marginal_pi, kappa, v)
  }
  deriv
}


#' Case where pi[r, r] with k and k2
#'
#' Second derivative of pi[r, r] = p_r^2 * (1 + kappa * (1 - pe) / pe) wrt
#' marginal_pi[k] and marginal_pi[k2]. Here p_r = 1 - sum(marginal_pi[1:(r - 1)])
#' and pe = sum(marginal_pi^2).
#'
#' @param k first index into marginal_pi
#' @param k2 second index into marginal_pi
#' @param marginal_pi expected proportions for each category
#' @param kappa current estimate of kappa coefficient
#' @param v symmetry matrix (unused)
#' @return second order derivative
handle_tied_maximum <- function(k, k2, marginal_pi, kappa, v) {
  n_cat <- length(marginal_pi)
  p_last <- 1.0 - sum(marginal_pi[1:(n_cat - 1)])
  pe <- sum(marginal_pi^2)
  # d pe / d marginal_pi[x] for each free index involved
  grad_k <- 2.0 * marginal_pi[k] - 2.0 * p_last
  grad_k2 <- 2.0 * marginal_pi[k2] - 2.0 * p_last
  base <- 2.0 + 2.0 * kappa * (1.0 - pe) / pe

  if (k != k2) {
    return(base
           - 2.0 * p_last * kappa * (-1.0 / pe^2) * grad_k
           - 2.0 * p_last * kappa * (-1.0 / pe^2) * grad_k2
           + p_last^2 * kappa * (2.0 / pe^3) * grad_k * grad_k2
           + p_last^2 * kappa * (-2.0 / pe^2))
  }
  base - 4.0 * p_last * kappa * (-1.0 / pe^2) * grad_k +
    p_last^2 * kappa * (2.0 / pe^3) * grad_k^2 +
    p_last^2 * kappa * (-4.0 / pe^2)
}


#' Case where i == j, i < r, j < r
#'
#' Second derivative of pi[j, j] = p_j^2 * (1 + kappa * (1 - pe) / pe) wrt
#' marginal_pi[k] and marginal_pi[k2]. pe depends on every free proportion
#' through the last one, 1 - sum(marginal_pi[1:(r - 1)]).
#'
#' @param j index of pi
#' @param k first index into marginal_pi
#' @param k2 second index into marginal_pi
#' @param marginal_pi expected proportions for each of the categories
#' @param kappa current estimate of kappa coefficient
#' @param v symmetry matrix (unused)
#' @returns derivative
handle_tied_below_maximum <- function(j, k, k2, marginal_pi, kappa, v) {
  n_cat <- length(marginal_pi)
  p_last <- 1.0 - sum(marginal_pi[1:(n_cat - 1)])
  pe <- sum(marginal_pi^2)
  p_j <- marginal_pi[j]
  # d pe / d marginal_pi[x]
  grad <- function(x) 2.0 * marginal_pi[x] - 2.0 * p_last
  g_j <- grad(j)

  if (k == j && k2 == j) {
    out <- 2.0 + 2.0 * kappa * (1.0 - pe) / pe
    out <- out - 4.0 * p_j * kappa * (1.0 / pe^2) * g_j
    out <- out + p_j^2 * kappa * (2.0 / pe^3) * g_j^2
    return(out - p_j^2 * kappa * 4.0 / pe^2)
  }
  if (k == j || k2 == j) {
    # exactly one of k, k2 is j; `other` is the remaining index
    other <- if (k == j) k2 else k
    g_other <- grad(other)
    out <- 2.0 * p_j * kappa * (-1 / pe^2) * g_other
    out <- out + p_j^2 * kappa * 2.0 / pe^3 * g_j * g_other
    return(out + p_j^2 * kappa * (-2.0 / pe^2))
  }
  if (k == k2) {
    g_k <- grad(k)
    return(p_j^2 * kappa * (2.0 / pe^3) * g_k^2 - p_j^2 * kappa * (4.0 / pe^2))
  }
  p_j^2 * kappa * (2.0 / pe^3) * grad(k) * grad(k2) - p_j^2 * kappa * (2.0 / pe^2)
}


#' Case where pi[i, r] or pi[r, j] (exactly one index is r) with k and k2
#'
#' The non-r index of the cell determines which helper applies. The cases are:
#' it equals both k and k2 (handle_max_i_i), it equals one of them
#' (handle_max_i_k), or it equals neither (handle_max_k_k2).
#'
#' @param i first index of pi
#' @param j second index of pi
#' @param k first index into marginal_pi
#' @param k2 second index into marginal_pi
#' @param marginal_pi expected proportions for each category
#' @param kappa current estimate of kappa coefficient
#' @param v symmetry matrix
#' @return second order derivative
handle_one_maximum <- function(i, j, k, k2, marginal_pi, kappa, v) {
  r <- length(marginal_pi)
  other <- if (j == r) i else j
  hits_k <- other == k
  hits_k2 <- other == k2

  if (hits_k && hits_k2) {
    handle_max_i_i(other, marginal_pi, kappa, v)
  } else if (hits_k) {
    handle_max_i_k(other, k2, marginal_pi, kappa, v)
  } else if (hits_k2) {
    handle_max_i_k(other, k, marginal_pi, kappa, v)
  } else {
    handle_max_k_k2(other, k, k2, marginal_pi, kappa, v)
  }
}


#' Case where j == r, i == k == k2
#'
#' Second derivative of the cell (i, r) wrt marginal_pi[i] twice. v[i, i] is
#' used as the diagonal symmetry value. The last proportion is taken as
#' 1 - sum(marginal_pi[1:(r - 1)]).
#'
#' @param i index into marginal_pi
#' @param marginal_pi expected proportions for each category
#' @param kappa current estimate of kappa coefficient
#' @param v symmetry matrix
#' @returns second-order derivative
handle_max_i_i <- function(i, marginal_pi, kappa, v) {
  n_cat <- length(marginal_pi)
  p_last <- 1.0 - sum(marginal_pi[1:(n_cat - 1)])
  pe <- sum(marginal_pi^2)
  p_i <- marginal_pi[i]
  # d pe / d marginal_pi[i]
  grad_i <- 2.0 * p_i - 2.0 * p_last

  -2.0 * v[i, i] - 4.0 * p_i * kappa * (-1.0 / pe^2) * grad_i -
    p_i^2 * kappa * (2 / pe^3) * grad_i^2 -
    p_i^2 * kappa * (-4.0 / pe^2)
}

#' Case where the cell (i, r) is differentiated wrt marginal_pi[i] and another
#' free proportion marginal_pi[k], k != i
#'
#' handle_one_maximum calls this for both orderings: (k == i, k2 != i) and
#' (k != i, k2 == i). In either case the `k` passed here is the index that
#' differs from i. The last proportion is taken as
#' 1 - sum(marginal_pi[1:(r - 1)]).
#'
#' @param i index of the non-r cell index (also one of the differentiation indices)
#' @param k the other differentiation index into marginal_pi (and column of v)
#' @param marginal_pi expected proportions for each category
#' @param kappa current estimate of kappa coefficient
#' @param v symmetry matrix
#' @returns second-order derivative
handle_max_i_k <- function(i, k, marginal_pi, kappa, v) {
  r <- length(marginal_pi)
  sum_rest <- sum(marginal_pi[1:(r - 1)])
  pe <- sum(marginal_pi^2)
  pi_i <- marginal_pi[i]
  pi_k <- marginal_pi[k]
  pi_r <- 1.0 - sum_rest
  # d pe / d marginal_pi[x] for x = k and x = i
  der_pe_k <- 2.0 * pi_k - 2.0 * pi_r
  der_pe_i <- 2.0 * pi_i - 2.0 * pi_r

  deriv <- -v[i, k] - 2.0 * pi_i * kappa * (-1.0 / pe^2) * der_pe_k
  deriv <- deriv - pi_i^2 * kappa * (2.0 / pe^3) * der_pe_k * der_pe_i
  deriv <- deriv - pi_i^2 * kappa * (-2.0 / pe^2)
  deriv
}


#' Case where j == r, i != k && i != k2
#'
#' Second derivative of the cell (i, r) wrt marginal_pi[k] and marginal_pi[k2],
#' neither of which is i. Only the pe-dependent term contributes. The last
#' proportion is taken as 1 - sum(marginal_pi[1:(r - 1)]).
#'
#' @param i index into pi
#' @param k first index into marginal_pi
#' @param k2 second index into marginal_pi
#' @param marginal_pi expected proportions for each category
#' @param kappa current estimate of kappa coefficient
#' @param v symmetry matrix (unused)
#' @returns second-order derivative
handle_max_k_k2 <- function(i, k, k2, marginal_pi, kappa, v) {
  r <- length(marginal_pi)
  sum_rest <- sum(marginal_pi[1:(r - 1)])
  pe <- sum(marginal_pi^2)
  pi_i <- marginal_pi[i]
  pi_k <- marginal_pi[k]
  pi_k2 <- marginal_pi[k2]
  pi_r <- 1.0 - sum_rest
  if (k == k2) {
    # d^2 pe / d marginal_pi[k]^2 is 4 (2 from pi_k, 2 from the dependent pi_r)
    deriv <- -pi_i^2 * kappa * (2.0 / pe^3) * (2.0 * pi_k - 2.0 * pi_r)^2
    deriv <- deriv - pi_i^2 * kappa * (-4.0 / pe^2)
  } else {  # k != k2: the mixed d^2 pe term is 2 (from the dependent pi_r)
    deriv <- -pi_i^2 * kappa * (2.0 / pe^3) * (2.0 * pi_k - 2.0 * pi_r) * (2.0 * pi_k2 - 2.0 * pi_r)
    deriv <- deriv - pi_i^2 * kappa * (-2.0 / pe^2)
  }
  deriv
}

#' Case where i != j, i < r && j < r
#'
#' pi[i, j] = marginal_pi[i] * marginal_pi[j] * v[i, j] is bilinear in the two
#' proportions. Its only non-zero second derivative is therefore the mixed
#' partial over {i, j}, which equals v[i, j].
#'
#' @param i first index of pi
#' @param j second index of pi
#' @param k first index of marginal_pi
#' @param k2 second index of marginal_pi
#' @param marginal_pi expected proportions of each of the categories (unused)
#' @param kappa current value of kappa coefficient (unused)
#' @param v symmetry matrix
#' @returns second-order derivative
handle_untied_below_maximum <- function(i, j, k, k2, marginal_pi, kappa, v) {
  mixed_pair <- (i == k && j == k2) || (j == k && i == k2)
  if (mixed_pair) v[i, j] else 0.0
}


#' Second order partial pi wrt marginal_pi and v
#'
#' Cells away from row/column r depend on v[i1, j1] only when they are that
#' entry (or its mirror). The derivative wrt marginal_pi[k] is then the other
#' marginal proportion. A cell in row/column r depends on v[i1, j1] when its
#' non-r index is i1 or j1. Differentiating wrt marginal_pi[k] then gives
#' -marginal_pi[j1] for k == i1, and -marginal_pi[i1] for k == j1.
#'
#' @param i first index of pi
#' @param j second index of pi
#' @param k index of marginal_pi
#' @param i1 first index of v
#' @param j1 second index of v
#' @param marginal_pi expected proportions of each of the categories
#' @param kappa current value of kappa coefficient
#' @param v symmetry matrix
#' @returns derivative
Schuster_second_deriv_pi_wrt_marginal_pi_v <- function(i, j, k, i1, j1, marginal_pi, kappa, v) {
  r <- length(marginal_pi)
  if (i == j || i1 == j1) {
    return(0.0)
  }

  if (i < r && j < r) {
    own_entry <- (i == i1 && j == j1) || (i == j1 && j == i1)
    if (!own_entry) {
      return(0.0)
    }
    if (i == k) {
      return(marginal_pi[j])
    }
    if (j == k) {
      return(marginal_pi[i])
    }
    return(0.0)
  }

  # Cell in row or column r: look at its other index.
  other <- if (i == r) j else i
  if (other != i1 && other != j1) {
    return(0.0)
  }
  if (k == i1) {
    return(-marginal_pi[j1])
  }
  if (k == j1) {
    return(-marginal_pi[i1])
  }
  0.0
}


#' Derivative of log(likelihood) wrt kappa.
#'
#' Accumulates (n[i, j] / pi[i, j] - N) * d pi[i, j] / d kappa over all cells,
#' row by row. v is first passed through Schuster_enforce_constraints_on_v().
#'
#' @param n matrix of observed counts
#' @param marginal_pi expected proportions for each category
#' @param kappa current value of kappa coefficient
#' @param v symmetry matrix
#' @returns derivative of log(L) wrt kappa
Schuster_derivative_log_l_wrt_kappa <- function(n, marginal_pi, kappa, v) {
  n_cat <- nrow(n)
  total <- sum(n)
  v <- Schuster_enforce_constraints_on_v(marginal_pi, kappa, v)
  cell_pi <- Schuster_compute_pi(marginal_pi, kappa, v)
  # per-cell score weight n_ij / pi_ij - N
  weight <- n / cell_pi - total

  score <- 0.0
  for (row in seq_len(n_cat)) {
    for (col in seq_len(n_cat)) {
      slope <- Schuster_derivative_pi_wrt_kappa(row, col, marginal_pi, kappa, v)
      score <- score + weight[row, col] * slope
    }
  }
  score
}


#' Derivative of log(likelihood) wrt marginal_pi[k]
#'
#' Accumulates (n[i, j] / pi[i, j] - N) * d pi[i, j] / d marginal_pi[k] over all
#' cells, row by row. Only free proportions (k <= r - 1) are allowed.
#'
#' @param n matrix of observed counts
#' @param k index into marginal_pi
#' @param marginal_pi expected proportions of each of the categories
#' @param kappa current value of the kappa coefficient
#' @param v symmetry matrix
#' @returns derivative of log(L) wrt marginal_pi[k]
Schuster_derivative_log_l_wrt_marginal_pi <- function(n, k, marginal_pi, kappa, v) {
  n_cat <- nrow(n)
  if (k > n_cat - 1) {
    stop("out of bounds value for k")
  }
  total <- sum(n)
  cell_pi <- Schuster_compute_pi(marginal_pi, kappa, v)
  weight <- n / cell_pi - total

  score <- 0.0
  for (row in seq_len(n_cat)) {
    for (col in seq_len(n_cat)) {
      slope <- Schuster_derivative_pi_wrt_marginal_pi(row, col, k, marginal_pi, kappa, v)
      score <- score + weight[row, col] * slope
    }
  }
  score
}


#' Derivative of log(likelihood) wrt v[i1, j1]
#'
#' Accumulates (n[i, j] / pi[i, j] - N) * d pi[i, j] / d v[i1, j1] over all
#' cells, row by row.
#'
#' @param n matrix of observed counts
#' @param i1 first index into v
#' @param j1 second index into v
#' @param marginal_pi expected marginal proportions
#' @param kappa current value of kappa coefficient
#' @param v symmetry matrix
#' @returns derivative of log(L) wrt v[i1, j1]
Schuster_derivative_log_l_wrt_v <- function(n, i1, j1, marginal_pi, kappa, v) {
  n_cat <- nrow(n)
  total <- sum(n)
  cell_pi <- Schuster_compute_pi(marginal_pi, kappa, v)
  weight <- n / cell_pi - total

  score <- 0.0
  for (row in seq_len(n_cat)) {
    for (col in seq_len(n_cat)) {
      slope <- Schuster_derivative_pi_wrt_v(row, col, i1, j1, marginal_pi, kappa, v)
      score <- score + weight[row, col] * slope
    }
  }
  score
}


#' Second order partial log(L) wrt marginal_pi^2.
#'
#' Builds the (r - 1) x (r - 1) Hessian block over the free marginal
#' proportions. Each cell contributes
#' (n/pi - N) * d2pi - (n / pi^2) * dpi[k] * dpi[k1].
#'
#' @param n matrix of observed counts
#' @param marginal_pi expected proportions for each response category
#' @param kappa current estimate of kappa coefficient
#' @param v symmetry matrix
#' @returns (r - 1) x (r - 1) matrix: second derivative of log(L) wrt marginal_pi^2
Schuster_second_deriv_log_l_wrt_marginal_pi_2 <- function(n, marginal_pi, kappa, v) {
  r <- nrow(n)
  N <- sum(n)
  v <- Schuster_enforce_constraints_on_v(marginal_pi, kappa, v)
  pi <- Schuster_compute_pi(marginal_pi, kappa, v)
  free <- seq_len(r - 1)
  deriv <- matrix(0.0, nrow=r - 1, ncol=r - 1)
  for (i in seq_len(r)) {
    for (j in seq_len(r)) {
      weight <- n[i, j] / pi[i, j] - N
      curvature <- n[i, j] / pi[i, j]^2
      # First derivatives of pi[i, j] wrt every free proportion, computed once
      # per cell rather than once per (k, k1) pair.
      grad <- vapply(free, function(k) {
        Schuster_derivative_pi_wrt_marginal_pi(i, j, k, marginal_pi, kappa, v)
      }, numeric(1))
      for (k in free) {
        for (k1 in free) {
          der <- Schuster_second_deriv_pi_wrt_marginal_pi_2(i, j, k, k1, marginal_pi, kappa, v)
          deriv[k, k1] <- deriv[k, k1] + weight * der
          deriv[k, k1] <- deriv[k, k1] - curvature * grad[k] * grad[k1]
        }
      }
    }
  }
  deriv
}


#' Second order partial log(L) wrt marginal_pi and kappa.
#'
#' Each cell contributes
#' (n/pi - N) * d2pi/(dpi_k dkappa) - (n / pi^2) * dpi/dpi_k * dpi/dkappa.
#'
#' @param n matrix of observed counts
#' @param marginal_pi expected proportions for each response category
#' @param kappa current estimate of kappa coefficient
#' @param v symmetry matrix
#' @returns vector of length r - 1: second derivative of log(L) wrt marginal_pi and kappa
Schuster_second_deriv_log_l_wrt_marginal_pi_kappa <- function(n, marginal_pi, kappa, v) {
  r <- nrow(n)
  N <- sum(n)
  v <- Schuster_enforce_constraints_on_v(marginal_pi, kappa, v)
  pi <- Schuster_compute_pi(marginal_pi, kappa, v)

  deriv <- rep(0.0, r - 1)
  for (i in seq_len(r)) {
    for (j in seq_len(r)) {
      weight <- n[i, j] / pi[i, j] - N
      curvature <- n[i, j] / pi[i, j]^2
      # kappa-derivative of this cell does not depend on k: compute it once.
      der_kappa <- Schuster_derivative_pi_wrt_kappa(i, j, marginal_pi, kappa, v)
      for (k in seq_len(r - 1)) {
        der1 <- Schuster_derivative_pi_wrt_marginal_pi(i, j, k, marginal_pi, kappa, v)
        d <- Schuster_second_deriv_pi_wrt_marginal_pi_kappa(i, j, k, marginal_pi, kappa, v)
        deriv[k] <- deriv[k] + weight * d
        deriv[k] <- deriv[k] - curvature * der1 * der_kappa
      }
    }
  }
  deriv
}


#' Second order partial log(L) wrt marginal_pi and v.
#'
#' Rows index the free proportions marginal_pi[1:(r - 1)]. Columns index the
#' free v entries v[i1, j1] with i1 < j1 <= r - 1, ordered (1, 2), (1, 3), ...,
#' (2, 3), ... (i1 outer, j1 inner). With r <= 2 there are no free v entries,
#' and a matrix with zero columns is returned.
#'
#' @param n matrix of observed counts
#' @param marginal_pi expected proportions for each response category
#' @param kappa current estimate of kappa coefficient
#' @param v symmetry matrix
#' @returns (r - 1) x ((r - 1)(r - 2) / 2) matrix: second derivative of log(L)
#'   wrt marginal_pi and v
Schuster_second_deriv_log_l_wrt_marginal_pi_v <- function(n, marginal_pi, kappa, v) {
  r <- nrow(n)
  N <- sum(n)
  v <- Schuster_enforce_constraints_on_v(marginal_pi, kappa, v)
  pi <- Schuster_compute_pi(marginal_pi, kappa, v)

  # Free v entries in column order; seq-based so r = 2 yields no pairs instead
  # of iterating over 1:0.
  pairs <- which(upper.tri(diag(r - 1)), arr.ind = TRUE)
  pairs <- pairs[order(pairs[, 1], pairs[, 2]), , drop = FALSE]
  n_pairs <- nrow(pairs)

  deriv <- matrix(0.0, nrow=r - 1, ncol=(r - 1) * (r - 2) / 2)
  for (i in seq_len(r)) {
    for (j in seq_len(r)) {
      weight <- n[i, j] / pi[i, j] - N
      curvature <- n[i, j] / pi[i, j]^2
      # v-derivatives of this cell do not depend on k: compute them once.
      der_v <- vapply(seq_len(n_pairs), function(p) {
        Schuster_derivative_pi_wrt_v(i, j, pairs[p, 1], pairs[p, 2], marginal_pi, kappa, v)
      }, numeric(1))
      for (k in seq_len(r - 1)) {
        der1 <- Schuster_derivative_pi_wrt_marginal_pi(i, j, k, marginal_pi, kappa, v)
        for (p in seq_len(n_pairs)) {
          second <- Schuster_second_deriv_pi_wrt_marginal_pi_v(i, j, k, pairs[p, 1], pairs[p, 2],
                                                              marginal_pi, kappa, v)
          deriv[k, p] <- deriv[k, p] + weight * second
          deriv[k, p] <- deriv[k, p] - curvature * der1 * der_v[p]
        }
      }
    }
  }
  deriv
}


#' Second order partial log(L) wrt kappa^2.
#'
#' Each cell contributes
#' (n/pi - N) * d2pi/dkappa2 - (n / pi^2) * (dpi/dkappa)^2.
#'
#' @param n matrix of observed counts
#' @param marginal_pi expected proportions for each response category
#' @param kappa current estimate of kappa coefficient
#' @param v symmetry matrix
#' @returns scalar: second derivative of log(L) wrt kappa^2
Schuster_second_deriv_log_l_wrt_kappa_2 <- function(n, marginal_pi, kappa, v) {
  r <- nrow(n)
  N <- sum(n)
  v <- Schuster_enforce_constraints_on_v(marginal_pi, kappa, v)
  pi <- Schuster_compute_pi(marginal_pi, kappa, v)
  deriv <- 0.0
  for (i in seq_len(r)) {
    for (j in seq_len(r)) {
      der <- Schuster_derivative_pi_wrt_kappa(i, j, marginal_pi, kappa, v)
      second <- Schuster_second_deriv_pi_wrt_kappa_2(i, j, marginal_pi, kappa, v)
      deriv <- deriv + (n[i, j] / pi[i, j] - N) * second
      deriv <- deriv - (n[i, j] / pi[i, j]^2) * der^2
    }
  }
  deriv
}


#' Second order partial log(L) wrt kappa and v.
#'
#' Sums, over every cell (i, j) of the table,
#' (n[i, j] / pi[i, j] - N) * d2pi / (dkappa dv[i1, j1])
#' - n[i, j] / pi[i, j]^2 * dpi / dkappa * dpi / dv[i1, j1].
#' All cells are visited, as in the other Hessian blocks, because the
#' observed table n is generally not symmetric. Visiting only j >= i and
#' doubling would use 2 * n[i, j] in place of n[i, j] + n[j, i].
#'
#' @param n matrix of observed counts
#' @param marginal_pi expected proportions for each response category
#' @param kappa current estimate of kappa coefficient
#' @param v symmetry matrix
#' @returns vector of length (r - 1) * (r - 2) / 2 with the second derivative
#' of log(L) wrt kappa and each free v[i1, j1], i1 < j1 < r (row-major order)
Schuster_second_deriv_log_l_wrt_kappa_v <- function(n, marginal_pi, kappa, v) {
  r <- nrow(n)
  N <- sum(n)
  v <- Schuster_enforce_constraints_on_v(marginal_pi, kappa, v)
  pi <- Schuster_compute_pi(marginal_pi, kappa, v)
  deriv <- rep(0.0, (r - 1) * (r - 2) / 2)
  for (i in seq_len(r)) {
    for (j in seq_len(r)) {
      w1 <- n[i, j] / pi[i, j] - N
      w2 <- n[i, j] / pi[i, j]^2
      der1 <- Schuster_derivative_pi_wrt_kappa(i, j, marginal_pi, kappa, v)
      index <- 1
      # seq_len() keeps these loops empty when r = 2 (no free v parameters).
      for (i1 in seq_len(r - 2)) {
        for (j1 in (i1 + 1):(r - 1)) {
          der2 <- Schuster_derivative_pi_wrt_v(i, j, i1, j1, marginal_pi,
                                               kappa, v)
          d <- Schuster_second_deriv_pi_wrt_kappa_v(i, j, i1, j1, marginal_pi,
                                                    kappa, v)
          deriv[index] <- deriv[index] + w1 * d - w2 * der1 * der2
          index <- index + 1
        }
      }
    }
  }
  deriv
}


#' Second order partial log(L) wrt v^2.
#'
#' Sums, over every cell (i, j) of the table,
#' (n[i, j] / pi[i, j] - N) * d2pi / (dv[i1, j1] dv[i2, j2])
#' - n[i, j] / pi[i, j]^2 * dpi / dv[i1, j1] * dpi / dv[i2, j2].
#'
#' @param n matrix of observed counts
#' @param marginal_pi expected proportions for each response category
#' @param kappa current estimate of kappa coefficient
#' @param v symmetry matrix
#' @returns square matrix, of order (r - 1) * (r - 2) / 2, of second
#' derivatives of log(L) wrt pairs of free v[i1, j1], i1 < j1 < r
#' (row-major order). It is 0 x 0 when r = 2.
Schuster_second_deriv_log_l_wrt_v_2 <- function(n, marginal_pi, kappa, v) {
  r <- nrow(n)
  n_parm <- (r - 1) * (r - 2) / 2
  N <- sum(n)
  v <- Schuster_enforce_constraints_on_v(marginal_pi, kappa, v)
  pi <- Schuster_compute_pi(marginal_pi, kappa, v)
  deriv_mat <- matrix(0.0, nrow = n_parm, ncol = n_parm)
  for (i in seq_len(r)) {
    for (j in seq_len(r)) {
      w1 <- n[i, j] / pi[i, j] - N
      w2 <- n[i, j] / pi[i, j]^2
      # First derivatives of pi[i, j] wrt each free v, computed once per cell
      # and reused for every (index, index2) pair. This assumes
      # Schuster_derivative_pi_wrt_v has no side effects.
      der_v <- vector("double", n_parm)
      index <- 1
      # seq_len() keeps these loops empty when r = 2 (no free v parameters).
      for (i1 in seq_len(r - 2)) {
        for (j1 in (i1 + 1):(r - 1)) {
          der_v[index] <- Schuster_derivative_pi_wrt_v(i, j, i1, j1,
                                                       marginal_pi, kappa, v)
          index <- index + 1
        }
      }
      index <- 1
      for (i1 in seq_len(r - 2)) {
        for (j1 in (i1 + 1):(r - 1)) {
          index2 <- 1
          for (i2 in seq_len(r - 2)) {
            for (j2 in (i2 + 1):(r - 1)) {
              d <- Schuster_second_deriv_pi_wrt_v_2(i, j, i1, j1, i2, j2,
                                                    marginal_pi, kappa, v)
              deriv_mat[index, index2] <- (deriv_mat[index, index2] + w1 * d
                                           - w2 * der_v[index] * der_v[index2])
              index2 <- index2 + 1
            }
          }
          index <- index + 1
        }
      }
    }
  }
  deriv_mat
}


#' Gradient vector log(L) wrt parameters.
#'
#' Work is delegated to functions that compute partial derivatives.
#' This function is responsible for laying them out in correct
#' positions in the vector:
#'   1..(r - 1)   marginal_pi[1..r-1]
#'   r            kappa
#'   (r + 1)..    free v[i1, j1], i1 < j1 < r, in row-major order
#' @param n matrix of observed counts
#' @param marginal_pi expected proportions for each response category
#' @param kappa current estimate of kappa coefficient
#' @param v symmetry matrix
#' @returns gradient vector of length r + (r - 1) * (r - 2) / 2
Schuster_gradient <- function(n, marginal_pi, kappa, v) {
  r <- nrow(n)
  n_v <- (r - 1) * (r - 2) / 2
  gradient <- vector("double", (r - 1) + 1 + n_v)

  index <- 1
  for (k in seq_len(r - 1)) {
    gradient[index] <- Schuster_derivative_log_l_wrt_marginal_pi(n, k, marginal_pi, kappa, v)
    index <- index + 1
  }

  gradient[index] <- Schuster_derivative_log_l_wrt_kappa(n, marginal_pi, kappa, v)
  index <- index + 1

  # seq_len() keeps this loop empty when r = 2 (no free v parameters). With
  # 1:(r - 2) it would write past the end of the gradient vector.
  for (i1 in seq_len(r - 2)) {
    for (j1 in (i1 + 1):(r - 1)) {
      gradient[index] <- Schuster_derivative_log_l_wrt_v(n, i1, j1, marginal_pi, kappa, v)
      index <- index + 1
    }
  }

  gradient
}


#' Computes the hessian matrix of second-order partial derivatives
#' of log(L).
#'
#' Work is delegated to functions that compute second-order partial
#' derivatives.  This function is responsible for laying them out in
#' correct positions in the matrix, using the same parameter order as
#' Schuster_gradient: marginal_pi[1..r-1], kappa, then the free v[i1, j1].
#' @param n matrix of observed counts
#' @param marginal_pi expected proportions for each category
#' @param kappa current estimate of the kappa coefficient
#' @param v symmetry matrix
#' @returns symmetric hessian matrix of order r + (r - 1) * (r - 2) / 2
Schuster_hessian <- function(n, marginal_pi, kappa, v) {
  r <- nrow(n)
  n_v <- (r - 1) * (r - 2) / 2
  n_parms <- (r - 1) + 1 + n_v

  # Index sets of each parameter group inside the hessian. v_idx is empty
  # when r = 2, where (r + 1):(r + n_v) would have run backwards.
  pi_idx <- seq_len(r - 1)
  kappa_idx <- r
  v_idx <- r + seq_len(n_v)

  hessian <- matrix(0.0, nrow = n_parms, ncol = n_parms)

  mat_marginal_pi_marginal_pi <- Schuster_second_deriv_log_l_wrt_marginal_pi_2(n, marginal_pi, kappa, v)
  vec_marginal_pi_kappa <- Schuster_second_deriv_log_l_wrt_marginal_pi_kappa(n, marginal_pi, kappa, v)
  mat_marginal_pi_v <- Schuster_second_deriv_log_l_wrt_marginal_pi_v(n, marginal_pi, kappa, v)
  deriv_kappa <- Schuster_second_deriv_log_l_wrt_kappa_2(n, marginal_pi, kappa, v)
  vec_kappa_v <- Schuster_second_deriv_log_l_wrt_kappa_v(n, marginal_pi, kappa, v)
  mat_v_v <- Schuster_second_deriv_log_l_wrt_v_2(n, marginal_pi, kappa, v)

  hessian[pi_idx, pi_idx] <- mat_marginal_pi_marginal_pi
  hessian[pi_idx, kappa_idx] <- vec_marginal_pi_kappa
  hessian[kappa_idx, pi_idx] <- vec_marginal_pi_kappa
  hessian[kappa_idx, kappa_idx] <- deriv_kappa

  if (n_v > 0) {
    hessian[pi_idx, v_idx] <- mat_marginal_pi_v
    hessian[v_idx, pi_idx] <- t(mat_marginal_pi_v)
    hessian[kappa_idx, v_idx] <- vec_kappa_v
    hessian[v_idx, kappa_idx] <- vec_kappa_v
    hessian[v_idx, v_idx] <- mat_v_v
  }

  hessian
}


#' Determines whether the candidate pi matrix is valid.
#'
#' All elements must lie in the open interval (0, 1). NA and NaN elements
#' do not, so they make the matrix invalid. Callers can therefore use the
#' result directly in `if` without risking "missing value where TRUE/FALSE
#' needed".
#' @param pi matrix of model-based proportions
#' @returns single logical (never NA) indicating whether or not
#' the matrix is valid.
Schuster_is_pi_valid <- function(pi) {
  !anyNA(pi) && all(pi > 0.0 & pi < 1.0)
}


#' Computes the Newton-Raphson update
#'
#' Evaluates the gradient and then the hessian of log(L) at the current
#' parameters, and returns the solution of hessian %*% x = gradient.
#' The caller subtracts this vector from the current parameters.
#' @param n matrix of observed counts
#' @param marginal_pi expected proportions for each category
#' @param kappa current value of kappa coefficient
#' @param v symmetry matrix
#' @returns the vector of updates
Schuster_update <- function(n, marginal_pi, kappa, v) {
  grad <- Schuster_gradient(n, marginal_pi, kappa, v)
  hess <- Schuster_hessian(n, marginal_pi, kappa, v)
  solve(a = hess, b = grad)
}


#' Performs Newton-Raphson step.
#'
#' Starts from the full Newton step and halves it until the candidate is
#' valid. A candidate is valid when:
#' - every marginal_pi lies in (0, 1), with the last one implied by the
#'   others summing to 1;
#' - every element of v (after solving for its dependent entries) is
#'   positive and finite;
#' - every model-based proportion lies in (0, 1).
#' If no valid step is found within `max_halvings` attempts, a warning is
#' issued and the current (valid) parameters are returned unchanged. The
#' invalid last candidate is never returned.
#' @param n matrix of observed counts
#' @param marginal_pi expected proportions for each category
#' @param kappa current estimate of the kappa coefficient
#' @param v symmetry matrix
#' @param max_halvings integer. maximum number of step sizes tried
#' (1, 1/2, 1/4, ...). Default 15.
#' @returns a list containing updated versions of model quantities
#'    marginal_pi
#'    kappa
#'    v
Schuster_newton_raphson <- function(n, marginal_pi, kappa, v, max_halvings = 15) {
  r <- nrow(n)
  free_pi <- seq_len(r - 1)
  update <- Schuster_update(n, marginal_pi, kappa, v)
  step <- 1.0
  for (attempt in seq_len(max_halvings)) {
    marginal_pi1 <- marginal_pi
    marginal_pi1[free_pi] <- marginal_pi[free_pi] - step * update[free_pi]
    marginal_pi1[r] <- 1.0 - sum(marginal_pi1[free_pi])
    # is.finite() guards against NaN updates, which would otherwise make the
    # comparisons NA and the `if` error out.
    pi_ok <- all(is.finite(marginal_pi1)) &&
      all(marginal_pi1 > 0.0 & marginal_pi1 < 1.0)
    if (pi_ok) {
      kappa1 <- kappa - step * update[r]
      v1 <- v
      index <- r + 1
      # seq_len() keeps this loop empty when r = 2 (no free v parameters).
      for (i1 in seq_len(r - 2)) {
        for (j1 in (i1 + 1):(r - 1)) {
          v1[i1, j1] <- v[i1, j1] - step * update[index]
          v1[j1, i1] <- v1[i1, j1]
          index <- index + 1
        }
      }
      v1 <- Schuster_solve_for_v(marginal_pi1, kappa1, v1)

      if (all(is.finite(v1)) && all(v1 > 0.0)) {
        pi1 <- Schuster_compute_pi(marginal_pi1, kappa1, v1, FALSE)
        if (isTRUE(Schuster_is_pi_valid(pi1))) {
          return(list(marginal_pi = marginal_pi1, kappa = kappa1, v = v1))
        }
      }
    }
    step <- step / 2.0
  }
  warning("Schuster_newton_raphson: no valid step found after ", max_halvings,
          " halvings; keeping current estimates", call. = FALSE)
  list(marginal_pi = marginal_pi, kappa = kappa, v = v)
}


#' Computes the common diagonal term v-tilde.
#'
#' The last category proportion is taken as 1 - sum of the first r - 1
#' proportions, whatever value was passed in for it. With
#' pe = sum(marginal_pi^2), the result is 1 + kappa * (1 - pe) / pe.
#'
#' @param marginal_pi expected proportions for each category
#' @param kappa current estimate of kappa coefficient
#' @param validate logical. should the value of pi[r,r] be checked for validity?
#' Default is TRUE; when TRUE a non-positive implied last proportion stops
#' with an error.
#' @returns v-tilde
Schuster_v_tilde <- function(marginal_pi, kappa, validate=TRUE) {
  r <- length(marginal_pi)
  implied_last <- 1.0 - sum(marginal_pi[1:(r - 1)])
  if (validate && implied_last <= 0.0) {
    stop(paste("out of range estimate for marginal_pi[r]", implied_last))
  }
  marginal_pi[r] <- implied_last
  chance_agreement <- sum(marginal_pi^2)
  1.0 + kappa * (1.0 - chance_agreement) / chance_agreement
}


#' Compute v matrix subject to constraints on rows 1..r-1.
#'
#' marginal_pi[r] is first replaced by 1 - sum(marginal_pi[1:(r - 1)]).
#' Every diagonal element is then set to v-tilde. For each row i < r, the
#' entry v[i, r] (mirrored into v[r, i]) is chosen so that
#' sum_j marginal_pi[j] * v[i, j] equals 1.
#'
#' @param marginal_pi expected proportions for each category
#' @param kappa current estimate of kappa coefficient
#' @param v symmetry matrix
#' @returns new v matrix with last row/column set to agree with constraints.
#' Element v[r, r] is set to v-tilde. Stops with "constraint violated" if
#' any row constraint misses 1 by more than 1e-6.
Schuster_enforce_constraints_on_v <- function(marginal_pi, kappa, v) {
  r <- length(marginal_pi)
  head_idx <- 1:(r - 1)
  marginal_pi[r] <- 1.0 - sum(marginal_pi[head_idx])

  diag(v) <- Schuster_v_tilde(marginal_pi, kappa)

  # Weighted sum over the known columns 1..r-1 (diagonal included), folded
  # left to right from 0 in the same order as a running total.
  known <- vapply(head_idx, function(row) {
    Reduce(`+`, v[row, head_idx] * marginal_pi[head_idx], 0.0)
  }, numeric(1))
  last_col <- (1.0 - known) / marginal_pi[r]
  v[head_idx, r] <- last_col
  v[r, head_idx] <- last_col

  for (row in head_idx) {
    total <- sum(v[row, ] * marginal_pi)
    if (abs(total - 1.0) > 1.0e-6) {
      stop(paste("constraint violated", row, total))
    }
  }
  v
}


#' Solves for the last row and diagonal of symmetry matrix v
#' (v-tilde) using constraint equations
#'
#' Every diagonal element is set to v-tilde. For each row i < r, v[i, r]
#' (mirrored into v[r, i]) is then set so that
#' sum_j marginal_pi[j] * v[i, j] == 1, that is
#' v[i, r] = (1 - sum_{j < r} marginal_pi[j] * v[i, j]) / marginal_pi[r].
#' Only columns 1..r-1 enter the sum. A previous version also added the
#' column-r term, using the partially accumulated v[i, r] as if it were the
#' entry, which broke the constraint.
#'
#' @param marginal_pi expected proportions for each category
#' @param kappa current estimate of kappa coefficient
#' @param v symmetry matrix
#' @returns revised version of v matrix with last row and diagonal modified
Schuster_solve_for_v <- function(marginal_pi, kappa, v) {
  r <- length(marginal_pi)
  known_cols <- seq_len(r - 1)
  v1 <- v
  diag(v1) <- Schuster_v_tilde(marginal_pi, kappa)

  for (i in known_cols) {
    known <- sum(marginal_pi[known_cols] * v1[i, known_cols])
    v1[i, r] <- (1.0 - known) / marginal_pi[r]
    v1[r, i] <- v1[i, r]
  }
  v1
}


#' Solves for the last row and diagonal of symmetry matrix v
#' (parameter v-tilde) using linear algebra formulation from paper.
#'
#' The unknowns are c = (v[1, r], ..., v[r - 1, r], v-tilde). They satisfy
#' the r row constraints sum_j marginal_pi[j] * v[i, j] == 1:
#'   row i < r: marginal_pi[r] * c[i] + marginal_pi[i] * c[r]
#'                = 1 - sum_{j < r, j != i} marginal_pi[j] * v[i, j]
#'   row r:     sum_{i < r} marginal_pi[i] * c[i] + marginal_pi[r] * c[r] = 1
#' Unlike Schuster_solve_for_v, v-tilde comes from the constraints, not
#' from kappa; `kappa` is accepted only for interface compatibility.
#'
#' @param marginal_pi expected proportions for each category
#' @param kappa current estimate of kappa coefficient (unused)
#' @param v symmetry matrix; only its off-diagonal entries v[i, j] with
#' i, j < r are read
#' @returns revised version of v matrix with last row and diagonal modified
Schuster_solve_for_v1 <- function(marginal_pi, kappa, v) {
  r <- length(marginal_pi)
  known <- seq_len(r - 1)

  A <- diag(marginal_pi[r], nrow = r)
  A[r, ] <- marginal_pi
  A[, r] <- marginal_pi

  b <- c(vapply(known, function(i) {
    others <- setdiff(known, i)
    1.0 - sum(marginal_pi[others] * v[i, others])
  }, numeric(1)), 1.0)

  solution <- solve(A, b)

  v1 <- v
  v1[known, r] <- solution[known]
  v1[r, known] <- solution[known]
  diag(v1) <- solution[r]
  v1
}

