diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py index 0dd48140..ce025852 100644 --- a/src/setfit/sampler.py +++ b/src/setfit/sampler.py @@ -129,8 +129,10 @@ def __init__( self.len_neg_pairs = min([min(self.pos_pairs_combination, self.neg_pairs_combination), self.max_pos_or_neg]) elif sampling_strategy == SamplingStrategy.OVERSAMPLING: - self.len_pos_pairs = min([max(self.pos_pairs_combination, self.neg_pairs_combination), self.max_pos_or_neg]) - self.len_neg_pairs = min([max(self.pos_pairs_combination, self.neg_pairs_combination), self.max_pos_or_neg]) + num_pos_or_neg_pairs = max(self.pos_pairs_combination, self.neg_pairs_combination) + # safeguard for either negative samples only or single positive. + self.len_pos_pairs = min([num_pos_or_neg_pairs, self.max_pos_or_neg]) if self.pos_pairs_combination > 0 else 0 + self.len_neg_pairs = min([num_pos_or_neg_pairs, self.max_pos_or_neg]) if self.neg_pairs_combination > 0 else 0 def generate_positive_pair(self) -> Generator[SentencePair, None, None]: pair_generator = shuffle_combinations(self.sentence_labels)