The function you mention generates samples from the posterior predictive distribution, but I assume you want to sample from the prior rather than the posterior.

Something like the following should work, although you may run into import errors, CUDA/device errors, etc. Keep in mind that the library size is assumed to be 1 here and that every cell is assigned batch index 0, which could be a problem if the model was trained on multiple batches.

```python
import numpy as np
import torch
from torch.distributions import Normal
@torch.no_grad()
def prior_predictive_sample(
    self,
    n_samples: int = 1000,
    *,
    library_size: float = 1.0,
    batch_index: int = 0,
    label_index: int = 0,
) -> np.ndarray:
    r"""
    Generate observation samples from the prior predictive distribution.

    Latent codes are drawn from the standard normal prior ``z ~ N(0, I)``.
    They are decoded with ``self.model.decoder``, and counts are then sampled
    from the model's observation likelihood (``self.model.gene_likelihood``).
    Every sample shares the same library size, batch index and label index.

    ``self`` is presumably the scvi model wrapper that exposes the trained
    module as ``self.model``. Confirm this against your scvi version.

    Parameters
    ----------
    n_samples
        Number of samples to draw (must be >= 1).
    library_size
        Library size used for every sample (must be > 0). Its log is passed
        to the decoder, which exponentiates it. The default of 1 matches the
        original snippet.
    batch_index
        Batch index assigned to every sample. This matters when the model
        was trained on multiple batches.
    label_index
        Label index assigned to every sample. It is used to select the
        dispersion when ``dispersion == "gene-label"``.

    Returns
    -------
    x_new : :py:class:`numpy.ndarray`
        Sampled counts with shape (n_samples, n_genes).

    Raises
    ------
    ValueError
        If ``gene_likelihood`` is not one of ``"zinb"``, ``"nb"`` or
        ``"poisson"``, if ``n_samples < 1``, or if ``library_size <= 0``.
    """
    model = self.model
    if model.gene_likelihood not in ["zinb", "nb", "poisson"]:
        raise ValueError("Invalid gene_likelihood.")
    if n_samples < 1:
        raise ValueError("n_samples must be a positive integer.")
    if library_size <= 0:
        raise ValueError("library_size must be positive.")

    # Create every tensor on the model's device so a GPU-resident model does
    # not receive CPU inputs.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Standard normal prior. torch's Normal takes a *scale* (std), not a
    # variance.
    z_loc = torch.zeros(n_samples, model.n_latent, device=device)
    z_scale = torch.ones(n_samples, model.n_latent, device=device)
    z = Normal(z_loc, z_scale).sample()

    # Categorical covariates as integer indices, which one-hot encoders accept.
    dec_batch_index = torch.full(
        (n_samples, 1), batch_index, dtype=torch.long, device=device
    )
    y = torch.full((n_samples, 1), label_index, dtype=torch.long, device=device)
    # The decoder exponentiates `library`, so it receives the log library size.
    library = torch.full((n_samples, 1), float(library_size), device=device).log()

    _, px_r, px_rate, px_dropout = model.decoder(
        model.dispersion, z, library, dec_batch_index, y
    )

    # All samples share one batch/label index. Selecting that column of
    # ``model.px_r`` is therefore equivalent to
    # ``F.linear(one_hot(index, n_cat), model.px_r)`` and broadcasts against
    # px_rate, just like the "gene" case does.
    if model.dispersion == "gene-label":
        px_r = model.px_r[:, label_index]
    elif model.dispersion == "gene-batch":
        px_r = model.px_r[:, batch_index]
    elif model.dispersion == "gene":
        px_r = model.px_r
    # For any other dispersion mode, the decoder's px_r is used as-is.
    px_r = torch.exp(px_r)

    # NOTE(review): NegativeBinomial / ZeroInflatedNegativeBinomial are
    # presumably scvi's distribution classes. They must be imported in the
    # calling module, and the import path depends on the scvi version.
    if model.gene_likelihood == "poisson":
        # Cap extreme rates, presumably to keep the Poisson sampler stable.
        dist = torch.distributions.Poisson(torch.clamp(px_rate, max=1e8))
    elif model.gene_likelihood == "nb":
        dist = NegativeBinomial(mu=px_rate, theta=px_r)
    else:  # "zinb" (validated above)
        dist = ZeroInflatedNegativeBinomial(
            mu=px_rate, theta=px_r, zi_logits=px_dropout
        )

    exprs = dist.sample()
    return exprs.cpu().numpy()  # shape (n_samples, n_genes)
```