Model complexity selection

Hi everybody,

I am currently using scvi-tools to perform batch correction and dimensionality reduction on a dataset consisting of multiple patients. I am not sure whether the complexity of the model chosen for dimensionality reduction is appropriate.

Here is an example of the code I am currently running:

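```python
import scanpy as sc
import scvi

# Keep the 5,000 most highly variable genes (subset=True drops the rest in place)
sc.pp.highly_variable_genes(adata_da_HVG, n_top_genes=5000, subset=True)

# Register the data; "patient" is a hypothetical obs column holding the batch label
scvi.model.SCVI.setup_anndata(adata_da_HVG, batch_key="patient")
# Default architecture: n_layers=1, n_hidden=128, n_latent=10
model_da = scvi.model.SCVI(adata_da_HVG)
```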


I wonder whether there is a way to visualize model complexity in order to examine the bias-variance tradeoff for this particular dataset. I would like to check whether I need more layers to perform the dimensionality reduction more effectively.


I believe you can investigate the bias-variance tradeoff by looking at the training curves. Your model_da should have a field, .history, containing a dictionary of training metrics for each epoch of model training. It definitely has the reconstruction error for the training set. I can’t remember how to make training also record the reconstruction error for the test set. @adamgayoso, is this in the documentation somewhere?

Once you have these curves for alternative models, you can compare them. Models whose reconstruction error is higher on both the training and test sets have more bias. Models whose test error rises relative to their training error have more variance. (Caveat: I always mix up the concepts of bias and variance for models, so I hope I’m getting it right this time.)
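A minimal sketch of that comparison, assuming you have already pulled each per-epoch loss curve out of model_da.history into a plain Python list. In recent scvi-tools versions the keys look like reconstruction_loss_train and reconstruction_loss_validation, but that is an assumption; print model_da.history.keys() to check yours. The helper summarize_curves is hypothetical, not part of scvi-tools. It reports the final train/validation gap and whether the validation loss has climbed back up from its minimum, a common sign of overfitting.

```python
from typing import Dict, Sequence, Union


def summarize_curves(
    train: Sequence[float], val: Sequence[float]
) -> Dict[str, Union[float, int, bool]]:
    """Summarize loss curves where lower is better.

    Hypothetical helper, not part of scvi-tools.
    """
    if not train or not val:
        raise ValueError("both curves must be non-empty")
    # Position in the validation curve with the lowest loss
    # (an evaluation index, which equals the epoch only if validation ran every epoch)
    best_val_idx = min(range(len(val)), key=lambda i: val[i])
    return {
        "final_train": train[-1],
        "final_val": val[-1],
        # Generalization gap at the end of training; a large gap suggests variance
        "gap": val[-1] - train[-1],
        "best_val_idx": best_val_idx,
        # Validation loss has risen since its minimum: a typical overfitting signal
        "val_rising": val[-1] > val[best_val_idx],
    }
```

To build the lists from a trained model (assuming the key names above), something like model_da.history["reconstruction_loss_train"].iloc[:, 0].tolist() should work, and likewise for the validation key. Comparing models, higher final_train and final_val point to more bias, while a large gap or val_rising points to more variance.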

Now, how do you quantify the complexity of the model? That is hard, I think. You could count all the parameters, but there are also various forms of regularization, both for the neural networks and for the latent variables (and other parameters). I’m not sure how that factors into the definition of model complexity.
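Counting parameters is at least easy to do. The sketch below is a rough, hypothetical back-of-the-envelope count for a plain fully connected encoder/decoder using SCVI's default sizes (n_hidden=128, n_latent=10, n_layers=1). It ignores batch norm, batch covariates, the library-size encoder, dispersion parameters and any extra decoder output heads, so it will not match scvi-tools exactly; the point is only to show how the count scales with depth.

```python
from typing import Sequence


def mlp_param_count(layer_sizes: Sequence[int]) -> int:
    """Weights plus biases of a stack of fully connected layers."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))


def vae_dense_params(n_genes: int, n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1) -> int:
    """Rough dense-layer count for a simple VAE (hypothetical, not scvi-tools' exact model)."""
    # Encoder: genes -> n_layers hidden layers, then mean and variance heads -> latent
    encoder = mlp_param_count([n_genes] + [n_hidden] * n_layers)
    encoder += 2 * (n_hidden * n_latent + n_latent)
    # Decoder: latent -> n_layers hidden layers -> genes (single output head)
    decoder = mlp_param_count([n_latent] + [n_hidden] * n_layers + [n_genes])
    return encoder + decoder
```

With 5,000 genes this rough count comes to about 1.3 million parameters, almost all of them in the first encoder layer and the last decoder layer, which scale with the number of genes; each extra 128-unit hidden layer adds only 16,512 parameters per network. For the exact number on a trained model, the underlying PyTorch module is available as model_da.module, so sum(p.numel() for p in model_da.module.parameters()) counts every parameter.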



There are two options here. Note that by default what you call the “test” set is actually called the “validation” set. This naming reflects the fact that the validation set can be used for early stopping. I can go into more detail if this is confusing.

  1. Pass check_val_every_n_epoch=1 as an argument to train to record all the losses on the validation set at every epoch. Training will be slower; you can set this value to 10 or 20 instead if you prefer.
  2. Call model.get_reconstruction_error(indices=model.validation_indices). I believe that here higher is better, whereas the trainer history logs the negative of this value.
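Combining option 1 with the comparison of alternative models, a minimal sketch could train one model per candidate depth and keep the one with the lowest final validation reconstruction loss. This assumes a recent scvi-tools in which SCVI accepts n_layers, train accepts check_val_every_n_epoch, and the history key is reconstruction_loss_validation; fit_scvi_val_loss and select_by_validation are hypothetical names. Setting scvi.settings.seed before each fit should keep the train/validation split the same across candidates, so their losses are comparable.

```python
from typing import Callable, Dict, Iterable


def fit_scvi_val_loss(adata, n_layers: int, max_epochs: int = 400, seed: int = 0) -> float:
    """Train SCVI with `n_layers` hidden layers; return the last validation
    reconstruction loss (lower is better). Requires scvi-tools and an AnnData
    already registered with scvi.model.SCVI.setup_anndata."""
    import scvi  # lazy import: the selection helper below does not need scvi

    scvi.settings.seed = seed  # same seed, so the same train/validation split
    model = scvi.model.SCVI(adata, n_layers=n_layers)
    # Option 1: evaluate on the validation set every epoch
    model.train(max_epochs=max_epochs, check_val_every_n_epoch=1)
    # Key name is an assumption; check model.history.keys() in your version
    return float(model.history["reconstruction_loss_validation"].iloc[-1, 0])


def select_by_validation(
    fit: Callable[[int], float], candidates: Iterable[int]
) -> Dict[str, object]:
    """Fit one model per candidate setting and pick the lowest validation loss."""
    losses = {c: fit(c) for c in candidates}
    best = min(losses, key=losses.get)
    return {"losses": losses, "best": best}
```

For example, select_by_validation(lambda n: fit_scvi_val_loss(adata_da_HVG, n), [1, 2, 3]) would compare one, two and three hidden layers. Keeping the selection logic separate from the training function makes it easy to swap in another criterion, such as the minimum of the validation curve instead of its last value.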