Constructing a high-level model: Problem with .train function

Hello everybody,

First of all, I would like to thank you for creating this great tool and the detailed tutorials!

Unfortunately, I keep running into problems when creating my own models: different error messages occur when I try to train them.
As an example, I am trying to create a VAE that differs slightly from the one used by scvi-tools.

First, I create a custom decoder:

from typing import Iterable

import torch
from scvi.nn import FCLayers
from torch import nn as nn


# Decoder 
class MyDecoder(nn.Module):
    """
    Decodes data from latent space to data space.
    ``n_input`` dimensions to ``n_output``
    dimensions using a fully-connected neural network of ``n_hidden`` layers.
    Output is the decoded mean of the data-space distribution
    Parameters
    ----------
    n_input
        The dimensionality of the input (latent space)
    n_output
        The dimensionality of the output (data space)
    n_cat_list
        A list containing the number of categories
        for each category of interest. Each category will be
        included using a one-hot encoding
    n_layers
        The number of fully-connected hidden layers
    n_hidden
        The number of nodes per hidden layer
    dropout_rate
        Dropout rate to apply to each of the hidden layers
    kwargs
        Keyword args for :class:`~scvi.nn.FCLayers`
    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_cat_list: Iterable[int] = None,
        n_layers: int = 1,
        n_hidden: int = 128,
        dropout_rate: float = 0.2,
        **kwargs,
    ):
        super().__init__()
        self.decoder = FCLayers(
            n_in=n_input,
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            **kwargs,
        )
        self.linear_out = nn.Linear(n_hidden, n_output)

    def forward(self, x: torch.Tensor, *cat_list: int):
        """
        The forward computation for a single sample.
         #. Decodes the data from the latent space using the decoder network
         #. Returns a tensor for the decoded mean
        Parameters
        ----------
        x
            tensor with shape ``(n_input,)``
        cat_list
            list of category membership(s) for this sample
        Returns
        -------
        :py:class:`torch.Tensor`
            Decoded tensor of shape ``(n_output,)``
        """
        p = self.linear_out(self.decoder(x, *cat_list))
        return p

After that, I define the VAE module:

import numpy as np
import torch
from torch.distributions import Normal, NegativeBinomial
from torch.distributions import kl_divergence as kl
from scvi.nn import Decoder, Encoder

from scvi import _CONSTANTS
from scvi.module.base import (
    BaseModuleClass,
    LossRecorder,
    auto_move_data,
)

class MyModule(BaseModuleClass):
    """
    Parameters
    ----------
    n_input
        Number of input genes
    n_latent
        Dimensionality of the latent space
    """

    def __init__(
        self,
        n_input: int,
        n_hidden: int = 800,
        n_latent: int = 10,
        n_layers: int = 2,
        dropout_rate: float = 0.1,
        kl_weight: float = 0.00005,
    ):
        super().__init__()
        # in the init, we create the parameters of our elementary stochastic computation unit.
        # First, we setup the parameters of the generative model
        self.n_layers = n_layers
        self.n_latent = n_latent
        self.kl_weight = kl_weight

        self.decoder = MyDecoder(
            n_latent,
            n_input,
        )

        # Second, we setup the parameters of the variational distribution
        self.z_encoder = Encoder(
            n_input,
            n_latent,
        )

    def _get_inference_input(self, tensors):
        """Parse the dictionary to get appropriate args"""
        # let us fetch the raw counts, and add them to the dictionary
        x = tensors[_CONSTANTS.X_KEY]
        input_dict = dict(x=x)
        return input_dict

    @auto_move_data
    def inference(self, x):
        """
        High level inference method.
        Runs the inference (encoder) model.
        """
        qz_m, qz_v, z = self.z_encoder(x)

        outputs = dict(qz_m=qz_m, qz_v=qz_v, z=z)
        return outputs

    def _get_generative_input(self, tensors, inference_outputs):
        z = inference_outputs["z"]
        input_dict = {
            "z": z
        }
        return input_dict

    @auto_move_data
    def generative(self, z):
        """Runs the generative model."""
        px = self.decoder(z)
        return dict(px=px)

    def loss(
        self,
        tensors,
        inference_outputs,
        generative_outputs,
    ):
        x = tensors[_CONSTANTS.X_KEY]
        qz_m = inference_outputs["qz_m"]
        qz_v = inference_outputs["qz_v"]
        p = generative_outputs["px"]

        kld = kl(
            Normal(qz_m, torch.sqrt(qz_v)),
            Normal(0, 1),
        ).sum(dim=1)
        rl = self.get_reconstruction_loss(x, p)
        loss = (0.5 * rl + 0.5 * (kld * self.kl_weight)).mean()
        kl_global = torch.randn(1)
        return LossRecorder(loss, rl, kld, kl_global)

    def get_reconstruction_loss(self, x, px) -> torch.Tensor:
        loss = ((x - px) ** 2).sum(dim=1)
        return loss

Then I create the model class:

import numpy as np
import pandas as pd
import scanpy as sc
import torch
from anndata import AnnData
from scvi.module import VAE
from scvi.data import setup_anndata
from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin, VAEMixin

class Try1(UnsupervisedTrainingMixin, BaseModelClass, VAEMixin):
    """
    single-cell Variational Inference [Lopez18]_.
    """

    def __init__(
        self,
        adata: AnnData,
        n_latent: int = 10,
        n_hidden: int = 800,
        n_layers: int = 2,
        dropout_rate: float = 0.1,
        **model_kwargs,
    ):
        super(Try1, self).__init__(adata)
        self.adata = adata

        self.module = MyModule(
            n_input=self.summary_stats["n_vars"],
            #n_batch=self.summary_stats["n_batch"],
            n_latent=n_latent,
            n_hidden=n_hidden,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
            **model_kwargs,
        )
        self._model_summary_string = (
            "VAE Model with the following params: \nn_hidden: {}, \nn_latent: {}, n_layers: {}, dropout_rate: {}"
        ).format(
            n_hidden,
            n_latent,
            n_layers,
            dropout_rate,
        )
        self.init_params_ = self._get_init_params(locals())

Then I try to train it on the pbmc5k data that I prepared earlier:

pbmc5k_ready = pbmc5k_ready.copy()
scvi.data.setup_anndata(pbmc5k_ready)
model = Try1(pbmc5k_ready)
model.train(2)

Unfortunately, I get this error message:

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

Epoch 1/2:   0%|          | 0/2 [00:00<?, ?it/s]

---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

<ipython-input-133-37babc234801> in <module>()
----> 1 model.train(2)

12 frames

/usr/local/lib/python3.7/dist-packages/scvi/model/base/_training_mixin.py in train(self, max_epochs, use_gpu, train_size, validation_size, batch_size, early_stopping, plan_kwargs, **trainer_kwargs)
     75             **trainer_kwargs,
     76         )
---> 77         return runner()

/usr/local/lib/python3.7/dist-packages/scvi/train/_trainrunner.py in __call__(self)
     74         self.training_plan.n_obs_training = len(self.model.train_indices)
     75 
---> 76         self.trainer.fit(self.training_plan, self.data_splitter)
     77         self._update_history()
     78 

/usr/local/lib/python3.7/dist-packages/scvi/train/_trainer.py in fit(self, *args, **kwargs)
    164                     message="`LightningModule.configure_optimizers` returned `None`",
    165                 )
--> 166             super().fit(*args, **kwargs)

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
    458         )
    459 
--> 460         self._run(model)
    461 
    462         assert self.state.stopped

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _run(self, model)
    756 
    757         # dispatch `start_training` or `start_evaluating` or `start_predicting`
--> 758         self.dispatch()
    759 
    760         # plugin will finalized fitting (e.g. ddp_spawn will load trained model)

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in dispatch(self)
    797             self.accelerator.start_predicting(self)
    798         else:
--> 799             self.accelerator.start_training(self)
    800 
    801     def run_stage(self):

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
     94 
     95     def start_training(self, trainer: 'pl.Trainer') -> None:
---> 96         self.training_type_plugin.start_training(trainer)
     97 
     98     def start_evaluating(self, trainer: 'pl.Trainer') -> None:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
    142     def start_training(self, trainer: 'pl.Trainer') -> None:
    143         # double dispatch to initiate the training loop
--> 144         self._results = trainer.run_stage()
    145 
    146     def start_evaluating(self, trainer: 'pl.Trainer') -> None:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in run_stage(self)
    807         if self.predicting:
    808             return self.run_predict()
--> 809         return self.run_train()
    810 
    811     def _pre_training_routine(self):

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in run_train(self)
    869                 with self.profiler.profile("run_training_epoch"):
    870                     # run train epoch
--> 871                     self.train_loop.run_training_epoch()
    872 
    873                 if self.max_steps and self.max_steps <= self.global_step:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/training_loop.py in run_training_epoch(self)
    564 
    565         # handle epoch_output on epoch end
--> 566         self.on_train_epoch_end(epoch_output)
    567 
    568         # log epoch metrics

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/training_loop.py in on_train_epoch_end(self, epoch_output)
    604 
    605             # lightningmodule hook
--> 606             training_epoch_end_output = model.training_epoch_end(processed_epoch_output)
    607 
    608             if training_epoch_end_output is not None:

/usr/local/lib/python3.7/dist-packages/scvi/train/_trainingplans.py in training_epoch_end(self, outputs)
    139         # kl global same for each minibatch
    140         kl_global = outputs[0]["kl_global"]
--> 141         elbo += kl_global
    142         self.log("elbo_train", elbo / n_obs)
    143         self.log("reconstruction_loss_train", rec_loss / n_obs)

RuntimeError: output with shape [] doesn't match the broadcast shape [1]

If anybody has an idea on how to solve this error, I would be very thankful!

Creating this model is just practice for an upcoming project in which I will have to create many different models, so I would also highly appreciate any tips on how to go about creating new models with scvi.
I have worked through the tutorials and looked at the skeleton, but I am still running into errors while training my own models, errors which I often cannot really interpret, since the train function seems to be something of a black box.

With kindest regards from Cologne, Germany,

Lunas

Thanks for using scvi-tools. There is no need to specify kl_global, as we already have a correct default of 0 set. If you do pass it, I believe you'd need to pass torch.tensor(0.0); your value here doesn't have the expected shape:

In [4]: torch.randn(1).shape
Out[4]: torch.Size([1])

In [5]: torch.tensor(0.0).shape
Out[5]: torch.Size([])
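
For concreteness, here is a minimal sketch of how the end of your loss method could look (names taken from your snippet; since kl_global already defaults to 0, you can also simply omit that argument):

    def loss(self, tensors, inference_outputs, generative_outputs):
        ...  # unchanged up to here
        loss = (0.5 * rl + 0.5 * (kld * self.kl_weight)).mean()
        # kl_global must be a 0-dim (scalar) tensor; it defaults to 0, so you can
        # equivalently write `return LossRecorder(loss, rl, kld)`
        return LossRecorder(loss, rl, kld, torch.tensor(0.0))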

Also, it’s great that you’re walking through the tutorials. At the end of the day, scvi-tools is a lightweight, single-cell-specific wrapper around PyTorch and PyTorch Lightning, so I advise familiarizing yourself with PyTorch Lightning.
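
To make that concrete, here is a toy LightningModule (nothing scvi-specific; the class and data are made up) showing the two hooks, training_step and configure_optimizers, that appear throughout your traceback and that scvi-tools' TrainingPlan implements for you:

import pytorch_lightning as pl
import torch
from torch import nn


class ToyPlan(pl.LightningModule):
    """Minimal Lightning module: define the per-batch loss and the optimizer."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(10, 10)

    def training_step(self, batch, batch_idx):
        # compute and log a toy reconstruction loss for this minibatch
        x = batch
        loss = ((self.net(x) - x) ** 2).mean()
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)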

Dear Adam,

Thanks a lot for your quick response!
I appreciate it :slight_smile:

When I remove my kl_global = torch.randn(1) variable, I get the same error. When I instead set the variable to torch.tensor(0.0).shape, I get a new error.

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

Epoch 1/2:   0%|          | 0/2 [00:00<?, ?it/s]

---------------------------------------------------------------------------

TypeError                                 Traceback (most recent call last)

<ipython-input-26-37babc234801> in <module>()
----> 1 model.train(2)

29 frames

/usr/local/lib/python3.7/dist-packages/scvi/model/base/_training_mixin.py in train(self, max_epochs, use_gpu, train_size, validation_size, batch_size, early_stopping, plan_kwargs, **trainer_kwargs)
     75             **trainer_kwargs,
     76         )
---> 77         return runner()

/usr/local/lib/python3.7/dist-packages/scvi/train/_trainrunner.py in __call__(self)
     74         self.training_plan.n_obs_training = len(self.model.train_indices)
     75 
---> 76         self.trainer.fit(self.training_plan, self.data_splitter)
     77         self._update_history()
     78 

/usr/local/lib/python3.7/dist-packages/scvi/train/_trainer.py in fit(self, *args, **kwargs)
    164                     message="`LightningModule.configure_optimizers` returned `None`",
    165                 )
--> 166             super().fit(*args, **kwargs)

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
    458         )
    459 
--> 460         self._run(model)
    461 
    462         assert self.state.stopped

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _run(self, model)
    756 
    757         # dispatch `start_training` or `start_evaluating` or `start_predicting`
--> 758         self.dispatch()
    759 
    760         # plugin will finalized fitting (e.g. ddp_spawn will load trained model)

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in dispatch(self)
    797             self.accelerator.start_predicting(self)
    798         else:
--> 799             self.accelerator.start_training(self)
    800 
    801     def run_stage(self):

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
     94 
     95     def start_training(self, trainer: 'pl.Trainer') -> None:
---> 96         self.training_type_plugin.start_training(trainer)
     97 
     98     def start_evaluating(self, trainer: 'pl.Trainer') -> None:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
    142     def start_training(self, trainer: 'pl.Trainer') -> None:
    143         # double dispatch to initiate the training loop
--> 144         self._results = trainer.run_stage()
    145 
    146     def start_evaluating(self, trainer: 'pl.Trainer') -> None:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in run_stage(self)
    807         if self.predicting:
    808             return self.run_predict()
--> 809         return self.run_train()
    810 
    811     def _pre_training_routine(self):

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in run_train(self)
    869                 with self.profiler.profile("run_training_epoch"):
    870                     # run train epoch
--> 871                     self.train_loop.run_training_epoch()
    872 
    873                 if self.max_steps and self.max_steps <= self.global_step:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/training_loop.py in run_training_epoch(self)
    497             # ------------------------------------
    498             with self.trainer.profiler.profile("run_training_batch"):
--> 499                 batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
    500 
    501             # when returning -1 from train_step, we end epoch early

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/training_loop.py in run_training_batch(self, batch, batch_idx, dataloader_idx)
    736 
    737                         # optimizer step
--> 738                         self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
    739                         if len(self.trainer.optimizers) > 1:
    740                             # revert back to previous state

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/training_loop.py in optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
    440             on_tpu=self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE,
    441             using_native_amp=using_native_amp,
--> 442             using_lbfgs=is_lbfgs,
    443         )
    444 

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/core/lightning.py in optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)
   1401 
   1402         """
-> 1403         optimizer.step(closure=optimizer_closure)
   1404 
   1405     def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/core/optimizer.py in step(self, closure, *args, **kwargs)
    212             profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}"
    213 
--> 214         self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
    215         self._total_optimizer_step_calls += 1
    216 

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/core/optimizer.py in __optimizer_step(self, closure, profiler_name, **kwargs)
    132 
    133         with trainer.profiler.profile(profiler_name):
--> 134             trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
    135 
    136     def step(self, *args, closure: Optional[Callable] = None, **kwargs):

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/accelerators/accelerator.py in optimizer_step(self, optimizer, opt_idx, lambda_closure, **kwargs)
    327         )
    328         if make_optimizer_step:
--> 329             self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
    330         self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
    331         self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs)

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/accelerators/accelerator.py in run_optimizer_step(self, optimizer, optimizer_idx, lambda_closure, **kwargs)
    334         self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
    335     ) -> None:
--> 336         self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
    337 
    338     def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in optimizer_step(self, optimizer, lambda_closure, **kwargs)
    191 
    192     def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):
--> 193         optimizer.step(closure=lambda_closure, **kwargs)
    194 
    195     @property

/usr/local/lib/python3.7/dist-packages/torch/optim/optimizer.py in wrapper(*args, **kwargs)
     86                 profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__)
     87                 with torch.autograd.profiler.record_function(profile_name):
---> 88                     return func(*args, **kwargs)
     89             return wrapper
     90 

/usr/local/lib/python3.7/dist-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     26         def decorate_context(*args, **kwargs):
     27             with self.__class__():
---> 28                 return func(*args, **kwargs)
     29         return cast(F, decorate_context)
     30 

/usr/local/lib/python3.7/dist-packages/torch/optim/adam.py in step(self, closure)
     64         if closure is not None:
     65             with torch.enable_grad():
---> 66                 loss = closure()
     67 
     68         for group in self.param_groups:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/training_loop.py in train_step_and_backward_closure()
    731                         def train_step_and_backward_closure():
    732                             result = self.training_step_and_backward(
--> 733                                 split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
    734                             )
    735                             return None if result is None else result.loss

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/training_loop.py in training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
    821         with self.trainer.profiler.profile("training_step_and_backward"):
    822             # lightning module hook
--> 823             result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
    824             self._curr_step_result = result
    825 

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/training_loop.py in training_step(self, split_batch, batch_idx, opt_idx, hiddens)
    288             model_ref._results = Result()
    289             with self.trainer.profiler.profile("training_step"):
--> 290                 training_step_output = self.trainer.accelerator.training_step(args)
    291                 self.trainer.accelerator.post_training_step()
    292 

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/accelerators/accelerator.py in training_step(self, args)
    202 
    203         with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
--> 204             return self.training_type_plugin.training_step(*args)
    205 
    206     def post_training_step(self) -> None:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in training_step(self, *args, **kwargs)
    153 
    154     def training_step(self, *args, **kwargs):
--> 155         return self.lightning_module.training_step(*args, **kwargs)
    156 
    157     def post_training_step(self):

/usr/local/lib/python3.7/dist-packages/scvi/train/_trainingplans.py in training_step(self, batch, batch_idx, optimizer_idx)
    126             "reconstruction_loss_sum": reconstruction_loss.sum(),
    127             "kl_local_sum": scvi_loss.kl_local.sum(),
--> 128             "kl_global": scvi_loss.kl_global,
    129             "n_obs": reconstruction_loss.shape[0],
    130         }

/usr/local/lib/python3.7/dist-packages/scvi/module/base/_base_module.py in kl_global(self)
     81     @property
     82     def kl_global(self) -> torch.Tensor:
---> 83         return self._get_dict_sum(self._kl_global)
     84 
     85 

/usr/local/lib/python3.7/dist-packages/scvi/module/base/_base_module.py in _get_dict_sum(dictionary)
     64         total = 0.0
     65         for value in dictionary.values():
---> 66             total += value
     67         return total
     68 

TypeError: unsupported operand type(s) for +=: 'float' and 'torch.Size'

Any ideas?

I downgraded with

!pip3 install --upgrade scvi-tools==0.13.0

because of this issue.

Could this lead to problems when implementing new models?

Thanks again for your help.

I will do that, thanks!

I would recommend going through the traceback and trying to find the bugs. For example, here there is a shape error with your kl_global term, which is why I proposed you exchange it with torch.tensor(0.0).

This is what I have been trying to do…
I will of course continue doing that.
But would you agree that, at first glance, there are no obvious mistakes in the model and that it is designed the way scvi models are supposed to be?
Furthermore, will it often be necessary to adjust the .train() function itself for new models, or should a new model be designed so that .train() works for most of them without needing to adjust the function?

Well, the overall structure looks good; the only bug you have relates to kl_global, which in your case you shouldn’t even have to specify.

For a large class of models you won’t need to change the train function. If you need customized inference procedures, you will potentially have to change it.
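
To give a concrete example (argument names taken from the train() signature in your traceback; the lr key inside plan_kwargs is an assumption about the default TrainingPlan), most customization can be done through the arguments of .train() rather than by rewriting it:

# hedged sketch: tuning training without touching .train() itself
model = Try1(pbmc5k_ready)
model.train(
    max_epochs=100,
    train_size=0.9,            # fraction of cells used for training
    batch_size=256,
    early_stopping=True,
    plan_kwargs={"lr": 1e-3},  # assumption: forwarded to the default TrainingPlan
)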