## ----setup, include = FALSE---------------------------------------------------
# Chunk defaults for the whole document: `collapse = TRUE` and a "#>" prefix
# on printed output.
knitr::opts_chunk$set(collapse = TRUE, comment = "#>")

## ----echo = T, message=FALSE--------------------------------------------------
library(baytaAAR)
library(ggplot2)
library(bayesplot)
library(tidybayes)
library(flexsurv)

## ----ordered probit regression, echo = FALSE, message=FALSE, fig.width=6, fig.height=4, fig.align = 'center'----
library(flexsurv)

# Regression of the latent trait on log(age): mu = beta0 + beta * log(age).
# beta0 is chosen so that the regression line passes through (10, -0.5).
beta <- 2
beta0 <- -0.5 - beta * log(10)

# Spread of the latent trait around the line (sigma) and the number of
# standard deviations drawn on each side of the mean (k).
sigma <- 0.5
k <- 3

# Cut points on the latent scale and the labels of the five stages they
# separate.
trans <- c(0.5, 1.5, 2.5, 3.5)
trait_stages <- data.frame(
  stage = seq_len(5),
  trait = c("first", "second", "third", "fourth", "fifth")
)

# Ages at which the regression line crosses each cut point (the regression
# equation solved for age) ...
ages <- exp((trans - beta0) / beta)

# ... and the value of the regression line at exactly those ages.
mu_vals <- beta0 + beta * log(ages)

# Build the outline of a sideways normal density as a closed polygon.
#
# The bell lies on its side: the values it is evaluated at run along the
# y axis (centred on `mu_y`), and the density height, rescaled so that its
# peak equals `peak_width`, is added to `x0` along the x axis.
#
# Arguments:
#   x0              x position the bell is attached to; also stored as the
#                   `group` id of every returned row.
#   mu_y            centre of the bell on the y axis.
#   clip_to_x_axis  if TRUE, keep only the outline points with
#                   y >= `axis_y` and close the polygon at (x0, axis_y);
#                   otherwise close it at (x0, mu_y - n_sd * sd_y).
#   sd_y            standard deviation of the bell. Defaults to the
#                   script-level `sigma`, which is looked up lazily when the
#                   function is called, so existing calls behave as before.
#   n_sd            half-width of the bell in standard deviations. Defaults
#                   to the script-level `k` (same lazy lookup).
#   axis_y          y position of the x axis used for clipping.
#   peak_width      horizontal extent (in x units) of the density peak.
#   n_points        number of points along the outline.
#
# Returns a data.frame with columns `x`, `y` and `group`; the last row is
# the closing vertex.
make_gaussian_polygon <- function(x0, mu_y, clip_to_x_axis = FALSE,
                                  sd_y = sigma, n_sd = k, axis_y = -0.5,
                                  peak_width = 3, n_points = 100) {
  half_range <- n_sd * sd_y
  offsets <- seq(-half_range, half_range, length.out = n_points)
  density <- dnorm(offsets, mean = 0, sd = sd_y)

  # Rescale so the tallest point of the bell sits `peak_width` right of x0.
  outline_x <- x0 + density / max(density) * peak_width
  outline_y <- mu_y + offsets

  if (clip_to_x_axis) {
    above_axis <- outline_y >= axis_y
    outline_x <- outline_x[above_axis]
    outline_y <- outline_y[above_axis]
    base_y <- axis_y
  } else {
    base_y <- mu_y - half_range
  }

  data.frame(
    x = c(outline_x, x0),
    y = c(outline_y, base_y),
    group = x0
  )
}

# One sideways bell per threshold age, stacked into a single data frame for
# geom_polygon(). Only the bell attached to the smallest age is clipped at
# the x axis. The `==` test is exact because `x0` is itself an element of
# `ages`.
paths_list <- Map(
  function(x0, mu_y) {
    make_gaussian_polygon(x0, mu_y, clip_to_x_axis = (x0 == min(ages)))
  },
  ages,
  mu_vals
)
paths <- do.call(rbind, paths_list)

