Skip to content

Commit

Permalink
Fixed limits and warning in tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
Will Langdale committed Oct 18, 2024
1 parent 0df92ca commit 899bbce
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
20 changes: 11 additions & 9 deletions src/matchbox/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pandas import DataFrame
from pyarrow import Table as ArrowTable
- from pydantic import BaseModel, Field
+ from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import (
LABEL_STYLE_TABLENAME_PLUS_COL,
MetaData,
Expand Down Expand Up @@ -39,6 +39,12 @@ class Cluster(BaseModel):
class SourceWarehouse(BaseModel):
"""A warehouse where source data for datasets in Matchbox can be found."""

+ model_config = ConfigDict(
+ populate_by_name=True,
+ extra="forbid",
+ arbitrary_types_allowed=True,
+ )

alias: str
db_type: str
user: str
Expand All @@ -48,11 +54,6 @@ class SourceWarehouse(BaseModel):
database: str
_engine: Engine | None = None

- class Config:
- populate_by_name = True
- extra = "forbid"
- arbitrary_types_allowed = True

@property
def engine(self) -> Engine:
if self._engine is None:
Expand Down Expand Up @@ -97,14 +98,15 @@ def from_engine(cls, engine: Engine, alias: str | None = None) -> "SourceWarehou
class Source(BaseModel):
"""A dataset that can be indexed in the Matchbox database."""

+ model_config = ConfigDict(
+ populate_by_name=True,
+ )

database: SourceWarehouse | None = None
db_pk: str
db_schema: str
db_table: str

- class Config:
- populate_by_name = True

def __str__(self) -> str:
return f"{self.db_schema}.{self.db_table}"

Expand Down
10 changes: 9 additions & 1 deletion src/matchbox/server/postgresql/utils/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,10 @@ def query(
"""
tables: list[pa.Table] = []

+ if limit:
+ limit_base = limit // len(selector)
+ limit_remainder = limit % len(selector)

for source, fields in selector.items():
if model is None:
# We want raw data with no clusters
Expand All @@ -300,7 +304,11 @@ def query(
hash_query = _model_to_hashes(source, model, engine=engine)

if limit:
- hash_query = hash_query.limit(limit / len(selector))
+ remain = 0
+ if limit_remainder:
+ remain = 1
+ limit_remainder -= 1
+ hash_query = hash_query.limit(limit_base + remain)

mb_hashes = sql_to_df(hash_query, engine, return_type="arrow")

Expand Down

0 comments on commit 899bbce

Please sign in to comment.