I have a general question about how to pass trained scvi models to downstream analysis tools, e.g. Captum (a model interpretability library for PyTorch). These tools, e.g. Captum, usually accept an nn.Module as input. How can I retrieve the underlying nn.Module from a trained scvi model object and pass it to these tools?
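A minimal sketch of what I have in mind is below. It assumes that scvi-tools model classes keep their torch network in a `.module` attribute (I believe they do, but treat this as an assumption). It also assumes that Captum wants a forward function that maps an input tensor to a single output tensor. `ToyModule` is a hypothetical stand-in for that network. The real scvi module's forward takes a dictionary of tensors and returns several outputs, as far as I know, so the wrapper body would need adapting.

```python
import torch
from torch import nn


class ToyModule(nn.Module):
    """Hypothetical stand-in for a trained scvi module; forward returns a tuple."""

    def __init__(self, n_genes: int = 5, n_latent: int = 2):
        super().__init__()
        self.encoder = nn.Linear(n_genes, n_latent)
        self.decoder = nn.Linear(n_latent, n_genes)

    def forward(self, x: torch.Tensor):
        z = self.encoder(x)
        return z, self.decoder(z)


class LatentWrapper(nn.Module):
    """Wrap a module so forward(x) returns one tensor, the shape Captum expects."""

    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z, _ = self.inner(x)
        return z  # shape (batch, n_latent); a Captum `target` would pick a column


inner = ToyModule()  # with scvi-tools this would presumably be `model.module`
wrapped = LatentWrapper(inner).eval()

x = torch.randn(3, 5, requires_grad=True)
out = wrapped(x)
# Raw gradient of latent dimension 0 w.r.t. the input (the basis of saliency maps)
grads = torch.autograd.grad(out[:, 0].sum(), x)[0]
```

With Captum installed, the wrapped module could then be passed as the forward function, e.g. `IntegratedGradients(wrapped).attribute(x, target=0)`. I have not checked this against a real scvi module.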
Related questions apply to any tool that takes a pl.LightningModule as input. For example, in Weights & Biases, wandb.watch() can take a pl.LightningModule to log gradients and weights during training, which is useful for diagnosing the model, but it is not clear how to do this with scvi models.
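As a fallback I am considering something simpler than wandb.watch(), which is a different technique: plain PyTorch code that reads weight and gradient norms directly from any nn.Module after a backward pass. In the sketch below, `param_stats` is a hypothetical helper and the small `nn.Sequential` network is a toy stand-in for the scvi model's torch network. It is not how scvi-tools or wandb do this internally.

```python
import torch
from torch import nn


def param_stats(module: nn.Module) -> dict:
    """Collect L2 norms of weights and (if present) gradients for each parameter."""
    stats = {}
    for name, p in module.named_parameters():
        stats[f"weight_norm/{name}"] = p.detach().norm().item()
        if p.grad is not None:
            stats[f"grad_norm/{name}"] = p.grad.detach().norm().item()
    return stats


# Hypothetical toy network standing in for the scvi model's nn.Module
net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))
loss = net(torch.randn(8, 4)).pow(2).mean()
loss.backward()  # populates .grad on every parameter
stats = param_stats(net)
```

Since wandb.watch() accepts any torch nn.Module, I would expect `wandb.watch(model.module, log="all")` (or `wandb.log(param_stats(model.module))`) to be the scvi equivalent. I have not verified either call against scvi's training loop.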