# Polygon for the density of the latent trait, drawn to the left of x = 10.
# The density is evaluated on [0, 5.5], scaled by 25 and subtracted from 10;
# all y values are shifted down by 0.5. The last vertex (10, lowest y)
# closes the shape.
latent_z <- seq(from = 0, to = 5.5, length.out = 300)
latent_density <- dnorm(latent_z, mean = 2.5, sd = 1.2)
latent_x <- 10 - 25 * latent_density
latent <- data.frame(
  x = c(latent_x, 10),
  y = c(latent_z, min(latent_z)) - 0.5,
  group = 1
)

# Polygon for a Gompertz density hanging below the x axis. The density is
# evaluated on ages 10-100 (shifted so that age 10 maps to 0) and rescaled
# to a peak height of 2. The first half of the vertices traces the curve;
# the second half runs back along the baseline y = -0.5.
# NOTE(review): the curve is offset by 0.4 while the baseline sits at -0.5,
# so wherever the scaled density is below 0.1 the curve lies above the
# baseline. Confirm whether 0.4 is intended.
age_seq <- seq(from = 10, to = 100, length.out = 400)
gompertz_density <- flexsurv::dgompertz(
  age_seq - 10,
  shape = 0.06,
  rate = 0.002
)
gompertz_density_scaled <- gompertz_density / max(gompertz_density) * 2
gompertz <- data.frame(
  x = c(age_seq, rev(age_seq)),
  y = c(-gompertz_density_scaled - 0.4, rep_len(-0.5, length(age_seq))),
  group = 1
)

# End points of the two axes; the x axis sits at y = -0.5.
x_axis <- data.frame(x = c(10, 100), y = rep(-0.5, 2))
y_axis <- data.frame(x = rep(10, 2), y = c(-2.5, 5.5))

# Two arrow segments for the equation label: up from the x axis at age 50
# to the regression line, then left from there to x = 10.
arrow_x <- 50
arrow_y <- beta0 + beta * log(arrow_x)
arrow_df <- data.frame(
  x = rep(arrow_x, 2),
  y = c(-0.5, arrow_y),
  xend = c(arrow_x, 10),
  yend = rep(arrow_y, 2)
)

# Horizontal threshold lines, one per cut point.
thresholds <- data.frame(y = trans)

# Regression line evaluated on a grid of ages from 10 to 100.
age_seq <- seq(from = 10, to = 100, length.out = 400)
reg_df <- data.frame(x = age_seq, y = beta0 + beta * log(age_seq))

# Hand-drawn x-axis ticks: small downward strokes every 10 years.
x_breaks <- seq(from = 20, to = 90, by = 10)
tick_df <- data.frame(x = x_breaks, xend = x_breaks, y = -0.5, yend = -0.6)

