Skip to content

Commit

Permalink
Merge pull request #47 from Ed-XCF/feature/hybrid-property-to-dict
Browse files Browse the repository at this point in the history
to dict support hybrid_property and add Comparator
  • Loading branch information
JoshYuJump authored Dec 10, 2021
2 parents 41985a8 + 0787ab7 commit 04a02b8
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 5 deletions.
19 changes: 19 additions & 0 deletions bali/db/comparator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from collections import defaultdict

from sqlalchemy import case
from sqlalchemy.ext.hybrid import Comparator
from sqlalchemy.util.langhelpers import dictlike_iteritems


class CaseComparator(Comparator):
def __init__(self, whens, expression):
super().__init__(expression)
self.whens, self.reversed_whens = dictlike_iteritems(whens), defaultdict(list)
for k, v in self.whens:
self.reversed_whens[v].append(k)

def __clause_element__(self):
return case(self.whens, self.expression)

def __eq__(self, other):
return super().__clause_element__().in_(self.reversed_whens[other])
7 changes: 4 additions & 3 deletions bali/db/connection.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@ _BooleanField = Union[bool, Column[Optional[bool]]]

class BaseModel:
__abstract__ = True
__asdict_include_hybrid_properties__ = False
created_time: _DateTimeField
updated_time: _DateTimeField
is_active: _BooleanField

def _asdict(self) -> Dict[str, Any]: ...
def _asdict(self, include_hybrid_properties=__asdict_include_hybrid_properties__) -> Dict[str, Any]: ...

def to_dict(self) -> Dict[str, Any]: ...
def to_dict(self, include_hybrid_properties=__asdict_include_hybrid_properties__) -> Dict[str, Any]: ...

def dict(self) -> Dict[str, Any]: ...
def dict(self, include_hybrid_properties=__asdict_include_hybrid_properties__) -> Dict[str, Any]: ...

@classmethod
def exists(cls: Type[_M], **attrs) -> bool: ...
Expand Down
15 changes: 13 additions & 2 deletions bali/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
import pytz
from sqlalchemy import Column, DateTime, Boolean
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.sql.functions import func
from sqlalchemy.types import TypeDecorator
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.inspection import inspect

from ..utils import timezone

Expand Down Expand Up @@ -38,6 +41,7 @@ def process_bind_param(self, value, _):
def get_base_model(db):
class BaseModel(db.Model):
__abstract__ = True
__asdict_include_hybrid_properties__ = False

created_time = Column(DateTime, default=datetime.utcnow)
updated_time = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
Expand Down Expand Up @@ -95,8 +99,15 @@ def delete(self):
db.session.delete(self)
db.session.commit() if context_auto_commit.get() else db.session.flush()

def _asdict(self):
return {c.name: getattr(self, c.name, None) for c in self.__table__.columns}
def _asdict(self, include_hybrid_properties=__asdict_include_hybrid_properties__):
output_fields = []
for i in inspect(type(self)).all_orm_descriptors:
if isinstance(i, InstrumentedAttribute):
output_fields.append(i.key)
elif isinstance(i, hybrid_property) and include_hybrid_properties:
output_fields.append(i.__name__)

return {i: getattr(self, i, None) for i in output_fields}

dict = to_dict = _asdict

Expand Down
13 changes: 13 additions & 0 deletions bali/db/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ..exceptions import OperatorModelError

OPERATOR_SPLITTER = '__'
REVERSER = "-"

OPERATORS = {
'isnull': lambda c, v: (c == None) if v else (c != None), # noqa
Expand Down Expand Up @@ -68,3 +69,15 @@ def get_filters_expr(cls, **filters):
expressions.append(op(column, value))

return expressions


def dj_lookup_to_sqla(expression: str) -> Tuple:
col_name, op_name = expression, "exact"
if OPERATOR_SPLITTER in col_name:
col_name, op_name = col_name.rsplit(OPERATOR_SPLITTER, 1)
return OPERATORS[op_name], col_name


def dj_ordering_to_sqla(expression: str):
wrapper = desc if expression.startswith(REVERSER) else asc
return wrapper(expression.lstrip(REVERSER))

0 comments on commit 04a02b8

Please sign in to comment.