-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix bugs that occur when using sklearn.metric and add benchmark codes
- Loading branch information
1 parent
33cfb46
commit 8c672e5
Showing
14 changed files
with
719 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,4 +6,5 @@ temporary_ckpt_data/ | |
*egg* | ||
build/ | ||
dist/ | ||
.pytest_cache/ | ||
.pytest_cache/ | ||
benchmark/benchmark_ckpt/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import argparse | ||
from datasets import load_diabetes | ||
from pipelines import VIMEPipeLine, SubTabPipeLine, SCARFPipeLine | ||
|
||
def main(): | ||
parser = argparse.ArgumentParser(add_help=True) | ||
|
||
parser.add_argument('--model', type=str, choices=["vime", "subtab", "scarf"]) | ||
parser.add_argument('--data', type=str) | ||
|
||
parser.add_argument('--labeled_sample_ratio', type=float, default=0.1) | ||
parser.add_argument('--valid_size', type=float, default=0.2) | ||
parser.add_argument('--test_size', type=float, default=0.2) | ||
parser.add_argument('--random_seed', type=int, default=0) | ||
|
||
parser.add_argument('--batch_size', type=int, default=128) | ||
parser.add_argument('--first_phase_patience', type=int, default=8) | ||
parser.add_argument('--second_phase_patience', type=int, default=16) | ||
parser.add_argument('--n_trials', type=int, default=50) | ||
parser.add_argument('--n_jobs', type=int, default=32) | ||
parser.add_argument('--max_epochs', type=int, default=200) | ||
|
||
parser.add_argument('--accelerator', type=str, choices=["cuda", "cpu"]) | ||
parser.add_argument('--devices', nargs='+', type=int, required=True) | ||
|
||
parser.add_argument('--fast_dev_run', action="store_true") | ||
|
||
args = parser.parse_args() | ||
|
||
if args.accelerator == 'cpu': | ||
args.device = 'auto' | ||
|
||
data, label, continuous_cols, category_cols, output_dim, metric, metric_hparams = load_diabetes() | ||
|
||
if args.model == "vime": | ||
pipeline_class = VIMEPipeLine | ||
elif args.model == "subtab": | ||
pipeline_class = SubTabPipeLine | ||
elif args.model == "scarf": | ||
pipeline_class = SCARFPipeLine | ||
|
||
if args.fast_dev_run: | ||
args.max_epochs = 1 | ||
args.first_phase_patience = 1 | ||
args.second_phase_patience = 1 | ||
args.n_trials = 1 | ||
|
||
pipeline = pipeline_class(args, data, label, continuous_cols, category_cols, output_dim, metric, metric_hparams) | ||
|
||
pipeline.benchmark() | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .diabetes import load_diabetes | ||
|
||
__all__=["load_diabetes"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from sklearn.datasets import fetch_openml | ||
import numpy as np | ||
from types import SimpleNamespace | ||
from typing import Tuple, List | ||
import pandas as pd | ||
from sklearn.preprocessing import LabelEncoder | ||
from sklearn.preprocessing import OneHotEncoder, MinMaxScaler | ||
|
||
def load_diabetes(): | ||
|
||
diabetes = fetch_openml(data_id = 37, data_home='./data_cache') | ||
|
||
data = diabetes.data | ||
|
||
le = LabelEncoder() | ||
label = pd.Series(le.fit_transform(diabetes.target)) | ||
|
||
category_cols = [] | ||
continuous_cols = list(map(str, data.columns)) | ||
|
||
scaler = MinMaxScaler() | ||
data[continuous_cols] = scaler.fit_transform(data[continuous_cols]) | ||
|
||
return data, label, continuous_cols, category_cols, 2, "accuracy_score", {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
hparams_range = { | ||
|
||
'hidden_dim' : ['suggest_int', ['hidden_dim', 16, 512]], | ||
'encoder_depth' : ['suggest_int', ['encoder_depth', 2, 6]], | ||
'head_depth' : ['suggest_int', ['head_depth', 1, 3]], | ||
'corruption_rate' : ['suggest_float', ['corruption_rate', 0.1, 0.7]], | ||
'dropout_rate' : ['suggest_float', ['dropout_rate', 0.05, 0.3]], | ||
|
||
'lr' : ['suggest_float', ['lr', 0.0001, 0.05]], | ||
'weight_decay' : ['suggest_float', ['weight_decay', 0.00001, 0.0005]], | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
hparams_range = { | ||
|
||
'hidden_dim' : ['suggest_int', ['hidden_dim', 4, 1024]], | ||
|
||
'tau' : ["suggest_float", ["tau", 0.05, 0.15]], | ||
"use_cosine_similarity" : ["suggest_categorical", ["use_cosine_similarity", [True, False]]], | ||
"use_contrastive" : ["suggest_categorical", ["use_contrastive", [True, False]]], | ||
"use_distance" : ["suggest_categorical", ["use_distance", [True, False]]], | ||
|
||
"n_subsets" : ["suggest_int", ["n_subsets", 2, 7]], | ||
"overlap_ratio" : ["suggest_float", ["overlap_ratio", 0., 1]], | ||
|
||
"mask_ratio" : ["suggest_float", ["mask_ratio", 0.1, 0.3]], | ||
"noise_level" : ["suggest_float", ["noise_level", 0.5, 2]], | ||
"noise_type" : ["suggest_categorical", ["noise_type", ["Swap", "Gaussian", "Zero_Out"]]], | ||
|
||
'lr' : ['suggest_float', ['lr', 0.0001, 0.05]], | ||
'weight_decay' : ['suggest_float', ['weight_decay', 0.00001, 0.0005]], | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
hparams_range = { | ||
|
||
'hidden_dim' : ['suggest_int', ['hidden_dim', 16, 512]], | ||
|
||
'p_m' : ["suggest_float", ["p_m", 0.1, 0.9]], | ||
'alpha1' : ["suggest_float", ["alpha1", 0.1, 5]], | ||
'alpha2' : ["suggest_float", ["alpha2", 0.1, 5]], | ||
'beta' : ["suggest_float", ["beta", 0.1, 10]], | ||
'K' : ["suggest_int", ["K", 2, 20]], | ||
|
||
|
||
'lr' : ['suggest_float', ['lr', 0.0001, 0.05]], | ||
'weight_decay' : ['suggest_float', ['weight_decay', 0.00001, 0.0005]], | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .vime_pipeline import VIMEPipeLine | ||
from .subtab_pipeline import SubTabPipeLine | ||
from .scarf_pipeline import SCARFPipeLine | ||
__all__ = ["VIMEPipeLine", "SubTabPipeLine", "SCARFPipeLine"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
import argparse | ||
import pandas as pd | ||
from typing import List, Type, Dict, Any | ||
from ts3l.utils import BaseConfig, RegressionMetric, ClassificationMetric | ||
from ts3l.pl_modules.base_module import TS3LLightining | ||
|
||
from abc import ABC, abstractmethod | ||
import optuna | ||
from sklearn.model_selection import train_test_split | ||
|
||
from copy import deepcopy | ||
|
||
class PipeLine(ABC): | ||
|
||
def __init__(self, | ||
args: argparse.Namespace, | ||
data: pd.DataFrame, | ||
label: pd.Series, | ||
continuous_cols: List[str], | ||
category_cols: List[str], | ||
output_dim: int, | ||
metric: str, | ||
metric_hparams: Dict[str, Any] = {}): | ||
self.args = args | ||
self.data = data | ||
self.label = label | ||
self.continuous_cols = continuous_cols | ||
self.category_cols = category_cols | ||
self.output_dim = output_dim | ||
self.metric = metric | ||
self.metric_hparams = metric_hparams | ||
|
||
X_train, X_valid, y_train, y_valid = train_test_split(data, label, test_size = args.valid_size + args.test_size, random_state=args.random_seed) | ||
|
||
self.X_valid, self.X_test, self.y_valid, self.y_test = train_test_split(X_valid, y_valid, test_size = args.test_size / (args.valid_size + args.test_size), random_state=args.random_seed) | ||
|
||
if args.labeled_sample_ratio == 1: | ||
self.X_train, self.y_train = X_train, y_train | ||
self.X_unlabeled = None | ||
else: | ||
self.X_train, self.X_unlabeled, self.y_train, _ = train_test_split(X_train, y_train, train_size = args.labeled_sample_ratio, random_state=args.random_seed) | ||
|
||
self.direction = "maximize" if self.output_dim > 1 else "minimize" | ||
|
||
self.config_class = None | ||
self.pl_module_class = None | ||
self.hparams_range = None | ||
|
||
self.__configure_metric() | ||
self.initialize() | ||
|
||
self.check_attributes() | ||
|
||
|
||
def __objective(self, trial: optuna.trial.Trial, | ||
) -> float: | ||
|
||
hparams = {} | ||
for k, v in self.hparams_range.items(): | ||
hparams[k] = getattr(trial, v[0])(*v[1]) | ||
|
||
config = self._get_config(hparams) | ||
pl_module = self.pl_module_class(config) | ||
|
||
pl_module = self.fit_model(pl_module, config) | ||
|
||
return self.evaluate(pl_module, config) | ||
|
||
@abstractmethod | ||
def fit_model(self, pl_module: TS3LLightining, config: Type[BaseConfig]): | ||
pass | ||
|
||
@abstractmethod | ||
def evaluate(self, pl_module: TS3LLightining, config: Type[BaseConfig]): | ||
pass | ||
|
||
@abstractmethod | ||
def initialize(self): | ||
pass | ||
|
||
def check_attributes(self): | ||
if self.config_class is None: | ||
raise NotImplementedError('self.config_class must be defined') | ||
if self.pl_module_class is None: | ||
raise NotImplementedError('self.pl_module_class must be defined') | ||
if self.hparams_range is None: | ||
raise NotImplementedError('self.hparams_range must be defined') | ||
|
||
def __tune_hyperparameters(self): | ||
study = optuna.create_study(direction=self.direction,sampler=optuna.samplers.TPESampler(seed=self.args.random_seed)) | ||
study.optimize(self.__objective, n_trials=self.args.n_trials, show_progress_bar=False) | ||
|
||
print("Number of finished trials: ", len(study.trials)) | ||
print("Best trial:") | ||
|
||
trial = study.best_trial | ||
hparams = dict(trial.params.items()) | ||
|
||
print(" Evaluation Results: {}".format(trial.value)) | ||
print(" Best hyperparameters: ", hparams) | ||
|
||
return hparams | ||
|
||
def __configure_metric(self): | ||
|
||
if self.output_dim == 1: | ||
self.metric = RegressionMetric(self.metric, self.metric_hparams) | ||
else: | ||
self.metric = ClassificationMetric(self.metric, self.metric_hparams) | ||
|
||
@abstractmethod | ||
def _get_config(self, _hparams: Dict[str, Any]): | ||
hparams = deepcopy(_hparams) | ||
|
||
hparams["optim_hparams"] = { | ||
"lr" : hparams["lr"], | ||
"weight_decay": hparams["weight_decay"] | ||
} | ||
del hparams["lr"] | ||
del hparams["weight_decay"] | ||
|
||
hparams["task"] = "regression" if self.output_dim == 1 else "classification" | ||
hparams["loss_fn"] = "MSELoss" if self.output_dim == 1 else "CrossEntropyLoss" | ||
hparams["metric"] = self.metric.__name__ | ||
hparams["metric_hparams"] = self.metric_hparams | ||
hparams["random_seed"] = self.args.random_seed | ||
|
||
return hparams | ||
|
||
|
||
def benchmark(self): | ||
hparams = self.__tune_hyperparameters() | ||
|
||
config = self._get_config(hparams) | ||
pl_module = self.pl_module_class(config) | ||
|
||
pl_module = self.fit_model(pl_module, config) | ||
|
||
print("Evaluation %s: %.4f" % (self.metric.__name__, self.evaluate(pl_module, config))) | ||
|
Oops, something went wrong.