Skip to content

Commit

Permalink
Merge pull request #173 from COINtoolbox/172-add-malanchev-to-time-do…
Browse files Browse the repository at this point in the history
…main

implement malanchev features for time_domain
  • Loading branch information
emilleishida authored Feb 15, 2024
2 parents 9c4deea + b160546 commit ffbd689
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 33 deletions.
4 changes: 2 additions & 2 deletions resspect/feature_extractors/light_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def conv_flux_mag(flux, zpt: float = 27.5):
return np.array(mag)

def check_queryable(self, mjd: float, filter_lim: float, criteria: int =1,
days_since_last_obs=2, feature_method='Bazin',
days_since_last_obs=2, feature_method: str = 'Bazin',
filter_cut='r'):
"""Check if this object can be queried in a given day.
Expand Down Expand Up @@ -307,7 +307,7 @@ def check_queryable(self, mjd: float, filter_lim: float, criteria: int =1,
self.last_mag = self.conv_flux_mag([fitted_flux])[0]

else:
raise ValueError('Only "Bazin" features are implemented!')
raise ValueError('Only "Bazin" and "malanchev" features are implemented!')

elif sum(surv_flag):
raise ValueError('Criteria needs to be "1" or "2". \n ' + \
Expand Down
2 changes: 0 additions & 2 deletions resspect/feature_extractors/malanchev.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ def fit(self, band: str) -> np.ndarray:

# build filter flag
band_indices = self.photometry['band'] == band
if not sum(band_indices) > (len(self.features_names) - 1):
return np.array([])

extractor = licu.Extractor(licu.AndersonDarlingNormal(),
licu.InterPercentileRange(0.05),
Expand Down
56 changes: 38 additions & 18 deletions resspect/time_domain_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def load_dataset(file_names_dict: dict, survey_name: str = 'DES',
Currently only nclass == 2 is implemented.
feature_extraction_method: str (optional)
Feature extraction method. The current implementation only
accepts method=='bazin' or 'photometry'.
accepts method=='bazin', 'photometry', or 'malanchev'.
Default is 'bazin'.
is_save_samples: bool (optional)
If True, save training and test samples to file.
Expand Down Expand Up @@ -137,10 +137,10 @@ def _load_first_loop_and_full_data(
number_of_classes
Number of classes to consider in the classification
Currently only number_of_classes == 2 is implemented.
feature_extraction_method
Chosen classifier.
The current implementation accepts `RandomForest`,
'GradientBoostedTrees', 'KNN', 'MLP', 'SVM' and 'NB'.
feature_extraction_method: str (optional)
Feature extraction method. The current implementation only
accepts method=='bazin', 'photometry', or 'malanchev'.
Default is 'bazin'.
is_save_samples
If True, save training and test samples to file.
Default is False.
Expand All @@ -151,14 +151,16 @@ def _load_first_loop_and_full_data(
file_names_dict=first_loop_file_name,
survey_name=survey_name, is_separate_files=is_separate_files,
initial_training=0, ia_training_fraction=ia_training_fraction,
is_queryable=is_queryable)
is_queryable=is_queryable,
feature_extraction_method=feature_extraction_method)
light_curve_file_name = {None: initial_light_curve_file_name['train']}
light_curve_data = load_dataset(
file_names_dict=light_curve_file_name,
survey_name=survey_name, is_separate_files=is_separate_files,
initial_training=initial_training,
ia_training_fraction=ia_training_fraction,
is_queryable=is_queryable)
is_queryable=is_queryable,
feature_extraction_method=feature_extraction_method)
else:
first_loop_file_name = {'pool': first_loop_file_name}
first_loop_data = load_dataset(
Expand Down Expand Up @@ -459,7 +461,7 @@ def _save_metrics_and_queried_sample(
def _load_next_day_data(
next_day_features_file_name: str, is_separate_files: bool,
is_queryable: bool, survey_name: str, ia_training_fraction: float,
is_save_samples: bool):
is_save_samples: bool, feature_extraction_method: str='bazin'):
"""
Loads features of next day
Expand Down Expand Up @@ -488,13 +490,14 @@ def _load_next_day_data(
next_day_data = load_dataset(
next_day_features_file_name, survey_name, samples_list=['pool'],
is_separate_files=is_separate_files, is_queryable=is_queryable,
is_save_samples=is_save_samples
)
is_save_samples=is_save_samples,
feature_extraction_method=feature_extraction_method)
else:
next_day_features_file_name = {None: next_day_features_file_name}
next_day_data = load_dataset(
next_day_features_file_name, survey_name, is_queryable=is_queryable,
initial_training=0, ia_training_fraction=ia_training_fraction)
initial_training=0, ia_training_fraction=ia_training_fraction,
feature_extraction_method=feature_extraction_method)
return next_day_data


