新型肺炎COVID-19が今後2022年まで流行が続くというのをrstanで再現しようとしたが断念した

読んだ。
Projecting the transmission dynamics of SARS-CoV-2 through the postpandemic period. - PubMed - NCBI
COVID-19がこのままだと2022年も続いているのではないか、といろいろ話題になった論文。
githubがおいてあるが、ここにあるデータは
GitHub - c2-d2/CoV-seasonality: Code for regression model from Kissler, Tedijanto et al. 2020.

コロナウイルスの流行具合をCDCのNREVSSというデータベースから2015-2019年の間までデータを使った、というが、肝心のデータは2018年4月以降しかもう置いてないっぽい。
NREVSS | Coronavirus National Trends | CDC
https://www.cdc.gov/surveillance/nrevss/images/coronavirus/Corona4PP_Nat.htm
こちらの過去の報告の図からデータを引っこ抜ければ、2014-2017年はデータが取れるが、めんどうなのでやめた。そもそも再現できないじゃん…
Human coronavirus circulation in the United States 2014-2017. - PubMed - NCBI

コロナウイルスの検査数と、インフルエンザ様症状(ILI)の受診割合でかけることで、週ごとの(\beta型)コロナウイルスの頻度、としている。日ごとのR_u

R_u=\displaystyle\sum_{t=u}^{u+i_{max}}\frac{b(t)g(t-u)}{\displaystyle\sum_{a=0}^{i_{max}}b(t-a)g(a)}
b(t):時刻tの株(strain)ごとの発生率
g(a):時刻a のgeneration interval。論文では平均8.4日、標準偏差3.8日のweibull 分布を使っているが、githubでは他の報告の対数正規分布やガンマ分布も使っている。
i_{max}:generation interval が累積確率分布で99%になる日。論文では19日としている。
github ではfunc.WTという関数で定義されているが、各株のデータのはいっているdata.frameを入力するようなので、vectorを入力するような感じに書き換えた。

