Skip to content

Commit

Permalink
removed unnecessary np uses and casting
Browse files Browse the repository at this point in the history
  • Loading branch information
DemirTonchev committed Jan 13, 2025
1 parent d989f22 commit d229add
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions src/setfit/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,26 +111,26 @@ def __init__(
# calculate number of positive and negative combinations
label_counts = Counter(labels)
# number of positive pairs: unordered pairs drawn without replacement from each label's n elements, n * (n - 1) / 2
self.total_pos_pairs = int(sum([n * (n - 1) / 2 for n in label_counts.values()]))
self.pos_pairs_combination = sum([n * (n - 1) // 2 for n in label_counts.values()])
# number of negative pairs: sum of the products of the counts for every pair of distinct labels
self.total_neg_pairs = sum(a * b for a, b in combinations(label_counts.values(), 2))
self.neg_pairs_combination = sum(a * b for a, b in combinations(label_counts.values(), 2))

if num_iterations is not None and num_iterations > 0:
iterations = num_iterations * len(self.sentences)
self.len_pos_pairs = iterations if self.pos_pairs_combination > 0 else 0
self.len_neg_pairs = iterations if self.neg_pairs_combination > 0 else 0

elif sampling_strategy == SamplingStrategy.UNIQUE:
self.len_pos_pairs = int(np.min([self.total_pos_pairs, self.max_pos_or_neg]))
self.len_neg_pairs = int(np.min([self.total_neg_pairs, self.max_pos_or_neg]))
self.len_pos_pairs = min(self.pos_pairs_combination, self.max_pos_or_neg)
self.len_neg_pairs = min(self.neg_pairs_combination, self.max_pos_or_neg)

elif sampling_strategy == SamplingStrategy.UNDERSAMPLING:
self.len_pos_pairs = int(np.min([min(self.total_pos_pairs, self.total_neg_pairs), self.max_pos_or_neg]))
self.len_neg_pairs = int(np.min([min(self.total_pos_pairs, self.total_neg_pairs), self.max_pos_or_neg]))
self.len_pos_pairs = min([min(self.pos_pairs_combination, self.neg_pairs_combination), self.max_pos_or_neg])
self.len_neg_pairs = min([min(self.pos_pairs_combination, self.neg_pairs_combination), self.max_pos_or_neg])

elif sampling_strategy == SamplingStrategy.OVERSAMPLING:
self.len_pos_pairs = int(np.min([max(self.total_pos_pairs, self.total_neg_pairs), self.max_pos_or_neg]))
self.len_neg_pairs = int(np.min([max(self.total_pos_pairs, self.total_neg_pairs), self.max_pos_or_neg]))
self.len_pos_pairs = min([max(self.pos_pairs_combination, self.neg_pairs_combination), self.max_pos_or_neg])
self.len_neg_pairs = min([max(self.pos_pairs_combination, self.neg_pairs_combination), self.max_pos_or_neg])

def generate_positive_pair(self) -> Generator[SentencePair, None, None]:
pair_generator = shuffle_combinations(self.sentence_labels)
Expand Down

0 comments on commit d229add

Please sign in to comment.