diff --git a/embeddings/model/lightning_module/lightning_module.py b/embeddings/model/lightning_module/lightning_module.py index dde23b09..09b8e363 100644 --- a/embeddings/model/lightning_module/lightning_module.py +++ b/embeddings/model/lightning_module/lightning_module.py @@ -106,7 +106,7 @@ def predict( with open(os.path.join(predpath, file), "rb") as f: batch_indices = pickle.load(f) all_batch_indices.append(list(flatten(batch_indices))) - all_batch_indices = torch.Tensor([y for x in all_batch_indices for y in x]) + all_batch_indices = torch.Tensor([y for x in all_batch_indices for y in x]).long() probabilities = softmax(torch.cat(all_logits), dim=1)[all_batch_indices] preds = torch.cat(all_preds)[all_batch_indices]