Fit models for use in examples
example: (string) The name of the example. The currently available examples are:
- "logistic": logistic regression with intercept and 3 predictors.
- "schools": the so-called "eight schools" model, a hierarchical meta-analysis. Fitting this model will result in warnings about divergences.
- "schools_ncp": non-centered parameterization of the "eight schools" model that fixes the problem with divergences.
To print the Stan code for a given example, use `print_example_program(example)`.
method: (string) Which fitting method should be used? The default is the "sample" method (MCMC).
...: Arguments passed to the chosen method. See the help pages for the individual methods for details.
quiet: (logical) If `TRUE` (the default), fitting the model is wrapped in `utils::capture.output()`.
The fitted model object returned by the selected method.
# \dontrun{
print_example_program("logistic")
#> data {
#> int<lower=0> N;
#> int<lower=0> K;
#> array[N] int<lower=0, upper=1> y;
#> matrix[N, K] X;
#> }
#> parameters {
#> real alpha;
#> vector[K] beta;
#> }
#> model {
#> target += normal_lpdf(alpha | 0, 1);
#> target += normal_lpdf(beta | 0, 1);
#> target += bernoulli_logit_glm_lpmf(y | X, alpha, beta);
#> }
#> generated quantities {
#> vector[N] log_lik;
#> for (n in 1 : N) {
#> log_lik[n] = bernoulli_logit_lpmf(y[n] | alpha + X[n] * beta);
#> }
#> }
fit_logistic_mcmc <- cmdstanr_example("logistic", chains = 2)
fit_logistic_mcmc$summary()
#> # A tibble: 105 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk
#> <chr> <num> <num> <num> <num> <num> <num> <num> <num>
#> 1 lp__ -65.9 -65.6 1.40 1.25 -68.6 -64.3 1.00 1066.
#> 2 alpha 0.381 0.385 0.216 0.223 0.0256 0.732 1.00 1952.
#> 3 beta[1] -0.667 -0.657 0.250 0.249 -1.09 -0.266 1.00 1902.
#> 4 beta[2] -0.275 -0.271 0.223 0.222 -0.655 0.0809 1.00 2026.
#> 5 beta[3] 0.691 0.680 0.272 0.261 0.257 1.14 1.00 2036.
#> 6 log_lik[1] -0.516 -0.508 0.100 0.0986 -0.696 -0.364 1.00 2121.
#> 7 log_lik[2] -0.399 -0.376 0.146 0.142 -0.665 -0.195 1.00 2027.
#> 8 log_lik[3] -0.496 -0.463 0.218 0.204 -0.895 -0.205 0.999 2110.
#> 9 log_lik[4] -0.448 -0.433 0.154 0.149 -0.731 -0.227 1.00 1978.
#> 10 log_lik[5] -1.19 -1.17 0.276 0.277 -1.65 -0.776 1.00 2176.
#> # ℹ 95 more rows
#> # ℹ 1 more variable: ess_tail <num>
fit_logistic_optim <- cmdstanr_example("logistic", method = "optimize")
fit_logistic_optim$summary()
#> # A tibble: 105 × 2
#> variable estimate
#> <chr> <num>
#> 1 lp__ -63.9
#> 2 alpha 0.364
#> 3 beta[1] -0.632
#> 4 beta[2] -0.259
#> 5 beta[3] 0.649
#> 6 log_lik[1] -0.515
#> 7 log_lik[2] -0.394
#> 8 log_lik[3] -0.469
#> 9 log_lik[4] -0.442
#> 10 log_lik[5] -1.14
#> # ℹ 95 more rows
fit_logistic_vb <- cmdstanr_example("logistic", method = "variational")
fit_logistic_vb$summary()
#> # A tibble: 106 × 7
#> variable mean median sd mad q5 q95
#> <chr> <num> <num> <num> <num> <num> <num>
#> 1 lp__ -66.7 -66.3 1.91 1.73 -70.5 -64.4
#> 2 lp_approx__ -2.03 -1.75 1.37 1.23 -4.66 -0.332
#> 3 alpha 0.519 0.518 0.223 0.224 0.153 0.895
#> 4 beta[1] -0.690 -0.687 0.234 0.244 -1.06 -0.298
#> 5 beta[2] -0.296 -0.297 0.262 0.251 -0.720 0.144
#> 6 beta[3] 0.546 0.548 0.309 0.310 0.0501 1.06
#> 7 log_lik[1] -0.454 -0.448 0.0931 0.0926 -0.622 -0.311
#> 8 log_lik[2] -0.530 -0.498 0.215 0.215 -0.951 -0.244
#> 9 log_lik[3] -0.488 -0.446 0.238 0.228 -0.906 -0.178
#> 10 log_lik[4] -0.535 -0.516 0.191 0.180 -0.875 -0.267
#> # ℹ 96 more rows
print_example_program("schools")
#> data {
#> int<lower=1> J;
#> vector<lower=0>[J] sigma;
#> vector[J] y;
#> }
#> parameters {
#> real mu;
#> real<lower=0> tau;
#> vector[J] theta;
#> }
#> model {
#> target += normal_lpdf(tau | 0, 10);
#> target += normal_lpdf(mu | 0, 10);
#> target += normal_lpdf(theta | mu, tau);
#> target += normal_lpdf(y | theta, sigma);
#> }
fit_schools_mcmc <- cmdstanr_example("schools")
#> Warning: 391 of 4000 (10.0%) transitions ended with a divergence.
#> See https://mc-stan.org/misc/warnings for details.
fit_schools_mcmc$summary()
#> # A tibble: 11 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
#> <chr> <num> <num> <num> <num> <num> <num> <num> <num> <num>
#> 1 lp__ -55.6 -56.0 6.40 7.14 -65.7 -45.9 1.14 19.4 333.
#> 2 mu 6.95 7.68 4.20 3.72 -0.0865 14.5 1.07 246. 40.2
#> 3 tau 4.24 3.19 3.55 3.11 0.592 11.2 1.18 15.7 9.70
#> 4 theta[1] 9.05 7.63 6.17 5.01 -0.247 20.1 1.11 575. 1129.
#> 5 theta[2] 7.29 7.96 5.17 4.40 -1.33 15.6 1.06 383. 660.
#> 6 theta[3] 6.04 7.16 6.37 5.11 -4.96 15.8 1.05 345. 591.
#> 7 theta[4] 7.16 7.85 5.55 4.64 -2.09 16.3 1.04 287. 330.
#> 8 theta[5] 5.44 6.31 5.68 4.65 -4.51 14.3 1.02 185. 57.9
#> 9 theta[6] 6.24 7.03 5.68 4.70 -3.84 15.5 1.02 162. 83.5
#> 10 theta[7] 8.94 7.77 5.69 4.69 0.261 18.8 1.15 489. 1084.
#> 11 theta[8] 7.45 8.18 6.27 5.20 -2.53 16.9 1.06 527. 1455.
print_example_program("schools_ncp")
#> data {
#> int<lower=1> J;
#> vector<lower=0>[J] sigma;
#> vector[J] y;
#> }
#> parameters {
#> real mu;
#> real<lower=0> tau;
#> vector[J] theta_raw;
#> }
#> transformed parameters {
#> vector[J] theta = mu + tau * theta_raw;
#> }
#> model {
#> target += normal_lpdf(tau | 0, 10);
#> target += normal_lpdf(mu | 0, 10);
#> target += normal_lpdf(theta_raw | 0, 1);
#> target += normal_lpdf(y | theta, sigma);
#> }
fit_schools_ncp_mcmc <- cmdstanr_example("schools_ncp")
fit_schools_ncp_mcmc$summary()
#> # A tibble: 19 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk
#> <chr> <num> <num> <num> <num> <num> <num> <num> <num>
#> 1 lp__ -46.9 -46.7 2.45 2.37 -51.5 -43.4 1.00 1390.
#> 2 mu 6.45 6.53 4.16 4.02 -0.486 13.3 1.00 3299.
#> 3 tau 4.63 3.83 3.59 3.39 0.324 11.6 1.00 1997.
#> 4 theta_raw[1] 0.343 0.373 0.968 0.963 -1.26 1.90 0.999 3956.
#> 5 theta_raw[2] 0.0288 0.0326 0.919 0.900 -1.49 1.54 1.00 4784.
#> 6 theta_raw[3] -0.150 -0.150 0.930 0.871 -1.70 1.39 1.00 4594.
#> 7 theta_raw[4] 0.0241 0.0138 0.929 0.917 -1.51 1.53 1.00 3987.
#> 8 theta_raw[5] -0.254 -0.266 0.923 0.936 -1.77 1.29 1.00 4286.
#> 9 theta_raw[6] -0.138 -0.151 0.936 0.911 -1.69 1.44 1.00 4732.
#> 10 theta_raw[7] 0.354 0.386 0.935 0.917 -1.22 1.88 1.00 4376.
#> 11 theta_raw[8] 0.0685 0.0563 0.960 0.955 -1.47 1.65 1.00 4469.
#> 12 theta[1] 8.82 8.19 6.70 5.73 -0.854 21.4 1.00 3878.
#> 13 theta[2] 6.71 6.68 5.47 4.95 -2.08 15.8 1.00 4371.
#> 14 theta[3] 5.58 5.85 6.23 5.28 -4.90 15.2 1.00 4089.
#> 15 theta[4] 6.54 6.57 5.71 5.16 -2.62 16.0 1.00 4307.
#> 16 theta[5] 4.90 5.24 5.62 5.24 -5.14 13.5 1.00 3910.
#> 17 theta[6] 5.65 5.84 5.75 5.24 -4.27 14.7 1.00 4570.
#> 18 theta[7] 8.69 8.23 5.89 5.49 0.0330 19.2 1.00 4332.
#> 19 theta[8] 7.00 6.80 6.39 5.51 -2.89 17.5 1.00 3804.
#> # ℹ 1 more variable: ess_tail <num>
# optimization fails for hierarchical model
cmdstanr_example("schools", "optimize", quiet = FALSE)
#> Initial log joint probability = -52.1838
#> Iter log prob ||dx|| ||grad|| alpha alpha0 # evals Notes
#> 99 122.653 0.275645 9.62182e+09 0.14 0.3154 173
#> Iter log prob ||dx|| ||grad|| alpha alpha0 # evals Notes
#> 187 258.482 0.227211 1.52288e+17 1e-12 0.001 402 LS failed, Hessian reset
#> Optimization terminated with error:
#> Line search failed to achieve a sufficient decrease, no more progress can be made
#> Finished in 0.1 seconds.
#> variable estimate
#> lp__ 258.48
#> mu 0.28
#> tau 0.00
#> theta[1] 0.28
#> theta[2] 0.28
#> theta[3] 0.28
#> theta[4] 0.28
#> theta[5] 0.28
#> theta[6] 0.28
#> theta[7] 0.28
#>
#> # showing 10 of 11 rows (change via 'max_rows' argument or 'cmdstanr_max_rows' option)
# }