func.SI_pull <- function(value, serial_int){
  if(serial_int == "SARS"){
    return(dweibull(value, shape=SARS_shape, scale=SARS_scale))
  } else if(serial_int == "nCoV_Li"){
    return(dgamma(value, shape=nCoV_Li_shape, scale=nCoV_Li_scale))
  } else if(serial_int == "nCoV_Nishiura"){
    return(dweibull(value, shape=nCoV_Nishiura_shape, scale=nCoV_Nishiura_scale))
  }
}
func.WT <- function(week_list, daily_inc, proxy_name, org_list, serial_int){
  df.Reff_results <- data.frame(Week_start=week_list) # Initialize dataframe to hold results
  stop <- get(paste0(serial_int, "_stop")) 
  # Estimate daily R effective
  for(v in org_list){ #For each strain
    RDaily <- numeric()
    for(u in 1:(length(week_list)*7)){ #Estimate for each day
      sumt <- 0
      for(t in u:(u+stop)){ #Look ahead starting at day u through (u+max SI)
        suma <- 0 
        for(a in 0:(stop)){ #Calc denominator, from day t back through (t-max SI)
          suma = daily_inc[t-a,v]*func.SI_pull(a, serial_int) + suma
        } 
        sumt = (daily_inc[t,v]*func.SI_pull(t-u, serial_int))/suma + sumt
      }
      RDaily[u] = sumt
    }   
    # Take geometric mean over daily values for current week, and one week before + after (for smoothing), to get weekly estimates
    RWeekly <- numeric()
    for(i in 1:length(week_list)) RWeekly[i] <- geoMean(as.numeric(RDaily[max(1,7*(i-2)+1):min(length(RDaily),7*(i+1))]), na.rm=TRUE)
    RWeekly[1:3] <- NA #Set results for first 3 weeks to NA (unreliable using WT)
    RWeekly[(length(RWeekly)-2):length(RWeekly)] <- NA #Set results for last 3 weeks to NA (unreliable using WT)
    df.Reff_results[,paste0(v,"_R_",proxy_name)] <- RWeekly
  } 
  return(df.Reff_results)
}
Ru <- function(vec, pdf="weibull", pars=c(2.35, 9.48), alpha=0.99){
  Imax <- eval(parse(text=sprintf("ceiling(q%s(%f, %f, %f))", pdf, alpha, pars[1], pars[2])))
  g <- function(x) eval(parse(text=sprintf("d%s(%f, %f, %f)", pdf, x, pars[1], pars[2])))
  res <- mapply(function(u)  sum(
         mapply(function(t) (vec[t]*g(t-u))/sum(mapply(function(a) vec[t-a]*g(a), 0:Imax)), u:(u+Imax))
         ),  (Imax+1):(length(RDaily)-Imax))
  NAs <- rep(NA, Imax)
  RDaily <- c(NAs, res, NAs)
  RWeekly <- mapply(function(i) geoMean(RDaily[max(1,7*(i-2)+1):min(length(RDaily),7*(i+1))], na.rm=TRUE), 1:(ceiling(length(RDaily)/7))
  RWeekly[1:3] <- RWeekly[(length(RWeekly)-2):length(RWeekly)] <- NA
  retuen(list(RDaily, RWeekly))
}

感染の流行はtwo-strain SEIRS モデルというものを使っている。要は図のSEIRS x two strain の16コンポーネント微分方程式になるのだが、解けるわけがないので前にやったとおりrstanで推定してみる。
mikuhatsune.hatenadiary.com
16コンポーネントはこんな感じになる。
f:id:MikuHatsune:20200419224151p:plain

さて、このSEIRSモデルに\mu, \beta(t), \nu, \chi_{12}, \chi_{21}, \gamma, \delta_1, \delta_2 の8つのパラメータが設定してある。ここで、\beta(t) だけが時間依存のパラメータで、再生産係数R_0 に対して
\beta(t)=\gamma R_0(t)
R_0=\max(R_0)\left[\frac{f}{2}\cos(\frac{2\pi}{52}(t-\phi))+(1-\frac{f}{2})\right]
と、R_0(t)に依存することになる。
これをrstanで頑張ってみようと思ったが、rstanで微分方程式を作ったときに、パラメータは各時刻tですべて一定でしか(自分の力量では)出来ないので、今回の再現では\betaは一様分布からサンプリングされる決め打ちのシミュレーションになってしまった。誰か時間依存\beta(t) のままやれる方法を教えてほしい。f:id:MikuHatsune:20200419224943p:plain
16コンパートメントもあるとゴチャゴチャするので、googlevisも作ってみた。

(htmlが破壊されてレイアウトが変になったので削除しました)


うまく再現できなかったがこれも引っ張り過ぎると他のことが進まなくなるのでここらへんで終了。

# SEIRS two-strain 16 compornents
lab <- c("S", "E", "I", "R")
labs <- c(outer(paste0(lab, 1), paste0(lab, 2), paste0))
library(igraph)
mat <- as.matrix(read.table(text="
0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0
0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0
0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0
1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0
0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0
0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0
0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0
0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0
0 0 0 0 0 0 0 0 0 1 0 0 1 0 0 0
0 0 0 0 0 0 0 0 0 0 1 0 0 1 0 0
0 0 0 0 0 0 0 0 0 0 0 1 0 0 1 0
0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1
1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0
0 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0
0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1
0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0
", header=FALSE))
rownames(mat) <- colnames(mat) <- labs
g <- graph_from_adjacency_matrix(mat)

rot <- -pi/4
lay <- cbind(rep(c(-3,-1,1,3), each=4), rep(c(3,1,-1,-3), 4))
lay <- t(matrix(c(cos(rot), sin(rot), -sin(rot), cos(rot)), 2) %*% t(lay))


par(mar=rep(0.5, 4))
V(g)$label.cex <- 1.5
V(g)$label.font <- 2
V(g)$size <- 25
E(g)$curved <- 0 #0.05
E(g)$curved[c(7, 25, 31:32)] <- c(1, -1) #0.05
E(g)$curved[c(15, 27, 23, 29)] <- c(1, -1)*0.7 #0.05
plot(g, layout=lay)

# SEIRS two strain
library(rstan)
rstan_options(auto_write=TRUE)
options(mc.cores=parallel::detectCores())

code <- "
functions {
   real[] sir(
      real t,
      real[] y,
      real[] p, // mu beta nu x12 x21 gamma delta1 delta2
      real[] x_r,
      int[] x_i
   ) {
     // S1S2 E1S2 I1S2 R1S2 1-4
     // S1E2 E1E2 I1E2 R1E2 5-8
     // S1I2 E1I2 I1I2 R1I2 9-12
     // S1R2 E1R2 I1R2 R1R2 13-16
      real dydt[16];
      dydt[1] = -p[1]*y[1]*y[3]-p[2]*y[1]*(y[3]+y[7]+y[11]+y[15])-p[2]*y[1]*(y[9]+y[10]+y[11]+y[12])+p[7]*y[4]+p[8]*y[13];
      dydt[2] = p[1]*y[1]*y[3]-p[1]*y[2]-p[3]*y[2]-(1-p[4])*p[2]*y[2]*(y[9]+y[10]+y[11]+y[12])+p[8]*y[14];
      dydt[3] = p[3]*y[2]-p[6]*y[3]-p[1]*y[3]-(1-p[4])*y[3]*(y[9]+y[10]+y[11]+y[12])+p[8]*y[15];
      dydt[4] = -(p[1]+p[7]+(1-p[4])*(y[9]+y[10]+y[11]+y[12]))*y[4]+p[6]*y[3]+p[8]*y[16];
      dydt[5] = -(p[1]+p[3]+(1-p[5])*p[2]*(y[3]+y[7]+y[11]+y[15]))*y[5]+p[2]*(y[9]+y[10]+y[11]+y[12])*y[1]+p[7]*y[8];
      dydt[6] = -(p[1]+2*p[3])*y[6]+p[2]*((1-p[4])*(y[9]+y[10]+y[11]+y[12])*y[2]+(1-p[5])*(y[3]+y[7]+y[11]+y[15])*y[5]);
      dydt[7] = -(p[1]+p[3]+p[6])*y[7]+p[3]*y[6]+(1-p[4])*p[2]*(y[9]+y[10]+y[11]+y[12])*y[3];
      dydt[8] = -(p[1]+p[3]+p[7])*y[8]+p[6]*y[7]+(1-p[5])*p[2]*(y[9]+y[10]+y[11]+y[12])*y[4];
      dydt[9] = -(p[1]+p[6]+(1-p[6])*p[2]*(y[3]+y[7]+y[11]+y[15]))*y[9]+p[7]*y[12]+p[3]*y[5];
      dydt[10]= -(p[1]+p[3]+p[6])*y[7]+p[3]*y[6]+(1-p[5])*p[2]*(y[3]+y[7]+y[11]+y[15])*y[9];
      dydt[11]= -(p[1]+2*p[6])*y[11]+p[3]*(y[7]+y[10]);
      dydt[12]= -(p[1]+p[6]+p[7])*y[12]+p[3]*y[8]+p[6]*y[11];
      dydt[13]= -(p[1]+p[8]+(1-p[6])*p[2]*(y[3]+y[7]+y[11]+y[15]))*y[13]+p[6]*y[9]+p[7]*y[16];
      dydt[14]= -(p[1]+p[3]+p[8])*y[14]+p[6]*y[10]+(1-p[6])*p[2]*(y[3]+y[7]+y[11]+y[15])*y[13];
      dydt[15]= -(p[1]+p[6]+p[8])*y[15]+p[3]*y[14]+p[6]*y[11];
      dydt[16]= -(p[1]+p[7]+p[8])*y[16]+p[6]*(y[12]+y[15]);
      return dydt;
   }
}
data {
   int<lower=1> T; // 時間の数
   real T0; // time 0
   real TS[T]; // 各時刻
   real Y0[16]; // Y の初期値 SEIR
   real m01[2]; // mu   の範囲
   real b01[2]; // beta の範囲
   real n01[2]; // nu   の範囲
   real x12[2]; // x12  の範囲
   real x21[2]; // x21  の範囲
   real g01[2]; // gamma  の範囲
   real d1[2]; // delta1 の範囲
   real d2[2]; // delta2 の範囲
}
transformed data {
   real x_r[0];
   int x_i[0];
}
parameters {
   real<lower=m01[1], upper=m01[2]> mu;
   real<lower=b01[1], upper=b01[2]> beta;
   real<lower=n01[1], upper=n01[2]> nu;
   real<lower=x12[1], upper=x12[2]> X12;
   real<lower=x21[1], upper=x21[2]> X21;
   real<lower=g01[1], upper=g01[2]> gamma;
   real<lower=d1[1], upper=d1[2]> delta1;
   real<lower=d2[1], upper=d2[2]> delta2;
}
transformed parameters {
   real y_hat[T,16];
   real theta[8];
   theta[1] = mu;
   theta[2] = beta;
   theta[3] = nu;
   theta[4] = X12;
   theta[5] = X21;
   theta[6] = gamma;
   theta[7] = delta1;
   theta[8] = delta2;
   y_hat = integrate_ode(sir, Y0, T0, TS, theta, x_r, x_i);
}
model {
  // 何もしない
}
"

m0 <- stan_model(model_code=code)
y0 <-  replace(rep(0, 16), 1, 0) # 初期値
y0[c(3, 7, 11, 15, 9, 10, 12)] <- 0.0005
y0 <- replace(y0, 1, 1-sum(y0))
times <- seq(0, 100, by=1)[-1] # 時刻 0.2 刻みはどういうこと?

# 各パラメータを振る範囲
lim <- list(m01=1/c(81, 79), b01=1/c(14, 1), n01=1/c(14, 3), x12=c(0, 1), x21=c(0, 1), g01=1/c(14, 3), d1=1/c(100, 25), d2=1/c(100, 25))

standata <- c(list(T=length(times), T0=0, TS=times, Y0=y0), lim)
fit <- sampling(m0, standata, iter=200, warmup=100, chain=2)
ex <- extract(fit, pars=head(fit@model_pars, -1))

cols <- matlab::jet.colors(length(labs))
i <- 23 # たまたまこのシミュレーションが見栄えがよかったので
matplot(ex$y_hat[i,,], type="l", lty=1, col=cols, lwd=2, xlab="time", ylab="proportion")
legend("topright", legend=labs, ncol=4, col=cols, pch=15, bg="white")

# googlevis で可視化
library(googleVis)
library(stringr)
gmat <- cbind.data.frame(Time=times, ex$y_hat[i,,])
colnames(gmat) <- c("Times", labs)
url <- "https://science.sciencemag.org/content/early/2020/04/14/science.abb5793"
webpage <- "Science. 2020 Apr 14. pii: eabb5793. doi: 10.1126/science.abb5793."
headertxt <- sprintf("<span><a href='%s' target='_blank'>%s</a></span> on %s", url, webpage, Sys.Date())
# googlevis 用のプロットオプション
ops <- list(
  fontSize=12, height=600, width=600, lineWidth=1,
  vAxis="{title:'Proportion', fontSize: 20, titleTextStyle: {fontSize: 20, bold: 'true', italic: 'false'}, viewWindow: {min: 0, max: 1}, gridlines: {color: '#00F', minSpacing: 100}, logScale: 'false'}",
  sizeAxis="{minValue: -80, maxValue: 75, maxSize: 15}",
  hAxis="{title:'Time', fontSize: 20, titleTextStyle: {fontSize: 20, bold: 'true', italic: 'false'}, viewWindow: {min: 0, max: 105}}",
  explorer="{actions: ['dragToZoom', 'rightClickToReset']}",
  focusTarget="category",
  selectionMode="multiple",
  legend="{position: 'right', textStyle: {color: 'black', fontSize: 10}, pageIndex: 1}"
)
gvis <- gvisLineChart(gmat, options=ops)
gvis$html$caption <- NULL
gvis$html$footer <- str_replace(gvis$html$footer,
                                "<span> \n  .+\n  .+\n</span>",
                                headertxt)
plot(gvis)