diff --git a/dpat/data/pmchhg_h5_dataset.py b/dpat/data/pmchhg_h5_dataset.py index 70d283e..2d81f64 100644 --- a/dpat/data/pmchhg_h5_dataset.py +++ b/dpat/data/pmchhg_h5_dataset.py @@ -539,9 +539,11 @@ def __getitem__(self, index: int) -> H5ItemObject: tile_y=metadata["all_tile_y"], case_id=case_id, img_id=img_id, - cc=metadata["all_location"].astype(str)[0], ) + if self.clinical_context: + data_obj = data_obj | dict(cc=metadata["all_location"].astype(str)[0][0]) + return data_obj def __len__(self): diff --git a/dpat/mil/models/ccmil.py b/dpat/mil/models/ccmil.py index 75bf04e..6a0d274 100644 --- a/dpat/mil/models/ccmil.py +++ b/dpat/mil/models/ccmil.py @@ -131,7 +131,7 @@ def _common_step(self, batch): """ # [0] to convert list to string, # because all tiles in a bag have the same clinical context. - x, y, cc = batch["data"], batch["target"], batch["cc"][0] + x, y, cc = batch["data"], batch["target"], batch["cc"] y_hat, A = self(x, cc) loss = self.loss_fn(y_hat, y) return loss, y_hat, y, A