# Species included in the analysis, and the subset of them that are salmonids
species <- c("Drer", "Olat", "Eluc", "Ssal", "Salp", "Omyk", "Okis")
salmonids <- c("Ssal", "Salp", "Omyk", "Okis")

# Read the rooted species tree and prune it down to the species listed above
spcTree <- read.tree(
  "input-data/from_ortho_pipeline/SpeciesTree_rooted_node_labels.txt"
)
tree <- keep.tip(spcTree, species)

#### Load the combined expression matrix and filter its rows (genes)
combExprMat <- readRDS("data/BSNormalize/combExprMat.RDS")
mat <- as.matrix(combExprMat)
hasNA <- rowSums(is.na(mat)) > 0     # rows containing at least one NA
noVar <- apply(mat, 1, var) == 0     # rows whose variance is exactly zero
gene.data <- mat[!hasNA & !noVar, ]  # keep rows that pass both filters

# Pairwise distances between the tips of the pruned tree
dSpc <- ape::cophenetic.phylo(tree)
# Generate a table of mean divergence for each pair of columns (samples) of an
# expression matrix.
#
# m: numeric matrix with one row per gene and one named column per sample.
#
# Returns a tibble with one row per unordered pair of columns:
#   d       - squared Euclidean distance between the two columns divided by
#             nrow(m), i.e. the mean squared difference across rows
#   sampleA - column name of the first member of the pair
#   sampleB - column name of the second member of the pair
# A matrix with fewer than two columns yields a zero-row tibble.
# Stops with an error if `m` has no column names.
distTbl <- function(m){
  sampleIDs <- colnames(m)
  if (is.null(sampleIDs)) {
    stop("distTbl: `m` must have column names", call. = FALSE)
  }
  d <- dist(t(m))^2 / nrow(m)

  # as.vector() on a dist object walks the lower triangle column by column:
  # (2,1), (3,1), ..., (n,1), (3,2), ...  which() on lower.tri() enumerates the
  # index pairs in exactly that order and is also valid for n < 2, where the
  # previous 1:(n - 1) construction counted backwards.
  n <- length(sampleIDs)
  pairIdx <- which(lower.tri(matrix(nrow = n, ncol = n)), arr.ind = TRUE)

  tibble(d = as.vector(d),
         sampleA = sampleIDs[pairIdx[, 2]],
         sampleB = sampleIDs[pairIdx[, 1]])
}

# Fit an exponential saturation curve by minimizing the sum of squared errors,
# then evaluate the fitted curve at new distances.
#
# Model:  d = a * (1 - exp(-b * dSpc)) + intersect
#
# df:      data frame with numeric columns `d` (observed divergence) and
#          `dSpc` (distance) to fit against
# dSpcNew: numeric vector of distances at which to evaluate the fitted curve
#
# Returns a tibble with columns `dSpc` (equal to dSpcNew) and `d` (the fitted
# values). Emits a warning if optim() reports that it did not converge.
fitExp <- function(df, dSpcNew){
  # Model curve for parameters par = c(a, b, intersect)
  dExpFun <- function(par, dSpc){
    par[1]*(1-exp(-(par[2]*dSpc))) + par[3]
  }

  # Sum of squared residuals. Columns are read directly from `df` so the
  # objective stays cheap (optim() calls it many times) and can never pick up
  # a global object named `dSpc` by accident.
  ssFun <- function(par, df){
    sum((df[["d"]] - dExpFun(par, df[["dSpc"]]))^2)
  }

  optRes <- optim(par = c(a = 6.3, b = 25, intersect = 1.2), fn = ssFun, df = df)

  # optim() signals failure only through its return value; surface it
  if (optRes$convergence != 0) {
    warning("fitExp: optim() did not converge (code ", optRes$convergence, ")",
            call. = FALSE)
  }

  tibble(dSpc = dSpcNew,
         d = dExpFun(optRes$par, dSpcNew))
}

# Calculate mean expression per species.
#
# m: numeric matrix with one row per gene; columns are samples and the column
#    names identify the species each sample belongs to (repeated names mark
#    replicate samples of the same species).
#
# Returns a numeric matrix with one row per gene and one column per distinct
# column name of `m`, holding the row means over that species' samples.
spcMean <- function(m){
  colsBySpecies <- split(seq_len(ncol(m)), colnames(m))
  # drop = FALSE keeps a species with a single sample as a one-column matrix;
  # without it the subset collapses to a vector and rowMeans() fails.
  vapply(colsBySpecies,
         function(i) rowMeans(m[, i, drop = FALSE]),
         FUN.VALUE = numeric(nrow(m)))
}
# Pairwise divergence table between species means, annotated with the
# evolutionary distance of each pair. On the right-hand side `dSpc` is the
# distance matrix computed earlier (the table has no such column yet); it is
# indexed with a two-column character matrix, i.e. looked up by the
# (sampleA, sampleB) dimnames.
df <- mutate(
  distTbl(spcMean(gene.data)),
  dSpc = dSpc[cbind(sampleA, sampleB)]
)

