This function sets up a Joint Species Distribution Model in which the residual associations among species are modelled in a reduced-rank format using a set of latent factors. The factor specification is extremely flexible, allowing users to include spatial, temporal or any other type of predictor effects to more efficiently capture unmodelled residual associations. The observation model can also be highly flexible (including all smooth, GP and other effects that mvgam can handle).

Usage

jsdgam(
  formula,
  factor_formula = ~-1,
  knots,
  factor_knots,
  data,
  newdata,
  family = poisson(),
  unit = time,
  species = series,
  share_obs_params = FALSE,
  priors,
  n_lv = 2,
  backend = getOption("brms.backend", "cmdstanr"),
  algorithm = getOption("brms.algorithm", "sampling"),
  control = list(max_treedepth = 10, adapt_delta = 0.8),
  chains = 4,
  burnin = 500,
  samples = 500,
  thin = 1,
  parallel = TRUE,
  threads = 1,
  silent = 1,
  run_model = TRUE,
  return_model_data = FALSE,
  ...
)

Arguments

formula

A formula object specifying the GAM observation model formula. These are exactly like the formula for a GLM except that smooth terms, s(), te(), ti(), t2(), as well as time-varying dynamic() terms, nonparametric gp() terms and offsets using offset(), can be added to the right hand side to specify that the linear predictor depends on smooth functions of predictors (or linear functionals of these). Details of the formula syntax used by mvgam can be found in mvgam_formulae

factor_formula

A formula object specifying the linear predictor effects for the latent factors. Use by = trend within calls to functional terms (i.e. s(), te(), ti(), t2(), dynamic(), or gp()) to ensure that each factor captures a different axis of variation. See the example below as an illustration

knots

An optional list containing user specified knot values to be used for basis construction. For most bases the user simply supplies the knots to be used, which must match up with the k value supplied (note that the number of knots is not always just k). Different terms can use different numbers of knots, unless they share a covariate
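
As a rough illustration of how a knots list lines up with k, the sketch below builds a cubic regression spline basis directly with mgcv. The covariate name x, the choice of a "cr" basis and the evenly spaced knot values are assumptions made for this example only. The same named list is the kind of object that would be supplied through knots (or factor_knots).

```r
library(mgcv)

set.seed(2)
dat <- data.frame(x = runif(100))  # hypothetical covariate on (0, 1)

# For a 'cr' basis the number of supplied knots must equal k
k <- 5
knots <- list(x = seq(0, 1, length.out = k))

# Build the basis for s(x) using the user-specified knots
sm <- smoothCon(s(x, bs = "cr", k = k), data = dat, knots = knots)[[1]]

ncol(sm$X)  # one basis function (column) per knot
sm$xp       # knot locations stored on the 'cr' smooth object
```

The list is named by covariate, which is why terms that share a covariate also share its knots.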

factor_knots

An optional list containing user specified knot values to be used for basis construction of any smooth terms in factor_formula. For most bases the user simply supplies the knots to be used, which must match up with the k value supplied (note that the number of knots is not always just k). Different terms can use different numbers of knots, unless they share a covariate

data

A dataframe or list containing the model response variable and covariates required by the GAM formula and factor_formula objects

newdata

Optional dataframe or list of test data containing the same variables as in data. If included, the observations in variable y will be set to NA when fitting the model so that posterior simulations can be obtained

family

family specifying the observation family for the outcomes. Default is poisson(). See mvgam_families for the currently supported families and further details

unit

The unquoted name of the variable that represents the unit of analysis in data over which latent residuals should be correlated. This variable should be either a numeric or integer variable in the supplied data. Defaults to time to be consistent with other functionalities in mvgam, though note that the data need not be time series in this case. See examples below for further details and explanations

species

The unquoted name of the factor variable that indexes the different response units in data (usually 'species' in a JSDM). Defaults to series to be consistent with other mvgam models
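
A minimal sketch of the long-format layout that the unit and species arguments describe is given below. The column names site, species and count, and the balanced design in which every species appears once at every site, are assumptions chosen for illustration.

```r
set.seed(10)
n_sites <- 4
species_names <- c("sp_a", "sp_b", "sp_c")

# One row per site-by-species combination
dat <- expand.grid(site = seq_len(n_sites),
                   species = species_names,
                   stringsAsFactors = FALSE)

# 'species' must be a factor; 'site' is an integer unit identifier
dat$species <- factor(dat$species, levels = species_names)
dat$count <- rpois(nrow(dat), lambda = 5)

str(dat)
table(dat$site, dat$species)
```

With data shaped like this, the grouping variables would be passed as unit = site and species = species.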

share_obs_params

logical. If TRUE and the family has additional family-specific observation parameters (e.g. variance components in student_t() or gaussian(), or dispersion parameters in nb() or betar()), these parameters will be shared across all outcome variables. This is handy if you have multiple outcomes (time series in most mvgam models) that you believe share some properties, such as being from the same species over different spatial units. Default is FALSE.

priors

An optional data.frame with prior definitions (in Stan syntax) or, preferentially, a vector containing objects of class brmsprior (see prior for details). See get_mvgam_priors for more information on changing default prior distributions

n_lv

integer specifying the number of latent factors to use for modelling residual associations. Cannot be greater than the number of species. Defaults arbitrarily to 2

backend

Character string naming the package to use as the backend for fitting the Stan model. Options are "cmdstanr" (the default) or "rstan". Can be set globally for the current R session via the "brms.backend" option (see options). Details on the rstan and cmdstanr packages are available at https://mc-stan.org/rstan/ and https://mc-stan.org/cmdstanr/, respectively

algorithm

Character string naming the estimation approach to use. Options are "sampling" for MCMC (the default), "meanfield" for variational inference with factorized normal distributions, "fullrank" for variational inference with a multivariate normal distribution, "laplace" for a Laplace approximation (only available when using cmdstanr as the backend) or "pathfinder" for the pathfinder algorithm (only currently available when using cmdstanr as the backend). Can be set globally for the current R session via the "brms.algorithm" option (see options). Limited testing suggests that "meanfield" performs best out of the non-MCMC approximations for dynamic GAMs, possibly because of the difficulties estimating covariances among the many spline parameters and latent trend parameters. But rigorous testing has not been carried out

control

A named list for controlling the sampler's behaviour. Currently only accepts settings for max_treedepth and adapt_delta, the two entries in the default list shown under Usage

chains

integer specifying the number of parallel chains for the model. Ignored if algorithm %in% c('meanfield', 'fullrank', 'pathfinder', 'laplace')

burnin

integer specifying the number of warmup iterations of the Markov chain to run to tune sampling algorithms. Ignored if algorithm %in% c('meanfield', 'fullrank', 'pathfinder', 'laplace')

samples

integer specifying the number of post-warmup iterations of the Markov chain to run for sampling the posterior distribution

thin

Thinning interval for monitors. Ignored if algorithm %in% c('meanfield', 'fullrank', 'pathfinder', 'laplace')
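
The arithmetic linking chains, burnin and samples to the number of posterior draws can be sketched as follows. The values mirror the example further down (chains = 2 with the default burnin and samples), and thin = 1 is assumed so that every post-warmup iteration is kept.

```r
chains  <- 2
burnin  <- 500
samples <- 500

# Each chain runs warmup iterations followed by sampling iterations
iter_per_chain <- burnin + samples

# Only post-warmup iterations are retained (assuming thin = 1)
total_draws <- chains * samples

# 60 observed sites and 10 species give 600 observations in the example
n_obs <- 60 * 10

c(iter_per_chain = iter_per_chain, total_draws = total_draws, n_obs = n_obs)
```

These numbers agree with the "1000 by 600 log-likelihood matrix" reported by loo() in the example output.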

parallel

logical specifying whether multiple cores should be used for generating MCMC simulations in parallel. If TRUE, the number of cores to use will be min(c(chains, parallel::detectCores() - 1))
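
The core count described above can be previewed before fitting. This is only a sketch of the stated rule; the handling of a missing detectCores() result and the floor of one core are additions made for this example, not documented behaviour of jsdgam.

```r
chains <- 4

# detectCores() can return NA on some platforms; fall back to 2 here
# (an assumption of this sketch)
available <- parallel::detectCores()
if (is.na(available)) available <- 2

# The documented rule, guarded so that at least one core is used
n_cores <- max(1, min(c(chains, available - 1)))
n_cores
```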

threads

integer. Experimental option to use multithreading for within-chain parallelisation in Stan. We recommend its use only if you are experienced with Stan's reduce_sum function and have a slow running model that cannot be sped up by any other means. Currently works for all families when using cmdstanr as the backend

silent

Verbosity level between 0 and 2. If 1 (the default), most of the informational messages of compiler and sampler are suppressed. If 2, even more messages are suppressed. The actual sampling progress is still printed. Set refresh = 0 to turn this off as well. If using backend = "rstan" you can also set open_progress = FALSE to prevent opening additional progress bars.

run_model

logical. If FALSE, the model is not fitted but instead the function will return the model file and the data / initial values that are needed to fit the model outside of mvgam

return_model_data

logical. If TRUE, the list of data that is needed to fit the model is returned, along with the initial values for smooth and AR parameters, once the model is fitted. This will be helpful if users wish to modify the model file to add other stochastic elements that are not currently available in mvgam. Default is FALSE to reduce the size of the returned object, unless run_model == FALSE

...

Other arguments to pass to mvgam

Value

A list object of class mvgam containing model output, the text representation of the model file, the mgcv model output (for easily generating simulations at unsampled covariate values), Dunn-Smyth residuals for each species and key information needed for other functions in the package. See mvgam-class for details. Use methods(class = "mvgam") for an overview of available methods

Details

Joint Species Distribution Models allow for responses of multiple species to be learned hierarchically, whereby responses to environmental variables in formula can be partially pooled and any latent, unmodelled residual associations can also be learned. In mvgam, both of these effects can be modelled with the full power of latent factor Hierarchical GAMs, providing unmatched flexibility to model full communities of species. When calling jsdgam, an initial State-Space model using trend_model = 'None' is set up and then modified to include the latent factors and their linear predictors. Consequently, you can inspect priors for these models using get_mvgam_priors by supplying the relevant formula, factor_formula, data and family arguments and keeping the default trend_model = 'None'.

In a JSDGAM, the expectation of response \(Y_{ij}\) is modelled with

$$g(\mu_{ij}) = X_i\beta + u_i\theta_j,$$

where \(g(.)\) is a known link function, \(X\) is a design matrix of linear predictors (with associated \(\beta\) coefficients), \(u\) are \(n_{lv}\)-variate latent factors (\(n_{lv} \ll n_{species}\)) and \(\theta_j\) are species-specific loadings on the latent factors. The design matrix \(X\) and \(\beta\) coefficients are constructed and modelled using formula and can contain any of mvgam's predictor effects, including random intercepts and slopes, multidimensional penalized smooths and GP effects. The factor loadings \(\theta_j\) are constrained for identifiability but can be used to reconstruct an estimate of the species' residual variance-covariance matrix using \(\Theta \Theta'\), where \(\Theta\) is the matrix whose rows are the \(\theta_j\) (see the example below and residual_cor() for details). The latent factors are further modelled using:

$$u_i \sim \text{Normal}(Q_i\beta_{factor}, 1),$$

where the second design matrix \(Q\) and associated \(\beta_{factor}\) coefficients are constructed and modelled using factor_formula. Again, the effects that make up this linear predictor can contain any of mvgam's allowed predictor effects, providing enormous flexibility for modelling species' communities.
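
To make the two equations concrete, the following base R sketch simulates data with the same structure. It is not the code used internally by jsdgam. The dimensions, the single covariate in each design matrix, the species-specific columns of beta and the unconstrained loadings are all assumptions made for illustration.

```r
set.seed(1)
n_sites   <- 50  # units i
n_species <- 6   # species j
n_lv      <- 2   # latent factors

# Observation model: design matrix X (intercept + one covariate) with
# one column of beta coefficients per species
X    <- cbind(1, rnorm(n_sites))
beta <- rbind(runif(n_species, 1, 2),       # intercepts
              runif(n_species, -0.5, 0.5))  # covariate slopes

# Factor model: design matrix Q (one covariate) and beta_factor
# coefficients, one column per latent factor
Q           <- cbind(runif(n_sites))
beta_factor <- matrix(rnorm(n_lv), nrow = 1, ncol = n_lv)

# u_i ~ Normal(Q_i beta_factor, 1)
u <- Q %*% beta_factor + matrix(rnorm(n_sites * n_lv), n_sites, n_lv)

# Species-specific loadings theta_j, stored as the rows of Theta
Theta <- matrix(rnorm(n_species * n_lv, sd = 0.5), n_species, n_lv)

# g(mu_ij) = X_i beta + u_i theta_j, using a log link for Poisson counts
eta <- X %*% beta + u %*% t(Theta)
Y   <- matrix(rpois(length(eta), lambda = exp(eta)), n_sites, n_species)

# Reduced-rank residual covariance and correlation among species
res_cov <- Theta %*% t(Theta)
res_cor <- cov2cor(res_cov)
qr(res_cov)$rank  # at most n_lv
```

Because the residual covariance is built from only n_lv columns of loadings, its rank cannot exceed n_lv, which is what makes the representation reduced-rank.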

References

Nicholas J Clark & Konstans Wells (2023). Dynamic generalised additive models (DGAMs) for forecasting discrete ecological time series. Methods in Ecology and Evolution. 14:3, 771-784.

David I Warton, F Guillaume Blanchet, Robert B O’Hara, Otso Ovaskainen, Sara Taskinen, Steven C Walker & Francis KC Hui (2015). So many variables: joint modeling in community ecology. Trends in Ecology & Evolution 30:12, 766-779.

Author

Nicholas J Clark

Examples

# \donttest{
# Simulate latent count data for 500 spatial locations and 10 species
set.seed(0)
N_points <- 500
N_species <- 10

# Species-level intercepts (on the log scale)
alphas <- runif(N_species, 2, 2.25)

# Simulate a covariate and species-level responses to it
temperature <- rnorm(N_points)
betas <- runif(N_species, -0.5, 0.5)

# Simulate points uniformly over a space
lon <- runif(N_points, min = 150, max = 155)
lat <- runif(N_points, min = -20, max = -19)

# Set up spatial basis functions as a tensor product of lat and lon
sm <- mgcv::smoothCon(mgcv::te(lon, lat, k = 5),
                      data = data.frame(lon, lat),
                      knots = NULL)[[1]]

# The design matrix for this smooth is in the 'X' slot
des_mat <- sm$X
dim(des_mat)
#> [1] 500  25

# Function to generate a random covariance matrix where all variables
# have unit variance (i.e. diagonals are all 1)
random_Sigma <- function(N){
  # Build a lower-triangular factor whose rows each have unit norm
  L_Omega <- matrix(0, N, N)
  L_Omega[1, 1] <- 1
  for (i in 2 : N) {
    bound <- 1
    for (j in 1 : (i - 1)) {
      L_Omega[i, j] <- runif(1, -sqrt(bound), sqrt(bound))
      bound <- bound - L_Omega[i, j] ^ 2
    }
    L_Omega[i, i] <- sqrt(bound)
  }
  Sigma <- L_Omega %*% t(L_Omega)
  return(Sigma)
}

# Simulate a variance-covariance matrix for the correlations among
# basis coefficients
Sigma <- random_Sigma(N = NCOL(des_mat))

# Now simulate the species-level basis coefficients hierarchically, where
# spatial basis function correlations are a convex sum of a base correlation
# matrix and a species-level correlation matrix
basis_coefs <- matrix(NA, nrow = N_species, ncol = NCOL(Sigma))
base_field <- mgcv::rmvn(1, mu = rep(0, NCOL(Sigma)), V = Sigma)
for(t in 1:N_species){
  corOmega <- (cov2cor(Sigma) * 0.7) +
                 (0.3 * cov2cor(random_Sigma(N = NCOL(des_mat))))
  basis_coefs[t, ] <- mgcv::rmvn(1, mu = rep(0, NCOL(Sigma)), V = corOmega)
}

# Simulate the latent spatial processes
st_process <- do.call(rbind, lapply(seq_len(N_species), function(t){
  data.frame(lat = lat,
             lon = lon,
             species = paste0('species_', t),
             temperature = temperature,
             process = alphas[t] +
               betas[t] * temperature +
               des_mat %*% basis_coefs[t,])
}))

# Now take noisy observations at some of the points (60)
obs_points <- sample(1:N_points, size = 60, replace = FALSE)
obs_points <- data.frame(lat = lat[obs_points],
                         lon = lon[obs_points],
                         site = 1:60)

# Keep only the process data at these points
st_process %>%
  dplyr::inner_join(obs_points, by = c('lat', 'lon')) %>%
  # now take noisy Poisson observations of the process
  dplyr::mutate(count = rpois(NROW(.), lambda = exp(process))) %>%
  dplyr::mutate(species = factor(species,
                                 levels = paste0('species_', 1:N_species))) %>%
  dplyr::group_by(lat, lon) -> dat

# View the count distributions for each species
library(ggplot2)
ggplot(dat, aes(x = count)) +
  geom_histogram() +
  facet_wrap(~ species, scales = 'free')
#> `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.


ggplot(dat, aes(x = lat, y = lon, col = log(count + 1))) +
  geom_point(size = 2.25) +
  facet_wrap(~ species, scales = 'free') +
  scale_color_viridis_c() +
  theme_classic()


# Inspect default priors for a joint species model with three spatial factors
priors <- get_mvgam_priors(formula = count ~
                            # Environmental model includes random slopes for
                            # a linear effect of temperature
                            s(species, bs = 're', by = temperature),

                          # Each factor estimates a different nonlinear spatial process, using
                          # 'by = trend' as in other mvgam State-Space models
                          factor_formula = ~ gp(lat, lon, k = 6, by = trend) - 1,
                          n_lv = 3,

                          # The data and grouping variables
                          data = dat,
                          unit = site,
                          species = species,

                          # Poisson observations
                          family = poisson())
#> Warning: gp effects in mvgam cannot yet handle autogrouping
#> resetting all instances of 'gr = TRUE' to 'gr = FALSE'
#> This warning is displayed once per session.
head(priors)
#>                                            param_name param_length
#> 1                                         (Intercept)            1
#> 2                                   vector[1] mu_raw;            1
#> 3                       vector<lower=0>[1] sigma_raw;            1
#> 4 real<lower=0> alpha_gp_trend(lat, lon):trendtrend1;            1
#> 5 real<lower=0> alpha_gp_trend(lat, lon):trendtrend2;            1
#> 6 real<lower=0> alpha_gp_trend(lat, lon):trendtrend3;            1
#>                                    param_info
#> 1                                 (Intercept)
#> 2             s(species):temperature pop mean
#> 3               s(species):temperature pop sd
#> 4 gp(lat, lon):trendtrend1 marginal deviation
#> 5 gp(lat, lon):trendtrend2 marginal deviation
#> 6 gp(lat, lon):trendtrend3 marginal deviation
#>                                                          prior
#> 1                        (Intercept) ~ student_t(3, 2.1, 2.5);
#> 2                                       mu_raw ~ std_normal();
#> 3                            sigma_raw ~ student_t(3, 0, 2.5);
#> 4 alpha_gp_trend(lat, lon):trendtrend1 ~ student_t(3, 0, 2.5);
#> 5 alpha_gp_trend(lat, lon):trendtrend2 ~ student_t(3, 0, 2.5);
#> 6 alpha_gp_trend(lat, lon):trendtrend3 ~ student_t(3, 0, 2.5);
#>                                            example_change new_lowerbound
#> 1                             (Intercept) ~ normal(0, 1);           <NA>
#> 2                            mu_raw ~ normal(0.65, 0.15);           <NA>
#> 3                          sigma_raw ~ exponential(0.27);           <NA>
#> 4 alpha_gp_trend(lat, lon):trendtrend1 ~ normal(0, 0.86);           <NA>
#> 5 alpha_gp_trend(lat, lon):trendtrend2 ~ normal(0, 0.84);           <NA>
#> 6 alpha_gp_trend(lat, lon):trendtrend3 ~ normal(0, 0.78);           <NA>
#>   new_upperbound
#> 1           <NA>
#> 2           <NA>
#> 3           <NA>
#> 4           <NA>
#> 5           <NA>
#> 6           <NA>

# Fit a JSDM that estimates hierarchical temperature responses
# and that uses three latent spatial factors
mod <- jsdgam(formula = count ~
                # Environmental model includes random slopes for a
                # linear effect of temperature
                s(species, bs = 're', by = temperature),

              # Each factor estimates a different nonlinear spatial process, using
              # 'by = trend' as in other mvgam State-Space models
              factor_formula = ~ gp(lat, lon, k = 6, by = trend) - 1,
              n_lv = 3,

              # Change default priors for the random effect standard deviation
              # and factor GP marginal deviations to standard normal
              priors = c(prior(std_normal(),
                               class = sigma_raw),
                         prior(std_normal(),
                               class = `alpha_gp_trend(lat, lon):trendtrend1`),
                         prior(std_normal(),
                               class = `alpha_gp_trend(lat, lon):trendtrend2`),
                         prior(std_normal(),
                               class = `alpha_gp_trend(lat, lon):trendtrend3`)),

              # The data and the grouping variables
              data = dat,
              unit = site,
              species = species,

              # Poisson observations
              family = poisson(),
              chains = 2,
              silent = 2)

# Plot species-level intercept estimates
plot_predictions(mod, condition = 'species',
                 type = 'link')


# Plot species' hierarchical responses to temperature
plot_predictions(mod, condition = c('temperature', 'species', 'species'),
                 type = 'link')


# Plot posterior median estimates of the latent spatial factors
plot(mod, type = 'smooths', trend_effects = TRUE)


# Or using gratia, if you have it installed
if(requireNamespace('gratia', quietly = TRUE)){
  gratia::draw(mod, trend_effects = TRUE)
}


# Calculate residual spatial correlations
post_cors <- residual_cor(mod)
names(post_cors)
#>  [1] "cor"        "cor_lower"  "cor_upper"  "sig_cor"    "cov"       
#>  [6] "prec"       "prec_lower" "prec_upper" "sig_prec"   "trace"     
# Look at posterior median estimates, followed by upper and lower
# credible interval estimates, for some of the estimated correlations
post_cors$cor[1:5, 1:5]
#>            species_1  species_2   species_3  species_4   species_5
#> species_1  1.0000000  0.7336280  0.72991602  0.7861474 -0.24696228
#> species_2  0.7336280  1.0000000  0.15420441  0.3813268 -0.47995732
#> species_3  0.7299160  0.1542044  1.00000000  0.8908844 -0.03924676
#> species_4  0.7861474  0.3813268  0.89088435  1.0000000 -0.44731327
#> species_5 -0.2469623 -0.4799573 -0.03924676 -0.4473133  1.00000000
post_cors$cor_upper[1:5, 1:5]
#>           species_1  species_2 species_3  species_4  species_5
#> species_1 1.0000000  0.9242392 0.9566867  0.9575828  0.1509028
#> species_2 0.9242392  1.0000000 0.5199946  0.7191710 -0.1730502
#> species_3 0.9566867  0.5199946 1.0000000  0.9774616  0.3136909
#> species_4 0.9575828  0.7191710 0.9774616  1.0000000 -0.1076519
#> species_5 0.1509028 -0.1730502 0.3136909 -0.1076519  1.0000000
post_cors$cor_lower[1:5, 1:5]
#>            species_1  species_2  species_3  species_4  species_5
#> species_1  1.0000000  0.4226772  0.4168132  0.4720316 -0.6018739
#> species_2  0.4226772  1.0000000 -0.2345903 -0.0158343 -0.7385235
#> species_3  0.4168132 -0.2345903  1.0000000  0.7304959 -0.3965220
#> species_4  0.4720316 -0.0158343  0.7304959  1.0000000 -0.7197583
#> species_5 -0.6018739 -0.7385235 -0.3965220 -0.7197583  1.0000000
# A quick and dirty plot of the posterior median correlations
image(post_cors$cor)


# Posterior predictive checks and ELPD-LOO can ascertain model fit
pp_check(mod, type = "pit_ecdf_grouped",
         group = "species", ndraws = 100)

loo(mod)
#> Warning: Some Pareto k diagnostic values are too high. See help('pareto-k-diagnostic') for details.
#> 
#> Computed from 1000 by 600 log-likelihood matrix.
#> 
#>          Estimate    SE
#> elpd_loo  -2190.8  51.4
#> p_loo       421.9  27.1
#> looic      4381.6 102.8
#> ------
#> MCSE of elpd_loo is NA.
#> MCSE and ESS estimates assume MCMC draws (r_eff in [0.0, 1.6]).
#> 
#> Pareto k diagnostic values:
#>                           Count Pct.    Min. ESS
#> (-Inf, 0.67]   (good)     438   73.0%   1       
#>    (0.67, 1]   (bad)      132   22.0%   <NA>    
#>     (1, Inf)   (very bad)  30    5.0%   <NA>    
#> See help('pareto-k-diagnostic') for details.

# Forecast log(counts) for entire region (site value doesn't matter as long
# as each spatial location has a different and unique site identifier);
# note this calculation takes a few minutes because of the need to calculate
# draws from the stochastic latent factors
newdata <- st_process %>%
                   dplyr::mutate(species = factor(species,
                                                  levels = paste0('species_',
                                                                  1:N_species))) %>%
                   dplyr::group_by(lat, lon) %>%
                   dplyr::mutate(site = dplyr::cur_group_id()) %>%
                   dplyr::ungroup()
preds <- predict(mod, newdata = newdata)

# Plot the median log(count) predictions on a grid
newdata$log_count <- preds[,1]
ggplot(newdata, aes(x = lat, y = lon, col = log_count)) +
  geom_point(size = 1.5) +
  facet_wrap(~ species, scales = 'free') +
  scale_color_viridis_c() +
  theme_classic()

# }
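
The random_Sigma() helper used in the simulation above is described as returning a matrix with unit diagonals. A quick way to check that property, and that the result is a valid (positive semi-definite) covariance matrix, is sketched below. The helper is restated so the snippet runs on its own; the seed and the matrix size of 5 are arbitrary choices.

```r
# Restated helper: each row of the lower-triangular factor has unit norm
random_Sigma <- function(N) {
  L_Omega <- matrix(0, N, N)
  L_Omega[1, 1] <- 1
  for (i in 2:N) {
    bound <- 1
    for (j in 1:(i - 1)) {
      L_Omega[i, j] <- runif(1, -sqrt(bound), sqrt(bound))
      bound <- bound - L_Omega[i, j]^2
    }
    L_Omega[i, i] <- sqrt(bound)
  }
  L_Omega %*% t(L_Omega)
}

set.seed(3)
Sigma <- random_Sigma(5)

# Diagonals should all equal 1 (up to floating point error)
diag(Sigma)

# Smallest eigenvalue should not be meaningfully negative
min(eigen(Sigma, symmetric = TRUE)$values)
```

Since the diagonals are 1, the matrix is already a correlation matrix, so the cov2cor() calls in the simulation leave it essentially unchanged.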