rstanで打ち切りデータがあるときのパラメータ推定をする

これに、打ち切りデータがあるときの平均値の推定問題がある。

StanとRでベイズ統計モデリング (Wonderful R)

StanとRでベイズ統計モデリング (Wonderful R)

データが正規分布に従うのだろうが、L=25を下回るデータは<25となっているので、このデータを無視して平均値を推定すると、本当にデータがあったときより過大評価するだろう、ということ。
打ち切りの場合に<25を無視するのではなく、<25となったときに(-\infty, L]区間までの(累積)確率分布を考慮すれば、少しはマシな推定になる、という話。
オリジナルのコードは

data {
  int N_obs;
  int N_cens;
  real Y_obs[N_obs];
  real L;
}
parameters{
  real<lower=0> m;
  real<lower=0> s_Y;
}
model{
  for(n in 1:N_obs){
    Y_obs[n] ~ normal(m, s_Y);
  }
  target += N_cens * normal_lcdf(L | m, s_Y);
}

となっているが、打ち切りのデータにインデックスをつけてforを回すパターンにした。また、サンプリングはY ~ normal() でもよいが、target 記法で合わせてみようと思って
サンプリングの部分はtarget += _lpdf(data | parameters)(そのデータをサンプリングする確率密度lpdf
打ち切りの部分はtarget += _lcdf(censor | parameters)L までの累積確率分布lcdf
となる。

rstan で推定した、正規分布のパラメータ\mu の事後分布である。この事後分布の最頻値は、シミュレーションで設定した真の値 50 より(たまたま)小さい。L は40 に設定したが、L を下回った値は<40 と記録されてしまうとして、このデータを除外して平均を推定した場合はOmit の縦線となり、Censor の縦線より大きくなっている。
f:id:MikuHatsune:20200405230844p:plain

さて、rstanで打ち切りデータを考慮した場合と、除外して平均を推定した場合とで、どれくらい推定値が変わるかをシミュレーションした。
除外した場合は、真の値よりほとんどの場合で大きい値を推定してしまう。
打ち切りを考慮した場合は、だいたい真の値を推定しているようだった。
f:id:MikuHatsune:20200405230827p:plain

library(rstan)
rstan_options(auto_write=TRUE)
options(mc.cores=parallel::detectCores())


# 偏差値っぽい感じで
m_true <- 50
s_true <- 10
n <- 30
y <- rnorm(n, m_true, s_true)

L <- 40
idx <- which(y < L)
y0 <- replace(y, idx, L)


code <- "
data {
  int N;
  real Y[N];
  int I[N];
  real L;
}
parameters{
  real<lower=0> m;
  real<lower=0> s;
}
model{
  for(i in 1:N){
    if(I[i] == 0){ // データがある場合
      target += normal_lpdf(Y[i] | m, s);
    } else {       // 打ち切りの場合
      target += normal_lcdf(L | m, s);
    }
  }
}
"

m0 <- stan_model(model_code=code)
standata <- list(N=length(y), Y=y, I=replace(rep(0, length(y)), y < L, 1), L=L)
fit <- sampling(m0, standata, iter=1000, warmup=300, chain=4)
ex <- extract(fit, pars=head(fit@model_pars, -1))


library(vioplot)
par(mar=c(3, 2, 3, 2))
plot(y, rep(0.99, n), pch=16, col=2, xlab="", ylab="", ylim=c(0.95, 1.5), yaxt="n")
vioplot(ex$m, horizontal=TRUE, side="right", ylim=range(y), add=TRUE)
pa <- par()$usr
abline(v=m_true); text(m_true, pa[4], "True", xpd=TRUE, srt=0, pos=3)
abline(v=median(ex$m)); text(median(ex$m), pa[4], "Censor", xpd=TRUE, srt=0, pos=3, offset=1.5)
abline(v=mean(y[y>L])); text(mean(y[y>L]), pa[4], "Omit", xpd=TRUE, srt=0, pos=3)
abline(v=L); text(L, pa[4], "L", xpd=TRUE, srt=0, pos=3)


# 除外した平均と、rstanで推定した平均の違いをシミュレーション
iter <- 1000
M <- matrix(0, iter, 2)
for(i in 1:iter){
  y <- rnorm(n, m_true, s_true)
  standata <- list(N=length(y), Y=y, I=replace(rep(0, length(y)), y < L, 1), L=L)
  fit <- sampling(m0, standata, iter=1000, warmup=300, chain=4)
  ex <- extract(fit, pars=head(fit@model_pars, -1))
  M[i, ] <- c(mean(y[y > L]), median(ex$m))
}


xylim <- range(M)
par(mar=c(5, 5, 3, 3), cex.lab=1.5, cex.axis=1.2)
plot(M, xlim=xylim, ylim=xylim, xlab="L 以下を除外した平均", ylab="累積確率を考慮した平均", pch=16)
abline(v=m_true, h=m_true, lty=3, col=4)
abline(0, 1, lty=3, col=4)
vioplot(M[,1], at=par()$usr[4], add=TRUE, side="right", xpd=TRUE, horizontal=TRUE, wex=3)
vioplot(M[,2], at=par()$usr[2], add=TRUE, side="right", xpd=TRUE, wex=3)

新型肺炎COVID-19の重症化率を推定する

読んだ。
Eurosurveillance | Estimating the infection and case fatality ratio for coronavirus disease (COVID-19) using age-adjusted data from the outbreak on the Diamond Princess cruise ship, February 2020
ダイヤモンド・プリンセス号のデータを使って、感染した場合の重症化率(というか死亡率)を推定する。結論から言うと全年齢層で1.3%、70歳以上だと6.4%になる。
しかしながら、各日に報告された感染者数と死亡者数から単純に計算した重症化率 naive case fatality ratio (nCFR) は、すべての感染者の転機がわからず、死亡が観測されるまでの時差によるので、この時差 delay を考慮した推定を行う。
この推定は、時刻t の(報告された)感染者数c_t、感染もしくは入院から死亡までt の時差がある割合f_t(訳に自信なし)(これは対数正規分布になっている)として、実際に得ている観測数に対しての過小評価分u_t
u_t=\frac{\displaystyle\sum_{i=0}^t\displaystyle\sum_{j=0}^\infty c_{i-j}f_t}{\displaystyle\sum_{i=0}^t c_j}
を計算することで出来る、とあるが、githubにあるスクリプトを見ても
u_t=\frac{\displaystyle\sum_{i=0}^t\displaystyle\sum_{j=0}^{\require{color}\textcolor{red}{i}} c_{i-j}f_t}{\displaystyle\sum_{i=0}^t c_{\require{color}\textcolor{red}{i}}}
な気がするのだが、よくわからない。
GitHub - thimotei/cCFRDiamondPrincess


ダイヤモンド・プリンセス号の死亡者は70歳以上のカテゴリにしかいないので、70歳以上のみにわけたときの重症化率も推定している。

やってみた結果、

Age group cIFR (95% CI) cCFR (95% CI)
All ages 1.299 (0.525, 2.657) 2.601 (1.049, 5.323)
70歳以上 6.445 (2.621, 12.785) 12.910 (5.259, 25.611)

となり、70歳以上と、cCFR (corrected CFR) はよいが、全年齢のcIFR (corrected infection fatality ratio) は信頼区間がずれた。

ちなみに、対数正規分布のパラメータは平均mと中央値Mが与えられていて、これを対数正規分布用のパラメータに変換するには\mu=\log(M)\sigma=\sqrt{2(\log(M)-\mu)} としているが、なんか違うっぽい。
RPubs - 対数正規分布から、平均値(中央値)と標準偏差を指定して乱数を生成させる
図の見た目はあっている。
f:id:MikuHatsune:20200331231832p:plain

正規分布のパラメータ
zMean <- 13
zMedian <- 9.1
propSymptomatic <- 619/309
propOver70Cases <- 124/619
propOver70Deaths <- 7/7
mu <- log(zMedian)
sigma <- sqrt(2*(log(zMean) - mu))

x <- seq(0, 40, length=300)
plot(x, dlnorm(x, mu, sigma), type="l", lwd=5,
     xlab="Days after hospitalisation", ylab="P(death on a given day|death)")


# 一部改変した
scale_cfr <- function(data_1_in, mu, sigma){
  case_incidence <- data_1_in$new_cases
  death_incidence <- data_1_in$new_deaths
  cumulative_known_t <- rep(0, nrow(data_1_in)) # cumulative cases with known outcome at time tt
  # calculating CDF between each of the days to determine the probability of death on each day

  for(ii in 1:nrow(data_1_in)){
    known_i <- 0 # number of cases with known outcome at time ii
    for(jj in 0:(ii - 1)){
      known_jj <- case_incidence[ii - jj]*(plnorm(jj, mu, sigma) - plnorm(jj - 1, mu, sigma))
      known_i <- known_i + known_jj
    }
    cumulative_known_t[ii] <- known_i # Tally cumulative known
  }
  # naive CFR value
  b_tt <- cumsum(death_incidence)/cumsum(case_incidence) 
  # corrected CFR estimator
  p_tt <- cumsum(death_incidence)/cumsum(cumulative_known_t)
  data.frame(nCFR = b_tt, cIFR = p_tt, total_deaths = cumsum(death_incidence), 
             cum_known_t = cumulative_known_t, total_cases = cumsum(case_incidence))
}

newData <- read.table(text="
         date new_cases new_deaths
1  2020-02-05        10          0
2  2020-02-06        10          0
3  2020-02-07        41          0
4  2020-02-08         3          0
5  2020-02-09         6          0
6  2020-02-10        65          0
7  2020-02-11         0          0
8  2020-02-12        39          0
9  2020-02-13        44          0
10 2020-02-14         0          0
11 2020-02-15        67          0
12 2020-02-16        70          0
13 2020-02-17        99          0
14 2020-02-18        88          0
15 2020-02-19        79          0
16 2020-02-20        13          2
17 2020-02-21         0          0
18 2020-02-22         0          0
19 2020-02-23         0          1
20 2020-02-24         0          0
21 2020-02-25         0          1
22 2020-02-26        71          0
23 2020-02-27         0          0
24 2020-02-28         0          2
25 2020-02-29         0          0
26 2020-03-01         0          1
27 2020-03-02         0          0
28 2020-03-03         0          0
29 2020-03-04         0          0
30 2020-03-05        -9          0
")

ageCorrectedDatNew <- cbind.data.frame(date=newData$date,
                                       new_cases=round(newData$new_cases * propOver70Cases),
                                       new_deaths=round(newData$new_deaths * propOver70Deaths))

res0 <- scale_cfr(newData, mu, sigma)
res1 <- scale_cfr(ageCorrectedDatNew, mu, sigma)
res <- list("all_age"=res0, "over_70"=res1)

ci <- mapply(function(z) binom.test(tail(z$total_deaths, 1), round(sum(z$cum_known_t)))$conf.int[1:2], res)
ci <- rbind(mapply(function(z) tail(z$cIFR, 1), res), ci)

# 無症候性陽性患者を考慮した CRF
ci * propSymptomatic

新型肺炎COVID-19 の感染陽性患者数の過小報告分をrstanで推定する

読んだ。
Ascertainment rate of novel coronavirus disease (COVID-19) in Japan | medRxiv
ascertainment rate という、感染者数(PCR陽性ベース)がどれくらいか、つまり、1だと実際の報告数が潜在的な患者数と同一で、>1だと過剰に報告されている、<1だと過小報告されている、と考えるならば、0.44(95%CI 0.37-0.50)で、軽症患者については実際の患者数は2倍くらい(実際の0.44分しか報告されていない)だろう、ということらしい。

2020年2月28日までの疫学データから、RT-PCRで陽性確定となった患者がほぼ毎日厚生労働省のHPで見れるので、それを都道府県、10歳きざみの年齢、そして重症度でカウントする。重症度の定義は、酸素療法を要して肺炎もしくは挿管されている患者か、ICUに入室した患者、となっている。

厚生労働省のページをいちいち見に行ってもよいが、広島県のHPが時系列で厚生労働省の発表をまとめているので、そこから毎日データを見に行って確認した。2020年2月28日までは症例198まで番号がついている。論文ではNは言及がないが、図を見ると足すと200人前後っぽいのでたぶんいいのかもしれない。
患者の発生状況等 - 広島市公式ホームページ

都道府県ベースで10歳きざみのデータが報告されている。そして各自治体レベルで、重症かそうかの報告がされている。ただし、「胸部X線およびCTで肺炎像」というのがはたしてすべて重症肺炎かというとそうでもない。抗生剤さえ開始すれば酸素療法はいらない肺炎は多数存在する。そもそもコロナウイルスは抗生剤が効くようなものではないので、細菌性肺炎が合併するならまだしも、抗生剤はいらない。
重症度については各自治体で報告様式が異なっている。例えば東京都は、「症例○は重篤である」と書いてあるし、北海道では「症例○は非侵襲性呼吸補助療法を要した」「挿管された」などと書いてある。
図からはsevere な症例について21症例がみてとれるが、今回自分でデータを漁ったところ、12症例しかわからなかった。ただし、ただの肺炎(画像診断で肺炎、となっている症例)は除外したのでこれが過小カウントになってしまったかもしれない。
都道府県ごとの人口は平成26年度版の5歳階級データを拝借した。
https://www.e-stat.go.jp/dbview?sid=0003104197

モデルとしては、各都道府県xの各年齢階級aの人口N_{x,a}について、非重症患者数D_{n}と重症患者数D_s (これらは各都道府県の各年齢階級のインデックスがつく)について、確率p_{x,a}ポアソン過程で患者が発生する、としている。
ここで、非重症患者は重症患者よりf_a の割合で患者報告数が多い(これは年齢階級のみのインデックス)、というパラメータをいれている。ここのパラメータは中国の初期段階での患者報告データから流用している。おそらくtable 1である。
Clinical Characteristics of Coronavirus Disease 2019 in China. - PubMed - NCBI

さて、肝心のascertainment rate はk として、対数尤度関数を
ll=\displaystyle\sum_x\displaystyle\sum_a\ln[\frac{(N_{x,a}kf_ap_{x,a})^{D_{n,x,a}}exp(-N_{x,a}kf_a p_{x,a})}{D_{n,x,a}!}\frac{(N_{x,a}p_{x,a})^{D_{s,x,a}}exp(-N_{x,a}p_{x,a})}{D_{s,x,a}!}]
と定義するが、これはポアソン分布の確率密度関数
P(X=k|\lambda)=\frac{\lambda^k exp(-\lambda)}{k!}
であり、非重症患者では\lambda \gets N_{x,a}kf_ap_{x,a}、重症患者では\lambda \gets N_{x,a}p_{x,a} とすればよい。

ここまで出来たのでrstanでやってみる。
結論から言うとascertainment rate k の推定はそれなりによかった。

     2.5%       50%     97.5% 
0.3640505 0.4345467 0.5170327 

しかし、肝心の推定患者数はグダグダだった。非重症患者はもとより、重症患者が異常に多く推定されてしまった。
f:id:MikuHatsune:20200329220208p:plain

グダグダだった理由として、2020年2月28日までの都道府県別年齢階級別のデータがほんとうに論文で解析したデータとあっているのか不明だし、f_a の値も不明だし、重症患者の定義が

severe dyspnea that required oxygen support plus pneumonia or intubation

と書いてあるが、肺炎であることが必ずしも重症ではないし、重症患者のカウントの仕方が不明だった。

こんな短い論文ですら再現出来ないのだから自分の技量は(ここで記事が途絶えている

# 都道府県別の年齢階級別の人口
pop <- read.table(text="
         10歳未満 10代 20代 30代 40代 50代 60代 70代 80代
北海道        397  462  511  643  738  693  853  641  461
青森県         95  123  109  147  172  182  210  163  120
岩手県         98  119  108  144  160  172  196  160  128
宮城県        191  217  256  305  316  296  329  239  180
秋田県         70   88   73  110  123  144  172  139  118
山形県         88  105   92  127  136  151  174  135  123
福島県        149  190  171  221  243  265  293  218  185
茨城県        239  279  281  363  411  366  444  324  213
栃木県        165  186  189  253  280  254  301  207  145
群馬県        163  194  179  239  283  242  296  221  159
埼玉県        603  673  778  953 1137  858 1031  808  398
千葉県        505  560  636  798  953  738  907  715  385
東京都       1020 1041 1662 2072 2220 1583 1611 1339  841
神奈川県      760  824  999 1246 1497 1080 1186  957  547
新潟県        179  213  201  268  300  295  356  272  228
富山県         84   99   88  125  148  128  167  130  101
石川県         97  112  115  139  161  138  171  127   97
福井県         68   78   70   92  104  100  116   89   75
山梨県         66   83   79   94  118  108  122   95   76
長野県        176  204  171  245  289  258  306  250  211
岐阜県        174  203  190  240  283  246  303  238  165
静岡県        316  349  331  452  528  460  548  431  292
愛知県        682  727  835 1012 1147  858  976  773  446
三重県        154  177  170  219  257  224  263  211  149
滋賀県        135  145  154  185  202  167  191  139   97
京都府        208  238  301  321  371  294  374  303  200
大阪府        723  824  950 1131 1366  999 1231 1047  566
兵庫県        472  532  542  674  815  664  798  634  411
奈良県        110  134  133  156  193  166  210  168  105
和歌山県       75   91   81  102  130  123  150  125   94
鳥取県         48   55   49   67   71   74   88   65   60
島根県         57   63   54   76   82   87  110   85   82
岡山県        165  184  194  230  255  223  279  220  173
広島県        247  265  276  346  396  332  417  319  234
山口県        111  128  118  155  178  167  229  182  143
徳島県         58   68   65   87   96   97  122   92   78
香川県         82   92   84  116  130  118  154  114   93
愛媛県        112  128  117  159  180  175  219  167  138
高知県         54   66   57   81   92   92  119   92   84
福岡県        455  472  557  655  688  609  736  536  381
佐賀県         77   86   76   97  102  107  124   90   77
長崎県        117  132  117  150  171  185  216  163  136
熊本県        160  171  166  207  217  230  265  202  177
大分県         98  106  105  135  144  145  183  140  116
宮崎県        100  108   91  127  133  146  173  129  107
鹿児島県      148  157  143  188  195  224  250  190  172
沖縄県        166  162  157  188  195  182  168  116   85
", check.name=FALSE
) * 1000

# 非重症患者
infec <- read.table(text="
         10歳未満 10代 20代 30代 40代 50代 60代 70代 80代
北海道          4    2    5    5    8    9   13   10    7
青森県          0    0    0    0    0    0    0    0    0
岩手県          0    0    0    0    0    0    0    0    0
宮城県          0    0    0    0    0    0    0    0    0
秋田県          0    0    0    0    0    0    0    0    0
山形県          0    0    0    0    0    0    0    0    0
福島県          0    0    0    0    0    0    0    0    0
茨城県          0    0    0    0    0    0    0    0    0
栃木県          0    0    0    0    0    0    1    0    0
群馬県          0    0    0    0    0    0    0    0    0
埼玉県          0    0    0    1    0    0    0    0    0
千葉県          0    0    2    0    1    2    4    2    0
東京都          0    0    1    2    4    7    2    8    2
神奈川県        0    0    3    2    2    6    2    1    4
新潟県          0    0    0    0    0    0    1    0    0
富山県          0    0    0    0    0    0    0    0    0
石川県          0    0    0    0    0    2    1    0    0
福井県          0    0    0    0    0    0    0    0    0
山梨県          0    0    0    0    0    0    0    0    0
長野県          0    0    0    0    0    0    1    0    0
岐阜県          0    0    0    0    0    1    0    0    0
静岡県          0    0    0    0    0    0    1    0    0
愛知県          0    0    1    0    2    2   14    8    1
三重県          0    0    0    0    0    1    0    0    0
滋賀県          0    0    0    0    0    1    0    0    0
京都府          0    0    2    0    0    0    0    0    0
大阪府          0    0    0    0    5    1    0    0    0
兵庫県          0    0    0    0    0    0    0    0    0
奈良県          0    0    0    0    0    0    1    0    0
和歌山県        0    0    0    1    1    5    2    1    1
鳥取県          0    0    0    0    0    0    0    0    0
島根県          0    0    0    0    0    0    0    0    0
岡山県          0    0    0    0    0    0    0    0    0
広島県          0    0    0    0    0    0    0    0    0
山口県          0    0    0    0    0    0    0    0    0
徳島県          0    0    0    0    0    0    0    0    0
香川県          0    0    0    0    0    0    0    0    0
愛媛県          0    0    0    0    0    0    0    0    0
高知県          0    0    0    1    0    0    0    0    0
福岡県          0    0    0    0    0    0    2    0    0
佐賀県          0    0    0    0    0    0    0    0    0
長崎県          0    0    0    0    0    0    0    0    0
熊本県          0    0    1    0    0    2    2    0    0
大分県          0    0    0    0    0    0    0    0    0
宮崎県          0    0    0    0    0    0    0    0    0
鹿児島県        0    0    0    0    0    0    0    0    0
沖縄県          0    0    0    0    0    0    2    0    1
", check.name=FALSE
)

# 重症患者
infec_severe <- read.table(text="
         10歳未満 10代 20代 30代 40代 50代 60代 70代 80代
北海道          0    0    1    0    0    1    0    0    2
青森県          0    0    0    0    0    0    0    0    0
岩手県          0    0    0    0    0    0    0    0    0
宮城県          0    0    0    0    0    0    0    0    0
秋田県          0    0    0    0    0    0    0    0    0
山形県          0    0    0    0    0    0    0    0    0
福島県          0    0    0    0    0    0    0    0    0
茨城県          0    0    0    0    0    0    0    0    0
栃木県          0    0    0    0    0    0    0    0    0
群馬県          0    0    0    0    0    0    0    0    0
埼玉県          0    0    0    0    0    0    0    0    0
千葉県          0    0    0    0    0    0    0    0    0
東京都          0    0    0    0    0    2    1    1    1
神奈川県        0    0    0    0    0    0    0    0    1
新潟県          0    0    0    0    0    0    0    0    0
富山県          0    0    0    0    0    0    0    0    0
石川県          0    0    0    0    0    0    0    0    0
福井県          0    0    0    0    0    0    0    0    0
山梨県          0    0    0    0    0    0    0    0    0
長野県          0    0    0    0    0    0    0    0    0
岐阜県          0    0    0    0    0    0    0    0    0
静岡県          0    0    0    0    0    0    0    0    0
愛知県          0    0    0    0    0    0    0    0    0
三重県          0    0    0    0    0    0    0    0    0
滋賀県          0    0    0    0    0    0    0    0    0
京都府          0    0    0    0    0    0    0    0    0
大阪府          0    0    0    0    0    0    0    0    0
兵庫県          0    0    0    0    0    0    0    0    0
奈良県          0    0    0    0    0    0    0    0    0
和歌山県        0    0    0    0    0    0    0    1    0
鳥取県          0    0    0    0    0    0    0    0    0
島根県          0    0    0    0    0    0    0    0    0
岡山県          0    0    0    0    0    0    0    0    0
広島県          0    0    0    0    0    0    0    0    0
山口県          0    0    0    0    0    0    0    0    0
徳島県          0    0    0    0    0    0    0    0    0
香川県          0    0    0    0    0    0    0    0    0
愛媛県          0    0    0    0    0    0    0    0    0
高知県          0    0    0    0    0    0    0    0    0
福岡県          0    0    0    0    0    0    0    0    0
佐賀県          0    0    0    0    0    0    0    0    0
長崎県          0    0    0    0    0    0    0    0    0
熊本県          0    0    0    0    0    0    0    0    0
大分県          0    0    0    0    0    0    0    0    0
宮崎県          0    0    0    0    0    0    0    0    0
鹿児島県        0    0    0    0    0    0    0    0    0
沖縄県          0    0    0    0    0    0    1    0    0
", check.name=FALSE
)

# nejm のfa
f1 <- rep(c(8, 490, 241, 109), c(2, 3, 2, 2))/848
f2 <- rep(c(1, 67, 51, 44), c(2, 3, 2, 2))/163
fa <- f1/f2

# http://www.ourphn.org.au/wp-content/uploads/20200225-Article-COVID-19.pdf
# の論文にある fa
# これを使ってもいい再現にはならない
f1 <- c(0.01,1,7,18,38,130,309,213,208) # death
f2 <- c(416,549,3619,7600,8571,10008,8583,3918,1408) # confirmed case
fa <- f2/f1

library(vioplot)
library(rstan)
rstan_options(auto_write=TRUE)
options(mc.cores=parallel::detectCores())

code <- "
  data{
    int<lower=0> A;          // age group
    int<lower=0> X;          // prefecture
    matrix<lower=0>[X, A] N; // population
    int<lower=0> Dn[X, A];   // non-severe
    int<lower=0> Ds[X, A];   // severe
    matrix<lower=0>[X, A] f; // ratio of non-severe to severe
  }
  parameters{
    matrix<lower=0, upper=0.3>[X, A] p;
    real<lower=0> k;
  }
  transformed parameters{
    matrix<lower=0>[X, A] lambda1;
    matrix<lower=0>[X, A] lambda2;
    lambda1 = (k * N .* p) .* f;
    lambda2 = (N .* p);
  }
  model{
    for(a in 1:A){
      for(x in 1:X){
        Dn[x, a] ~ poisson(lambda1[x, a]);
        Ds[x, a] ~ poisson(lambda2[x, a]);
      }
    }
  }
"

m0 <- stan_model(model_code=code)
standata <- list(X=nrow(pop), A=ncol(pop), N=pop, Dn=infec, Ds=infec_severe, f=t(replicate(nrow(pop), fa)))
fit <- sampling(m0, standata, iter=10000, warmup=3000, chain=4)
ex <- extract(fit, pars=head(fit@model_pars, -1))
alpha <- 0.05
cia <- c(alpha/2, 0.5, 1-alpha/2)
quantile(ex$k, cia) # k の推定区間

est <- list("non-severe"=t(mapply(function(z) colSums(ex$lambda1[z,,]), 1:dim(ex$lambda1)[1])),
            severe=t(mapply(function(z) colSums(ex$lambda2[z,,]), 1:dim(ex$lambda2)[1])))
Ninfec <- lapply(list(infec, infec_severe), colSums)
xl <- c(0.5, ncol(pop)+0.5)
yl <- c(0, max(unlist(est), unlist(Ninfec)))
par(mfrow=c(2, 1), las=1)
for(i in seq(Ninfec)){
  plot(Ninfec[[i]], type="n", col="red", xaxt="n", xlim=xl, ylim=yl, xlab="Age group", ylab="count")
  vioplot(as.data.frame(est[[i]]), ylim=yl, add=TRUE)
  points(Ninfec[[i]], pch=15, col="red")
  axis(1, at=seq(ncol(pop)), labels=colnames(pop))
  legend("topleft", legend=c("Estimate", "Data"), pch=15, col=c(grey(0.3), "red"))
  title(names(est)[i])
}

新型肺炎COVID-19 の潜伏期間をrstanで推定する

読んだ。
Incubation period of 2019 novel coronavirus (2019-nCoV) infections among travellers from Wuhan, China, 20-28 January 2020. - PubMed - NCBI

最初に武漢で肺炎が発生したときに、88症例について感染履歴を聴取して、ワイブル分布で潜伏期間を推定すると平均6.4日(95% credible interval (CI): 5.6–7.7)、潜伏期間の幅は2.1から11.1日(2.5th to 97.5th percentile)だった、という。
論文ではワイブル分布のほかに、ガンマ分布、対数正規分布で推定して、looicでもっともよかったのがワイブル分布だった、と言っている。
supplemental にスクリプトがあったのでぱくってやってみる。

結果としてはだいぶずれたような気がする。

weibull gamma lognormal
2.5% 4.629226 5.531133 5.174367
50% 6.828260 6.353485 6.070993
97.5% 9.616947 7.583359 7.347166
looIC 480.119496 524.649064 586.254024

前処理にtidyverseがうざすぎるのでこちらで前処理したデータを置いておく。

library(rstan)
library(loo)
rstan_options(auto_write=TRUE)
options(mc.cores=parallel::detectCores())

# tidyverse がうざいので加工済みデータを置いておく
input <- list(N=88, 
tStartExposure=c(-2,-18,-18,10,3,8,12,13,-18,-18,8,10,-11,-18,-18,-18,12,-18,-18,-18,-18,-18,-18,11,11,-18,-18,-18,-18,-18,12,12,15,15,-18,-18,-18,-18,19,-18,-18,-18,-18,-18,-18,-18,18,-18,-18,6,-18,-18,-18,9,-18,11,-18,-18,-18,-18,-18,-18,-18,-18,-18,-18,19,-18,-18,-18,-18,-18,-18,-18,13,13,-18,-18,-18,-18,-18,-18,13,-18,-18,-18,-18,-18),
tEndExposure=c(3,12,3,11,4,16,16,16,15,15,16,11,9,2,12,17,15,17,18,13,16,11,18,14,18,20,12,10,17,15,15,14,17,20,18,13,23,23,22,20,21,21,21,15,21,17,20,18,20,7,20,20,20,20,18,22,22,18,18,18,9,20,18,22,23,19,19,22,22,13,22,22,23,23,17,17,22,18,18,22,18,20,15,20,19,23,22,22),
tSymptomOnset=c(3,15,4,14,9,16,16,16,16,16,16,14,10,10,14,20,19,19,20,20,17,13,19,19,20,21,18,18,18,16,20,16,20,20,19,17,24,24,23,21,23,23,23,21,22,24,24,19,21,14,23,23,21,20,21,22,23,19,23,21,13,22,24,25,25,25,25,24,24,21,23,23,24,25,18,18,23,22,22,24,24,25,22,25,25,24,25,25)
)

# 3つの確率分布を一気に作る
dists <- c("weibull", "gamma", "lognormal")
code <- sprintf("
  data{
    int<lower=1> N;
    vector[N] tStartExposure;
    vector[N] tEndExposure;
    vector[N] tSymptomOnset;
  }
  parameters{
    real<lower=0> par[2];
    vector<lower=0, upper=1>[N] uE;	// Uniform value for sampling between start and end exposure
  }
  transformed parameters{
    vector[N] tE; 	// infection moment
    tE = tStartExposure + uE .* (tEndExposure - tStartExposure);
  }
  model{
    // Contribution to likelihood of incubation period
    target += %s_lpdf(tSymptomOnset -  tE  | par[1], par[2]);
  }
  generated quantities {
    // likelihood for calculation of looIC
    vector[N] log_lik;
    for (i in 1:N) {
      log_lik[i] = %s_lpdf(tSymptomOnset[i] -  tE[i]  | par[1], par[2]);
    }
  }
", dists, dists)
names(code) <- dists

m0 <- mapply(stan_model, model_code=code)
fit <- mapply(sampling, m0, list(input), iter=10000, warmup=3000, chain=4)
ps <- mapply(function(z) extract(z)$par, fit, SIMPLIFY=FALSE)

# 確率分布のパラメータから得られる解析的な平均値
means <- cbind(ps$weibull[,1]*factorial(1+1/ps$weibull[,2]+1),
               ps$gamma[,1] / ps$gamma[,2],
               exp(ps$lognormal[,1]))
alpha <- c(0.025, 0.5, 0.975)
res <- apply(means, 2, quantile, alpha)
ll <- mapply(function(z) loo(extract_log_lik(z))$looic, fit)
rbind(res, looIC=ll)

新型肺炎COVID-19の感染力R0を推定する

読んだ。
www.ncbi.nlm.nih.gov
巷を賑わせているCOVID-19だが、厚生労働省がダイヤモンド・プリンセス号のPCR陽性者数を逐一ネットに挙げていたので、この論文にもあるようにそこからデータを取ってきて、COVID-19の感染力を推定しようと思った。
Basic reproduction number - Wikipedia

ここで、感染症の感染力とは、いろいろゴニョゴニョやった結果、R_0 と呼ばれる、感染者ひとりが何人に感染させるか、という推定値で感染力の強さがわかる。ここで、この論文での推定値はR_0=2.28 (95% CI 2.06-2.52) ということになっている。
COVID-19は飛沫感染ということになっているので、だいたい1-3くらいのR_0 である。例えば、最強の空気感染の感染形態をもつ麻疹は、R_0=18 くらいあるようなので、めっちゃ強い(小並感
ちなみに、感染力R_0 がわかると、集団のどれくらいの人が免疫をつければ感染を押さえ込めるかがモデル的にわかって、1-\frac{1}{R_0} となる。R_0<1 なら勝手に感染が収束するし、R_0>1なら少しは集団免疫がないと感染が大爆発する。季節性インフルエンザはR_0=3 くらいなので、60-70% くらいがワクチン接種をして集団免疫をつけると感染の大流行が防げるし、感染力の強い麻疹は、95% くらいが集団免疫をつけないと大流行する。
昔、こんなことをやった。
mikuhatsune.hatenadiary.com

さて、論文のとおりにやってみよう。論文ではまず、日本に停泊したダイヤモンド・プリンセス号のPCR陽性者数を、厚生労働省の発表から取ってきたようである。しかし、第8報(なぜか全角である)だけ引用されていて、停泊して隔離(?)され始めた2月3日から2月16日までのデータを利用したようだが、正直データは集まらないと思う。
横浜港に寄港したクルーズ船内で確認された新型コロナウイルス感染症について
新型コロナウイルス感染症の現在の状況と厚生労働省の対応について(令和2年2月6日版)
横浜港で検疫中のクルーズ船内で確認された新型コロナウイルス感染症について(第3報)
横浜港で検疫中のクルーズ船内で確認された新型コロナウイルス感染症について(第4報)
横浜港で検疫中のクルーズ船内で確認された新型コロナウイルス感染症について(第5報)
横浜港で検疫中のクルーズ船内で確認された新型コロナウイルス感染症について(第6報)
新型コロナウイルス感染症の現在の状況と厚生労働省の対応について(令和2年2月12日版)
横浜港で検疫中のクルーズ船内で確認された新型コロナウイルス感染症について(第8報)
横浜港で検疫中のクルーズ船内で確認された新型コロナウイルス感染症について(第9報)
横浜港で検疫中のクルーズ船内で確認された新型コロナウイルス感染症について(第10報)
横浜港で検疫中のクルーズ船内で確認された新型コロナウイルス感染症について(第11報)
横浜港で検疫中のクルーズ船内で確認された新型コロナウイルス感染症について(第12報)
横浜港で検疫中のクルーズ船内で確認された新型コロナウイルス感染症について(第13報)
横浜港で検疫中のクルーズ船内で確認された新型コロナウイルス感染症について(第14報)
横浜港で検疫中のクルーズ船内で確認された新型コロナウイルス感染症について(2月23日公表分)
横浜港で検疫中のクルーズ船内で確認された新型コロナウイルス感染症について(2月26日公表分)

実はwiki にまとまっていた。
クルーズ客船における2019年コロナウイルス感染症の流行状況 - Wikipedia

論文ではearlyR というパッケージのget_Rという関数でR_0 の推定を行なっている。これには感染から発症までの期間 serial interval というものが必要らしく、ガンマ分布で近似しているようなので、こちらの論文から平均 7.5、標準偏差 3.4 となるようにすると、shape=4.865917、scale=1.541333 となる。
Early Transmission Dynamics in Wuhan, China, of Novel Coronavirus-Infected Pneumonia. - PubMed - NCBI

f:id:MikuHatsune:20200317233338p:plain

2月16日までのデータで論文の通り(? データがあっているのかわからないしコードも本当にあっているのかわからないのでわからない)に実行すると、R_0=2.322 だった。
f:id:MikuHatsune:20200317233653p:plain

さて、ダイヤモンド・プリンセス号のデータは2月26日まであって、プロットをみる限りでは2月20日くらいから感染者数の急激な増加は収まっているようにみえるので、2月26日までのデータでやってみる。ここで、PCR陽性が判明した2月5日から、1日ずつ増加させてみながらR_0 を推定すると、感染拡大の超初期は莫大に感染が増えている(R_0 が大きい)が、なんらかの値に収束していくようである。こういうようにR_0 を推定するのが果たして正しいのかどうかは知らない。
ここで、R_0 の収束にはdrc パッケージのEXD.3 関数を用い、漸近モデルとした。このとき、R_0=1.62 くらいだった。
(飛んでいる日は前後の感染者総数および検査総数から適当に補完してシミュレーションしたので、若干のばらつきがある)
f:id:MikuHatsune:20200317234030p:plain

get_R の仕様をまだよく読み込んでないのだが、\lambda というのが感染の勢い(?) を表している。ここでは観察から14日目(2月19日)にピークで、その後感染拡大の勢いが小さくなっている。
f:id:MikuHatsune:20200317234357p:plain

本当にこれであっているのかよくわからないので、サンプルでエボラをやってみた。2014年にエボラの大流行があって、outbreaks というパッケージにデータが入っている。エボラ自体は血液感染で、R_0=2 くらいらしい。
これもCOVID-19 と同じようにやってみると、結局R_0=1.5 くらいになった。いいのかよくわからない。
f:id:MikuHatsune:20200317234517p:plain

積んでるリスト
Time-varying transmission dynamics of Novel Coronavirus Pneumonia in China | bioRxiv
www.ncbi.nlm.nih.gov
https://web.stanford.edu/~jhj1/teachingdocs/Jones-on-R0.pdf

dat <- read.delim("corona_infection.txt", stringsAsFactors=FALSE)

library(earlyR)
library(incidence)
library(drc)
library(stringr)
library(outbreaks)

# 論文のデータを再現する
xx <- seq(0, 23, length=1000)
x <- 0:21

m0 <- 7.5
sd0 <- 3.4

# ガンマ分布のパラメータに戻す
a <- (m0/sd0)^2
s <- (sd0^2/m0)

plot(xx, dgamma(xx, shape=a, scale=s), type="l", xlab="Days after onset", ylab="", lwd=3)
points(x, dgamma(x, shape=a, scale=s), pch=15)

# 厚生労働省のデータは飛んでいる日があるので
# 飛んでいる前後の日のデータから適当にサンプリングして補完する
n <- dat$total_positive
positive_N_simulator <- function(n){
  r <- rle(is.na(n))
  r0 <- tapply(n, rep(seq(r$lengths), r$lengths), c)
  r1 <- tapply(seq(n), rep(seq(r0), sapply(r0, length)), c)
  r2 <- sapply(sapply(r0, is.na), any)
  r3 <- sapply(r1[r2], length)
  r4 <- list(n[sapply(r1[r2], head, 1) - 1], n[sapply(r1[r2], tail, 1) + 1])
  r5 <- sapply(mapply(function(z) sample(r4[[1]][z]:r4[[2]][z], size=r3[z], replace=FALSE), seq(r3)), sort)
  n1 <- replace(n, is.na(n), unlist(r5))
  n2 <- mapply(rbinom, n=1, size=unlist(r5), prob=sample(na.omit(dat$total_positive/dat$total_n), size=1))
  N <- replace(dat$test_positive, is.na(dat$test_positive), diff(n1)[which(is.na(dat$test_positive))-1])
  return(N)
}

# wiki からとったデータ
total <- 3711
wiki <- c(10, 20, 61, 64, 70, 135, 135, 174, 218, 218, 285, 355, 454, 542, 621, 634, 691, 705, 706, 706, 712)
N <- c(wiki[1], diff(wiki))

# 飛んでいる日を適当に補完してシミュレーション
iter <- 50
R0 <- matrix(NA, nrow(dat), iter)
for(j in seq(iter)){
  N <- positive_N_simulator(dat$total_positive)
  for(i in 2:length(N)){
    inc <- incidence(rep(seq(N[1:i]), N[1:i]))
    res <- get_R(inc, si_mean=m0, si_sd=sd0, max_R=20)
    R0[i,j] <- res$R_ml
  }
}

# フィッテング
Y <- c(R0)
X <- c(row(R0))
fit <- drm(Y ~ X, fct=EXD.3())
cur <- predict(fit, as.data.frame(xx))

yl <- c(0, max(R0, na.rm=TRUE))
txt <- format(as.Date(dat$date), "%m/%d")
rownames(R0) <- txt
par(mar=c(4, 5, 2, 2))
matplot(R0, pch=16, xaxt="n", ylim=yl, las=2, xlab="", ylab=expression(R[0]), col=1, cex.lab=2)
axis(1, at=seq(nrow(R0)), labels=txt, las=2)
abline(h=1, lty=3)
lines(xx, cur, lwd=3, col="red")


# エボラでやってみる
onset <- outbreaks::ebola_sim$linelist$date_of_onset
N <- table(onset)
R0 <- rep(0, length(N))
res <- vector("list", length(N))
for(i in 2:length(R0)){
  inc <- incidence(rep(seq(N[1:i]), N[1:i]))
  res[[i]] <- get_R(inc, disease="ebola", max_R=20)
  R0[i] <- res[[i]]$R_ml
}
xx <- seq(0, 200, length=1000)
Y <- R0[R0 > 0]
X <- seq(R0)[R0 > 0]
fit <- drm(Y ~ X, fct=EXD.3())
cur <- predict(fit, as.data.frame(xx))

xl <- c(0, length(N))
par(mfrow=c(3, 1), mar=c(4, 5, 2, 2), cex.lab=1.5, las=1)
plot(R0, xlab="Date", ylab=expression(R[0]), xlim=xl, pch=15)
lines(xx, cur, lwd=3, col="red")
abline(h=1, lty=3)
plot(res[[40]]) # 適当に40日目
plot(tail(res, 1)[[1]], "lambda", xlim=xl)
date	test_positive	test_n	asymptomatic	total_positive	total_n	total_asymptomatic	title	url
2020-02-05	10	31		10	31		横浜港に寄港したクルーズ船内で確認された新型コロナウイルス感染症について	https://www.mhlw.go.jp/stf/newpage_09276.html
2020-02-06	10	71		20	102		新型コロナウイルス感染症の現在の状況と厚生労働省の対応について(令和2年2月6日版)	https://www.mhlw.go.jp/stf/newpage_09360.html
2020-02-07	41	171		61	273		横浜港に寄港したクルーズ船内で確認された新型コロナウイルス感染症について	https://www.mhlw.go.jp/stf/newpage_09340.html
2020-02-08	3	6		64	279		横浜港に寄港したクルーズ船内で確認された新型コロナウイルス感染症について(第4報)	https://www.mhlw.go.jp/stf/newpage_09398.html
2020-02-09	6	57		70	336		横浜港で検疫中のクルーズ船内で確認された新型コロナウイルス感染症について(第5報)	https://www.mhlw.go.jp/stf/newpage_09405.html
2020-02-10	65	103		135	439		横浜港で検疫中のクルーズ船内で確認された新型コロナウイルス感染症について(第6報)	https://www.mhlw.go.jp/stf/newpage_09419.html
2020-02-11
2020-02-12	39	67		174	492		新型コロナウイルス感染症の現在の状況と厚生労働省の対応について(令和2年2月12日版)	https://www.mhlw.go.jp/stf/newpage_09450.html
2020-02-13	44	221		218	713		横浜港で検疫中のクルーズ船内で確認された新型コロナウイルス感染症について(第8報)	https://www.mhlw.go.jp/stf/newpage_09425.html
2020-02-14								
2020-02-15	67	217	38	285	930	73	横浜港で検疫中のクルーズ船内で確認された新型コロナウイルス感染症について(第9報)	https://www.mhlw.go.jp/stf/newpage_09542.html
2020-02-16	70	289	38	355	1219	111	横浜港で検疫中のクルーズ船内で確認された新型コロナウイルス感染症について(第10報)	https://www.mhlw.go.jp/stf/newpage_09547.html
2020-02-17	99	504	70	454	1723	189	横浜港で検疫中のクルーズ船内で確認された新型コロナウイルス感染症について(第11報)	https://www.mhlw.go.jp/stf/newpage_09568.html
2020-02-18	88	681	65	542	2404	254	横浜港で検疫中のクルーズ船内で確認された新型コロナウイルス感染症について(第12報)	https://www.mhlw.go.jp/stf/newpage_09606.html
2020-02-19	79	607	68	621	3011	322	横浜港で検疫中のクルーズ船内で確認された新型コロナウイルス感染症について(第13報)	https://www.mhlw.go.jp/stf/newpage_09640.html
2020-02-20	13	52	6	634	3063	328	横浜港で検疫中のクルーズ船内で確認された新型コロナウイルス感染症について(第14報)	https://www.mhlw.go.jp/stf/newpage_09668.html
2020-02-21								
2020-02-22								
2020-02-23	57	831	52	691	3894	380	横浜港で検疫中のクルーズ船内で確認された新型コロナウイルス感染症について(223日公表分)	https://www.mhlw.go.jp/stf/newpage_09708.html
2020-02-24								
2020-02-25								
2020-02-26	14	167	12	705	4061	392	横浜港で検疫中のクルーズ船内で確認された新型コロナウイルス感染症について(226日公表分)	https://www.mhlw.go.jp/stf/newpage_09783.html

rstanで自分で定義した確率分布からサンプリングする:Johnson's SU 分布

本当はこの通りにしたかったが、自作関数のサンプリングが遅すぎたので先に正規分布normalからのサンプリングがvectorかひとつひとつかtargetかで変わるのかを検証していた。
結論から言うと組み込みのvector型サンプリングは速いが、自作関数はひとつひとつサンプリングすることになるので、圧倒的遅さ。
mikuhatsune.hatenadiary.com

これをやっていたが、実際のデータはやはり単純にガンマ分布ではあてはまりが悪かった。
mikuhatsune.hatenadiary.com
というわけで原点から最頻値が遠くて、かつ、左右非対称な分布をゴリ押ししようと思っていたら、Gumbel 分布gumbel)でできるようだが、もっと歪ませたいと思っていたら、Johnson's SU 分布とうものがあるらしい。これは、適当な変換を挟んで正規分布に従うようにして、確率密度関数
\frac{\delta}{\sqrt{2 \pi \lambda^2}}\frac{1}{\sqrt{1+(\frac{x-\xi}{\lambda})^2}}e^{-\frac{1}{2}*(\gamma+\delta\sinh^{-1} (\frac{x-\xi}{\lambda}))^2}
となる。これをrstanで定義すると
delta/(2*pi()*lambda^2)^0.5 * 1/(1+((x-xi)/lambda)^2)^0.5 * exp(-0.5*(gamma+delta*asinh((x-xi)/lambda))^2)
となる。

Johnson 関数群自体はSuppDistsパッケージで使える。適当に正の範囲(といっても定義域は負もある)で最頻値が50程度で左右非対称な分布を作る。パラメータを与えると、sJohnson関数で平均や分散など得られる。

sJohnson(parms)
$title
[1] "Johnson Distribution"

$gamma
[1] -5.5

$delta
[1] 2

$xi
[1] 7.5

$lambda
[1] 5

$type
SU 
 3 

$Mean
[1] 51.63246

$Median
[1] 46.44676

$Mode
[1] 37.23034

$Variance
[1] 561.298

$SD
[1] 23.69173

$ThirdCentralMoment
[1] -23119.93

$FourthCentralMoment
[1] 2784825

$PearsonsSkewness...mean.minus.mode.div.SD
[1] 0.6078967

$Skewness...sqrtB1
[1] -1.738587

$Kurtosis...B2.minus.3
[1] 5.83916


推定した結果はそこそこよい。しかし時間がクッソかかった。
f:id:MikuHatsune:20200314214536p:plain

library(rstan)
rstan_options(auto_write=TRUE)
options(mc.cores=parallel::detectCores())

library(SuppDists)
parms <- list(gamma=-5.5, delta=2, xi=7.5, lambda=5, type="SU")
n <- 3000
hoge <- rJohnson(n, parms)

code <- "
  functions{ //手作り正規分布を定義
    real johnsonSU_lpdf(real x, real gamma, real delta, real xi, real lambda){
      real tmp;
      tmp = delta/(2*pi()*lambda^2)^0.5 * 1/(1+((x-xi)/lambda)^2)^0.5 * exp(-0.5*(gamma+delta*asinh((x-xi)/lambda))^2);
      return log(tmp);
    }
  }
  data {
    int<lower=0> N;
    real<lower=0> y[N];
  }
  parameters {
  real gamma;
  real<lower=0> delta;
  real xi;
  real<lower=0> lambda;
  }
  model {
    for(i in 1:N){
    y[i] ~ johnsonSU(gamma, delta, xi, lambda);
    }
  }
"

m0 <- stan_model(model_code=code)
standata <- list(N=length(hoge), y=hoge)
fit <- sampling(m0, standata, iter=1000, warmup=500, chain=4)
ex <- extract(fit, pars=head(fit@model_pars, -1))
pa <- c(lapply(ex, median), type="SU")


x <- seq(0, 200, length=300)
d0 <- dJohnson(x, parms)
d1 <- dJohnson(x, pa)
hist(hoge, nclass=100, freq=F, main="Johnson SU", col="yellow")
lines(x, d0, col=2, lwd=4)
lines(x, d1, col=4, lwd=4)

rstanでの確率分布からのサンプリングの速さを比較する

rstanで自作関数、というかrstanに実装されていない確率分布からサンプリングをしたくてコードを書いていたが、その前にコードの書き方でサンプリングの効率というか速さが違うので速くなる書き方をしよう、という検証。
結論から言うと、実装されている関数ならば、vector 型でサンプリングするのがよく、自作関数を作るとひとつのrealもしくはintでサンプリングするので、遅い。

5つのパターンを比較する。どれも正規分布N(1, 1)からデータを作り、正規分布normal(m, s)で推定する。
code0は、vector型でy ~ normal(m, s)とサンプリングする。
code1は、ひとつひとつy[i] ~ normal(m, s)とサンプリングする。
code2は、target記法を用いてtarget += normal(y[i] | m, s)とサンプリングする。
code3は、自作関数tmp = delta/(2*pi()*lambda^2)^0.5 * 1/(1+((x-xi)/lambda)^2)^0.5 * exp(-0.5*(gamma+delta*asinh((x-xi)/lambda))^2); を定義し、code1のようにひとつひとつサンプリングする。
code4は、code3の自作関数をtarget記法でサンプリングする。
ちなみにrstanで自作関数を利用するとき、func_lpdfで関数を定義し、返り値はlog(return)とする。また、target記法で利用するときは、func_lpdfで記述する。

結果としてはcode0vector型でのサンプリングが一番早かった。rstanに定義されている関数を利用する場合は、ひとつひとつサンプリングするcode1のほうがtarget記法より速いようだった。
自作関数を利用する場合は、target記法のほうが速そうな気がする。
f:id:MikuHatsune:20200312225316p:plain

hoge <- rnorm(300, 1, 1)
code0 <- "
    data {
    int<lower=0> N;
    real y[N];
  }
  parameters {
    real m;
    real<lower=0> s;
  }
  model {
    y ~ normal(m, s);
  }
"
code1 <- "
    data {
    int<lower=0> N;
    real y[N];
  }
  parameters {
    real m;
    real<lower=0> s;
  }
  model {
    for(i in 1:N){
      y[i] ~ normal(m, s);
    }
  }
"
code2 <- "
    data {
    int<lower=0> N;
    real y[N];
  }
  parameters {
    real m;
    real<lower=0> s;
  }
  model {
    for(i in 1:N){
      target += normal_lpdf(y[i] | m, s);
    }
  }
"
code3 <- "
  functions{ //手作り正規分布を定義
    real my_lpdf(real x, real m, real s){
      real tmp;
      tmp = 1/(2*pi()*s^2)^0.5 * exp(-0.5*((x-m)/s)^2);
      return log(tmp);
    }
  }
  data {
    int<lower=0> N;
    real y[N];
  }
  parameters {
    real m;
    real<lower=0> s;
  }
  model {
    for(i in 1:N){
      y[i] ~ my(m, s);
    }
  }
"
code4 <- "
  functions{ //手作り正規分布を定義
    real my_lpdf(real x, real m, real s){
      real tmp;
      tmp = 1/(2*pi()*s^2)^0.5 * exp(-0.5*((x-m)/s)^2);
      return log(tmp);
    }
  }
  data {
    int<lower=0> N;
    real y[N];
  }
  parameters {
    real m;
    real<lower=0> s;
  }
  model {
    for(i in 1:N){
      target += my_lpdf(y[i] | m, s);
    }
  }
"
codes <- list(code0, code1, code2, code3, code4)
names(codes) <- c("vector", "elementwise", "element_target", "self_function", "self_target")

# 一括してコンパイルするが
# コンピュータのメモリが貧弱だと並列にすると死ぬかもしれない
library(rstan)
rstan_options(auto_write=TRUE)
# options(mc.cores=parallel::detectCores())


ms <- mapply(stan_model, model_code=codes)
standata <- list(N=length(hoge), y=hoge)
fits <- mapply(function(z) sampling(z, standata, iter=2000, warmup=1000, chain=100), ms)

elp <- mapply(function(w) colSums(mapply(function(z) attributes(z)$elapsed_time, w@sim$samples)), fits)


library(vioplot)
vioplot(elp, horizontal=TRUE, side="right", yaxt="n", xlab="elapsed time [sec]")
abline(h=seq(ncol(elp)), lty=3)
axis(1)
text(par()$usr[2], seq(ncol(elp))+0.3, colnames(elp), xpd=TRUE, pos=2)