#' fcsim
#'
#' Simulate a differential-expression fold change for every gene. Genes
#' listed in \code{de.id} receive log-normal fold changes; every other gene
#' keeps a fold change of exactly 1.
#' @param n.gene total number of genes
#' @param de.id index of differentially expressed genes
#' @param fc.loc location parameter for fold change (log-normal distribution)
#' @param fc.scale scale parameter for fold change (log-normal distribution)
#' @return A numeric vector of length \code{n.gene} holding each gene's fold change.
#' @references Zappia, L., Phipson, B., & Oshlack, A. (2017). Splatter: Simulation of single-cell RNA sequencing data. Genome Biology, 18(1). https://doi.org/10.1186/s13059-017-1305-0
#' @import stats
#'
fcsim <- function(n.gene, de.id, fc.loc, fc.scale) {
  # draw the DE fold changes first so the RNG stream is consumed as before
  de.fold <- rlnorm(length(de.id), meanlog = fc.loc, sdlog = fc.scale)
  fold.change <- rep(1, times = n.gene)  # non-DE genes are unchanged
  fold.change[de.id] <- de.fold
  fold.change
}

#' GeneScape
#'
#' This function simulates single cell RNA-seq data with complicated differential expression and correlation structure.
#' @param nCells number of cells. When \code{groups} is given and \code{nCells} is not, \code{nCells} is set to \code{length(groups)}.
#' @param nGroups number of (major) cell groups. If NULL, it is taken from the number of unique \code{groups},
#'        or 4 when \code{groups} is also NULL. Must be given explicitly when \code{de.fc.mat} is supplied.
#' @param groups group information for cells, encoded by integers starting from 1. When \code{de.fc.mat} is supplied,
#'        cells of the j-th sub-cell-type are labelled \code{nGroups + j}.
#' @param lib.size.loc location parameter for library size (log-normal distribution). A number or a vector of length nGroups
#' @param lib.size.scale scale parameter for library size (log-normal distribution). A number or a vector of length nGroups
#' @param de.fc.mat differential expression fold change matrix (genes x groups; major cell types first, then sub-cell-types),
#'        e.g. the \code{de.fc} element returned by this function. Must be supplied together with groups, nGroups and sub.major;
#'        when given, the other differential-expression and sub-cell-type simulation parameters are ignored.
#' @param nGenes number of genes
#' @param gene.mean.shape shape parameter for mean expression level (Gamma distribution)
#' @param gene.mean.rate rate parameter for mean expression level (Gamma distribution)
#' @param gene.means mean gene expression levels (a vector of length nGenes). Simulated when NULL.
#' @param de.n number of differentially expressed genes in each cell type. Should be an integer or a vector of length nGroups
#' @param de.share number of shared DE genes between neighbor cell types. Should be a vector of length (nGroups - 1)
#' @param de.id the index of genes that are DE across cell types. Should be a list of vectors.
#'        Each vector corresponds to a cell type. With non-null value of de.id, de.n and de.share would be ignored.
#' @param de.fc.loc the location parameter for the fold change of DE genes. Should be a number, a vector of length nGroups
#'        or a list of vectors with exactly same structure as de.id (or the structure defined by de.n and de.share)
#' @param de.fc.scale the scale parameter for fold change (log-normal distribution). Should be a number or a vector of length nGroups
#' @param add.sub whether to add sub-cell-types
#' @param sub.major the major cell types (indices in 1..nGroups) that the sub-cell-types belong to. Defaults to one sub-cell-type per major type (1:nGroups).
#' @param sub.prop proportion of sub-cell-types in the corresponding major cell type. A number or a vector of length sub.major
#' @param sub.group list of cell indices for each sub-cell-type. With non-null sub.group specified, sub.prop would be ignored.
#' @param sub.de.n number of differentially expressed genes in each sub-cell-type compared to the corresponding major cell type. Should be an integer or a vector of length sub.major
#' @param sub.de.id list with the index of additional differentially expressed genes between each sub-cell-type and its major cell type
#' @param sub.de.common whether all sub-cell-types share the additional DE genes and fold changes of the first sub-cell-type (\code{sub.de.id[[1]]})
#' @param sub.de.fc.loc similar to de.fc.loc, but for additional differentially expressed genes in sub-cell-types
#' @param sub.de.fc.scale similar to de.fc.scale, but for additional differentially expressed genes in sub-cell-types
#' @param add.cor whether to add pathways (correlated genes)
#' @param cor.n number of pathways included. Should be an integer. Ignored when cor.id is given.
#' @param cor.size number of correlated genes (length of pathway). Should be a number or a vector of length cor.n
#' @param cor.cor correlation parameter rho: genes k positions apart within a pathway have correlation rho^k (up to band.width).
#'        A number or a vector with one value per pathway.
#' @param cor.id gene index of correlated (pathway) genes. Should be a list of vectors, with each vector represents a pathway. With non-null value of cor.id, cor.n would be ignored.
#' @param band.width No correlation exists if distance of 2 genes are further than band.width in a pathway
#' @param add.hub whether to add hub genes
#' @param hub.n number of hub genes included. Should be an integer. Ignored when hub.id is given.
#' @param hub.size number of genes correlated to the hub gene (the hub gene itself is not counted). Should be a number or a vector with one value per hub
#' @param hub.cor correlation parameters between hub genes and their correlated genes. A number or a vector with one value per hub
#' @param hub.id gene index of hub genes (a vector). With non-null value of hub.id, hub.n would be ignored.
#' @param hub.fix user defined genes correlated to hub genes (others are randomly selected). Should be a list of vectors with one element per hub gene.
#' @param drop whether to add dropout
#' @param dropout.location mid point of the logistic dropout curve: the scaled mean expression at which the dropout probability is 0.5
#'        (the logistic is applied to the scaled mean itself)
#' @param dropout.slope slope of the logistic dropout curve; negative values make dropout less likely for highly expressed genes
#' @return A list with elements
#' \describe{
#'   \item{counts}{gene-by-cell count matrix, after dropout when \code{drop = TRUE}}
#'   \item{true.counts}{gene-by-cell count matrix without dropout}
#'   \item{groups}{cell labels; sub-cell-types are labelled \code{nGroups + 1}, \code{nGroups + 2}, ...}
#'   \item{library.size}{simulated library size of each cell}
#'   \item{gene.mean.exp}{baseline mean expression of each gene}
#'   \item{de.fc}{gene-by-group fold change matrix (major cell types first, then sub-cell-types)}
#'   \item{cor.list}{list of gene indices of each pathway (empty unless \code{add.cor})}
#'   \item{hub.list}{list of gene indices of each hub, hub gene first (empty unless \code{add.hub})}
#' }
#' @references Zappia, L., Phipson, B., & Oshlack, A. (2017). Splatter: Simulation of single-cell RNA sequencing data. Genome Biology, 18(1). https://doi.org/10.1186/s13059-017-1305-0
#' @details Compared to splat method in Splatter R package, this function can fix the number and position of differentially expressed genes,
#' have more complicated differential expression structure, add sub-cell-types, correlated genes (AR(1) correlation structure with bound, mimicking pathways) and hub genes.
#' @import corpcor
#' @import MASS
#' @import stats
#' @examples
#' set.seed(1)
#' data <- GeneScape()
#' @export
#'

