I am trying to understand the adversarial training in the GIMVI training plan. If I understand it correctly, its purpose is to encourage latent space mixing, i.e., the latent code learned from adata_seq can also represent adata_spatial information and vice versa.
- In your implementation below, you train a classifier to predict the “adversarial” (flipped) mode label. But since this is a binary classification, it won’t make a difference to just use the true mode label, right? (i.e., “predicting all seq samples as 0 and all spatial samples as 1” is equivalent to “predicting all seq samples as 1 and all spatial samples as 0”; see the toy check after the snippet)
- Would it make more sense to use uniform target probabilities (0.5, 0.5) instead of the flipped one-hot (0, 1), so that the modality cannot be “distinguished” from the latent code? (sketched after the snippet)
```python
# fool classifier if doing adversarial training
batch_tensor = [
    torch.zeros((z.shape[0], 1), device=z.device) + i for i, z in enumerate(zs)
]
if kappa > 0 and self.adversarial_classifier is not False:
    fool_loss = self.loss_adversarial_classifier(
        torch.cat(zs), torch.cat(batch_tensor), False
    )
    loss += fool_loss * kappa
```
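To make the first point concrete, here is a toy check (plain PyTorch, nothing from scvi-tools): swapping which modality is called 0 and which is called 1 only permutes the classifier's logit columns and leaves the cross-entropy unchanged.

```python
import torch
import torch.nn.functional as F

# Fake classifier logits for two cells, with the "seq -> 0, spatial -> 1" labels.
logits = torch.tensor([[2.0, -1.0], [0.5, 1.5]])
labels = torch.tensor([0, 1])

# Relabel the modalities ("seq -> 1, spatial -> 0") and permute the logit columns to match.
swapped_logits = logits[:, [1, 0]]
swapped_labels = 1 - labels

# Both losses are identical, which is what I mean by the two labelings being equivalent.
print(F.cross_entropy(logits, labels))
print(F.cross_entropy(swapped_logits, swapped_labels))
```

And for the second point, a minimal sketch of the uniform-target alternative I have in mind. This is not the scvi-tools implementation, just an illustration assuming the adversarial classifier maps latent codes to 2 raw logits:

```python
import torch
import torch.nn.functional as F

def uniform_fool_loss(adversarial_classifier, zs):
    """Push the predicted modality distribution toward (0.5, 0.5).

    Assumes `adversarial_classifier` returns raw logits of shape (n_cells, 2).
    """
    z_cat = torch.cat(zs)
    log_probs = F.log_softmax(adversarial_classifier(z_cat), dim=-1)
    uniform = torch.full_like(log_probs, 1.0 / log_probs.shape[-1])
    # KL(uniform || predicted) is zero exactly when the classifier outputs 0.5/0.5,
    # i.e. when the modality cannot be read off the latent code.
    return F.kl_div(log_probs, uniform, reduction="batchmean")

# This would replace the flipped-label fool loss above, e.g.:
# loss += kappa * uniform_fool_loss(self.adversarial_classifier, zs)
```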