This function implements a Particle Marginal Metropolis-Hastings (PMMH) algorithm to perform Bayesian inference in state-space models. It first runs a pilot chain to tune the proposal distribution and the number of particles for the particle filter, and then runs the main PMMH chain.

Usage

pmmh(
  y,
  m,
  init_fn,
  transition_fn,
  log_likelihood_fn,
  log_priors,
  pilot_init_params,
  burn_in,
  num_chains = 4,
  obs_times = NULL,
  algorithm = c("SISAR", "SISR", "SIS"),
  resample_fn = c("stratified", "systematic", "multinomial"),
  param_transform = NULL,
  tune_control = default_tune_control(),
  verbose = FALSE,
  return_latent_state_est = FALSE,
  seed = NULL,
  num_cores = 1
)

Arguments

y

A numeric vector or matrix of observations. Each row represents an observation at a time step.

m

An integer specifying the total number of MCMC iterations.

init_fn

A function to initialize the state-space model.

transition_fn

A function that defines the state transition of the state-space model.

log_likelihood_fn

A function that calculates the log-likelihood for the state-space model given latent states.

log_priors

A list of functions for computing the log-prior of each parameter.

pilot_init_params

A list of length num_chains in which each element is a named vector of initial parameter values for one chain.

burn_in

An integer indicating the number of initial MCMC iterations to discard as burn-in.

num_chains

An integer specifying the number of PMMH chains to run.

obs_times

A numeric vector indicating the time points at which observations in y are available. Must be of the same length as the number of rows in y. If not specified, it is assumed that observations are available at consecutive time steps, i.e., obs_times = 1:nrow(y).

algorithm

A character string specifying the particle filtering algorithm to use. Must be one of "SISAR", "SISR", or "SIS". Defaults to "SISAR".

resample_fn

A character string specifying the resampling method. Must be one of "stratified", "systematic", or "multinomial". Defaults to "stratified".

param_transform

An optional character vector that specifies the transformation applied to each parameter before proposing. The proposal is made using a multivariate normal distribution on the transformed scale. Parameters are then mapped back to their original scale before evaluation. Currently supports "log", "invlogit", and "identity". If NULL, the "identity" transformation is used for all parameters.

tune_control

A list generated by default_tune_control containing tuning parameters for the pilot chain, such as pilot_m, pilot_n, pilot_reps, pilot_proposal_sd, pilot_algorithm, and pilot_resample_fn.

verbose

A logical value indicating whether to print information about pilot_run tuning. Defaults to FALSE.

return_latent_state_est

A logical value indicating whether to return the latent state estimates for each time step. Defaults to FALSE.

seed

An optional integer to set the seed for reproducibility.

num_cores

An integer specifying the number of cores to use for parallel processing. Defaults to 1. Each chain is assigned to its own core, so the number of cores cannot exceed the number of chains (num_chains). Progress information is limited when more than one core is used.
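As a rough illustration of what param_transform describes, the sketch below proposes on the transformed scale and maps the result back to the original scale. It is not the package's implementation: the helper name propose_transformed is hypothetical, only "log" and "identity" are handled, and independent normal steps stand in for the multivariate normal proposal. The log-Jacobian term is the standard correction needed when a symmetric proposal is made on the log scale while the prior is stated on the original scale. Whether and how pmmh applies it internally is an assumption here.

```r
# Hypothetical helper (not part of the package): random-walk proposal on the
# transformed scale, mapped back to the original scale.
propose_transformed <- function(theta, transform, proposal_sd) {
  transform <- unlist(transform)  # accept a character vector or a list
  stopifnot(all(transform %in% c("log", "identity")))
  is_log <- transform == "log"

  # Map the current values to the transformed scale
  eta <- theta
  eta[is_log] <- log(theta[is_log])

  # Symmetric normal step on the transformed scale (independent components
  # for brevity; a multivariate normal would use a full covariance matrix)
  eta_new <- eta + rnorm(length(eta), mean = 0, sd = proposal_sd)

  # Map back to the original scale before evaluating prior and likelihood
  theta_new <- eta_new
  theta_new[is_log] <- exp(eta_new[is_log])

  # For theta = exp(eta), log|d theta / d eta| = eta, so the log-Jacobian
  # ratio is the difference of the summed log-scale values
  list(
    theta = theta_new,
    log_jacobian_ratio = sum(eta_new[is_log]) - sum(eta[is_log])
  )
}

set.seed(1)
out <- propose_transformed(
  theta = c(phi = 0.8, sigma_x = 1, sigma_y = 0.5),
  transform = c(phi = "identity", sigma_x = "log", sigma_y = "log"),
  proposal_sd = 0.1
)
```

Because the scale parameters are proposed on the log scale, the proposed sigma_x and sigma_y are always positive, which avoids wasting proposals on values outside their support.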

Value

A list containing:

theta_chain

A data frame of post-burn-in parameter samples.

latent_state_chain

If return_latent_state_est is TRUE, a list of matrices containing the latent state estimates for each time step.

diagnostics

Diagnostics containing ESS and Rhat for each parameter (see ess and rhat for documentation).
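To make the Rhat diagnostic concrete, here is a minimal sketch of the classical Gelman-Rubin statistic in base R. The function name basic_rhat is hypothetical, and the package's rhat may use a refined variant (for example a split or rank-normalized version), so values need not match exactly.

```r
# Hypothetical helper: classical Gelman-Rubin Rhat for one parameter.
# `chains` is a matrix with one row per iteration and one column per chain.
basic_rhat <- function(chains) {
  n <- nrow(chains)
  between <- n * var(colMeans(chains))    # between-chain variance B
  within <- mean(apply(chains, 2, var))   # mean within-chain variance W
  var_hat <- (n - 1) / n * within + between / n
  sqrt(var_hat / within)
}

set.seed(1)
# Two chains sampling the same distribution: Rhat should be close to 1
mixed <- cbind(rnorm(1000), rnorm(1000))
# Two chains stuck in different regions: Rhat should be well above 1
stuck <- cbind(rnorm(1000), rnorm(1000, mean = 3))
rhat_mixed <- basic_rhat(mixed)
rhat_stuck <- basic_rhat(stuck)
```

Values near 1 suggest that the chains agree with each other, while larger values indicate that between-chain differences dominate, as in the second case.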

Details

The PMMH algorithm is essentially a Metropolis-Hastings algorithm that, instead of the exact likelihood, uses a likelihood estimate obtained from a particle filter (see also particle_filter). Values are proposed from a multivariate normal distribution in the transformed space. The proposal covariance is estimated from the pilot chain.
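The idea can be sketched in base R as follows. This is an illustration under stated assumptions, not the package's code: the helper names (log_mean_exp, stratified_resample, pf_loglik, pmmh_step) are hypothetical, the model is hard-coded to the one used in the Examples, resampling happens at every step, the proposal is an untransformed random walk with independent components, and y holds placeholder data.

```r
# Numerically stable log of the mean of exp(lw)
log_mean_exp <- function(lw) {
  m <- max(lw)
  m + log(mean(exp(lw - m)))
}

# Stratified resampling: one uniform draw per equal-width stratum of [0, 1)
stratified_resample <- function(w) {
  n <- length(w)
  u <- (seq_len(n) - 1 + runif(n)) / n
  pmin(findInterval(u, cumsum(w)) + 1L, n)  # guard against rounding
}

# Bootstrap particle filter estimate of the log-likelihood log p(y | theta)
pf_loglik <- function(y, n, theta) {
  particles <- rnorm(n, mean = 0, sd = 1)  # initial state draw
  loglik <- 0
  for (i in seq_along(y)) {
    if (i > 1) {  # propagate particles through the state transition
      particles <- theta[["phi"]] * particles + sin(particles) +
        rnorm(n, mean = 0, sd = theta[["sigma_x"]])
    }
    lw <- dnorm(y[i], mean = particles, sd = theta[["sigma_y"]], log = TRUE)
    if (!any(is.finite(lw))) return(-Inf)  # all weights vanished
    loglik <- loglik + log_mean_exp(lw)
    w <- exp(lw - max(lw))
    particles <- particles[stratified_resample(w / sum(w))]
  }
  loglik
}

# One Metropolis-Hastings step that uses the estimated likelihood.
# The estimate for the current state is stored and reused, not recomputed.
pmmh_step <- function(theta, loglik, y, n, proposal_sd, log_prior) {
  current <- list(theta = theta, loglik = loglik)
  prop <- theta + rnorm(length(theta), mean = 0, sd = proposal_sd)
  lp_prop <- log_prior(prop)
  if (!is.finite(lp_prop)) return(current)  # outside the prior support
  ll_prop <- pf_loglik(y, n, prop)
  log_ratio <- ll_prop + lp_prop - loglik - log_prior(theta)
  if (is.finite(log_ratio) && log(runif(1)) < log_ratio) {
    list(theta = prop, loglik = ll_prop)
  } else {
    current
  }
}

log_prior <- function(th) {
  dnorm(th[["phi"]], mean = 0, sd = 1, log = TRUE) +
    dexp(th[["sigma_x"]], rate = 1, log = TRUE) +
    dexp(th[["sigma_y"]], rate = 1, log = TRUE)
}

set.seed(1)
y <- rnorm(10)  # placeholder data, for illustration only
theta <- c(phi = 0.8, sigma_x = 1, sigma_y = 0.5)
state <- list(theta = theta, loglik = pf_loglik(y, 100, theta))
for (iter in 1:50) {
  state <- pmmh_step(state$theta, state$loglik, y, 100, 0.1, log_prior)
}
```

Reusing the stored estimate for the current state, rather than re-estimating it at every iteration, is what allows the chain to target the exact posterior despite the noise in the likelihood estimate (Andrieu et al., 2010).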

References

Andrieu et al. (2010). Particle Markov chain Monte Carlo methods. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 72(3):269–342. doi: 10.1111/j.1467-9868.2009.00736.x

Examples

init_fn <- function(particles) {
  rnorm(particles, mean = 0, sd = 1)
}
transition_fn <- function(particles, phi, sigma_x) {
  phi * particles + sin(particles) +
    rnorm(length(particles), mean = 0, sd = sigma_x)
}
log_likelihood_fn <- function(y, particles, sigma_y) {
  dnorm(y, mean = particles, sd = sigma_y, log = TRUE)
}
log_prior_phi <- function(phi) {
  dnorm(phi, mean = 0, sd = 1, log = TRUE)
}
log_prior_sigma_x <- function(sigma) {
  dexp(sigma, rate = 1, log = TRUE)
}
log_prior_sigma_y <- function(sigma) {
  dexp(sigma, rate = 1, log = TRUE)
}
log_priors <- list(
  phi = log_prior_phi,
  sigma_x = log_prior_sigma_x,
  sigma_y = log_prior_sigma_y
)
# Generate data
t_val <- 10
x <- numeric(t_val)
y <- numeric(t_val)
x[1] <- rnorm(1, mean = 0, sd = 1)
y[1] <- rnorm(1, mean = x[1], sd = 0.5)
for (t in 2:t_val) {
  x[t] <- 0.8 * x[t - 1] + sin(x[t - 1]) + rnorm(1, mean = 0, sd = 1)
  y[t] <- x[t] + rnorm(1, mean = 0, sd = 0.5)
}
# In practice, use many more MCMC iterations (m)
pmmh_result <- pmmh(
  y = y,
  m = 1000,
  init_fn = init_fn,
  transition_fn = transition_fn,
  log_likelihood_fn = log_likelihood_fn,
  log_priors = log_priors,
  pilot_init_params = list(
    c(phi = 0.8, sigma_x = 1, sigma_y = 0.5),
    c(phi = 1, sigma_x = 0.5, sigma_y = 1)
  ),
  burn_in = 100,
  num_chains = 2,
  param_transform = list(
    phi = "identity",
    sigma_x = "log",
    sigma_y = "log"
  ),
  tune_control = default_tune_control(pilot_m = 500, pilot_burn_in = 100)
)
#> Running chain 1...
#> Running pilot chain for tuning...
#> Using 50 particles for PMMH:
#> Running particle MCMC chain with tuned settings...
#> Running chain 2...
#> Running pilot chain for tuning...
#> Using 50 particles for PMMH:
#> Running particle MCMC chain with tuned settings...
#> PMMH Results Summary:
#>  Parameter Mean   SD Median CI Lower.2.5% CI Upper.97.5% ESS  Rhat
#>        phi 0.64 0.19   0.64          0.28           1.03  78 1.031
#>    sigma_x 0.52 0.40   0.43          0.02           1.49  38 1.003
#>    sigma_y 0.75 0.31   0.75          0.18           1.38  65 1.002
#> Warning: Some ESS values are below 400, indicating poor mixing. Consider running the chains for more iterations.
#> Warning: 
#> Some Rhat values are above 1.01, indicating that the chains have not converged. 
#> Consider running the chains for more iterations and/or increase burn_in.
# Convergence warnings are expected with so few MCMC iterations.

# Suppose we have data for t=1,2,3,5,6,7,8,9,10 (i.e., missing at t=4)

obs_times <- c(1, 2, 3, 5, 6, 7, 8, 9, 10)
y <- y[obs_times]

# Pass the observation times to pmmh() via obs_times
pmmh_result <- pmmh(
  y = y,
  m = 1000,
  init_fn = init_fn,
  transition_fn = transition_fn,
  log_likelihood_fn = log_likelihood_fn,
  log_priors = log_priors,
  pilot_init_params = list(
    c(phi = 0.8, sigma_x = 1, sigma_y = 0.5),
    c(phi = 1, sigma_x = 0.5, sigma_y = 1)
  ),
  burn_in = 100,
  num_chains = 2,
  obs_times = obs_times,
  param_transform = list(
    phi = "identity",
    sigma_x = "log",
    sigma_y = "log"
  ),
  tune_control = default_tune_control(pilot_m = 500, pilot_burn_in = 100)
)
#> Running chain 1...
#> Running pilot chain for tuning...
#> Using 50 particles for PMMH:
#> Running particle MCMC chain with tuned settings...
#> Running chain 2...
#> Running pilot chain for tuning...
#> Using 50 particles for PMMH:
#> Running particle MCMC chain with tuned settings...
#> PMMH Results Summary:
#>  Parameter Mean   SD Median CI Lower.2.5% CI Upper.97.5% ESS  Rhat
#>        phi 0.66 0.18   0.68          0.30           0.99  77 1.006
#>    sigma_x 0.61 0.44   0.56          0.05           1.60   9 1.077
#>    sigma_y 0.70 0.32   0.70          0.14           1.42  40 1.061
#> Warning: Some ESS values are below 400, indicating poor mixing. Consider running the chains for more iterations.
#> Warning: 
#> Some Rhat values are above 1.01, indicating that the chains have not converged. 
#> Consider running the chains for more iterations and/or increase burn_in.