Interpreting validation loss curve in query to reference mapping

Hello,
I’ve been using the scArches implementation in scvi-tools for query-to-reference mapping with trained scVI models. I wanted to look at the reconstruction loss on the training and validation sets to detect possible overfitting. While training the reference scVI model, I see the expected trend: the training loss is lower than the validation loss. While training the query model, however, the validation loss is consistently lower than the training loss, which I find puzzling.

Following the example in the reference-mapping tutorial:

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import anndata
import scvi
import scanpy as sc

sc.set_figure_params(figsize=(4, 4))
scvi.settings.seed = 94705

url = "https://figshare.com/ndownloader/files/24539828"
adata = sc.read("pancreas.h5ad", backup_url=url)
print(adata)

## Define query dataset
query = np.array([s in ["smartseq2", "celseq2"] for s in adata.obs.tech])
adata_ref = adata[~query].copy()
adata_query = adata[query].copy()
sc.pp.highly_variable_genes(
    adata_ref,
    n_top_genes=2000,
    batch_key="tech",
    subset=True
)

adata_query = adata_query[:, adata_ref.var_names].copy()

## Train reference
scvi.model.SCVI.setup_anndata(adata_ref, batch_key="tech", layer="counts")
arches_params = dict(
    use_layer_norm="both",
    use_batch_norm="none",
    encode_covariates=True,
    dropout_rate=0.2,
    n_layers=2,
)

vae_ref = scvi.model.SCVI(
    adata_ref,
    **arches_params
)
vae_ref.train(check_val_every_n_epoch=1)

Plotting validation error for reference model training:

plt.plot(vae_ref.history['reconstruction_loss_train']['reconstruction_loss_train'], label='train');
plt.plot(vae_ref.history['reconstruction_loss_validation']['reconstruction_loss_validation'], label='validation');
plt.legend()

[Screenshot: training and validation reconstruction loss curves for the reference model]
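The snippet above stops at the reference model, so `vae_q` is not defined in it. Below is a minimal sketch of the query step, assuming it follows the `load_query_data` pattern from the reference-mapping tutorial. The helper name `train_query_model` is hypothetical, and `max_epochs=200` and `weight_decay=0.0` are assumed tutorial-style settings; the values behind the plots in this post may differ.

```python
def train_query_model(adata_query, vae_ref, train_size=0.9, max_epochs=200):
    """Map a query dataset onto a trained reference SCVI model (scArches-style).

    Hypothetical helper: the name and the default values are assumptions,
    not settings confirmed by the original post.
    """
    import scvi  # imported lazily; assumes scvi-tools is installed

    # Build the query model from the reference model. With the default
    # arguments, most of the reference weights are expected to stay frozen.
    vae_q = scvi.model.SCVI.load_query_data(adata_query, vae_ref)

    # Log the validation loss every epoch so that it can be plotted
    # against the training loss afterwards.
    vae_q.train(
        max_epochs=max_epochs,
        train_size=train_size,
        plan_kwargs=dict(weight_decay=0.0),
        check_val_every_n_epoch=1,
    )
    return vae_q
```

With the objects defined earlier, this would be called as `vae_q = train_query_model(adata_query, vae_ref)`.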

Plotting validation error for query model training:

plt.plot(vae_q.history['reconstruction_loss_train']['reconstruction_loss_train'], label='train');
plt.plot(vae_q.history['reconstruction_loss_validation']['reconstruction_loss_validation'], label='validation');
plt.legend()

[Screenshot: training and validation reconstruction loss curves for the query model]

Any intuition on why this might be going on?

Many thanks in advance,
Emma

Hi Emma,

We are a bit stumped and will think about it. We reproduced it on the pancreas example. Some initial thoughts:

  1. It could be due to random sampling, e.g., the validation set is small and biased toward one cell type. You could try a query train size of 50% and see whether the result changes.
  2. The network has very few trainable parameters during query training.
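One way to check the first point is to compare the cell type composition of the two splits. The sketch below uses only NumPy and pandas; the helper name `split_composition` is hypothetical, and how you obtain the split indices from the trained model is left as an assumption (see the note after the code).

```python
import numpy as np
import pandas as pd


def split_composition(labels, train_idx, val_idx):
    """Return the fraction of each label in the training and validation splits.

    labels: one label per cell (e.g. cell type annotations).
    train_idx, val_idx: integer positions of the cells in each split.
    """
    labels = pd.Series(np.asarray(labels))
    train_frac = labels.iloc[train_idx].value_counts(normalize=True)
    val_frac = labels.iloc[val_idx].value_counts(normalize=True)

    # Align both splits on the label; a label missing from a split gets 0.
    table = pd.concat(
        [train_frac, val_frac], axis=1, keys=["train", "validation"]
    ).fillna(0.0)

    # Positive values mean the label is over-represented in validation.
    table["difference"] = table["validation"] - table["train"]
    return table.sort_values("difference")
```

Assuming the trained query model exposes `train_indices` and `validation_indices` and that the annotations are stored in `adata_query.obs["celltype"]`, this could be called as `split_composition(adata_query.obs["celltype"], vae_q.train_indices, vae_q.validation_indices)`. Large values in the `difference` column would point to a skewed validation set.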

We will get back to you if we have any more insight.
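For the second point, a quick way to see how small the trainable part of the query network is would be to count parameters on the underlying PyTorch module. This is a minimal sketch; `count_parameters` is a hypothetical helper, and it assumes the model exposes its network as `vae_q.module`.

```python
import torch


def count_parameters(module: torch.nn.Module):
    """Return (trainable, total) parameter counts for a PyTorch module."""
    # A parameter counts as trainable when it has requires_grad=True.
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    total = sum(p.numel() for p in module.parameters())
    return trainable, total
```

Calling `count_parameters(vae_q.module)` and `count_parameters(vae_ref.module)` would let you compare the two models. If parts of a weight tensor are frozen through gradient masking rather than `requires_grad`, the trainable count here should be read as an upper bound.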

This thread also has some relevant points: