Directly accessing scVI's decoder

Hi, thanks for this really exciting tool.

I’m trying out some trajectory analysis where I would like to map predicted trajectories from the latent space back to the original expression space; in essence I would like to directly pass the points of these trajectories (which are not part of the anndata object) through the model’s decoder.

I see that the VAE class has a function ‘generative()’ which looks like what I am after, but I am unfamiliar with the architecture of the model and I am unsure how best to access this function. What would be the best way for a user to directly pass latent space points through the model’s decoder?

Any help would be much appreciated, thanks!

Rory

Hi! Thanks for using scvi-tools.

It might be helpful to look at the structure of the get_latent_representation() method.

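    import numpy as np
    import torch
    from torch.distributions import Normal

    # Method of scvi-tools' VAEMixin; signature as in the 0.9.x releases
    # (it may differ in other versions).
    @torch.no_grad()
    def get_latent_representation(
        self, adata=None, indices=None, give_mean=True, mc_samples=5000, batch_size=None
    ):
        adata = self._validate_anndata(adata)
        # Minibatch data loader over the requested cells
        scdl = self._make_data_loader(
            adata=adata, indices=indices, batch_size=batch_size
        )
        latent = []
        for tensors in scdl:
            # Encoder pass: posterior mean/variance (qz_m, qz_v) and a sample z
            inference_inputs = self.module._get_inference_input(tensors)
            outputs = self.module.inference(**inference_inputs)
            qz_m = outputs["qz_m"]
            qz_v = outputs["qz_v"]
            z = outputs["z"]

            if give_mean:
                # does each model need to have this latent distribution param?
                if self.module.latent_distribution == "ln":
                    samples = Normal(qz_m, qz_v.sqrt()).sample([mc_samples])
                    z = torch.nn.functional.softmax(samples, dim=-1)
                    z = z.mean(dim=0)
                else:
                    z = qz_m

            latent += [z.cpu()]
        return np.array(torch.cat(latent))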
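When latent_distribution == "ln" (logistic normal), the give_mean branch above does not simply return softmax(qz_m): it estimates the mean of softmax(z) under the Gaussian posterior by Monte Carlo sampling. A NumPy sketch of that computation, for illustration only (logistic_normal_mean is a hypothetical name, and qz_v is treated as a variance, matching the qz_v.sqrt() call above):

```python
import numpy as np


def logistic_normal_mean(qz_m, qz_v, mc_samples=5000, seed=0):
    """Monte Carlo estimate of E[softmax(x)] with x ~ Normal(qz_m, sqrt(qz_v))."""
    rng = np.random.default_rng(seed)
    qz_m = np.asarray(qz_m, dtype=float)
    qz_v = np.asarray(qz_v, dtype=float)
    # Samples have shape (mc_samples, n_cells, n_latent)
    samples = rng.normal(qz_m, np.sqrt(qz_v), size=(mc_samples,) + qz_m.shape)
    # Numerically stable softmax over the latent dimension
    samples -= samples.max(axis=-1, keepdims=True)
    probs = np.exp(samples)
    probs /= probs.sum(axis=-1, keepdims=True)
    # Average over the Monte Carlo samples
    return probs.mean(axis=0)
```

Averaging after the softmax matters because softmax is nonlinear, so in general the mean of softmax(x) differs from softmax of the mean.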

Basically, what you want to do is create a dataloader over your values of z, iterate over it, and call self.module.generative(...) on each minibatch.
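A minimal sketch of that loop, assuming `module` is a trained scVI VAE (e.g. `model.module`) whose generative() has the signature quoted in this thread and returns a dict containing "px_scale". The helper name decode_latent and its defaults (every point decoded as batch 0, library of ones) are hypothetical choices for illustration, not part of scvi-tools:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()  # inference only: no gradients needed
def decode_latent(module, z, batch_index=None, batch_size=128):
    """Decode latent points z (n_points x n_latent) to px_scale.

    Hypothetical helper. `module` is assumed to be a trained scVI VAE
    whose generative(z, library, batch_index, ...) returns a dict
    containing "px_scale".
    """
    z = torch.as_tensor(z, dtype=torch.float32)
    n = z.shape[0]
    if batch_index is None:
        # Assumption: decode every point as if it came from batch 0
        batch_index = torch.zeros((n, 1), dtype=torch.long)
    else:
        # One integer batch index per latent point, shaped (n, 1)
        batch_index = torch.as_tensor(batch_index, dtype=torch.long).reshape(n, 1)
    # Library of ones: it only affects px_rate, not px_scale
    library = torch.ones((n, 1), dtype=torch.float32)

    # shuffle=False keeps the output rows in the same order as z
    loader = DataLoader(
        TensorDataset(z, batch_index, library), batch_size=batch_size, shuffle=False
    )
    scales = []
    for z_b, batch_b, library_b in loader:
        outputs = module.generative(z=z_b, library=library_b, batch_index=batch_b)
        scales.append(outputs["px_scale"].cpu())
    return torch.cat(scales).numpy()
```

Usage would look like `decode_latent(model.module, trajectory_points)`, where `trajectory_points` is an (n_points, n_latent) array of your predicted trajectory; each row of the result is the decoded px_scale for one point. If the model was trained on several batches, pass `batch_index` to choose which batch each point is decoded as.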

Note that the signature of generative in VAE is

    @auto_move_data
    def generative(
        self, z, library, batch_index, cont_covs=None, cat_covs=None, y=None
    ):

So, besides z, you’ll need batch_index: a (batch_size, 1) tensor of ints giving the batch each point is decoded as. For library you can pass a tensor of 1s of the same shape, as it only affects the computation of px_rate. From the dictionary returned by generative, you’ll want px_scale.
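For a handful of points you could also skip the dataloader and build the three required inputs directly. A minimal sketch, where make_generative_inputs is a hypothetical helper and decoding every point as a single batch (default 0) is an assumption:

```python
import torch


def make_generative_inputs(z, batch=0):
    """Build keyword arguments for VAE.generative() from latent points z.

    Hypothetical helper: every point is decoded as belonging to `batch`.
    """
    z = torch.as_tensor(z, dtype=torch.float32)
    n = z.shape[0]
    return {
        "z": z,
        # (n, 1) integer tensor of batch indices
        "batch_index": torch.full((n, 1), batch, dtype=torch.long),
        # Tensor of ones: library only affects px_rate, not px_scale
        "library": torch.ones((n, 1), dtype=torch.float32),
    }
```

It would then be used as `model.module.generative(**make_generative_inputs(z_points))["px_scale"]`, ideally inside `torch.no_grad()`, where `model` and `z_points` stand for your trained model and latent trajectory points.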

Please feel free to follow up with additional questions.