Skip to content

Commit

Permalink
safeguard for oversampling strategy, when DS is negatives only or sin…
Browse files Browse the repository at this point in the history
…gle pos pair
  • Loading branch information
DemirTonchev committed Jan 13, 2025
1 parent d229add commit ea88e9b
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/setfit/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,10 @@ def __init__(
self.len_neg_pairs = min([min(self.pos_pairs_combination, self.neg_pairs_combination), self.max_pos_or_neg])

elif sampling_strategy == SamplingStrategy.OVERSAMPLING:
self.len_pos_pairs = min([max(self.pos_pairs_combination, self.neg_pairs_combination), self.max_pos_or_neg])
self.len_neg_pairs = min([max(self.pos_pairs_combination, self.neg_pairs_combination), self.max_pos_or_neg])
num_pos_or_neg_pairs = max(self.pos_pairs_combination, self.neg_pairs_combination)
# saveguard for either negative samples only or single positive.
self.len_pos_pairs = min([num_pos_or_neg_pairs, self.max_pos_or_neg]) if self.pos_pairs_combination > 0 else 0
self.len_neg_pairs = min([num_pos_or_neg_pairs, self.max_pos_or_neg]) if self.neg_pairs_combination > 0 else 0

def generate_positive_pair(self) -> Generator[SentencePair, None, None]:
pair_generator = shuffle_combinations(self.sentence_labels)
Expand Down

0 comments on commit ea88e9b

Please sign in to comment.