Illustration of Adaptive Spline

Vivien Goepp

2022-06-07

Setup

library(aspline)
library(ggplot2)
library(dplyr)
library(tidyr)
library(mgcv)
library(splines2)
# Load the `helmet` dataset and pull out the two columns used throughout:
# `x` is the predictor and `y` the response of every fit below.
data("helmet")
x <- helmet[["x"]]
y <- helmet[["y"]]

Fit Aspline

# Candidate knots: `k` equally spaced points strictly inside the range of
# `x` (the two end points of the grid are dropped).
k <- 40
knots <- seq(min(x), max(x), length.out = k + 2)[-c(1, k + 2)]
# Grid of penalty values passed to aspline(), log-spaced from 1e-4 to 1e4.
pen <- 10 ^ seq(-4, 4, 0.25)
# Fine grid on which the fitted curve and the basis functions are drawn.
x_seq <- seq(min(x), max(x), length.out = 1000)
aridge <- aspline::aspline(x, y, knots, pen)
# Knot set with the smallest EBIC; looked up once and reused for both the
# refit and the plotting basis.
knots_best <- aridge$knots_sel[[which.min(aridge$ebic)]]
a_fit <- lm(y ~ splines2::bSpline(x, knots = knots_best))
X_seq <- splines2::bSpline(x_seq, knots = knots_best, intercept = TRUE)
# Scale every basis column by the matching coefficient and reshape to long
# format, one row per (x, basis function). All basis columns are selected
# with `-x`, so this works for any number of selected knots instead of
# assuming exactly nine columns.
# NOTE(review): coef(a_fit) starts with the lm() intercept while X_seq starts
# with the first B-spline, so the scaled curves presumably do not sum exactly
# to the fitted curve -- confirm this is acceptable for the illustration.
a_basis <- (X_seq %*% diag(coef(a_fit))) %>%
  as.data.frame() %>%
  dplyr::mutate(x = x_seq) %>%
  tidyr::pivot_longer(cols = -x, names_to = "spline_n", values_to = "y") %>%
  dplyr::filter(y != 0)
# tibble() replaces the deprecated data_frame().
a_predict <- dplyr::tibble(
  x = x_seq,
  pred = predict(a_fit, data.frame(x = x_seq))
)
# Draw the observations as open circles, the fitted curve on top of them, and
# the scaled basis functions as thin lines (one line per `spline_n`).
ggplot2::ggplot() +
  ggplot2::geom_point(
    mapping = ggplot2::aes(x = x, y = y),
    data = helmet,
    shape = 1
  ) +
  ggplot2::geom_line(
    mapping = ggplot2::aes(x = x, y = pred),
    data = a_predict,
    size = 0.5
  ) +
  ggplot2::geom_line(
    mapping = ggplot2::aes(x = x, y = y, group = spline_n),
    data = a_basis,
    linetype = 1,
    size = 0.1
  ) +
  ggplot2::labs(x = "", y = "") +
  ggplot2::theme(legend.position = "none")

Fit P-Splines

# Fit a penalised B-spline ("ps") smooth with mgcv; the basis dimension is
# set to length(knots) + 4 so that it has as many coefficients as `X` below
# has columns.
# NOTE(review): in mgcv's "ps" smooth, m[1] = 2 is documented as the cubic
# basis -- confirm that m = c(3, 2) is the intended order.
p_fit <- mgcv::gam(y ~ s(x, bs = "ps", k = length(knots) + 3 + 1, m = c(3, 2)))
# B-spline basis on the full candidate knot set, evaluated on the plot grid.
# NOTE(review): this splines2 basis is presumably only an approximation of
# the basis mgcv builds internally -- verify before reading the scaled
# curves as an exact decomposition of the fit.
X <- splines2::bSpline(x_seq, knots = knots, intercept = TRUE)
# Scale every basis column by the matching coefficient and reshape to long
# format. `cols = -x` selects all basis columns; the previous hard-coded
# V1..V9 kept only the first nine of them.
p_basis <- (X %*% diag(coef(p_fit))) %>%
  as.data.frame() %>%
  dplyr::mutate(x = x_seq) %>%
  tidyr::pivot_longer(cols = -x, names_to = "spline_n", values_to = "y") %>%
  dplyr::as_tibble() %>%
  dplyr::filter(y != 0)
# tibble() replaces the deprecated data_frame().
p_predict <- dplyr::tibble(
  x = x_seq,
  pred = predict(p_fit, data.frame(x = x_seq))
)
# Same layout as the adaptive spline figure: observations as open circles,
# the P-spline fit, and the scaled basis functions as thin lines.
ggplot2::ggplot() +
  ggplot2::geom_point(
    mapping = ggplot2::aes(x = x, y = y),
    data = helmet,
    shape = 1
  ) +
  ggplot2::geom_line(
    mapping = ggplot2::aes(x = x, y = pred),
    data = p_predict,
    size = 0.5
  ) +
  ggplot2::geom_line(
    mapping = ggplot2::aes(x = x, y = y, group = spline_n),
    data = p_basis,
    linetype = 1,
    size = 0.1
  ) +
  ggplot2::labs(x = "", y = "") +
  ggplot2::theme(legend.position = "none")