totalVI NaN loss with few proteins

Thank you for sharing these awesome tools with the bioinformatics community.

After using totalVI without any issues, I recently started experiencing problems in training my totalVI models and am writing to solicit some advice on how to resolve this. The totalVI example in the tutorial continues to run fine, but when I try something analogous for my data, I see the following:

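import numpy as np
import scanpy as sc
import scvi

# concatenated_adata: the combined AnnData for all six batches, with RNA counts in
# .layers["counts"] and protein counts in .obsm["protein_expression"]
x = concatenated_adata.copy()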

sc.pp.highly_variable_genes(
        x,
        batch_key="batch",
        flavor="seurat_v3",
        layer = 'counts',
        n_top_genes=4000,
        subset=True
)

scvi.data.setup_anndata(x,
                        batch_key = "batch",
                        layer = 'counts',
                        protein_expression_obsm_key = "protein_expression")

x_model = scvi.model.TOTALVI(x, latent_distribution = "normal", n_layers_decoder = 2)

x_model.train(max_epochs = 50,
              use_gpu = True)

which results in the following error(s):

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.
  and should_run_async(code)
If you pass `n_top_genes`, all cutoffs are ignored.
extracting highly variable genes
/lib/python3.8/site-packages/anndata-0.7.5-py3.8.egg/anndata/_core/anndata.py:1116: SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_sub[k] = df_sub[k].cat.remove_unused_categories()
anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scanpy/preprocessing/_highly_variable_genes.py:144: FutureWarning: Slicing a positional slice with .loc is not supported, and will raise TypeError in a future version.  Use .loc with labels or .iloc with positions instead.
  df.loc[: int(n_top_genes), 'highly_variable'] = True
--> added
    'highly_variable', boolean vector (adata.var)
    'highly_variable_rank', float vector (adata.var)
    'means', float vector (adata.var)
    'variances', float vector (adata.var)
    'variances_norm', float vector (adata.var)
INFO     Using batches from adata.obs["batch"]                                               
INFO     No label_key inputted, assuming all cells have same label                           
INFO     Using data from adata.layers["counts"]                                              
INFO     Computing library size prior per batch                                              
/lib/python3.8/site-packages/anndata-0.7.5-py3.8.egg/anndata/_core/anndata.py:1116: SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_sub[k] = df_sub[k].cat.remove_unused_categories()
INFO     Using protein expression from adata.obsm['protein_expression']                      
INFO     Using protein names from columns of adata.obsm['protein_expression']                
INFO     Found batches with missing protein expression                                       
INFO     Successfully registered anndata object containing 35523 cells, 4000 vars, 6 batches,
         1 labels, and 3 proteins. Also registered 0 extra categorical covariates and 0 extra
         continuous covariates.                                                              
INFO     Please do not further modify adata until model is trained.                          
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Set SLURM handle signals.
Epoch 1/50:   0%|          | 0/50 [00:03<?, ?it/s]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-13-50c61792beb1> in <module>
     21 x_model = scvi.model.TOTALVI(x, latent_distribution = "normal", n_layers_decoder = 2)
     22 
---> 23 x_model.train(max_epochs = 50,
     24               use_gpu = True)

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scvi/model/_totalvi.py in train(self, max_epochs, lr, use_gpu, train_size, validation_size, batch_size, early_stopping, check_val_every_n_epoch, reduce_lr_on_plateau, n_steps_kl_warmup, n_epochs_kl_warmup, adversarial_classifier, plan_kwargs, **kwargs)
    257             **kwargs,
    258         )
--> 259         return runner()
    260 
    261     @torch.no_grad()

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scvi/train/_trainrunner.py in __call__(self)
     73             self.trainer.fit(self.training_plan, train_dl)
     74         else:
---> 75             self.trainer.fit(self.training_plan, train_dl, val_dl)
     76         try:
     77             self.model.history_ = self.trainer.logger.history

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scvi/train/_trainer.py in fit(self, *args, **kwargs)
    150                 message="you defined a validation_step but have no val_dataloader",
    151             )
--> 152             super().fit(*args, **kwargs)

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
    512 
    513         # dispath `start_training` or `start_testing` or `start_predicting`
--> 514         self.dispatch()
    515 
    516         # plugin will finalized fitting (e.g. ddp_spawn will load trained model)

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in dispatch(self)
    552 
    553         else:
--> 554             self.accelerator.start_training(self)
    555 
    556     def train_or_test_or_predict(self):

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
     72 
     73     def start_training(self, trainer):
---> 74         self.training_type_plugin.start_training(trainer)
     75 
     76     def start_testing(self, trainer):

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
    109     def start_training(self, trainer: 'Trainer') -> None:
    110         # double dispatch to initiate the training loop
--> 111         self._results = trainer.run_train()
    112 
    113     def start_testing(self, trainer: 'Trainer') -> None:

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in run_train(self)
    643                 with self.profiler.profile("run_training_epoch"):
    644                     # run train epoch
--> 645                     self.train_loop.run_training_epoch()
    646 
    647                 if self.max_steps and self.max_steps <= self.global_step:

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in run_training_epoch(self)
    491             # ------------------------------------
    492             with self.trainer.profiler.profile("run_training_batch"):
--> 493                 batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
    494 
    495             # when returning -1 from train_step, we end epoch early

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in run_training_batch(self, batch, batch_idx, dataloader_idx)
    653 
    654                         # optimizer step
--> 655                         self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
    656 
    657                     else:

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
    424 
    425         # model hook
--> 426         model_ref.optimizer_step(
    427             self.trainer.current_epoch,
    428             batch_idx,

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py in optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)
   1382             # wraps into LightingOptimizer only for running step
   1383             optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, optimizer_idx)
-> 1384         optimizer.step(closure=optimizer_closure)
   1385 
   1386     def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py in step(self, closure, *args, **kwargs)
    212             profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}"
    213 
--> 214         self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
    215         self._total_optimizer_step_calls += 1
    216 

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py in __optimizer_step(self, closure, profiler_name, **kwargs)
    132 
    133         with trainer.profiler.profile(profiler_name):
--> 134             trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
    135 
    136     def step(self, *args, closure: Optional[Callable] = None, **kwargs):

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in optimizer_step(self, optimizer, opt_idx, lambda_closure, **kwargs)
    276         )
    277         if make_optimizer_step:
--> 278             self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
    279         self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
    280         self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs)

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in run_optimizer_step(self, optimizer, optimizer_idx, lambda_closure, **kwargs)
    281 
    282     def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
--> 283         self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
    284 
    285     def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in optimizer_step(self, optimizer, lambda_closure, **kwargs)
    158 
    159     def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):
--> 160         optimizer.step(closure=lambda_closure, **kwargs)

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/torch/optim/optimizer.py in wrapper(*args, **kwargs)
     87                 profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__)
     88                 with torch.autograd.profiler.record_function(profile_name):
---> 89                     return func(*args, **kwargs)
     90             return wrapper
     91 

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     25         def decorate_context(*args, **kwargs):
     26             with self.__class__():
---> 27                 return func(*args, **kwargs)
     28         return cast(F, decorate_context)
     29 

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/torch/optim/adam.py in step(self, closure)
     64         if closure is not None:
     65             with torch.enable_grad():
---> 66                 loss = closure()
     67 
     68         for group in self.param_groups:

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in train_step_and_backward_closure()
    647 
    648                         def train_step_and_backward_closure():
--> 649                             result = self.training_step_and_backward(
    650                                 split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
    651                             )

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
    741         with self.trainer.profiler.profile("training_step_and_backward"):
    742             # lightning module hook
--> 743             result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
    744             self._curr_step_result = result
    745 

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in training_step(self, split_batch, batch_idx, opt_idx, hiddens)
    291             model_ref._results = Result()
    292             with self.trainer.profiler.profile("training_step"):
--> 293                 training_step_output = self.trainer.accelerator.training_step(args)
    294                 self.trainer.accelerator.post_training_step()
    295 

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in training_step(self, args)
    155 
    156         with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
--> 157             return self.training_type_plugin.training_step(*args)
    158 
    159     def post_training_step(self):

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in training_step(self, *args, **kwargs)
    120 
    121     def training_step(self, *args, **kwargs):
--> 122         return self.lightning_module.training_step(*args, **kwargs)
    123 
    124     def post_training_step(self):

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scvi/train/_trainingplans.py in training_step(self, batch, batch_idx, optimizer_idx)
    346         if optimizer_idx == 1:
    347             inference_inputs = self.module._get_inference_input(batch)
--> 348             outputs = self.module.inference(**inference_inputs)
    349             z = outputs["z"]
    350             loss = self.loss_adversarial_classifier(z.detach(), batch_tensor, True)

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scvi/module/base/_decorators.py in auto_transfer_args(self, *args, **kwargs)
     30         # decorator only necessary after training
     31         if self.training:
---> 32             return fn(self, *args, **kwargs)
     33 
     34         device = list(set(p.device for p in self.parameters()))

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scvi/module/_totalvae.py in inference(self, x, y, batch_index, label, n_samples, transform_batch, cont_covs, cat_covs)
    436         else:
    437             categorical_input = tuple()
--> 438         qz_m, qz_v, ql_m, ql_v, latent, untran_latent = self.encoder(
    439             encoder_input, batch_index, *categorical_input
    440         )

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scvi/nn/_base_components.py in forward(self, data, *cat_list)
    984         qz_m = self.z_mean_encoder(q)
    985         qz_v = torch.exp(self.z_var_encoder(q)) + 1e-4
--> 986         z, untran_z = self.reparameterize_transformation(qz_m, qz_v)
    987 
    988         ql_gene = self.l_gene_encoder(data, *cat_list)

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/scvi/nn/_base_components.py in reparameterize_transformation(self, mu, var)
    950 
    951     def reparameterize_transformation(self, mu, var):
--> 952         untran_z = Normal(mu, var.sqrt()).rsample()
    953         z = self.z_transformation(untran_z)
    954         return z, untran_z

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/torch/distributions/normal.py in __init__(self, loc, scale, validate_args)
     48         else:
     49             batch_shape = self.loc.size()
---> 50         super(Normal, self).__init__(batch_shape, validate_args=validate_args)
     51 
     52     def expand(self, batch_shape, _instance=None):

anaconda3/envs/scvi_tools_env/lib/python3.8/site-packages/torch/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
     51                     continue  # skip checking lazily-constructed args
     52                 if not constraint.check(getattr(self, param)).all():
---> 53                     raise ValueError("The parameter {} has invalid values".format(param))
     54         super(Distribution, self).__init__()
     55 

ValueError: The parameter loc has invalid values

This is the AnnData setup summary:

Anndata setup with scvi-tools version 0.9.0.
              Data Summary              
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━┓
┃             Data             ┃ Count ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━┩
│            Cells             │ 35523 │
│             Vars             │ 10620 │
│            Labels            │   1   │
│           Batches            │   6   │
│           Proteins           │   3   │
│ Extra Categorical Covariates │   0   │
│ Extra Continuous Covariates  │   0   │
└──────────────────────────────┴───────┘
                   SCVI Data Registry                    
┏━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Data        ┃       scvi-tools Location        ┃
┡━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         X          │      adata.layers['counts']      │
│   batch_indices    │     adata.obs['_scvi_batch']     │
│    local_l_mean    │ adata.obs['_scvi_local_l_mean']  │
│    local_l_var     │  adata.obs['_scvi_local_l_var']  │
│       labels       │    adata.obs['_scvi_labels']     │
│ protein_expression │ adata.obsm['protein_expression'] │
└────────────────────┴──────────────────────────────────┘
                        Label Categories                        
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃      Source Location      ┃ Categories ┃ scvi-tools Encoding ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ adata.obs['_scvi_labels'] │     0      │          0          │
└───────────────────────────┴────────────┴─────────────────────┘
                        Batch Categories                         
┏━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃  Source Location   ┃     Categories     ┃ scvi-tools Encoding ┃
┡━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ adata.obs['batch'] │    PB_baseline     │          0          │
│                    │   PB_primary_CMV   │          1          │
│                    │  PB_steady_state   │          2          │
│                    │  LN_steady_state   │          3          │
│                    │ PB_CMV_rechallenge │          4          │
│                    │ LN_CMV_rechallenge │          5          │
└────────────────────┴────────────────────┴─────────────────────┘
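The setup log above reports "Found batches with missing protein expression". A per-batch check like the one below shows which of the six batches that message refers to. This is a minimal numpy sketch. It assumes, based on our understanding of scvi-tools 0.9, that a batch is flagged as missing when all of its protein counts are zero. `batches_missing_protein` is a hypothetical helper, not part of scvi-tools.

```python
import numpy as np

def batches_missing_protein(protein_counts, batch_labels):
    """Return the batch labels whose cells all have zero protein counts.

    protein_counts: (n_cells, n_proteins) array or DataFrame,
        e.g. adata.obsm["protein_expression"]
    batch_labels: length-n_cells labels, e.g. adata.obs["batch"]
    """
    if hasattr(protein_counts, "toarray"):  # densify scipy.sparse input
        protein_counts = protein_counts.toarray()
    protein_counts = np.asarray(protein_counts, dtype=float)
    batch_labels = np.asarray(batch_labels)
    missing = []
    for b in np.unique(batch_labels):
        # assumption: a batch is "missing" if its whole protein block sums to zero
        if protein_counts[batch_labels == b].sum() == 0:
            missing.append(b)
    return missing
```

Usage on the data from this thread would look like `batches_missing_protein(x.obsm["protein_expression"], x.obs["batch"])`.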

Any recommendations on what I might try? Thanks in advance!

It’s hard to tell what might be the problem just from this traceback. But you might try two things:

  1. See if you can run your model with just scVI (scvi.model.SCVI), ignoring the protein data for now.
  2. Lower the totalVI learning rate in the train method, e.g., to 2e-3.

And just to double-check: the protein data are count data, yes?
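One quick way to confirm that both modalities hold raw counts, and that no NaN or infinite values could propagate into the encoder, is a check like the one below. This is a minimal numpy sketch. `count_problems` is a hypothetical helper, not an scvi-tools function. For scipy.sparse input it inspects only the stored entries, because implicit zeros are always valid counts.

```python
import numpy as np

def count_problems(matrix):
    """Return a list of reasons why `matrix` is not finite, non-negative integer counts."""
    if hasattr(matrix, "tocsr"):  # scipy.sparse: only stored entries need checking
        values = np.asarray(matrix.tocsr().data, dtype=float)
    else:
        values = np.asarray(matrix, dtype=float).ravel()
    problems = []
    finite = values[np.isfinite(values)]
    if finite.size < values.size:
        problems.append("contains NaN or infinite values")
    if np.any(finite < 0):
        problems.append("contains negative values")
    if np.any(finite != np.round(finite)):
        problems.append("contains non-integer values (normalized or scaled data?)")
    return problems
```

With the data from this thread, an empty list from both `count_problems(x.layers["counts"])` and `count_problems(x.obsm["protein_expression"])` would rule out this class of input problem.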

I should have stated this: scVI and scANVI continue to work just fine, as does the totalVI example in the scvi-tools tutorial. All attempts to run totalVI on my data, including with a lower learning rate, have been unsuccessful. Interestingly, it completes a variable number of epochs before raising this exception, even when the learning rate is unchanged.

I can confirm that I am running totalVI on count data for both mRNA and protein.

I realise how incredibly annoying it must be to develop and share nice code and then get asked support-type questions. I’m really just looking for your gut reaction: is it likely something in my data, so that I should focus on finding a problem there, or would it be more worthwhile to step through your very nice code with ipdb or Spyder?

Thanks again for the benefit of your advice.

It’s either the data or some hyperparameter, though I’ve never seen this sort of error on the mean of the latent space before.

Does scVI run successfully on your data? If it does, I could quickly try to run totalVI if you are open to providing the data.

If scVI does not run successfully, are there any cells with all-zero counts?
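Checking for all-zero cells takes one line per modality. This matters here because subsetting to 4000 highly variable genes can leave some cells with no counts in the retained genes, even if they had counts before. Below is a minimal numpy sketch. `zero_count_cells` is a hypothetical helper. It works for dense arrays, DataFrames and scipy.sparse matrices, since all three support `.sum(axis=1)`.

```python
import numpy as np

def zero_count_cells(matrix):
    """Boolean mask over rows (cells) whose total count is zero."""
    # .sum(axis=1) returns an np.matrix for scipy.sparse; asarray + ravel flattens it
    totals = np.asarray(matrix.sum(axis=1)).ravel()
    return totals == 0
```

For example, `zero_count_cells(x.layers["counts"]).sum()` counts the empty cells after HVG subsetting. `zero_count_cells(x.obsm["protein_expression"]).sum()` does the same for protein. Cells in batches flagged as missing protein expression are expected to show up in the protein check.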

Thanks for sending the data. After some inspection, I suggest the following:

x_model = scvi.model.TOTALVI(adata, empirical_protein_background_prior=False, n_layers_decoder=2)

Basically, I think that because you have only 3 proteins, the new empirical initialization of the protein background prior is getting thrown off and producing bad values. We added this after the fact to better initialize the parameters that represent the protein background.
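The following illustrates why three proteins might be too few. It rests on an assumption that has not been checked against the scvi-tools source: that the empirical prior fits a two-component mixture to each cell's log protein counts and takes the lower component as background. Under that assumption, any split of three values into two non-empty groups leaves one group with a single value. That group's variance is exactly zero, so a background scale estimated this way can collapse. The sketch uses an exact 1-D two-means split as a stand-in for a Gaussian mixture. `best_two_group_split` is a hypothetical name.

```python
import numpy as np

def best_two_group_split(values):
    """Split 1-D values into a low and a high group, minimizing the
    within-group sum of squares (exact 1-D two-means; needs >= 2 values)."""
    v = np.sort(np.asarray(values, dtype=float))
    best = None
    for k in range(1, len(v)):  # low = v[:k], high = v[k:], both non-empty
        low, high = v[:k], v[k:]
        sse = ((low - low.mean()) ** 2).sum() + ((high - high.mean()) ** 2).sum()
        if best is None or sse < best[0]:
            best = (sse, low, high)
    return best[1], best[2]

# One cell with three log1p protein counts: one group necessarily has one value
low, high = best_two_group_split(np.log1p([2, 6, 140]))
print(len(low), len(high), min(low.var(), high.var()))
```

With dozens of proteins per cell, both groups usually contain several values, and the per-group spread is informative rather than degenerate.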

I can add a warning message when there are fewer than 10 proteins to alert users to this potential issue.

Thanks so much! Everything appears to work fine now. The data come from a species (and a cell type within that species) for which there aren’t many cross-reactive CITE-seq antibodies. Even those three antibodies help us quite a bit, and thanks to totalVI we are able to overcome some of those limitations by mapping onto references from other species with more CITE-seq antibodies.

Do you think it might make sense for the default value of empirical_protein_background_prior to be (the boolean value of) n_proteins >= 10, so that the empirical prior is switched off when there are only a few proteins?
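The proposed default and the warning mentioned earlier in the thread could be combined roughly as sketched below. The intent is that the empirical initialization is off by default when there are fewer than 10 proteins, and that a user who forces it on in that regime gets a warning. This is a hypothetical helper under those assumptions, not the scvi-tools API. The function name and constant are illustrative, and the threshold of 10 comes from this discussion.

```python
import warnings

MIN_PROTEINS_FOR_EMPIRICAL_PRIOR = 10  # threshold suggested in this thread

def resolve_empirical_prior(n_proteins, empirical_protein_background_prior=None):
    """Hypothetical: choose whether to use the empirical protein background prior."""
    if empirical_protein_background_prior is None:
        # default: only use the empirical initialization with enough proteins
        return bool(n_proteins >= MIN_PROTEINS_FOR_EMPIRICAL_PRIOR)
    if empirical_protein_background_prior and n_proteins < MIN_PROTEINS_FOR_EMPIRICAL_PRIOR:
        warnings.warn(
            f"Only {n_proteins} proteins found; the empirical protein background "
            "prior may be poorly initialized. Consider "
            "empirical_protein_background_prior=False.",
            UserWarning,
        )
    return bool(empirical_protein_background_prior)
```

An explicit user choice always wins in this design. The count-based rule only decides the default.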

Thanks again for sharing your awesome tool with this community.

Yes, we can definitely do this. Would you be willing to open an issue on GitHub and link to this discussion?