Add XGBPipeLine for benchmark and Fix bugs in TabularLightningModule
Alcoholrithm committed Mar 30, 2024
1 parent d8b79bc commit 687371b
Showing 7 changed files with 97 additions and 13 deletions.
8 changes: 5 additions & 3 deletions benchmark/benchmark.py
@@ -1,11 +1,11 @@
 import argparse
 from datasets import load_diabetes, load_abalone
-from pipelines import VIMEPipeLine, SubTabPipeLine, SCARFPipeLine
+from pipelines import VIMEPipeLine, SubTabPipeLine, SCARFPipeLine, XGBPipeLine

 def main():
     parser = argparse.ArgumentParser(add_help=True)

-    parser.add_argument('--model', type=str, choices=["vime", "subtab", "scarf"])
+    parser.add_argument('--model', type=str, choices=["xgb", "vime", "subtab", "scarf"])
     parser.add_argument('--data', type=str, choices=["diabetes", "abalone"])

     parser.add_argument('--labeled_sample_ratio', type=float, default=0.1)
@@ -37,7 +37,9 @@ def main():

     data, label, continuous_cols, category_cols, output_dim, metric, metric_hparams = load_data()

-    if args.model == "vime":
+    if args.model == "xgb":
+        pipeline_class = XGBPipeLine
+    elif args.model == "vime":
         pipeline_class = VIMEPipeLine
     elif args.model == "subtab":
         pipeline_class = SubTabPipeLine
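With "xgb" added to the parser choices, the dispatch above grows by one elif branch per model. A dictionary-based dispatch is an equivalent alternative, shown here only as a sketch (it is not part of the commit) and written against the pipeline classes imported above:

PIPELINES = {
    "xgb": XGBPipeLine,
    "vime": VIMEPipeLine,
    "subtab": SubTabPipeLine,
    "scarf": SCARFPipeLine,
}
pipeline_class = PIPELINES[args.model]  # cannot KeyError given the argparse choices above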
7 changes: 7 additions & 0 deletions benchmark/hparams_range/xgb.py
@@ -0,0 +1,7 @@
+hparams_range = {
+    'max_leaves' : ['suggest_int', ['max_leaves', 300, 4000]],
+    'n_estimators' : ['suggest_int', ['n_estimators', 10, 3000]],
+    'learning_rate' : ['suggest_float', ['learning_rate', 0, 1]],
+    'max_depth' : ['suggest_int', ['max_depth', 3, 20]],
+    'scale_pos_weight' : ['suggest_int', ['scale_pos_weight', 1, 100]],
+}
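Each entry pairs an Optuna suggest method name with its positional arguments. A minimal sketch of how such a spec could be consumed inside an objective function; the tuner wiring shown here is an assumption, not code from this commit:

import optuna
from hparams_range.xgb import hparams_range

def sample_hparams(trial: optuna.Trial) -> dict:
    # e.g. 'max_leaves' -> trial.suggest_int('max_leaves', 300, 4000)
    return {name: getattr(trial, fn)(*args) for name, (fn, args) in hparams_range.items()}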
4 changes: 3 additions & 1 deletion benchmark/pipelines/__init__.py
@@ -1,4 +1,6 @@
+from .xgb_pipeline import XGBPipeLine
 from .vime_pipeline import VIMEPipeLine
 from .subtab_pipeline import SubTabPipeLine
 from .scarf_pipeline import SCARFPipeLine
-__all__ = ["VIMEPipeLine", "SubTabPipeLine", "SCARFPipeLine"]
+
+__all__ = ["XGBPipeLine", "VIMEPipeLine", "SubTabPipeLine", "SCARFPipeLine"]
6 changes: 3 additions & 3 deletions benchmark/pipelines/vime_pipeline.py
@@ -80,8 +80,8 @@ def fit_model(self, pl_module: TS3LLightining, config: Type[BaseConfig]):

         pl_module.set_second_phase()

-        train_ds = VIMEDataset(self.X_train, self.y_train.values, config, unlabeled_data=self.X_unlabeled, continuous_cols=self.continuous_cols, category_cols=self.category_cols, is_second_phase=True)
-        test_ds = VIMEDataset(self.X_valid, self.y_valid.values, config, continuous_cols=self.continuous_cols, category_cols=self.category_cols, is_second_phase=True)
+        train_ds = VIMEDataset(self.X_train, self.y_train.values, config, unlabeled_data=self.X_unlabeled, continuous_cols=self.continuous_cols, category_cols=self.category_cols, is_second_phase=True, is_regression=True if self.output_dim==1 else False)
+        test_ds = VIMEDataset(self.X_valid, self.y_valid.values, config, continuous_cols=self.continuous_cols, category_cols=self.category_cols, is_second_phase=True, is_regression=True if self.output_dim==1 else False)

         pl_datamodule = TS3LDataModule(train_ds, test_ds, batch_size = self.args.batch_size, train_sampler="random" if self.output_dim == 1 else "weighted", train_collate_fn=VIMESemiSLCollateFN())

@@ -134,7 +134,7 @@ def evaluate(self, pl_module: TS3LLightining, config: Type[BaseConfig], X: pd.Da
             callbacks = None,
         )

-        test_ds = VIMEDataset(X, category_cols=self.category_cols, continuous_cols=self.continuous_cols, is_second_phase=True)
+        test_ds = VIMEDataset(X, category_cols=self.category_cols, continuous_cols=self.continuous_cols, is_second_phase=True, is_regression=True if self.output_dim==1 else False)
         test_dl = DataLoader(test_ds, self.args.batch_size, shuffle=False, sampler = SequentialSampler(test_ds), num_workers=self.args.n_jobs)

         preds = trainer.predict(pl_module, test_dl)
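The expression True if self.output_dim==1 else False now recurs in each VIMEDataset call above; it is equivalent to a plain comparison. A tiny self-contained illustration with assumed output dimensions, not part of this commit:

# Hypothetical output dimensions for a regression and a classification dataset:
for output_dim in (1, 3):
    is_regression = output_dim == 1   # same value as: True if output_dim == 1 else False
    print(output_dim, is_regression)  # 1 True, then 3 False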
74 changes: 74 additions & 0 deletions benchmark/pipelines/xgb_pipeline.py
@@ -0,0 +1,74 @@
+import argparse
+import pandas as pd
+from typing import List, Dict, Any, Union, Type
+
+from dataclasses import dataclass, field, asdict
+from .pipeline import PipeLine
+from xgboost import XGBClassifier, XGBRegressor
+from hparams_range.xgb import hparams_range
+
+@dataclass
+class XGBConfig:
+
+    max_leaves: int
+
+    n_estimators: int
+
+    learning_rate: float
+
+    max_depth: int
+
+    scale_pos_weight: int
+
+    early_stopping_rounds: int
+
+    # # task: str = field(default=None)
+
+    # def __post_init__(self):
+    #     if self.task is None:
+    #         raise ValueError("The task of the problem must be specified in the 'task' attribute.")
+    #     elif (type(self.task) is not str or (self.task != "regression" and self.task != "classification")):
+    #         raise ValueError(f"{self.task} is not a valid task. Choices are: ['regression', 'classification']")
+
+class XGBModule(object):
+    def __init__(self, model_class: Union[XGBClassifier, XGBRegressor]):
+        self.model_class = model_class
+
+    def __call__(self, config: XGBConfig):
+        return self.model_class(**asdict(config))
+
+class XGBPipeLine(PipeLine):
+    def __init__(self, args: argparse.Namespace, data: pd.DataFrame, label: pd.Series, continuous_cols: List[str], category_cols: List[str], output_dim: int, metric: str, metric_hparams: Dict[str, Any] = {}):
+        super().__init__(args, data, label, continuous_cols, category_cols, output_dim, metric, metric_hparams)
+
+    def initialize(self):
+        self.config_class = XGBConfig
+        if self.output_dim == 1:
+            self.pl_module_class = XGBRegressor
+        else:
+            self.pl_module_class = XGBClassifier
+        self.pl_module_class = XGBModule(self.pl_module_class)
+
+        self.hparams_range = hparams_range
+
+    def _get_config(self, hparams: Dict[str, Any]):
+        # hparams["task"] = "regression" if self.output_dim == 1 else "classification"
+        hparams["early_stopping_rounds"] = self.args.second_phase_patience
+
+        return self.config_class(**hparams)
+        # return asdict(self.config_class(**hparams))
+
+    def fit_model(self, pl_module: XGBModule, config: XGBConfig):
+
+        pl_module.fit(self.X_train, self.y_train, eval_set=[(self.X_valid, self.y_valid)], verbose = 0)
+
+        return pl_module
+
+    def evaluate(self, pl_module: XGBModule, config: XGBConfig, X: pd.DataFrame, y: pd.Series):
+
+        preds = pl_module.predict(X)
+
+        score = self.metric(preds, y)
+
+        return score

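A small standalone sketch of how XGBConfig and XGBModule fit together. It assumes the base PipeLine instantiates self.pl_module_class(config) before calling fit_model, which is not shown in this diff, and the hyperparameter values are placeholders:

from xgboost import XGBRegressor
from pipelines.xgb_pipeline import XGBConfig, XGBModule

# Placeholder values; in the benchmark they come from the Optuna search space above.
config = XGBConfig(max_leaves=512, n_estimators=200, learning_rate=0.1,
                   max_depth=6, scale_pos_weight=1, early_stopping_rounds=20)

module = XGBModule(XGBRegressor)      # the factory picked in initialize() when output_dim == 1
model = module(config)                # equivalent to XGBRegressor(**asdict(config))
# model.fit(X_train, y_train, eval_set=[(X_valid, y_valid)], verbose=0)  # as in fit_model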
8 changes: 4 additions & 4 deletions ts3l/pl_modules/base_module.py
@@ -212,8 +212,8 @@ def _on_second_phase_validation_start(self):
"""
         if len(self.second_phase_step_outputs) > 0:
             train_loss = torch.Tensor([out["loss"] for out in self.second_phase_step_outputs]).detach().mean()
-            y = torch.cat([out["y"] for out in self.second_phase_step_outputs]).detach().cpu()
-            y_hat = torch.cat([out["y_hat"] for out in self.second_phase_step_outputs]).detach().cpu()
+            y = torch.cat([out["y"] for out in self.second_phase_step_outputs if out["y"].numel() != 1]).detach().cpu()
+            y_hat = torch.cat([out["y_hat"] for out in self.second_phase_step_outputs if out["y_hat"].numel() != 1]).detach().cpu()

             train_score = self.metric(y_hat, y)

@@ -229,8 +229,8 @@ def _second_phase_validation_epoch_end(self) -> None:
         val_loss = torch.Tensor([out["loss"] for out in self.second_phase_step_outputs]).mean()


-        y = torch.cat([out["y"].cpu() for out in self.second_phase_step_outputs])
-        y_hat = torch.cat([out["y_hat"].cpu() for out in self.second_phase_step_outputs])
+        y = torch.cat([out["y"].cpu() for out in self.second_phase_step_outputs if out["y"].numel() != 1])
+        y_hat = torch.cat([out["y_hat"].cpu() for out in self.second_phase_step_outputs if out["y_hat"].numel() != 1])
         val_score = self.metric(y_hat, y)

         self.log("val_" + self.metric.__name__, val_score, prog_bar = True)
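The added numel() != 1 filter appears to guard against zero-dimensional (scalar) step outputs, which torch.cat refuses to concatenate, for example when a final batch of size one is squeezed upstream. A small illustration of that failure mode under this assumption:

import torch

# Outputs collected across validation steps; the last batch held a single sample
# and was squeezed to a 0-dim scalar somewhere upstream (assumed scenario).
step_outputs = [torch.randn(32), torch.randn(32), torch.tensor(0.7)]

# torch.cat(step_outputs)  # RuntimeError: zero-dimensional tensor cannot be concatenated
y_hat = torch.cat([out for out in step_outputs if out.numel() != 1])  # mirrors the added filter

Note that the filter also drops any genuinely single-element batch output from the metric, which is the trade-off this fix accepts.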
3 changes: 1 addition & 2 deletions ts3l/pl_modules/vime_lightning.py
@@ -97,13 +97,12 @@ def _get_second_phase_loss(self, batch:Dict[str, Any]):
             preds = torch.stack([u_y_hat[i, :] for i in range(len(u_y_hat)) if i % self.consistency_len != 0], dim = 0)
             unsupervised_loss += self.consistency_loss(preds, target)

-        labeled_x = x[y != self.u_label].squeeze()
+        labeled_x = x[y != self.u_label]
         labeled_y = y[y != self.u_label]

         y_hat = self.model(labeled_x).squeeze()

         supervised_loss = self.loss_fn(y_hat, labeled_y)
-
         loss = supervised_loss + self.beta * unsupervised_loss

         return loss, labeled_y, y_hat
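Dropping .squeeze() on labeled_x looks like a shape fix: when a batch contains exactly one labeled sample, squeezing removes the batch dimension before the forward pass. A small illustration with assumed shapes:

import torch

x = torch.randn(8, 16)                  # batch of 8 samples, 16 features (assumed shapes)
mask = torch.zeros(8, dtype=torch.bool)
mask[3] = True                           # only one sample in the batch is labeled

labeled = x[mask]                        # shape (1, 16): still a valid batch for self.model
labeled_squeezed = x[mask].squeeze()     # shape (16,): the batch dimension is silently dropped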
