Skip to content

Commit

Permalink
fix: restrict lightning module to only use 1 GPU for predict v2
Browse files: browse the repository at this point in the history
  • Loading branch information
djaniak committed Nov 27, 2023
1 parent 91d2163 commit b7b3ed3
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 31 deletions.
39 changes: 21 additions & 18 deletions embeddings/model/lightning_module/lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ def predict_step(self, *args: Any, **kwargs: Any) -> Optional[Tuple[STEP_OUTPUT,
return logits, preds

def predict(
self, dataloader: DataLoader[HuggingFaceDataset], trainer=None
self, dataloader: DataLoader[HuggingFaceDataset]
) -> Dict[str, nptyping.NDArray[Any]]:
predict_output = self._predict_with_trainer(dataloader, trainer)
predict_output = self._predict_with_trainer(dataloader)
assert predict_output
logits, predictions = zip(*predict_output)
probabilities = softmax(torch.cat(logits), dim=1).numpy()
Expand All @@ -85,25 +85,28 @@ def predict(
return result

def _predict_with_trainer(
self, dataloader: DataLoader[HuggingFaceDataset], trainer=None
self, dataloader: DataLoader[HuggingFaceDataset]
) -> Optional[_PREDICT_OUTPUT]:
if trainer is not None:
self.trainer = trainer
assert self.trainer is not None

try:
return self.trainer.predict(
model=self, dataloaders=dataloader, return_predictions=True, ckpt_path="last"
)
except MisconfigurationException: # model loaded but not fitted
_logger.warning(
"The best model checkpoint cannot be loaded because trainer.fit has not been called. Using current weights for prediction."
)
return self.trainer.predict(
model=self,
dataloaders=dataloader,
return_predictions=True,
)
torch.distributed.destroy_process_group()
if self.trainer.is_global_zero:
self.trainer = pl.Trainer(gpus=1)
try:
return self.trainer.predict(
model=self, dataloaders=dataloader, return_predictions=True, ckpt_path="last"
)
except MisconfigurationException: # model loaded but not fitted
_logger.warning(
"The best model checkpoint cannot be loaded because trainer.fit has not been called. Using current weights for prediction."
)
return self.trainer.predict(
model=self,
dataloaders=dataloader,
return_predictions=True,
)
else:
raise RuntimeError("Got `False` for `trainer.is_global_zero` attribute!")

def on_train_epoch_end(self) -> None:
self._aggregate_and_log_metrics(self.train_metrics)
Expand Down
14 changes: 1 addition & 13 deletions embeddings/task/lightning_task/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,7 @@ def build_task_model(self) -> None:

def predict(self, dataloader: DataLoader[Any], return_names: bool = True) -> Predictions:
assert self.model is not None
if (self.task_train_kwargs.get("accelerator", None) == "gpu") and (
torch.cuda.device_count() > 1
):
pred_trainer_kwargs = self.task_train_kwargs.copy()
pred_trainer_kwargs["devices"] = 1
self.trainer = pl.Trainer(
default_root_dir=str(self.output_path),
callbacks=self.callbacks,
logger=self.loggers,
inference_mode=self.inference_mode,
**self.task_train_kwargs,
)
results = self.model.predict(dataloader=dataloader, trainer=self.trainer)
results = self.model.predict(dataloader=dataloader)
results["names"] = np.array(self.model.target_names)
return Predictions(**results)

Expand Down

0 comments on commit b7b3ed3

Please sign in to comment.