Skip to content

Commit

Permalink
Fix a bug in XGBPipeLine of benchmark code and refactor test codes to…
Browse files Browse the repository at this point in the history
… use benchmark files
  • Loading branch information
Alcoholrithm committed Mar 30, 2024
1 parent 687371b commit 737241a
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 1,434 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ temporary_ckpt_data/
build/
dist/
.pytest_cache/
benchmark/benchmark_ckpt/
*benchmark_ckpt/
15 changes: 3 additions & 12 deletions benchmark/pipelines/xgb_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,6 @@ class XGBConfig:
scale_pos_weight: int

early_stopping_rounds: int

# # task: str = field(default=None)

# def __post_init__(self):
# if self.task is None:
# raise ValueError("The task of the problem must be specified in the 'task' attribute.")
# elif (type(self.task) is not str or (self.task != "regression" and self.task != "classification")):
# raise ValueError(f"{self.task} is not a valid task. Choices are: ['regression', 'classification']")

class XGBModule(object):
def __init__(self, model_class: Union[XGBClassifier, XGBRegressor]):
Expand All @@ -52,11 +44,9 @@ def initialize(self):
self.hparams_range = hparams_range

def _get_config(self, hparams: Dict[str, Any]):
# Build a config object from one set of hyperparameters.
#
# Stores self.args.second_phase_patience in `hparams` under the
# "early_stopping_rounds" key -- this assigns into the dict object that was
# passed in, so the caller's dict is modified -- and then returns
# self.config_class(**hparams).
# hparams["task"] = "regression" if self.output_dim == 1 else "classification"
hparams["early_stopping_rounds"] = self.args.second_phase_patience

return self.config_class(**hparams)
# return asdict(self.config_class(**hparams))

def fit_model(self, pl_module: XGBModule, config: XGBConfig):

Expand All @@ -66,8 +56,9 @@ def fit_model(self, pl_module: XGBModule, config: XGBConfig):

def evaluate(self, pl_module: XGBModule, config: XGBConfig, X: pd.DataFrame, y: pd.Series):
# Compute predictions for X with pl_module, pass them together with y to
# self.metric, and return the resulting score.
# `config` is accepted but not read anywhere in this body.

# NOTE(review): the value assigned here is overwritten by the predict_proba
# call below before it is ever used. This view looks like a diff with the
# removed and the added line both shown -- confirm which call the file
# actually contains.
preds = pl_module.predict(X)

preds = pl_module.predict_proba(X)

# print(preds.shape, y.shape)
# Argument order is (predictions, targets), as passed to self.metric.
score = self.metric(preds, y)

return score
Expand Down
33 changes: 0 additions & 33 deletions test/abalone.py

This file was deleted.

24 changes: 0 additions & 24 deletions test/diabetes.py

This file was deleted.

21 changes: 21 additions & 0 deletions test/misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from types import SimpleNamespace

def get_args():
    """Return a fresh namespace of fixed argument values.

    Every call builds a new ``SimpleNamespace`` whose attributes are set to
    the hard-coded values below; nothing is read from the command line or
    the environment.

    Returns:
        SimpleNamespace: namespace exposing ``max_epochs``,
        ``first_phase_patience``, ``second_phase_patience``, ``n_trials``,
        ``labeled_sample_ratio``, ``valid_size``, ``test_size``,
        ``random_seed``, ``batch_size``, ``n_jobs``, ``accelerator`` and
        ``devices``.
    """
    return SimpleNamespace(
        # Epoch, patience and trial counts, all set to 1.
        max_epochs=1,
        first_phase_patience=1,
        second_phase_patience=1,
        n_trials=1,
        # Sampling ratio, split sizes, seed and batch size.
        labeled_sample_ratio=1,
        valid_size=0.2,
        test_size=0.2,
        random_seed=0,
        batch_size=128,
        # Worker count and accelerator/device selection.
        n_jobs=4,
        accelerator="cpu",
        devices="auto",
    )
Loading

0 comments on commit 737241a

Please sign in to comment.