# Select a Nadaraya-Watson bandwidth by K-fold cross-validation.
#
# Each observation is randomly assigned to one of `n_folds` folds (this draws
# from the RNG, so call set.seed() first for reproducible results). For every
# candidate bandwidth, each fold is held out in turn. The held-out rows are
# predicted with `nadaraya_watson_multivariate()` using the remaining rows as
# training data, and the fold's mean squared prediction error is recorded.
# The bandwidth with the smallest average error across folds is returned.
#
# Arguments:
#   X       Predictor matrix (one row per observation). A plain vector is
#           treated as a single-column matrix. NOTE(review): callers appear
#           to pass a column subset of a larger design (e.g. X[, -i]) —
#           confirm.
#   Y       Response vector with one value per row of `X`.
#   h_grid  Non-empty vector of candidate bandwidths.
#   n_folds Number of folds (>= 2). Capped at nrow(X) so that no fold is
#           empty.
#
# Returns:
#   The element of `h_grid` with the lowest CV error (the first one on ties).
#   Bandwidths whose CV error is NA/NaN (e.g. every kernel weight underflows
#   to zero) are skipped. An error is raised if all of them are undefined.
nadaraya_watson_multivariate_cv <- function(X, Y, h_grid, n_folds = 10) {
  # Accept a bare vector as a one-predictor design matrix.
  if (is.null(dim(X))) {
    X <- matrix(X, ncol = 1L)
  }
  n_obs <- nrow(X)

  if (length(Y) != n_obs) {
    stop("`Y` must have one value per row of `X`.", call. = FALSE)
  }
  if (n_obs < 2L) {
    stop("At least two observations are required for cross-validation.",
         call. = FALSE)
  }
  if (length(h_grid) == 0L) {
    stop("`h_grid` must contain at least one bandwidth.", call. = FALSE)
  }
  if (n_folds < 2) {
    stop("`n_folds` must be at least 2.", call. = FALSE)
  }

  # More folds than rows would leave some folds empty. Their MSE would be
  # NaN, which would make every bandwidth's average NaN.
  n_folds <- min(n_folds, n_obs)
  fold_ids <- sample(rep(seq_len(n_folds), length.out = n_obs))

  # Average held-out MSE across folds for a single bandwidth `h`.
  cv_error_for <- function(h) {
    fold_mse <- vapply(seq_len(n_folds), function(fold) {
      in_valid <- fold_ids == fold
      # drop = FALSE keeps both sets 2-D. Otherwise a single-row fold or a
      # single-column X collapses to a vector and apply() below fails.
      X_train <- X[!in_valid, , drop = FALSE]
      Y_train <- Y[!in_valid]
      X_valid <- X[in_valid, , drop = FALSE]
      Y_valid <- Y[in_valid]

      predictions <- apply(X_valid, 1, function(x) {
        nadaraya_watson_multivariate(X_train, Y_train, x, h)
      })
      mean((Y_valid - predictions)^2)
    }, numeric(1))
    mean(fold_mse)
  }

  cv_errors <- vapply(h_grid, cv_error_for, numeric(1))

  if (all(is.na(cv_errors))) {
    stop("Cross-validation error is undefined for every bandwidth ",
         "in `h_grid`.", call. = FALSE)
  }

  # which.min() skips NA/NaN entries and returns the first minimum on ties.
  h_grid[which.min(cv_errors)]
} # End nadaraya_watson_multivariate_cv