Expand Down Expand Up @@ -724,7 +727,8 @@ def process_next_day_loop(
is_separate_files: bool, is_queryable: bool, survey_name: str,
ia_training_fraction: float, id_key_name: str,
light_curve_train_ids: np.ndarray, is_save_samples: bool,
canonical_data: DataBase, strategy: str) -> DataBase:
canonical_data: DataBase, strategy: str,
feature_extraction_method: str='bazin') -> DataBase:
"""
Runs next day active learning loop
Expand Down Expand Up @@ -759,10 +763,15 @@ def process_next_day_loop(
Query strategy. Options are (all can be run with budget):
"UncSampling", "UncSamplingEntropy", "UncSamplingLeastConfident",
"UncSamplingMargin", "QBDMI", "QBDEntropy", "RandomSampling"
feature_extraction_method: str (optional)
Feature extraction method. The current implementation only
accepts method=='bazin', 'photometry', or 'malanchev'.
Default is 'bazin'.
"""
next_day_data = _load_next_day_data(
next_day_features_file_name, is_separate_files, is_queryable,
survey_name, ia_training_fraction, is_save_samples)
survey_name, ia_training_fraction, is_save_samples,
feature_extraction_method=feature_extraction_method)
for metadata_value in light_curve_data.train_metadata[id_key_name].values:
next_day_pool_metadata = next_day_data.pool_metadata[id_key_name].values
if metadata_value in next_day_pool_metadata:
Expand Down Expand Up @@ -800,7 +809,7 @@ def run_time_domain_active_learning_loop(
light_curve_train_ids: np.ndarray, canonical_data: DataBase,
is_separate_files: bool, path_to_features_directory: str,
fname_pattern: list, survey_name: str, ia_training_fraction: float,
is_save_samples: bool, **kwargs: dict):
is_save_samples: bool, feature_extraction_method: str='bazin', **kwargs: dict):
"""
Runs time domain active learning loop
Expand Down Expand Up @@ -861,6 +870,10 @@ def run_time_domain_active_learning_loop(
is_save_samples
If True, save training and test samples to file.
Default is False.
feature_extraction_method: str (optional)
Feature extraction method. The current implementation only
accepts method=='bazin', 'photometry', or 'malanchev'.
Default is 'bazin'.
kwargs
All keywords required by the classifier function.
Expand Down Expand Up @@ -900,7 +913,8 @@ def run_time_domain_active_learning_loop(
light_curve_data = process_next_day_loop(
light_curve_data, next_day_features_file_name, is_separate_files,
is_queryable, survey_name, ia_training_fraction, id_key_name,
light_curve_train_ids, is_save_samples, canonical_data, strategy)
light_curve_train_ids, is_save_samples, canonical_data, strategy,
feature_extraction_method=feature_extraction_method)


# TODO: Too many arguments. Refactor and update docs
Expand All @@ -915,6 +929,7 @@ def time_domain_loop(days: list, output_metrics_file: str,
query_thre: float = 1.0, save_samples: bool = False,
sep_files: bool = False, survey: str = 'LSST',
initial_training: str = 'original',
feature_extraction_method: str = 'bazin',
save_full_query: bool = False, **kwargs):
"""
Perform the active learning loop. All results are saved to file.
Expand Down Expand Up @@ -992,6 +1007,10 @@ def time_domain_loop(days: list, output_metrics_file: str,
If int: choose the required number of samples at random,
ensuring that at least half are SN Ia
Default is 'original'.
feature_extraction_method: str (optional)
Feature extraction method. The current implementation only
accepts method=='bazin', 'photometry', or 'malanchev'.
Default is 'bazin'.
"""

# load features for the first obs day
Expand All @@ -1001,8 +1020,8 @@ def time_domain_loop(days: list, output_metrics_file: str,

first_loop_data, light_curve_data = _load_first_loop_and_full_data(
first_loop_file_name, path_to_ini_files, survey, initial_training,
ia_frac, queryable, sep_files, nclass, is_save_samples=save_samples
)
ia_frac, queryable, sep_files, nclass, is_save_samples=save_samples,
feature_extraction_method = feature_extraction_method)

# get keyword for obj identification
id_key_name = light_curve_data.identify_keywords()
Expand All @@ -1024,7 +1043,8 @@ def time_domain_loop(days: list, output_metrics_file: str,
strategy, budgets, queryable, query_thre, batch, output_metrics_file,
output_queried_file, save_full_query, id_key_name,
light_curve_train_ids, canonical_data, sep_files, path_to_features_dir,
fname_pattern, survey, ia_frac, save_samples, **kwargs)
fname_pattern, survey, ia_frac, save_samples,
feature_extraction_method=feature_extraction_method, **kwargs)

def main():
    """Placeholder entry point; performs no work and returns None."""
    return None
Expand Down
26 changes: 15 additions & 11 deletions resspect/time_domain_snpcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@

from resspect.feature_extractors.bazin import BazinFeatureExtractor
from resspect.feature_extractors.bump import BumpFeatureExtractor
from resspect.feature_extractors.malanchev import MalanchevFeatureExtractor
from resspect.lightcurves_utils import BAZIN_HEADERS
from resspect.lightcurves_utils import MALANCHEV_HEADERS
from resspect.lightcurves_utils import get_files_list
from resspect.lightcurves_utils import get_query_flags
from resspect.lightcurves_utils import maybe_create_directory
Expand All @@ -34,20 +36,22 @@

FEATURE_EXTRACTOR_MAPPING = {
"bazin": BazinFeatureExtractor,
"bump": BumpFeatureExtractor
"bump": BumpFeatureExtractor,
"malanchev": MalanchevFeatureExtractor
}


FEATURE_EXTRACTOR_HEADERS_MAPPING = {
"bazin": BAZIN_HEADERS
"bazin": BAZIN_HEADERS,
"malanchev": MALANCHEV_HEADERS
}


class SNPCCPhotometry:
"""
Handles photometric information for entire SNPCC data.
This class only works for Bazin feature extraction method.
This class only works for Bazin and Malanchev feature extraction methods.
Attributes
----------
Expand Down Expand Up @@ -105,15 +109,15 @@ def create_daily_file(self, output_dir: str,
If True, calculate cost of taking a spectra in the last
observed photometric point. Default is False.
feature_extractor: str
Feature extraction method, only possibility is 'Bazin'.
Feature extraction method. Accepted values are 'bazin' and 'malanchev'.
"""
maybe_create_directory(output_dir)
self._features_file_name = os.path.join(
output_dir, 'day_' + str(day) + '.csv')
logging.info('Creating features file')
with open(self._features_file_name, 'w') as features_file:
if feature_extractor not in FEATURE_EXTRACTOR_HEADERS_MAPPING:
raise ValueError('Only Bazin headers are supported')
raise ValueError('Only Bazin and Malanchev headers are supported')
self._header = FEATURE_EXTRACTOR_HEADERS_MAPPING[
feature_extractor]['snpcc_header']
if get_cost:
Expand Down Expand Up @@ -152,7 +156,7 @@ def _maybe_create_features_file(self, output_dir: str, day_of_survey: int,
day_of_survey: int
Day since the beginning of survey.
feature_extractor: str
Feature extraction method, only possibility is 'Bazin'.
Feature extraction method. Accepted values are 'bazin' and 'malanchev'.
get_cost: bool
if True, cost of taking a spectra is computed.
"""
Expand All @@ -171,13 +175,13 @@ def _verify_dataset_and_features_method(dataset_name: str,
dataset_name: str
name of the dataset used
feature_extractor: str
Feature extraction method, only possibility is 'Bazin'.
Feature extraction method. Accepted values are 'bazin' and 'malanchev'.
"""
if dataset_name != 'SNPCC':
raise ValueError('This class supports only SNPCC dataset!')
# TODO: Update when bump headers are available
if feature_extractor != 'bazin':
raise ValueError('Only bazin features are implemented!!')
if feature_extractor != 'bazin' and feature_extractor!='malanchev':
raise ValueError('Only bazin and malanchev features are implemented!!')

def _check_queryable(self, light_curve_data,
queryable_criteria: int,
Expand Down Expand Up @@ -384,7 +388,7 @@ def build_one_epoch(self, raw_data_dir: str, day_of_survey: int,
"""
Fit features for all objects with enough points in a given day.
Generate 1 file containing best-fit Bazin parameters for a given
Generate 1 file containing best-fit Bazin or malanchev parameters for a given
day of the survey.
Parameters
Expand All @@ -404,7 +408,7 @@ def build_one_epoch(self, raw_data_dir: str, day_of_survey: int,
Only used if "queryable_criteria == 2". Default is 2.
feature_extractor: str (optional)
Feature extraction method.
Only possibility is 'Bazin'.
Accepted values are 'bazin' or 'malanchev'.
get_cost: bool (optional)
If True, calculate cost of taking a spectra in the last
observed photometric point. Default is False.
Expand Down

0 comments on commit ffbd689

Please sign in to comment.