Thanks so much @adamgayoso! It really helps, and I totally understand the abstraction of the code, which makes it more flexible and compatible. Really appreciate the hard work!
Basically, I changed the code of CellAssignModule a little to fit other variables into the CellAssign algorithm. What I'm concerned about is that variable inference in the new model might not work: the predict function seems to output the originally defined, randomly initialized delta variables instead of the optimized ones. I debugged it for a while, and it looks like the loss function might never be called in the new code. I'm not sure why that happens.
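One common cause of that symptom (just a hedged guess, since I can't see your module code): if the new delta variables are created as plain tensors instead of `nn.Parameter`s, they never appear in `module.parameters()`, the optimizer never sees them, and they stay at their random initialization. A self-contained toy illustration (`ToyModule` is hypothetical, not the real CellAssignModule):

```python
import torch
from torch import nn

class ToyModule(nn.Module):
    """Toy module showing registered vs. unregistered attributes."""
    def __init__(self):
        super().__init__()
        # Registered as a parameter: visible to the optimizer, gets updated.
        self.delta_registered = nn.Parameter(torch.randn(3))
        # Plain tensor attribute: invisible to the optimizer, never updated.
        self.delta_plain = torch.randn(3)

m = ToyModule()
names = [name for name, _ in m.named_parameters()]
print(names)  # only "delta_registered" shows up
```

It might be worth checking that each new delta variable appears in `self.module.named_parameters()` after the module is built.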
The train function is still the same. The prediction function is below:
```python
import numpy as np
import torch


def predict(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return the learned delta variables for each cell."""
    adata = self._validate_anndata(None)
    scdl = self._make_data_loader(adata=adata)
    delta_c, delta_p, delta_cp = [], [], []
    for tensors in scdl:
        generative_inputs = self.module._get_generative_input(tensors, None)
        with torch.inference_mode():
            outputs = self.module.generative(**generative_inputs)
        # Accumulate per-batch outputs; note the original `if idx == 0`
        # version concatenated the first batch with itself.
        delta_c.append(outputs["delta_c"].cpu())
        delta_p.append(outputs["delta_p"].cpu())
        delta_cp.append(outputs["delta_cp"].cpu())
        # gamma = outputs["gamma"]
        # predictions += [gamma.cpu()]
    # to be better specified ??
    return (
        torch.cat(delta_c).numpy(),
        torch.cat(delta_p).numpy(),
        torch.cat(delta_cp).numpy(),
    )
```
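To confirm whether the loss actually depends on the delta variables, one quick check is to look at `.grad` after a backward pass: parameters the loss never touches keep `grad is None`. A toy sketch (standalone parameters, not your actual module):

```python
import torch
from torch import nn

delta_used = nn.Parameter(torch.randn(5))
delta_unused = nn.Parameter(torch.randn(5))

loss = (delta_used ** 2).sum()  # only delta_used enters the loss
loss.backward()

print(delta_used.grad is None)    # False: a gradient was computed
print(delta_unused.grad is None)  # True: the loss never used it
```

If the delta parameters in your module show `grad is None` after a training step, that would be consistent with the loss function never being called on them.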