Skip to content

Commit

Permalink
chore: add docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
taharallouche committed Nov 9, 2024
1 parent 458f7c6 commit 9916f9e
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
51 changes: 51 additions & 0 deletions hakeem/aggregation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,27 @@


class Aggregator(ABC):
"""
Abstract base class for aggregators.
This class provides a template for creating custom aggregation methods. It requires
subclasses to implement the `_aggregate` method, which performs the actual
aggregation logic.
Attributes:
task_column (str): The name of the column containing task identifiers.
worker_column (str): The name of the column containing worker identifiers.
Methods:
fit_predict(annotations: pd.DataFrame) -> pd.DataFrame:
Coerces the schema of the annotations DataFrame, then applies the aggregation
logic implemented by `_aggregate`.
_aggregate(annotations: pd.DataFrame) -> pd.DataFrame:
Abstract method to be implemented by subclasses. This method should contain
the logic for aggregating the annotations.
"""

def __init__(
self, task_column: str = COLUMNS.question, worker_column: str = COLUMNS.voter
) -> None:
Expand All @@ -23,6 +44,19 @@ def _aggregate(self, annotations: pd.DataFrame) -> pd.DataFrame:


class WeightedApprovalMixin:
"""
A mixin class that provides functionality to aggregate weighted answers.
Static Methods
--------------
_aggregate_weighted_answers(
weighted_answers: pd.DataFrame, task_column: str
) -> pd.DataFrame
Aggregates weighted answers by summing the scores for each task and
determining the winning alternatives.
"""

@staticmethod
def _aggregate_weighted_answers(
weighted_answers: pd.DataFrame, task_column: str
Expand All @@ -45,6 +79,23 @@ def _aggregate_weighted_answers(


class WeightedAggregator(Aggregator, WeightedApprovalMixin):
"""
A base class for aggregators that use weighted annotations.
This class extends the Aggregator and WeightedApprovalMixin classes and provides
a framework for aggregating annotations with weights. Subclasses must implement
the `compute_weights` method to define how weights are calculated.
Methods
-------
compute_weights(annotations: pd.DataFrame) -> pd.Series
Abstract method to compute weights for the given annotations. Must be
implemented by subclasses.
_aggregate(annotations: pd.DataFrame) -> pd.DataFrame
Aggregates the given annotations using the computed weights.
"""

@abstractmethod
def compute_weights(self, annotations: pd.DataFrame) -> pd.Series:
raise NotImplementedError
Expand Down
16 changes: 16 additions & 0 deletions hakeem/utils/coerce.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,22 @@
def coerce_schema(
annotations: pd.DataFrame, task_column: str, worker_column: str
) -> pd.DataFrame:
"""
Coerce the schema of the annotations DataFrame to ensure it contains the required
columns and reindex it based on the specified task and worker columns.
Parameters:
annotations (pd.DataFrame): The DataFrame containing annotation data.
task_column (str): The name of the column representing tasks.
worker_column (str): The name of the column representing workers.
Returns:
pd.DataFrame: A DataFrame reindexed by the task and worker columns,
containing only the label columns.
Raises:
ValueError: If the required columns are missing or if there are no label columns.
"""
all_columns = annotations.reset_index().columns
required = [task_column, worker_column]

Expand Down

0 comments on commit 9916f9e

Please sign in to comment.