GeneScape <- function(nCells = 6000, nGroups = NULL, groups = NULL, lib.size.loc = 9.3, lib.size.scale = 0.25,
                      de.fc.mat = NULL,
                      nGenes = 5000, gene.mean.shape = 0.3, gene.mean.rate = 0.15, gene.means = NULL,
                      de.n = 50, de.share = NULL, de.id = NULL, de.fc.loc = 0.7, de.fc.scale = 0.2,
                      add.sub = FALSE, sub.major = NULL, sub.prop = 0.1, sub.group = NULL, sub.de.n = 20,
                      sub.de.id = NULL, sub.de.common = FALSE, sub.de.fc.loc = 0.7, sub.de.fc.scale = 0.2,
                      add.cor = FALSE, cor.n = 4, cor.size = 20, cor.cor = 0.7, cor.id = NULL, band.width = 10,
                      add.hub = FALSE, hub.n = 10, hub.size = 20, hub.cor = 0.4, hub.id = NULL, hub.fix = NULL,
                      drop = FALSE, dropout.location = -2, dropout.slope = -1) {
  # groups fixes the number of cells; only an explicitly conflicting nCells is an error
  if (!is.null(groups) && length(groups) != nCells) {
    if (missing(nCells)) {
      nCells <- length(groups)
    } else {
      stop("Length of groups does not match nCells.")
    }
  }

  sub.de.genes <- NULL   # genes with sub-cell-type specific DE; kept out of pathway / hub genes
  relabel.sub <- FALSE   # whether cells still have to be relabelled into sub-cell-types

  if (is.null(de.fc.mat)) {
    if (is.null(nGroups)) {
      nGroups <- if (is.null(groups)) 4 else length(unique(groups))
    }

    if (is.null(groups)) {
      groups <- rep_len(seq_len(nGroups), nCells)
    } else if (length(unique(groups)) != nGroups) {
      stop("Number of unique groups does not match nGroup.")
    } else if (max(groups) != nGroups) {
      stop("groups is not encoded by integers starting from 1.")
    }

    lib.size.loc <- .gs_recycle(lib.size.loc, nGroups, "Length of lib.size.loc does not match nGroup.")
    lib.size.scale <- .gs_recycle(lib.size.scale, nGroups, "Length of lib.size.scale does not match nGroup.")
    lib.size <- rlnorm(nCells, lib.size.loc[groups], lib.size.scale[groups])   # library size
    gene.means <- .gs_gene_means(gene.means, nGenes, gene.mean.shape, gene.mean.rate)

    # DE genes of each cell type; neighbouring cell types share de.share genes
    if (is.null(de.id)) {
      de.n <- .gs_recycle(de.n, nGroups, "Length of de.n does not match nGroup.")
      if (is.null(de.share)) {
        de.share <- rep(0, nGroups - 1)
      } else if (length(de.share) != (nGroups - 1)) {
        stop("Length of de.share does not match nGroup - 1.")
      }
      de.id.set <- .gs_sample(seq_len(nGenes), sum(de.n) - sum(de.share))
      de.id <- .gs_partition(de.id.set, de.n, de.share)
    } else if (length(de.id) != nGroups) {
      stop("Wrong size of de.id. It should be a list of length nGroups.")
    }

    de.fc.scale <- .gs_recycle(de.fc.scale, nGroups, "Length of de.fc.scale does not match nGroup.")
    de.fc.loc <- .gs_fc_loc(de.fc.loc, de.id, "Length of de.fc.loc does not match nGroup.")

    # differential expression fold change of every gene in every cell type
    group.fc <- matrix(1, nrow = nGenes, ncol = nGroups)
    for (idx in seq_len(nGroups)) {
      group.fc[, idx] <- fcsim(nGenes, de.id[[idx]], de.fc.loc[[idx]], de.fc.scale[idx])
    }
    group.fc.total <- group.fc

    # simulate the sub-cell-types
    if (add.sub) {
      if (is.null(sub.major)) {
        sub.major <- seq_len(nGroups)
      } else if (!is.numeric(sub.major) || any(sub.major != round(sub.major)) ||
                 any(sub.major < 1) || max(sub.major) > nGroups) {
        stop("Index of sub groups' corresponding major cell types exceed nGroup.")
      }
      n.sub <- length(sub.major)

      if (is.null(sub.group)) {
        sub.prop <- .gs_recycle(sub.prop, n.sub, "Length of sub.prop does not match length of sub.major.")
        sub.group <- vector("list", n.sub)
        for (i in seq_len(n.sub)) {
          major.cells <- which(groups == sub.major[i])
          sub.group[[i]] <- .gs_sample(major.cells, round(sub.prop[i] * length(major.cells)))
        }
      }
      if (length(sub.group) != n.sub) {
        stop("Index of sub groups' corresponding major cell types does not match sub type cell ids.")
      }

      if (is.null(sub.de.id)) {
        sub.de.n <- .gs_recycle(sub.de.n, n.sub, "Length of sub.de.n does not match length of sub.major.")
        # sub-type DE genes are drawn from genes that are not DE between major types
        sub.de.id.set <- .gs_sample(setdiff(seq_len(nGenes), unlist(de.id)), sum(sub.de.n))
        sub.de.id <- .gs_partition(sub.de.id.set, sub.de.n)
      } else if (length(sub.de.id) != n.sub) {
        stop("Wrong size of sub.de.id. It should be a list of the same length as sub.major.")
      }

      sub.de.fc.scale <- .gs_recycle(sub.de.fc.scale, n.sub,
                                     "Length of sub.de.fc.scale does not match length of sub.major.")
      sub.de.fc.loc <- .gs_fc_loc(sub.de.fc.loc, sub.de.id,
                                  "Length of sub.de.fc.loc does not match length of sub.major.")

      # each sub-cell-type inherits its major cell type's fold changes ...
      sub.group.fc <- group.fc[, sub.major, drop = FALSE]
      # ... and then gets extra fold changes on its own DE genes
      if (sub.de.common) {
        common.id <- sub.de.id[[1]]
        sub.de.fc <- fcsim(nGenes, common.id, sub.de.fc.loc[[1]], sub.de.fc.scale[1])
        sub.group.fc[common.id, ] <- sub.de.fc[common.id]
      } else {
        for (idx in seq_len(n.sub)) {
          sub.id <- sub.de.id[[idx]]
          sub.de.fc <- fcsim(nGenes, sub.id, sub.de.fc.loc[[idx]], sub.de.fc.scale[idx])
          sub.group.fc[sub.id, idx] <- sub.de.fc[sub.id]
        }
      }
      group.fc.total <- cbind(group.fc, sub.group.fc)
      sub.de.genes <- unique(unlist(sub.de.id))
      relabel.sub <- TRUE
    }
  } else {
    if (nrow(de.fc.mat) != nGenes) {
      stop("Size of de.fc.mat does not match nGenes.")
    } else if (is.null(groups) || is.null(nGroups) ||
               length(sub.major) != (ncol(de.fc.mat) - nGroups)) {
      stop("de.fc.mat needs to be provided together with groups, sub.major and nGroups.")
    } else if (ncol(de.fc.mat) != max(groups) || ncol(de.fc.mat) < nGroups) {
      stop("Number of unique groups (including subgroups) does not match de.fc.mat.")
    }

    lib.size.loc <- .gs_recycle(lib.size.loc, nGroups, "Length of lib.size.loc does not match nGroup.")
    lib.size.scale <- .gs_recycle(lib.size.scale, nGroups, "Length of lib.size.scale does not match nGroup.")

    # cells of sub-cell-type j (label nGroups + j) draw library sizes like major type sub.major[j]
    groups.old <- groups
    is.sub.cell <- groups > nGroups
    if (any(is.sub.cell)) {
      groups.old[is.sub.cell] <- sub.major[groups[is.sub.cell] - nGroups]
    }
    lib.size <- rlnorm(nCells, lib.size.loc[groups.old], lib.size.scale[groups.old])   # library size
    gene.means <- .gs_gene_means(gene.means, nGenes, gene.mean.shape, gene.mean.rate)

    group.fc.total <- de.fc.mat
    # recover the DE genes: any fold change different from 1
    de.id <- lapply(seq_len(nGroups), function(i) which(de.fc.mat[, i] != 1))
    if (length(sub.major) > 0) {
      sub.de.id <- lapply(seq_along(sub.major), function(j) {
        setdiff(which(de.fc.mat[, j + nGroups] != 1), which(de.fc.mat[, sub.major[j]] != 1))
      })
      sub.de.genes <- unique(unlist(sub.de.id))
    }
    # groups already carries the sub-cell-type labels, so relabel.sub stays FALSE
  }

  groups.new <- groups
  if (relabel.sub) {
    for (i in seq_along(sub.major)) {
      groups.new[sub.group[[i]]] <- i + nGroups
    }
  }

  # simulate basic counts
  gene.means.fc <- matrix(gene.means, nrow = nGenes, ncol = nCells, byrow = FALSE) *
    as.matrix(group.fc.total[, groups.new, drop = FALSE])
  gene.means.scale <- gene.means.fc *
    matrix(lib.size / colSums(gene.means.fc), nrow = nGenes, ncol = nCells, byrow = TRUE)
  rm(gene.means.fc)
  true.counts <- matrix(rpois(as.numeric(nGenes) * as.numeric(nCells), lambda = gene.means.scale),
                        nrow = nGenes, ncol = nCells)

  # simulate the correlated genes (pathways)
  if (add.cor) {
    if (!is.null(cor.id)) {
      cor.n <- length(cor.id)   # cor.n is ignored when cor.id is given
    }
    cor.cor <- .gs_recycle(cor.cor, cor.n, "Length of cor.cor does not match cor.n.")
    if (is.null(cor.id)) {
      cor.size <- .gs_recycle(cor.size, cor.n, "Length of cor.size does not match cor.n.")
      candidates <- setdiff(seq_len(nGenes), c(unlist(de.id), sub.de.genes))
      cor.id <- .gs_partition(.gs_sample(candidates, sum(cor.size)), cor.size)
    }

    for (j in seq_along(cor.id)) {
      pathway <- cor.id[[j]]
      l <- length(pathway)
      if (l == 0) {
        next
      }
      # banded AR(1): corr(i, k) = cor.cor[j]^|i - k| within band.width positions, 0 beyond
      lag <- abs(outer(seq_len(l), seq_len(l), "-"))
      S <- cor.cor[j]^lag
      S[lag > band.width] <- 0
      diag(S) <- 1
      e1 <- min(eigen(S)$values)
      if (e1 <= 0) {
        S <- S + diag(0.1 - e1, nrow(S))   # shift the spectrum so S is positive definite
      }
      sim <- mvrnorm(nCells, rep(0, l), S)
      # Gaussian copula: correlated normals -> uniforms -> Poisson quantiles at the cell means
      true.counts[pathway, ] <- qpois(t(pnorm(sim)), gene.means.scale[pathway, , drop = FALSE])
    }
  }

  # simulate the hub genes
  hub.list <- list()
  if (add.hub) {
    if (!is.null(hub.id)) {
      hub.n <- length(hub.id)   # hub.n is ignored when hub.id is given
    }
    hub.size <- .gs_recycle(hub.size, hub.n, "Length of hub.size does not match hub.n.")
    hub.cor <- .gs_recycle(hub.cor, hub.n, "Length of hub.cor does not match hub.n.")
    if (is.null(hub.fix)) {
      hub.fix <- rep(list(numeric(0)), hub.n)
    } else if (length(hub.fix) != hub.n) {
      stop("Length of hub.fix does not match hub.n.")
    }

    pathway.genes <- if (add.cor) unlist(cor.id) else NULL
    if (is.null(hub.id)) {
      candidates <- setdiff(seq_len(nGenes), c(unlist(de.id), sub.de.genes, pathway.genes))
      hub.id <- .gs_sample(candidates, hub.n)
    }

    # genes already committed elsewhere cannot be drawn as hub partners
    all.fixed <- c(unlist(de.id), sub.de.genes, pathway.genes, hub.id, unlist(hub.fix))
    gene.set.left <- setdiff(seq_len(nGenes), all.fixed)
    hub.list <- vector("list", hub.n)
    for (j in seq_len(hub.n)) {
      n.genes <- hub.size[j] + 1   # the hub gene plus its hub.size correlated genes
      S <- diag(1, n.genes)
      S[1, -1] <- hub.cor[j]
      S[-1, 1] <- hub.cor[j]
      sim <- mvrnorm(nCells, rep(0, n.genes), make.positive.definite(S))
      partners <- .gs_sample(gene.set.left, hub.size[j] - length(hub.fix[[j]]))
      hub.pos <- c(hub.id[j], hub.fix[[j]], partners)
      gene.set.left <- setdiff(gene.set.left, hub.pos)
      hub.list[[j]] <- hub.pos
      true.counts[hub.pos, ] <- qpois(t(pnorm(sim)), gene.means.scale[hub.pos, , drop = FALSE])
    }
  }

  # add dropout
  if (drop) {
    # NOTE(review): splatter applies this logistic to log(mean); here it is applied to the
    # scaled mean itself, so dropout.location is not on splat's scale. Confirm intent.
    drop.prob <- 1 / (1 + exp(-dropout.slope * (gene.means.scale - dropout.location)))
    keep <- matrix(rbinom(as.numeric(nGenes) * as.numeric(nCells), 1, 1 - drop.prob),
                   nrow = nGenes, ncol = nCells)
    counts <- true.counts * keep
  } else {
    counts <- true.counts
  }

  cor.list <- if (add.cor) cor.id else list()

  return(list(counts = counts, true.counts = true.counts, groups = groups.new,
              library.size = lib.size, gene.mean.exp = gene.means,
              de.fc = group.fc.total,
              cor.list = cor.list, hub.list = hub.list))
}

