
#' @title Global Low Rank Correction Quantile VB
#' @description
#'  A variational Bayes algorithm for multi-source heterogeneous quantile
#'    regression under a spike-and-slab prior. It performs variable selection
#'    simultaneously on the homogeneous (shared across sources) and the
#'    heterogeneous (source-specific) covariates.
#'
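#' @param X Homogeneous covariates: an n x p x K array whose k-th slice
#'   \code{X[, , k]} holds the covariates of source k.
#' @param Z Heterogeneous covariates: an n x q x K array whose k-th slice
#'   \code{Z[, , k]} holds the covariates of source k.
#' @param y Response: an n x K matrix whose k-th column is the response of
#'   source k.
#' @param tau Quantile level, strictly between 0 and 1.
#' @param eps Convergence tolerance on the absolute change of the ELBO between
#'   successive iterations. Default: 1e-3.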
#'
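#' @return A list with components:
#'   \item{mu}{Posterior mean of the homogeneous coefficients (length p).}
#'   \item{lr_mu}{Low-rank-corrected posterior mean of the homogeneous
#'     coefficients.}
#'   \item{sigma}{Posterior covariance matrix (p x p) of the homogeneous
#'     coefficients.}
#'   \item{rho}{Posterior inclusion probabilities of the homogeneous
#'     coefficients.}
#'   \item{beta}{q x K matrix of posterior means of the heterogeneous
#'     coefficients; column k belongs to source k.}
#'   \item{lr_beta}{q x K matrix of low-rank-corrected heterogeneous
#'     coefficients.}
#'   \item{sigma_beta}{q x q x K array of posterior covariance matrices of the
#'     heterogeneous coefficients.}
#'   \item{rho_beta}{q x K matrix of posterior inclusion probabilities of the
#'     heterogeneous coefficients.}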
#' @export
#' @importFrom glmnet cv.glmnet glmnet
#' @importFrom lava tr
#' @importFrom Rcpp evalCpp
#' @importFrom stats optim coef
#' @useDynLib LRQVB, .registration = TRUE



lr_qvb_global <- function(X, Z, y, tau, eps = 1e-3) {
  # ---- Local helpers ---------------------------------------------------------

  # Ratio K_{3/2}(x) / K_{1/2}(x) of modified Bessel functions of the second
  # kind. It appears in the closed-form moments E[z] and E[1/z] of the latent
  # variables (presumably a GIG(1/2, chi, psi) factor -- TODO confirm).
  bessel <- function(x) {
    besselK(x, 1.5, expon.scaled = FALSE) / besselK(x, 0.5, expon.scaled = FALSE)
  }

  # diag(A %*% S %*% t(A)) without forming the n x n product:
  # (A S A')_ii = sum_j (A S)_ij * A_ij.
  quad_diag <- function(A, S) {
    rowSums((A %*% S) * A)
  }

  # Stack Z_k %*% (beta_k * incl_k) over all K sources into one length-nK
  # vector, in the same source order used for all_X / all_y.
  stack_zm <- function(beta, incl) {
    unlist(lapply(seq_len(K), function(k) {
      as.vector(Z[, , k] %*% (beta[, k] * incl[, k]))
    }))
  }

  # Evidence lower bound of the current variational state. Used only as the
  # convergence criterion. Reads n, p, K, X, Z, y, k1, k2, r_0, s_0, a_0, b_0,
  # tau1 and tau0 from the enclosing function.
  cal_elbo <- function(mu, sigma, rho, sigma2_qk, beta_qk, r, s, a_pi, b_pi,
                       chi, psi, E_z_inverse, E_z, rho_k) {
    f1 <- (3 * n * K + r_0 + 2) / 2 * (digamma(r / 2) - log(s / 2))
    f2 <- 0
    for (k in seq_len(K)) {
      resid <- y[, k] - X[, , k] %*% (mu * rho) -
        Z[, , k] %*% (beta_qk[, k] * rho_k[, k])
      f2 <- f2 - r / (2 * k2^2 * s) * sum(
        k1^2 * E_z[, k] - 2 * k1 * resid +
          E_z_inverse[, k] * (resid^2 + quad_diag(X[, , k], sigma) +
                                quad_diag(Z[, , k], sigma2_qk[, , k]))
      )
    }
    f3 <- -r / s * sum(E_z) - s_0 * r / (2 * s)
    f4 <- -0.5 * sum((mu^2 + diag(sigma)) * rho / tau1 +
                       (mu^2 + diag(sigma)) * (1 - rho) / tau0)
    f5 <- (sum(rho) + a_0 - 1) * (digamma(a_pi) - digamma(a_pi + b_pi)) +
      (p - sum(rho) + b_0 - 1) * (digamma(b_pi) - digamma(a_pi + b_pi))
    elbo1 <- f1 + f2 + f3 + f4 + f5

    # Keep the Bernoulli entropy term finite at rho == 0 or rho == 1.
    rho <- ifelse(rho >= 1, 1 - 1e-4, rho)
    rho <- ifelse(rho <= 0, 1e-4, rho)
    g1 <- -0.5 * sum(log(det(sigma) + 1e-3))
    g2 <- -0.5 * r - log(s / 2) + (1 + r / 2) * digamma(r / 2)
    g3 <- sum(0.25 * (log(psi) - log(chi)) -
                log(2 * besselK(sqrt(chi * psi), 0.5, expon.scaled = FALSE)) -
                0.5 * (chi * E_z_inverse + psi * E_z))
    g4 <- sum(rho * log(rho) + (1 - rho) * log(1 - rho))
    g5 <- (a_pi - 1) * (digamma(a_pi) - digamma(a_pi + b_pi)) +
      (b_pi - 1) * (digamma(b_pi) - digamma(a_pi + b_pi))
    elbo2 <- g1 + g2 + g3 + g4 + g5

    elbo1 - elbo2
  }

  # ---- Dimensions and hyper-parameters ---------------------------------------

  # dim() instead of nrow()/ncol() of a slice: a slice drops to a vector when
  # one of its dimensions is 1.
  n <- dim(X)[1]
  p <- dim(X)[2]
  K <- dim(X)[3]
  q <- dim(Z)[2]

  # Quantile-specific constants (they match the usual asymmetric-Laplace
  # location-scale mixture constants).
  k1 <- (1 - 2 * tau) / (tau * (1 - tau))
  k2 <- sqrt(2 / (tau * (1 - tau)))

  r_0 <- 4
  s_0 <- 1
  r <- 1
  s <- 1
  a_0 <- 1
  b_0 <- 1
  a_pi <- 1
  b_pi <- 1
  a_k <- rep(1, K)
  b_k <- rep(1, K)
  # Slab (tau1) and spike (tau0) prior variances for the homogeneous
  # coefficients, and their source-specific counterparts (tau1k, tau0k).
  tau1 <- 1
  tau0 <- 1 / n
  tau1k <- 1
  tau0k <- 1 / n

  # ---- Initial variational state ---------------------------------------------

  sigma <- diag(p)
  sigma2_qk <- array(diag(q), dim = c(q, q, K))
  beta_qk <- matrix(1, nrow = q, ncol = K)
  # The first 20 coefficients of each group start "switched on". The count is
  # capped at the group size so the initialisation also works for p or q < 20.
  n_on_x <- min(20, p)
  n_on_z <- min(20, q)
  rho <- c(rep(1, n_on_x), rep(0, p - n_on_x))
  rho_k <- matrix(c(rep(1, n_on_z), rep(0, q - n_on_z)), nrow = q, ncol = K)

  chi <- matrix(1, nrow = n, ncol = K)
  psi <- matrix(1, nrow = n, ncol = K)
  E_z_inverse <- sqrt(psi / chi) * bessel(sqrt(chi * psi)) - 1 / chi
  E_z <- sqrt(chi / psi) * bessel(sqrt(chi * psi))

  # Stack all sources row-wise (source 1 first). The heterogeneous offset uses
  # the inclusion-weighted coefficients, as everywhere else in the algorithm.
  all_X <- do.call(rbind, lapply(seq_len(K), function(k) X[, , k]))
  all_y <- do.call(c, lapply(seq_len(K), function(k) y[, k]))
  all_Zm <- stack_zm(beta_qk, rho_k)
  all_E_z_inverse <- as.vector(E_z_inverse)
  all_E_z <- as.vector(E_z)

  # Warm start the homogeneous means with a cross-validated lasso fit.
  fit <- glmnet(all_X, all_y, lambda = cv.glmnet(all_X, all_y)$lambda.min)
  mu <- coef(fit)[-1]

  elbo_new <- cal_elbo(mu, sigma, rho, sigma2_qk, beta_qk, r, s, a_pi, b_pi,
                       chi, psi, E_z_inverse, E_z, rho_k)
  dif <- 1

  # ---- Coordinate-ascent iterations ------------------------------------------

  while (dif >= eps) {
    c_lik <- r / (k2^2 * s)

    # Homogeneous coefficients: Gaussian update on the stacked data.
    XEzinv <- all_X * all_E_z_inverse
    XtWX <- crossprod(all_X, XEzinv)
    prior_prec <- rho / tau1 + (1 - rho) / tau0
    # nrow = p so diag() never turns a length-1 vector into an identity matrix.
    hess2 <- -c_lik * 2 * XtWX - diag(prior_prec, nrow = p)
    grad2 <- -c_lik * (XtWX %*% (mu * rho) -
                         crossprod(XEzinv, all_y - all_Zm) -
                         crossprod(all_X, rep(k1, n * K))) -
      (mu * rho) * prior_prec
    sigma <- eigen_inv(-hess2)
    mu <- sigma %*% (grad2 - hess2 %*% (mu * rho))
    rho_logit <- digamma(a_pi) - digamma(b_pi) + 0.5 * (log(tau0) - log(tau1)) +
      0.5 * (1 / tau0 - 1 / tau1) * (mu^2 + diag(sigma))
    rho <- as.vector(1 / (1 + exp(-rho_logit)))

    # Source-specific coefficients and latent-variable moments.
    # psi does not depend on the source, so compute it once per iteration.
    psi <- (r / s) * (2 + k1^2 / k2^2)
    trace_term <- 0
    for (k in seq_len(K)) {
      Xk <- X[, , k]
      Zk <- Z[, , k]
      ZEzinv <- Zk * E_z_inverse[, k]
      ZtWZ <- crossprod(Zk, ZEzinv)
      prior_prec_k <- rho_k[, k] / tau1k + (1 - rho_k[, k]) / tau0k
      hess2k <- -c_lik * 2 * ZtWZ - diag(prior_prec_k, nrow = q)
      grad2k <- -c_lik * (ZtWZ %*% (beta_qk[, k] * rho_k[, k]) -
                            crossprod(ZEzinv, y[, k] - Xk %*% (mu * rho)) -
                            crossprod(Zk, rep(k1, n))) -
        (beta_qk[, k] * rho_k[, k]) * prior_prec_k
      sigma2_qk[, , k] <- eigen_inv(-hess2k)
      beta_qk[, k] <- sigma2_qk[, , k] %*%
        (grad2k - hess2k %*% (beta_qk[, k] * rho_k[, k]))
      rho_logit_k <- digamma(a_k[k]) - digamma(b_k[k]) +
        0.5 * (log(tau0k) - log(tau1k)) +
        0.5 * (1 / tau0k - 1 / tau1k) *
          (beta_qk[, k]^2 + diag(sigma2_qk[, , k]))
      rho_k[, k] <- 1 / (1 + exp(-rho_logit_k))

      resid_k <- y[, k] - Xk %*% (mu * rho) - Zk %*% (beta_qk[, k] * rho_k[, k])
      xsx <- quad_diag(Xk, sigma)
      zsz <- quad_diag(Zk, sigma2_qk[, , k])
      chi[, k] <- c_lik * (resid_k^2 + xsx + zsz)
      # sum(diag(X S X')) == tr(X' X S); reuses the diagonals just computed.
      trace_term <- trace_term + sum(xsx) + sum(zsz)
      E_z_inverse[, k] <- sqrt(psi / chi[, k]) * bessel(sqrt(chi[, k] * psi)) -
        1 / chi[, k]
      E_z[, k] <- sqrt(chi[, k] / psi) * bessel(sqrt(chi[, k] * psi))
      a_k[k] <- sum(rho_k[, k]) + a_0
      b_k[k] <- q - sum(rho_k[, k]) + b_0
    }
    all_Zm <- stack_zm(beta_qk, rho_k)
    all_E_z_inverse <- as.vector(E_z_inverse)
    all_E_z <- as.vector(E_z)

    # Scale parameters.
    r <- r_0 + 3 * n * K
    resid_all <- all_y - all_X %*% (mu * rho) - all_Zm
    s <- 1 / k2^2 * (sum(k1^2 * all_E_z - 2 * k1 * resid_all +
                           all_E_z_inverse * resid_all^2) + trace_term) +
      2 * sum(all_E_z) + s_0

    a_pi <- sum(rho) + a_0
    b_pi <- p - sum(rho) + b_0

    elbo_old <- elbo_new
    elbo_new <- cal_elbo(mu, sigma, rho, sigma2_qk, beta_qk, r, s, a_pi, b_pi,
                         chi, psi, E_z_inverse, E_z, rho_k)
    dif <- abs(elbo_old - elbo_new)
    if (!is.finite(dif)) {
      stop("ELBO is not finite; the variational updates diverged.", call. = FALSE)
    }
  }

  # ---- Low-rank correction ---------------------------------------------------

  # Shift the posterior means along the posterior-covariance columns of the
  # selected coefficients (inclusion probability > 0.5). The shift minimises
  # the objective below with optim()'s default Nelder-Mead method.
  sel <- rho > 0.5
  minf <- function(delta) {
    mu1 <- mu + sigma[, sel, drop = FALSE] %*% delta
    resid <- all_y - all_X %*% (mu1 * rho) - all_Zm
    ll <- r / (k2^2 * s) *
      (t(resid) %*% (all_E_z_inverse * resid) - 2 * k1 * sum(resid))
    kl <- sum((0.5 / tau1 * (mu1^2 + diag(sigma))) * rho) +
      sum((0.5 / tau0 * (mu1^2 + diag(sigma))) * (1 - rho))
    drop(ll + kl)
  }
  mu1 <- mu
  if (any(sel)) {
    delta_hat <- optim(rep(0, sum(sel)), minf)$par
    mu1 <- mu + sigma[, sel, drop = FALSE] %*% delta_hat
  }

  # Same correction for each source's heterogeneous coefficients. The source
  # index and all fixed quantities are passed explicitly to the objective.
  minf_k <- function(delta, k, S_sel, base_k) {
    beta_k <- beta_qk[, k] + S_sel %*% delta
    resid <- base_k - Z[, , k] %*% (beta_k * rho_k[, k])
    ll <- r / (k2^2 * s) *
      (t(resid) %*% (E_z_inverse[, k] * resid) - 2 * k1 * sum(resid))
    m2_k <- beta_k^2 + diag(sigma2_qk[, , k])
    kl <- sum((0.5 / tau1k * m2_k) * rho_k[, k]) +
      sum((0.5 / tau0k * m2_k) * (1 - rho_k[, k]))
    drop(ll + kl)
  }
  beta1 <- beta_qk
  for (k in seq_len(K)) {
    sel_k <- rho_k[, k] > 0.5
    if (!any(sel_k)) {
      next
    }
    # matrix(..., nrow = q) keeps a q x m shape even when m == 1.
    S_sel <- matrix(sigma2_qk[, sel_k, k], nrow = q)
    base_k <- y[, k] - X[, , k] %*% (mu1 * rho)
    delta_hat <- optim(rep(0, sum(sel_k)), minf_k,
                       k = k, S_sel = S_sel, base_k = base_k)$par
    beta1[, k] <- beta_qk[, k] + S_sel %*% delta_hat
  }

  list(mu = mu, rho = rho, lr_mu = mu1, sigma = sigma, beta = beta_qk,
       lr_beta = beta1, sigma_beta = sigma2_qk, rho_beta = rho_k)
}


