Shared latent states in mvgam
Nicholas J Clark
2024-09-03
This vignette gives an example of how mvgam can be used to estimate models where multiple observed time series share the same latent process model. For full details on the basic mvgam functionality, please see the introductory vignette.
The trend_map argument
The trend_map argument in the mvgam() function is an optional data.frame that can be used to specify which series should depend on which latent process models (called “trends” in mvgam). This can be particularly useful if we wish to force multiple observed time series to depend on the same latent trend process, but with different observation processes. If this argument is supplied, a latent factor model is set up by setting use_lv = TRUE and using the supplied trend_map to set up the shared trends. Users familiar with the MARSS family of packages will recognize this as a way of specifying the \(Z\) matrix. This data.frame needs to have column names series and trend, with integer values in the trend column to state which trend each series should depend on. The series column should have a single unique entry for each time series in the data, with names that perfectly match the factor levels of the series variable in data. For example, if we were to simulate a collection of three integer-valued time series (using sim_mvgam), the following trend_map would force the first two series to share the same latent trend process:
set.seed(122)
simdat <- sim_mvgam(trend_model = AR(),
                    prop_trend = 0.6,
                    mu = c(0, 1, 2),
                    family = poisson())
trend_map <- data.frame(series = unique(simdat$data_train$series),
                        trend = c(1, 1, 2))
trend_map
#> series trend
#> 1 series_1 1
#> 2 series_2 1
#> 3 series_3 2
We can see that the factor levels in trend_map match those in the data:
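One quick way to confirm this (a small check added here for illustration; because trend_map$series was built from unique(simdat$data_train$series), the levels should be identical) is to compare the factor levels directly:
# compare the series factor levels in the data and in the trend_map
all.equal(levels(simdat$data_train$series), levels(trend_map$series))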
Checking trend_map with run_model = FALSE
Supplying this trend_map to the mvgam function for a simple model, but setting run_model = FALSE, allows us to inspect the constructed Stan code and the data objects that would be used to condition the model. Here we will set up a model in which each series has a different observation process (with only a different intercept per series in this case), and the two latent dynamic process models evolve as independent AR1 processes that also contain a shared nonlinear smooth function to capture repeated seasonality. This model is not too complicated but it does show how we can learn shared and independent effects for collections of time series in the mvgam framework:
fake_mod <- mvgam(y ~
                    # observation model formula, which has a
                    # different intercept per series
                    series - 1,
                  # process model formula, which has a shared seasonal smooth
                  # (each latent process model shares the SAME smooth)
                  trend_formula = ~ s(season, bs = 'cc', k = 6),
                  # AR1 dynamics (each latent process model has DIFFERENT
                  # dynamics); processes are estimated using the noncentred
                  # parameterisation for improved efficiency
                  trend_model = AR(),
                  noncentred = TRUE,
                  # supplied trend_map
                  trend_map = trend_map,
                  # data and observation family
                  family = poisson(),
                  data = simdat$data_train,
                  run_model = FALSE)
Inspecting the Stan code shows how this model is a dynamic factor model in which the loadings are constructed to reflect the supplied trend_map:
code(fake_mod)
#> // Stan model code generated by package mvgam
#> data {
#> int<lower=0> total_obs; // total number of observations
#> int<lower=0> n; // number of timepoints per series
#> int<lower=0> n_sp_trend; // number of trend smoothing parameters
#> int<lower=0> n_lv; // number of dynamic factors
#> int<lower=0> n_series; // number of series
#> matrix[n_series, n_lv] Z; // matrix mapping series to latent states
#> int<lower=0> num_basis; // total number of basis coefficients
#> int<lower=0> num_basis_trend; // number of trend basis coefficients
#> vector[num_basis_trend] zero_trend; // prior locations for trend basis coefficients
#> matrix[total_obs, num_basis] X; // mgcv GAM design matrix
#> matrix[n * n_lv, num_basis_trend] X_trend; // trend model design matrix
#> array[n, n_series] int<lower=0> ytimes; // time-ordered matrix (which col in X belongs to each [time, series] observation?)
#> array[n, n_lv] int ytimes_trend;
#> int<lower=0> n_nonmissing; // number of nonmissing observations
#> matrix[4, 4] S_trend1; // mgcv smooth penalty matrix S_trend1
#> array[n_nonmissing] int<lower=0> flat_ys; // flattened nonmissing observations
#> matrix[n_nonmissing, num_basis] flat_xs; // X values for nonmissing observations
#> array[n_nonmissing] int<lower=0> obs_ind; // indices of nonmissing observations
#> }
#> transformed data {
#>
#> }
#> parameters {
#> // raw basis coefficients
#> vector[num_basis] b_raw;
#> vector[num_basis_trend] b_raw_trend;
#>
#> // latent state SD terms
#> vector<lower=0>[n_lv] sigma;
#>
#> // latent state AR1 terms
#> vector<lower=-1, upper=1>[n_lv] ar1;
#>
#> // raw latent states
#> matrix[n, n_lv] LV_raw;
#>
#> // smoothing parameters
#> vector<lower=0>[n_sp_trend] lambda_trend;
#> }
#> transformed parameters {
#> // raw latent states
#> vector[n * n_lv] trend_mus;
#> matrix[n, n_series] trend;
#>
#> // basis coefficients
#> vector[num_basis] b;
#>
#> // latent states
#> matrix[n, n_lv] LV;
#> vector[num_basis_trend] b_trend;
#>
#> // observation model basis coefficients
#> b[1 : num_basis] = b_raw[1 : num_basis];
#>
#> // process model basis coefficients
#> b_trend[1 : num_basis_trend] = b_raw_trend[1 : num_basis_trend];
#>
#> // latent process linear predictors
#> trend_mus = X_trend * b_trend;
#> LV = LV_raw .* rep_matrix(sigma', rows(LV_raw));
#> for (j in 1 : n_lv) {
#> LV[1, j] += trend_mus[ytimes_trend[1, j]];
#> for (i in 2 : n) {
#> LV[i, j] += trend_mus[ytimes_trend[i, j]]
#> + ar1[j] * (LV[i - 1, j] - trend_mus[ytimes_trend[i - 1, j]]);
#> }
#> }
#>
#> // derived latent states
#> for (i in 1 : n) {
#> for (s in 1 : n_series) {
#> trend[i, s] = dot_product(Z[s, : ], LV[i, : ]);
#> }
#> }
#> }
#> model {
#> // prior for seriesseries_1...
#> b_raw[1] ~ student_t(3, 0, 2);
#>
#> // prior for seriesseries_2...
#> b_raw[2] ~ student_t(3, 0, 2);
#>
#> // prior for seriesseries_3...
#> b_raw[3] ~ student_t(3, 0, 2);
#>
#> // priors for AR parameters
#> ar1 ~ std_normal();
#>
#> // priors for latent state SD parameters
#> sigma ~ student_t(3, 0, 2.5);
#> to_vector(LV_raw) ~ std_normal();
#>
#> // dynamic process models
#>
#> // prior for (Intercept)_trend...
#> b_raw_trend[1] ~ student_t(3, 0, 2);
#>
#> // prior for s(season)_trend...
#> b_raw_trend[2 : 5] ~ multi_normal_prec(zero_trend[2 : 5],
#> S_trend1[1 : 4, 1 : 4]
#> * lambda_trend[1]);
#> lambda_trend ~ normal(5, 30);
#> {
#> // likelihood functions
#> vector[n_nonmissing] flat_trends;
#> flat_trends = to_vector(trend)[obs_ind];
#> flat_ys ~ poisson_log_glm(append_col(flat_xs, flat_trends), 0.0,
#> append_row(b, 1.0));
#> }
#> }
#> generated quantities {
#> vector[total_obs] eta;
#> matrix[n, n_series] mus;
#> vector[n_sp_trend] rho_trend;
#> vector[n_lv] penalty;
#> array[n, n_series] int ypred;
#> penalty = 1.0 / (sigma .* sigma);
#> rho_trend = log(lambda_trend);
#>
#> matrix[n_series, n_lv] lv_coefs = Z;
#> // posterior predictions
#> eta = X * b;
#> for (s in 1 : n_series) {
#> mus[1 : n, s] = eta[ytimes[1 : n, s]] + trend[1 : n, s];
#> ypred[1 : n, s] = poisson_log_rng(mus[1 : n, s]);
#> }
#> }
Notice the line that states “lv_coefs = Z;”. This uses the supplied \(Z\) matrix to construct the loading coefficients. The supplied matrix now looks exactly like what you’d use if you were to create a similar model in the MARSS package:
fake_mod$model_data$Z
#> [,1] [,2]
#> [1,] 1 0
#> [2,] 1 0
#> [3,] 0 1
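For readers coming from MARSS, a rough sketch of how the same loading matrix could be supplied to that package is shown below. This is illustrative only and not part of the mvgam workflow; it assumes the MARSS package is installed and that y is a hypothetical 3 x n matrix of observations (one row per series):
# the same Z matrix, written out explicitly
Z <- matrix(c(1, 0,
              1, 0,
              0, 1), nrow = 3, ncol = 2, byrow = TRUE)
# hypothetical MARSS call fixing the loadings to this Z matrix
# library(MARSS)
# marss_mod <- MARSS(y, model = list(Z = Z))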
Fitting and inspecting the model
Though this model doesn’t perfectly match the data-generating process (which allowed each series to have different underlying dynamics), we can still fit it to show what the resulting inferences look like:
full_mod <- mvgam(y ~ series - 1,
                  trend_formula = ~ s(season, bs = 'cc', k = 6),
                  trend_model = AR(),
                  noncentred = TRUE,
                  trend_map = trend_map,
                  family = poisson(),
                  data = simdat$data_train,
                  silent = 2)
The summary of this model is informative as it shows that only two latent process models have been estimated, even though we have three observed time series. The model also converges well:
summary(full_mod)
#> GAM observation formula:
#> y ~ series - 1
#>
#> GAM process formula:
#> ~s(season, bs = "cc", k = 6)
#>
#> Family:
#> poisson
#>
#> Link function:
#> log
#>
#> Trend model:
#> AR()
#>
#> N process models:
#> 2
#>
#> N series:
#> 3
#>
#> N timepoints:
#> 75
#>
#> Status:
#> Fitted using Stan
#> 4 chains, each with iter = 1000; warmup = 500; thin = 1
#> Total post-warmup draws = 2000
#>
#>
#> GAM observation model coefficient (beta) estimates:
#> 2.5% 50% 97.5% Rhat n_eff
#> seriesseries_1 -2.80 -0.67 1.5 1 931
#> seriesseries_2 -1.80 0.31 2.5 1 924
#> seriesseries_3 -0.84 1.30 3.4 1 920
#>
#> Process model AR parameter estimates:
#> 2.5% 50% 97.5% Rhat n_eff
#> ar1[1] -0.73 -0.430 -0.056 1.00 666
#> ar1[2] -0.30 -0.019 0.250 1.01 499
#>
#> Process error parameter estimates:
#> 2.5% 50% 97.5% Rhat n_eff
#> sigma[1] 0.33 0.49 0.67 1 854
#> sigma[2] 0.59 0.73 0.91 1 755
#>
#> GAM process model coefficient (beta) estimates:
#> 2.5% 50% 97.5% Rhat n_eff
#> (Intercept)_trend -1.40 0.7800 2.90 1 921
#> s(season).1_trend -0.21 -0.0072 0.21 1 1822
#> s(season).2_trend -0.30 -0.0480 0.18 1 1414
#> s(season).3_trend -0.16 0.0680 0.30 1 1664
#> s(season).4_trend -0.14 0.0660 0.29 1 1505
#>
#> Approximate significance of GAM process smooths:
#> edf Ref.df Chi.sq p-value
#> s(season) 1.48 4 0.67 0.93
#>
#> Stan MCMC diagnostics:
#> n_eff / iter looks reasonable for all parameters
#> Rhat looks reasonable for all parameters
#> 0 of 2000 iterations ended with a divergence (0%)
#> 0 of 2000 iterations saturated the maximum tree depth of 12 (0%)
#> E-FMI indicated no pathological behavior
#>
#> Samples were drawn using NUTS(diag_e) at Tue Sep 03 2:51:42 PM 2024.
#> For each parameter, n_eff is a crude measure of effective sample size,
#> and Rhat is the potential scale reduction factor on split MCMC chains
#> (at convergence, Rhat = 1)
Both series 1 and 2 share the exact same latent process estimates, while the estimates for series 3 are different:
plot(full_mod, type = 'trend', series = 1)
plot(full_mod, type = 'trend', series = 2)
plot(full_mod, type = 'trend', series = 3)
However, forecasts for series 1 and 2 will differ because they have different intercepts in the observation model.
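One way to see this (a sketch added here, assuming the default test split returned by sim_mvgam() in simdat$data_test) is to compute forecasts and plot them for the two series that share a trend:
# forecast ahead using the held-out test data and compare the two series
fc <- forecast(full_mod, newdata = simdat$data_test)
plot(fc, series = 1)
plot(fc, series = 2)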
Example: signal detection
Now we will explore a more complicated example. Here we simulate a true hidden signal that we are trying to track. This signal depends nonlinearly on some covariate (called productivity, which represents a measure of how productive the landscape is). The signal also demonstrates a fairly large amount of temporal autocorrelation:
set.seed(0)
# simulate a nonlinear relationship using the mgcv function gamSim
signal_dat <- mgcv::gamSim(n = 100, eg = 1, scale = 1)
#> Gu & Wahba 4 term additive model
# productivity is one of the variables in the simulated data
productivity <- signal_dat$x2
# simulate the true signal, which already has a nonlinear relationship
# with productivity; we will add in a fairly strong AR1 process to
# contribute to the signal
true_signal <- as.vector(scale(signal_dat$y) +
                           arima.sim(100, model = list(ar = 0.8, sd = 0.1)))
Plot the signal to inspect its evolution over time:
plot(true_signal, type = 'l',
     bty = 'l', lwd = 2,
     ylab = 'True signal',
     xlab = 'Time')
Next we simulate three sensors that are trying to track the same hidden signal. All of these sensors have different observation errors that can depend nonlinearly on a second external covariate, called temperature in this example. Again this makes use of gamSim:
sim_series = function(n_series = 3, true_signal){
  temp_effects <- mgcv::gamSim(n = 100, eg = 7, scale = 0.1)
  temperature <- temp_effects$y
  alphas <- rnorm(n_series, sd = 2)

  do.call(rbind, lapply(seq_len(n_series), function(series){
    data.frame(observed = rnorm(length(true_signal),
                                mean = alphas[series] +
                                  1.5 * as.vector(scale(temp_effects[, series + 1])) +
                                  true_signal,
                                sd = runif(1, 1, 2)),
               series = paste0('sensor_', series),
               time = 1:length(true_signal),
               temperature = temperature,
               productivity = productivity,
               true_signal = true_signal)
  }))
}
model_dat <- sim_series(true_signal = true_signal) %>%
  dplyr::mutate(series = factor(series))
#> Gu & Wahba 4 term additive model, correlated predictors
Plot the sensor observations:
plot_mvgam_series(data = model_dat, y = 'observed',
                  series = 'all')
And now plot the observed relationships between the three sensors and the temperature covariate:
plot(observed ~ temperature, data = model_dat %>%
       dplyr::filter(series == 'sensor_1'),
     pch = 16, bty = 'l',
     ylab = 'Sensor 1',
     xlab = 'Temperature')

plot(observed ~ temperature, data = model_dat %>%
       dplyr::filter(series == 'sensor_2'),
     pch = 16, bty = 'l',
     ylab = 'Sensor 2',
     xlab = 'Temperature')

plot(observed ~ temperature, data = model_dat %>%
       dplyr::filter(series == 'sensor_3'),
     pch = 16, bty = 'l',
     ylab = 'Sensor 3',
     xlab = 'Temperature')
The shared signal model
Now we can formulate and fit a model that allows each sensor’s observation error to depend nonlinearly on temperature while allowing the true signal to depend nonlinearly on productivity. By fixing all of the values in the trend column to 1 in the trend_map, we are assuming that all observation sensors are tracking the same latent signal. We use informative priors on the two variance components (process error and observation error), which reflect our prior belief that the observation error is smaller overall than the true process error:
mod <- mvgam(formula =
               # formula for observations, allowing for different
               # intercepts and hierarchical smooth effects of temperature
               observed ~ series +
                 s(temperature, k = 10) +
                 s(series, temperature, bs = 'sz', k = 8),
             trend_formula =
               # formula for the latent signal, which can depend
               # nonlinearly on productivity
               ~ s(productivity, k = 8) - 1,
             trend_model =
               # in addition to productivity effects, the signal is
               # assumed to exhibit temporal autocorrelation
               AR(),
             noncentred = TRUE,
             trend_map =
               # trend_map forces all sensors to track the same
               # latent signal
               data.frame(series = unique(model_dat$series),
                          trend = c(1, 1, 1)),
             # informative priors on process error
             # and observation error will help with convergence
             priors = c(prior(normal(2, 0.5), class = sigma),
                        prior(normal(1, 0.5), class = sigma_obs)),
             # Gaussian observations
             family = gaussian(),
             data = model_dat,
             silent = 2)
View a reduced version of the model summary, because there will be many spline coefficients in this model:
summary(mod, include_betas = FALSE)
#> GAM observation formula:
#> observed ~ series + s(temperature, k = 10) + s(series, temperature,
#> bs = "sz", k = 8)
#>
#> GAM process formula:
#> ~s(productivity, k = 8) - 1
#>
#> Family:
#> gaussian
#>
#> Link function:
#> identity
#>
#> Trend model:
#> AR()
#>
#> N process models:
#> 1
#>
#> N series:
#> 3
#>
#> N timepoints:
#> 100
#>
#> Status:
#> Fitted using Stan
#> 4 chains, each with iter = 1100; warmup = 600; thin = 1
#> Total post-warmup draws = 2000
#>
#>
#> Observation error parameter estimates:
#> 2.5% 50% 97.5% Rhat n_eff
#> sigma_obs[1] 1.4 1.7 2.1 1 1080
#> sigma_obs[2] 1.7 2.0 2.3 1 2112
#> sigma_obs[3] 2.0 2.3 2.7 1 2799
#>
#> GAM observation model coefficient (beta) estimates:
#> 2.5% 50% 97.5% Rhat n_eff
#> (Intercept) -3.40 -2.1 -0.790 1.00 946
#> seriessensor_2 -2.80 -1.4 -0.015 1.01 1263
#> seriessensor_3 0.53 3.1 4.700 1.00 897
#>
#> Approximate significance of GAM observation smooths:
#> edf Ref.df Chi.sq p-value
#> s(temperature) 1.74 9 0.09 1
#> s(series,temperature) 2.47 16 106.38 7.6e-05 ***
#> ---
#> Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
#>
#> Process model AR parameter estimates:
#> 2.5% 50% 97.5% Rhat n_eff
#> ar1[1] 0.39 0.59 0.78 1.01 541
#>
#> Process error parameter estimates:
#> 2.5% 50% 97.5% Rhat n_eff
#> sigma[1] 1.5 1.8 2.2 1 768
#>
#> Approximate significance of GAM process smooths:
#> edf Ref.df Chi.sq p-value
#> s(productivity) 1.04 7 5.12 1
#>
#> Stan MCMC diagnostics:
#> n_eff / iter looks reasonable for all parameters
#> Rhat looks reasonable for all parameters
#> 0 of 2000 iterations ended with a divergence (0%)
#> 0 of 2000 iterations saturated the maximum tree depth of 12 (0%)
#> E-FMI indicated no pathological behavior
#>
#> Samples were drawn using NUTS(diag_e) at Tue Sep 03 2:53:07 PM 2024.
#> For each parameter, n_eff is a crude measure of effective sample size,
#> and Rhat is the potential scale reduction factor on split MCMC chains
#> (at convergence, Rhat = 1)
Inspecting effects on both process and observation models
Don’t pay much attention to the approximate p-values of the smooth terms. The calculation for these values is incredibly sensitive to the estimates for the smoothing parameters, so I don’t tend to find them to be very meaningful. What are meaningful, however, are prediction-based plots of the smooth functions. All main effects can be quickly plotted with conditional_effects:
conditional_effects(mod, type = 'link')
conditional_effects is simply a wrapper to the more flexible plot_predictions function from the marginaleffects package. We can get more useful plots of these effects using this function for further customisation:
require(marginaleffects)
#> Loading required package: marginaleffects
plot_predictions(mod,
                 condition = c('temperature', 'series', 'series'),
                 points = 0.5) +
  theme(legend.position = 'none')
We have successfully estimated effects, some of them nonlinear, that impact the hidden process AND the observations. All in a single joint model. But there can always be challenges with these models, particularly when estimating both process and observation error at the same time.
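One informal way to look for such trade-offs (a sketch added here for illustration, assuming the posterior and bayesplot packages are available) is to inspect the joint posterior of the process and observation error terms, which ideally should not show a strong negative correlation:
require(posterior)
require(bayesplot)
# extract posterior draws and plot the joint posterior of the process error
# (sigma[1]) and the first sensor's observation error (sigma_obs[1])
draws <- as_draws_df(mod)
mcmc_pairs(draws, pars = c('sigma[1]', 'sigma_obs[1]'))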
Recovering the hidden signal
A final but very key question is whether we can successfully recover the true hidden signal. The trend slot in the returned model parameters has the estimates for this signal, which we can easily plot using the mvgam S3 method for plot. We can also overlay the true values for the hidden signal, which shows that our model has done a good job of recovering it:
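The exact plot call is not shown above, so the following is a minimal sketch of that overlay (the point styling is illustrative):
# plot the estimated latent signal (only one process model in this case)
plot(mod, type = 'trend')
# overlay the true simulated signal as points
points(true_signal, pch = 16, cex = 0.8)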
Further reading
The following papers and resources offer a lot of useful material about other types of State-Space models and how they can be applied in practice:
Holmes, Elizabeth E., Eric J. Ward, and Kellie Wills. “MARSS: multivariate autoregressive state-space models for analyzing time-series data.” R Journal 4.1 (2012): 11.
Ward, Eric J., et al. “Inferring spatial structure from time‐series data: using multivariate state‐space models to detect metapopulation structure of California sea lions in the Gulf of California, Mexico.” Journal of Applied Ecology 47.1 (2010): 47-56.
Auger‐Méthé, Marie, et al. “A guide to state–space modeling of ecological time series.” Ecological Monographs 91.4 (2021): e01470.