# Scatter plot of expression divergence against evolutionary distance for every
# species pair. The fitted exponential curve is added first, so it is drawn
# beneath the points. Points are coloured by whether both species of the pair
# are salmonids; the colour legend is suppressed.
ggplot(
  mutate(df,
         noSalmonids = !(sampleA %in% salmonids | sampleB %in% salmonids),
         onlySalmonids = (sampleA %in% salmonids) & (sampleB %in% salmonids),
         label = paste(sampleA, sampleB, sep = ":")),
  aes(x = dSpc, y = d)
) +
  geom_line(data = fitExp(df, dSpcNew = seq(from = 0, to = 0.6, length.out = 100))) +
  geom_point(aes(color = onlySalmonids), shape = 16, size = 3, alpha = 0.5) +
  scale_colour_manual(name = "Only salmonids",
                      values = c("#FF8C00", "#708090"),
                      limits = c(TRUE, FALSE)) +
  labs(x = "Evolutionary distance (substitutions per site)",
       y = "Mean expression divergence") +
  guides(colour = "none") +
  theme_bw() +
  theme(text = element_text(size = 8, colour = "black"))

# ggsave("Figures/suppFigDivergenceCurve.pdf", width = 3, height = 3, units = "in")
# Parameters for the BM / OU simulations below
runs <- 10  # number of trajectories simulated per model (replicate count)
time <- 100  # number of time steps per trajectory (passed to OUsim as n)
sig <- 1  # standard deviation of the per-step noise (passed to OUsim as sigma)

# Simulate one discrete-time Ornstein-Uhlenbeck trajectory with unit time step:
#   x[t + 1] = x[t] + noise[t] + alpha * (theta - x[t])
#
# n:     number of time steps to simulate (n = 0 returns just x0)
# sigma: standard deviation of the Gaussian noise added at each step
# alpha: strength of the pull towards theta; with alpha = 0 the process is a
#        plain random walk
# theta: value the process is pulled towards
# x0:    starting value
#
# Returns a numeric vector of length n + 1: x0 followed by the n simulated
# values.
OUsim <- function(n, sigma, alpha, theta, x0){
  bm <- rnorm(n, sd = sigma)
  x <- c(x0, bm)
  # x[i + 1] starts out holding the noise term for step i and is overwritten
  # in place with the new state. seq_len() makes the loop a no-op for n = 0,
  # where 1:n would iterate over c(1, 0).
  for (i in seq_len(n)) {
    x[i + 1] <- x[i] + x[i + 1] + alpha * (theta - x[i])
  }
  x
}

# Each matrix holds one simulated trajectory per row (runs rows) and one
# column per time point, named "0" .. time.
# (An earlier version filled these matrices row by row from
#  ornstein_uhlenbeck(T=time, n=time, nu=0, lambda=<alpha>, sig, x0=0).)

# BM: alpha = 0, i.e. no pull towards theta
BMmatrix <- t(replicate(
  runs,
  OUsim(n = time, sigma = sig, alpha = 0, theta = 0, x0 = 0)
))
colnames(BMmatrix) <- as.character(0:time)

# OU: alpha > 0, trajectories are pulled back towards theta
OUmatrix <- t(replicate(
  runs,
  OUsim(n = time, sigma = sig, alpha = 0.1, theta = 0, x0 = 0)
))
colnames(OUmatrix) <- as.character(0:time)

# Combine tables for plotting

# Reshape a simulation matrix (one row per run, columns named by time point)
# into a long table with columns model, run, x (time point, integer) and
# y (simulated value), ordered by run.
#
# simMatrix: numeric matrix as built above (rows = runs, colnames = time points)
# modelName: label stored in `model` and used as prefix of the `run` labels
simToLong <- function(simMatrix, modelName){
  # as_tibble() replaces the deprecated tbl_df()
  as_tibble(simMatrix) %>%
    mutate(model = modelName,
           run = paste(modelName, "run", seq_len(nrow(simMatrix)))) %>%
    # gather every time-point column, i.e. everything except the two labels
    # (replaces the positional selection `0:time+1`)
    gather(x, y, -model, -run) %>%
    mutate(x = as.integer(x)) %>%
    arrange(run)
}

BMtable <- simToLong(BMmatrix, "BM")
OUtable <- simToLong(OUmatrix, "OU")

modelTable <- bind_rows(BMtable, OUtable)
  
# Line plot of all simulated trajectories, one line per run, coloured by model.
# Axis text, ticks, grid lines and the colour legend are all suppressed.
ggplot(modelTable, aes(x = x, y = y, group = run, colour = model)) +
  geom_line(size = 0.5, alpha = 0.6) +
  scale_colour_manual(limits = c("OU", "BM"),
                      values = c("#018571", "#a6611a")) +
  guides(colour = "none") +
  labs(x = "Time (t)", y = "Expression (x)") +
  theme_bw() +
  theme(text = element_text(size = 8, colour = "black"),
        panel.grid = element_blank(),
        axis.text = element_blank(),
        axis.ticks = element_blank())

#ggsave("Figures/suppFigureModelSimulations.pdf", width = 3, height = 3, units = "in")