feat: interface
taharallouche committed Oct 26, 2024
1 parent c9b13e4 commit 8fa4e51
Showing 7 changed files with 239 additions and 102 deletions.
64 changes: 0 additions & 64 deletions hakeem/core/aggregation/aggregators.py

This file was deleted.

Empty file.
30 changes: 30 additions & 0 deletions hakeem/core/aggregation/aggregators/condorcet.py
@@ -0,0 +1,30 @@
import numpy as np
import pandas as pd

from hakeem.core.aggregation.base import WeightedAggregator
from hakeem.core.utils.inventory import COLUMNS, DEFAULT_RELIABILITY_BOUNDS


class CondorcetAggregator(WeightedAggregator):
    def __init__(
        self,
        lower_reliability_bound: float = DEFAULT_RELIABILITY_BOUNDS.lower,
        upper_reliability_bound: float = DEFAULT_RELIABILITY_BOUNDS.upper,
        task_column: str = COLUMNS.question,
        worker_column: str = COLUMNS.voter,
    ):
        super().__init__(task_column, worker_column)
        self.lower_reliability_bound = lower_reliability_bound
        self.upper_reliability_bound = upper_reliability_bound

    def compute_weights(self, annotations: pd.DataFrame) -> pd.Series:
        vote_size = annotations.sum(axis=1)
        reliabilities = (len(annotations.columns) - vote_size - 1) / (
            len(annotations.columns) - 2
        )
        reliabilities = reliabilities.clip(
            self.lower_reliability_bound, self.upper_reliability_bound
        )
        weights = np.log(reliabilities / (1 - reliabilities))

        return weights
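
For orientation, a minimal usage sketch (not part of the commit). It assumes the COLUMNS inventory resolves to "question" and "voter", and that ballots are 0/1 approvals over the label columns:

import pandas as pd

from hakeem.core.aggregation.aggregators.condorcet import CondorcetAggregator

# Three ballots on one question over m = 4 alternatives, indexed the way
# _coerce_annotations expects: (task, worker) in the index, labels as columns.
index = pd.MultiIndex.from_tuples(
    [("q1", "v1"), ("q1", "v2"), ("q1", "v3")], names=["question", "voter"]
)
annotations = pd.DataFrame(
    [[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0]],
    index=index,
    columns=["a", "b", "c", "d"],
)

aggregator = CondorcetAggregator()
# Ballot sizes 1, 2, 3 give raw reliabilities (4 - s - 1) / (4 - 2) = 1.0, 0.5, 0.0,
# clipped into DEFAULT_RELIABILITY_BOUNDS (assumed to lie strictly inside (0, 1))
# so the log-odds weights stay finite.
print(aggregator.compute_weights(annotations))
print(aggregator.fit_predict(annotations))
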
31 changes: 31 additions & 0 deletions hakeem/core/aggregation/aggregators/mallows.py
@@ -0,0 +1,31 @@
import numpy as np
import pandas as pd

from hakeem.core.aggregation.base import WeightedAggregator


class StandardApprovalAggregator(WeightedAggregator):
    @staticmethod
    def compute_weights(annotations: pd.DataFrame) -> pd.Series:
        return pd.Series(1, index=annotations.index)


class EuclidAggregator(WeightedAggregator):
    @staticmethod
    def compute_weights(annotations: pd.DataFrame) -> pd.Series:
        vote_size = annotations.sum(axis=1)
        return np.sqrt(vote_size + 1) - np.sqrt(vote_size - 1)


class JaccardAggregator(WeightedAggregator):
    @staticmethod
    def compute_weights(annotations: pd.DataFrame) -> pd.Series:
        vote_size = annotations.sum(axis=1)
        return 1 / vote_size


class DiceAggregator(WeightedAggregator):
    @staticmethod
    def compute_weights(annotations: pd.DataFrame) -> pd.Series:
        vote_size = annotations.sum(axis=1)
        return 2 / (vote_size + 1)
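
Side note: these four rules differ only in how a ballot's size s is mapped to a weight. A self-contained sketch tabulating the formulas (not part of the commit):

import numpy as np
import pandas as pd

# Ballot sizes s = 1..4; each column applies one rule's weight formula.
s = pd.Series([1, 2, 3, 4], name="vote_size")
weights = pd.DataFrame(
    {
        "standard": pd.Series(1.0, index=s.index),
        "euclid": np.sqrt(s + 1) - np.sqrt(s - 1),
        "jaccard": 1 / s,
        "dice": 2 / (s + 1),
    }
)
print(weights)  # every rule except the standard one discounts larger ballots
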
60 changes: 45 additions & 15 deletions hakeem/core/aggregation/base.py
@@ -6,40 +6,70 @@
 
 
 class Aggregator(ABC):
-    _name: str
+    def __init__(
+        self, task_column: str = COLUMNS.question, worker_column: str = COLUMNS.voter
+    ) -> None:
+        self.task_column = task_column
+        self.worker_column = worker_column
+
+    def fit_predict(self, annotations: pd.DataFrame) -> pd.DataFrame:
+        annotations = self._coerce_annotations(annotations)
+        return self._aggregate(annotations)
 
     @abstractmethod
-    def aggregate(self, annotations: pd.DataFrame) -> pd.DataFrame:
+    def _aggregate(self, annotations: pd.DataFrame) -> pd.DataFrame:
         raise NotImplementedError
 
-    def __str__(self) -> str:
-        return self._name
+    def _coerce_annotations(self, annotations: pd.DataFrame) -> pd.DataFrame:
+        all_columns = annotations.reset_index().columns
+        required = [self.task_column, self.worker_column]
+
+        if missing := set(required) - set(all_columns):
+            raise ValueError(
+                f"Annotations should have {self.task_column} and"
+                f" {self.worker_column} as columns or index levels, missing {missing}."
+            )
+
+        if set(all_columns) == set(required):
+            raise ValueError("Annotations should have at least one label column")
+
+        annotations = annotations.reset_index().set_index(required)[
+            [column for column in annotations.columns if column not in required]
+        ]
+
+        return annotations
 
 
-class VoterMixin:
+class WeightedApprovalMixin:
     @staticmethod
-    def _get_aggregated_labels(votes: pd.DataFrame) -> pd.DataFrame:
-        scores = votes.groupby(COLUMNS.question, sort=False)[votes.columns].sum()
-
-        scores = scores.reindex(votes.index.get_level_values(COLUMNS.question).unique())
+    def _get_aggregated_labels(
+        weighted_answers: pd.DataFrame, task_column: str
+    ) -> pd.DataFrame:
+        scores = weighted_answers.groupby(task_column, sort=False)[
+            weighted_answers.columns
+        ].sum()
+
+        scores = scores.reindex(
+            weighted_answers.index.get_level_values(task_column).unique()
+        )
 
         winning_alternatives = scores.idxmax(axis=1).astype(
-            pd.CategoricalDtype(categories=votes.columns)
+            pd.CategoricalDtype(categories=weighted_answers.columns)
         )
 
         aggregated_labels = pd.get_dummies(winning_alternatives)
 
         return aggregated_labels
 
 
-class WeightedAggregator(Aggregator, VoterMixin):
+class WeightedAggregator(Aggregator, WeightedApprovalMixin):
     @abstractmethod
-    def _compute_weights(self, annotations: pd.DataFrame) -> pd.Series:
+    def compute_weights(self, annotations: pd.DataFrame) -> pd.Series:
         raise NotImplementedError
 
-    def aggregate(self, annotations: pd.DataFrame) -> pd.DataFrame:
-        weights = self._compute_weights(annotations)
+    def _aggregate(self, annotations: pd.DataFrame) -> pd.DataFrame:
+        weights = self.compute_weights(annotations)
 
         weighted_answers = annotations.multiply(weights, axis="index")
 
-        return self._get_aggregated_labels(weighted_answers)
+        return self._get_aggregated_labels(weighted_answers, self.task_column)
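
The net effect of the refactor: the public entry point is now fit_predict, which validates and re-indexes the annotations before dispatching to the subclass hook. A minimal sketch (not part of the commit), again assuming COLUMNS.question == "question" and COLUMNS.voter == "voter", showing that flat columns are now accepted:

import pandas as pd

from hakeem.core.aggregation.aggregators.mallows import StandardApprovalAggregator

# Task and worker identifiers as plain columns; _coerce_annotations moves them
# into the index and keeps only the label columns, so no manual set_index is needed.
annotations = pd.DataFrame(
    {
        "question": ["q1", "q1", "q2", "q2"],
        "voter": ["v1", "v2", "v1", "v2"],
        "yes": [1, 0, 1, 1],
        "no": [0, 1, 0, 0],
    }
)

print(StandardApprovalAggregator().fit_predict(annotations))
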
36 changes: 18 additions & 18 deletions hakeem/paper_results/evaluation/accuracy.py
@@ -1,5 +1,4 @@
 import logging
-from collections.abc import Iterable
 from random import sample
 from typing import Mapping
 
@@ -10,8 +9,10 @@
 from sklearn.metrics import accuracy_score
 from tqdm import tqdm
 
-from hakeem.core.aggregation.aggregators import (
+from hakeem.core.aggregation.aggregators.condorcet import (
     CondorcetAggregator,
+)
+from hakeem.core.aggregation.aggregators.mallows import (
     DiceAggregator,
     EuclidAggregator,
     JaccardAggregator,
@@ -31,20 +32,19 @@ def compare_methods(
     groundtruth: pd.DataFrame,
     max_voters: int,
     n_batch: int,
-    aggregators: Iterable[Aggregator] = (
-        StandardApprovalAggregator(),
-        EuclidAggregator(),
-        JaccardAggregator(),
-        DiceAggregator(),
-        CondorcetAggregator(),
-    ),
+    aggregators: Mapping[str, Aggregator] = {
+        "Standard Approval Aggregator": StandardApprovalAggregator(),
+        "Euclidean Mallow Aggregator": EuclidAggregator(),
+        "Jaccard Mallow Aggregator": JaccardAggregator(),
+        "Dice Mallow Aggregator": DiceAggregator(),
+        "Condorcet Aggregator": CondorcetAggregator(),
+    },
 ) -> dict[str, NDArray]:
     accuracy = {
-        str(aggregator): np.zeros([n_batch, max_voters - 1])
-        for aggregator in aggregators
+        aggregator: np.zeros([n_batch, max_voters - 1]) for aggregator in aggregators
     }
     confidence_intervals = {
-        str(aggregator): np.zeros([max_voters - 1, 3]) for aggregator in aggregators
+        aggregator: np.zeros([max_voters - 1, 3]) for aggregator in aggregators
     }
 
     logging.info("Experiment started : running the different aggregators ...")
@@ -60,15 +60,15 @@ def compare_methods(
             annotations.index.get_level_values(COLUMNS.voter).isin(voters)
         ]
 
-        for aggregator in aggregators:
-            aggregated_labels = aggregator.aggregate(annotations_batch)
-            accuracy[str(aggregator)][batch, num - 1] = accuracy_score(
+        for name, aggregator in aggregators.items():
+            aggregated_labels = aggregator.fit_predict(annotations_batch)
+            accuracy[name][batch, num - 1] = accuracy_score(
                 groundtruth, aggregated_labels
             )
 
-        for aggregator in aggregators:
-            confidence_intervals[str(aggregator)][num - 1, :] = (
-                get_mean_confidence_interval(accuracy[str(aggregator)][:, num - 1])
+        for name in aggregators:
+            confidence_intervals[name][num - 1, :] = get_mean_confidence_interval(
+                accuracy[name][:, num - 1]
+            )
 
     logging.info("Experiment completed, gathering the results ..")
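
One consequence of the new Mapping parameter: results are keyed by caller-chosen display names rather than str(aggregator). A small sketch of the naming behaviour (not part of the commit; it mirrors the dict comprehensions above):

import numpy as np

from hakeem.core.aggregation.aggregators.condorcet import CondorcetAggregator
from hakeem.core.aggregation.aggregators.mallows import JaccardAggregator

# One accuracy matrix per named method, exactly as compare_methods builds them.
aggregators = {
    "Jaccard Mallow Aggregator": JaccardAggregator(),
    "Condorcet Aggregator": CondorcetAggregator(),
}
n_batch, max_voters = 25, 10
accuracy = {name: np.zeros([n_batch, max_voters - 1]) for name in aggregators}
print(list(accuracy))  # ['Jaccard Mallow Aggregator', 'Condorcet Aggregator']
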
