library(testthat)
library(dplyr)
library(dbplyr)
library(duckdb)
library(S7)

# Tests for lm_sql(): each case fits a model in-database (DuckDB via dbplyr)
# and the equivalent model with base R lm() on the local data, then compares
# coefficients and fit statistics.
describe("lm_sql() - SQL-Backed Linear Regression", {

    # -------------------------------------------------------------------------
    # Test Setup
    # -------------------------------------------------------------------------
    con <- DBI::dbConnect(duckdb::duckdb())
    # Register cleanup right after connecting so the connection is released
    # even if one of the copy_to() calls below fails.
    withr::defer(DBI::dbDisconnect(con, shutdown = TRUE))

    db_mtcars <- copy_to(con, mtcars, "mtcars_test", overwrite = TRUE)

    # Species is stored as character in the database; base R fits below use
    # the original factor column from `iris`.
    iris_df <- iris %>%
        mutate(Species = as.character(Species))
    db_iris <- copy_to(con, iris_df, "iris_test", overwrite = TRUE)

    # One NA in `hp` (row 1) and one in `wt` (row 2): listwise deletion
    # should drop exactly these two rows.
    na_df <- mtcars
    na_df[1, "hp"] <- NA
    na_df[2, "wt"] <- NA
    db_na <- copy_to(con, na_df, "na_test", overwrite = TRUE)

    # `y` is deliberately NOT an exact linear function of `dt`: a perfect fit
    # would make the R² comparison trivial (1 vs 1) and can make
    # summary.lm() warn "essentially perfect fit".
    date_df <- data.frame(
        y = c(11, 19, 32, 38, 51),
        dt = as.Date(c("2020-01-01", "2020-01-02", "2020-01-03", "2020-01-04", "2020-01-05")),
        x = c(1, 2, 3, 4, 5)
    )
    db_date <- copy_to(con, date_df, "date_test", overwrite = TRUE)

    # -------------------------------------------------------------------------
    # Numeric Regression Tests
    # -------------------------------------------------------------------------

    it("replicates base R lm() results for simple numeric regression", {
        fit_sql <- lm_sql(mpg ~ wt + hp + qsec, data = db_mtcars)
        fit_base <- lm(mpg ~ wt + hp + qsec, data = mtcars)

        expect_equal(
            fit_sql@coefficients, coef(fit_base),
            tolerance = 1e-5, ignore_attr = TRUE
        )
        expect_equal(fit_sql@r_squared, summary(fit_base)$r.squared, tolerance = 1e-5)
        expect_equal(
            fit_sql@std_error,
            summary(fit_base)$coefficients[, "Std. Error"],
            tolerance = 1e-5, ignore_attr = TRUE
        )
    })

    it("handles Intercept = FALSE correctly with matching R²", {
        fit_sql <- lm_sql(mpg ~ 0 + wt, data = db_mtcars)
        fit_base <- lm(mpg ~ 0 + wt, data = mtcars)

        expect_equal(fit_sql@coefficients, coef(fit_base), tolerance = 1e-5)
        expect_false("(Intercept)" %in% names(fit_sql@coefficients))
        # Without an intercept, base R reports the uncentered R².
        expect_equal(fit_sql@r_squared, summary(fit_base)$r.squared, tolerance = 1e-5)
    })

    # -------------------------------------------------------------------------
    # Categorical / Factor Handling Tests
    # -------------------------------------------------------------------------

    it("automatically dummifies categorical variables (scout method)", {
        fit_sql <- lm_sql(`Sepal.Length` ~ `Sepal.Width` + Species, data = db_iris)
        fit_base <- lm(Sepal.Length ~ Sepal.Width + Species, data = iris)

        sql_coefs <- fit_sql@coefficients
        # Normalize: Species_versicolor -> Speciesversicolor (base R naming)
        names(sql_coefs) <- gsub("Species_", "Species", names(sql_coefs))

        expect_equal(sql_coefs, coef(fit_base), tolerance = 1e-5)
    })

    it("handles categorical interaction terms (Sepal.Width * Species)", {
        fit_sql <- lm_sql(`Sepal.Length` ~ `Sepal.Width` * Species, data = db_iris)
        fit_base <- lm(Sepal.Length ~ Sepal.Width * Species, data = iris)

        # Term names differ in format, so compare positionally. Check the
        # length first so a missing/extra term fails with a clear message
        # rather than a confusing element-wise diff.
        expect_length(fit_sql@coefficients, length(coef(fit_base)))
        expect_equal(
            unname(fit_sql@coefficients), unname(coef(fit_base)),
            tolerance = 1e-5
        )
    })

    # -------------------------------------------------------------------------
    # Interaction Tests
    # -------------------------------------------------------------------------

    it("calculates interaction terms correctly", {
        fit_sql <- lm_sql(mpg ~ wt * am, data = db_mtcars)
        fit_base <- lm(mpg ~ wt * am, data = mtcars)

        expect_equal(fit_sql@coefficients, coef(fit_base), tolerance = 1e-5)
    })

    # -------------------------------------------------------------------------
    # Dot Expansion
    # -------------------------------------------------------------------------

    it("expands dot (y ~ .) to all columns", {
        # Use a subset to avoid near-singular matrix with full mtcars
        sub_df <- mtcars[, c("mpg", "wt", "hp", "qsec")]
        db_sub <- copy_to(con, sub_df, "sub_test", overwrite = TRUE)

        fit_sql <- lm_sql(mpg ~ ., data = db_sub)
        fit_base <- lm(mpg ~ ., data = sub_df)

        expect_equal(
            fit_sql@coefficients, coef(fit_base),
            tolerance = 1e-5, ignore_attr = TRUE
        )
    })

    # -------------------------------------------------------------------------
    # Grouped regression
    # -------------------------------------------------------------------------

    it("returns a tibble with model column for grouped data", {
        db_grouped <- db_mtcars %>% group_by(am)
        result <- lm_sql(mpg ~ wt + hp, data = db_grouped)

        expect_s3_class(result, "tbl_df")
        expect_true("model" %in% names(result))
        expect_equal(nrow(result), 2)  # am = 0 and am = 1
        expect_setequal(result$am, c(0, 1))

        # Check each group matches base R (group order is not assumed)
        for (i in seq_len(nrow(result))) {
            am_val <- result$am[i]
            sql_fit <- result$model[[i]]
            base_fit <- lm(mpg ~ wt + hp, data = mtcars[mtcars$am == am_val, ])
            expect_equal(sql_fit@coefficients, coef(base_fit), tolerance = 1e-4)
        }
    })

    # -------------------------------------------------------------------------
    # Robustness & Methods
    # -------------------------------------------------------------------------

    it("performs listwise deletion for NAs", {
        fit_sql <- lm_sql(mpg ~ hp + wt, data = db_na)
        fit_base <- lm(mpg ~ hp + wt, data = na_df)

        # 32 rows minus the two rows that contain an NA in a used column
        expect_equal(fit_sql@nobs, 30)
        expect_equal(fit_sql@nobs, nobs(fit_base))
        expect_equal(fit_sql@coefficients, coef(fit_base), tolerance = 1e-5)
    })

    it("throws error if passed a local data frame", {
        expect_error(
            lm_sql(mpg ~ wt, data = mtcars),
            "Requires a remote dbplyr 'tbl' object"
        )
    })

    # -------------------------------------------------------------------------
    # Broom / S7 Method Tests
    # -------------------------------------------------------------------------

    it("supports broom::tidy() with correct structure", {
        # broom is used via `::` only and is not attached above
        skip_if_not_installed("broom")

        fit_sql <- lm_sql(mpg ~ wt, data = db_mtcars)
        tidied <- broom::tidy(fit_sql, conf.int = TRUE)

        expect_s3_class(tidied, "tbl_df")
        expect_true(all(
            c("term", "estimate", "std.error", "conf.low", "conf.high") %in% names(tidied)
        ))
        expect_equal(nrow(tidied), 2)  # (Intercept) and wt
    })

    it("supports broom::glance() with correct structure", {
        skip_if_not_installed("broom")

        fit_sql <- lm_sql(mpg ~ wt, data = db_mtcars)
        glanced <- broom::glance(fit_sql)

        expect_s3_class(glanced, "tbl_df")
        expect_equal(glanced$nobs, nrow(mtcars))
        expect_equal(glanced$df, 1)  # one predictor besides the intercept
    })

    # -------------------------------------------------------------------------
    # Date Variable Tests
    # -------------------------------------------------------------------------

    it("handles Date predictor matching base R lm()", {
        # Reference: the Date is treated as its numeric value (days since epoch)
        fit_sql <- lm_sql(y ~ dt, data = db_date)
        fit_r   <- lm(y ~ as.numeric(dt), data = date_df)

        expect_equal(unname(fit_sql@coefficients), unname(coef(fit_r)), tolerance = 1e-6)
        expect_equal(fit_sql@r_squared, summary(fit_r)$r.squared, tolerance = 1e-6)
    })

    it("handles Date predictor alongside numeric predictor", {
        date_df2 <- data.frame(
            y = c(12, 25, 28, 43, 55),
            dt = as.Date(c("2020-01-01", "2020-01-02", "2020-01-03", "2020-01-04", "2020-01-05")),
            x = c(1, 3, 2, 5, 4)
        )
        db_date2 <- copy_to(con, date_df2, "date_test2", overwrite = TRUE)

        fit_sql <- lm_sql(y ~ dt + x, data = db_date2)
        fit_r   <- lm(y ~ as.numeric(dt) + x, data = date_df2)

        expect_equal(unname(fit_sql@coefficients), unname(coef(fit_r)), tolerance = 1e-6)
    })
})