# Schematic figure. Layers are added in drawing order; labels, coordinate
# limits and theme follow at the end. The x axis is drawn by hand at
# y = -0.5 (segments, ticks, tick labels and title), so the built-in x-axis
# text, ticks and title are blanked in the theme.
ggplot() +
  # Regression line and dashed threshold lines
  geom_line(
    data = reg_df,
    mapping = aes(x = x, y = y),
    color = "grey",
    linewidth = 1
  ) +
  geom_hline(
    data = thresholds,
    mapping = aes(yintercept = y),
    linetype = "dashed",
    color = "grey50"
  ) +
  # Stage names, placed 0.5 below each threshold (and below 4.5 for the last)
  annotate(
    "text",
    x = 1,
    y = c(trans, 4.5) - 0.5,
    label = trait_stages$trait,
    hjust = 0,
    size = 4
  ) +
  # Density polygons: bells at the threshold ages, latent trait, Gompertz
  geom_polygon(
    data = paths,
    mapping = aes(x = x, y = y, group = group),
    fill = "darkgreen",
    alpha = 0.4
  ) +
  geom_polygon(
    data = latent,
    mapping = aes(x = x, y = y, group = group),
    fill = "steelblue",
    alpha = 0.4
  ) +
  geom_polygon(
    data = gompertz,
    mapping = aes(x = x, y = y),
    fill = "lightgrey"
  ) +
  # Hand-drawn axes
  annotate("segment", x = 10, xend = 100, y = -0.5, yend = -0.5) +
  annotate("segment", x = 10, xend = 10, y = -0.5, yend = 5.5) +
  # Arrow and equation label
  geom_segment(
    data = arrow_df,
    mapping = aes(x = x, y = y, xend = xend, yend = yend),
    arrow = arrow(length = unit(0.2, "cm")),
    linewidth = 0.4
  ) +
  annotate(
    "text",
    x = arrow_x - 20,
    y = arrow_y + 0.3,
    label = expression(mu == beta[0] + beta[1]*log(age)),
    size = 5
  ) +
  # Hand-drawn x-axis ticks, tick labels and axis title
  geom_segment(
    data = tick_df,
    mapping = aes(x = x, xend = xend, y = y, yend = yend),
    linewidth = 0.4
  ) +
  geom_text(
    data = data.frame(x = x_breaks),
    mapping = aes(x = x, y = -0.75, label = x),
    size = 4
  ) +
  annotate(
    "text",
    x = mean(range(x_breaks)),
    y = -1.2,
    label = "Age-at-death (years)",
    size = 5
  ) +
  # Labels, visible window and theme
  labs(x = "Age-at-death (years)", y = "Latent trait variable") +
  coord_cartesian(xlim = c(5, 90), ylim = c(-2.5, 4.25)) +
  theme_minimal(base_size = 14) +
  theme(
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    axis.text.x = element_blank(),
    axis.ticks.x = element_blank(),
    axis.title.x = element_blank(),
    axis.text.y = element_blank(),
    axis.title.y = element_text(
      margin = margin(t = 0, r = 20, b = 0, l = 0),
      hjust = 0.7
    )
  )

## ----sorsum data, echo = T----------------------------------------------------
# Load the `sorsum_as` data set from the baytaAAR package and show its
# first rows.
data("sorsum_as", package = "baytaAAR")
head(sorsum_as)

## ----sorsum plot, echo = F, fig.width=4, fig.height=3, fig.align = 'center'----
# Histogram of the `auricular_surface` column in half-unit bins, with axis
# breaks at the integers 1 to 8 and no minor breaks.
ggplot(sorsum_as, aes(x = auricular_surface)) +
  geom_histogram(binwidth = 0.5) +
  scale_x_continuous(breaks = seq(1, 8, by = 1), minor_breaks = NULL) +
  labs(x = "\nauricular surface", y = "count\n") +
  theme_light()

## ----sorsum analysis, echo = TRUE---------------------------------------------
# Run bay.ta() on the second column of `sorsum_as` (presumably the
# auricular surface scores plotted above; confirm against the data set),
# with a minimum age of 18 and a fixed seed for reproducibility.
sorsum_as_res <- bay.ta(
  method = sorsum_as[, 2],
  minimum_age = 18,
  numSavedSteps = 1000,
  thinSteps = 1000,
  seed = 1234
)

## ----sorsum diagnostics, echo = TRUE------------------------------------------
# Summarise the fitted chains and show the first ten rows of the summary as
# a table rounded to four digits.
sorsum_as_res_diag <- diagnostic.summary(sorsum_as_res)
knitr::kable(head(sorsum_as_res_diag, 10), digits = 4)

## ----sorsum diagnostic table, echo = T----------------------------------------
# Print the result of diagnostics.max.min() for the summary.
diagnostics.max.min(sorsum_as_res_diag)