# Recycle a length-one parameter to length n; otherwise require exactly length n.
.gs_recycle <- function(x, n, msg) {
  if (length(x) == 1) {
    return(rep(x, n))
  }
  if (length(x) != n) {
    stop(msg, call. = FALSE)
  }
  x
}

# Use the supplied gene means, or draw them from Gamma(shape, rate) when NULL.
.gs_gene_means <- function(gene.means, nGenes, shape, rate) {
  if (is.null(gene.means)) {
    return(rgamma(nGenes, shape = shape, rate = rate))
  }
  if (length(gene.means) != nGenes) {
    stop("Length of gene.means does not match nGenes.", call. = FALSE)
  }
  gene.means
}

# Sample `size` elements of x without replacement. Unlike sample(), a
# length-one x is never expanded to 1:x.
.gs_sample <- function(x, size) {
  x[sample.int(length(x), size)]
}

# Split id.set into consecutive chunks of lengths `sizes`; chunk i starts
# shares[i - 1] elements before chunk i - 1 ends, so neighbours share them.
.gs_partition <- function(id.set, sizes, shares = rep(0, max(length(sizes) - 1, 0))) {
  chunks <- vector("list", length(sizes))
  end <- 0
  for (i in seq_along(sizes)) {
    start <- if (i == 1) 0 else end - shares[i - 1]
    chunks[[i]] <- id.set[start + seq_len(sizes[i])]
    end <- start + sizes[i]
  }
  chunks
}

# Expand fold-change locations (a number, one value per group, or already a
# list) into one vector per group with one entry per DE gene in `ids`.
.gs_fc_loc <- function(loc, ids, msg) {
  if (is.list(loc)) {
    if (length(loc) != length(ids)) {
      stop(msg, call. = FALSE)
    }
    return(loc)
  }
  loc <- .gs_recycle(loc, length(ids), msg)
  lapply(seq_along(ids), function(i) rep(loc[i], length(ids[[i]])))
}