## ----sorsum trace plots, echo = T---------------------------------------------
# Trace plots for the first individual's age and the first beta parameter,
# side by side, with parsed facet labels.
bayesplot::color_scheme_set("viridis")
bayesplot::mcmc_trace(
  sorsum_as_res,
  pars = c("age.s[1]", "beta[1]"),
  n_warmup = 300,
  facet_args = list(nrow = 1, labeller = label_parsed)
)

## ----sorsum gelman plots, echo = T--------------------------------------------
# Gelman plot for the parameters "b" and "beta[1]".
coda::gelman.plot(sorsum_as_res[, c("b", "beta[1]")])

## ----sorsum thresholds, echo = TRUE-------------------------------------------
# Extract the threshold chains from the fit, summarise them, and tabulate
# the resulting threshold matrix rounded to one digit.
# Note: this reuses the name `thresholds`, replacing the data frame of cut
# points built for the schematic figure.
thresholds <- threshold.chains(sorsum_as_res)
thresh_diag <- diagnostic.summary(thresholds)
knitr::kable(data.frame(threshold.matrix(thresh_diag)), digits = 1)

## ----thresholds plot, echo = TRUE, message=FALSE, warning=F, fig.width=5, fig.height=6, fig.align = 'center'----
# Ridge plot of the threshold chains (80% interval, median point estimate),
# restricted to ages 18-100.
bayesplot::color_scheme_set("gray")
bayesplot::mcmc_areas_ridges(
  thresholds,
  prob = 0.8,
  point_est = "median",
  border_size = 0.2
) +
  theme_light() +
  xlim(18, 100) +
  labs(x = "\nage-at-death (years)")

## ----sorsum agerange, echo = TRUE---------------------------------------------
# Table of the age estimation summary, rounded to three digits.
knitr::kable(age.estim.summary(sorsum_as_res_diag), digits = 3)

## ----sorsum ages, echo = TRUE, warning=F, fig.width=5, fig.height=6, fig.align = 'center'----
# Draws of `age.s` for the individuals numbered below 8, one half-eye per
# individual, with the mode and 95% HDI as point and interval.
sorsum_as_res |>
  tidybayes::spread_draws(age.s[age_number]) |>
  subset(age_number < 8) |>
  ggplot(aes(x = age.s, y = as.factor(age_number))) +
  tidybayes::stat_halfeye(
    point_interval = mode_hdi,
    .width = 0.95,
    fill = "lightgrey"
  ) +
  scale_x_continuous(breaks = seq(10, 100, by = 10), limits = c(18, 100)) +
  labs(x = "\nModal age-at-death (years)", y = "Individual no.\n") +
  theme_light() +
  theme(
    text = element_text(size = 12),
    panel.grid.minor.x = element_blank()
  )

## ----sorsum Gompertz plot, echo = TRUE, fig.width=6, fig.height=4, fig.align = 'center'----
# Gompertz density over ages 18-100. Ages are shifted by 18 before the
# density is evaluated; shape and rate are taken from the third column of
# the "b" and "a" rows of the diagnostic summary.
ggplot() +
  geom_function(
    fun = function(x) {
      flexsurv::dgompertz(
        x - 18,
        shape = sorsum_as_res_diag["b", 3],
        rate = sorsum_as_res_diag["a", 3]
      )
    }
  ) +
  scale_x_continuous(breaks = seq(10, 100, by = 10), limits = c(18, 100)) +
  labs(x = "\nAge in years", y = "density\n") +
  theme_light() +
  theme(
    text = element_text(size = 12),
    panel.grid.minor.x = element_blank()
  )

## ----sorsum analysis jags, echo = TRUE, eval=FALSE----------------------------
# sorsum_as_res <- bay.ta(
#   framework = "JAGS",
#   method = sorsum_as[,2],
#   minimum_age = 18,
#   thinSteps = 100,
#   numSavedSteps = 5000,
#   seed = 1234
# )

