diff --git a/bali/db/comparator.py b/bali/db/comparators.py similarity index 100% rename from bali/db/comparator.py rename to bali/db/comparators.py diff --git a/bali/db/connection.pyi b/bali/db/connection.pyi index fcc9a47..176b678 100644 --- a/bali/db/connection.pyi +++ b/bali/db/connection.pyi @@ -18,11 +18,11 @@ class BaseModel: updated_time: _DateTimeField is_active: _BooleanField - def _asdict(self, include_hybrid_properties=__asdict_include_hybrid_properties__) -> Dict[str, Any]: ... + def _asdict(self, **kwargs) -> Dict[str, Any]: ... - def to_dict(self, include_hybrid_properties=__asdict_include_hybrid_properties__) -> Dict[str, Any]: ... + def to_dict(self, **kwargs) -> Dict[str, Any]: ... - def dict(self, include_hybrid_properties=__asdict_include_hybrid_properties__) -> Dict[str, Any]: ... + def dict(self, **kwargs) -> Dict[str, Any]: ... @classmethod def exists(cls: Type[_M], **attrs) -> bool: ... diff --git a/bali/db/models.py b/bali/db/models.py index 8ef549f..6b54218 100644 --- a/bali/db/models.py +++ b/bali/db/models.py @@ -99,7 +99,12 @@ def delete(self): db.session.delete(self) db.session.commit() if context_auto_commit.get() else db.session.flush() - def _asdict(self, include_hybrid_properties=__asdict_include_hybrid_properties__): + def _asdict(self, **kwargs): + include_hybrid_properties = kwargs.setdefault( + "include_hybrid_properties", + self.__asdict_include_hybrid_properties__ + ) + output_fields = [] for i in inspect(type(self)).all_orm_descriptors: if isinstance(i, InstrumentedAttribute): diff --git a/bali/utils/timezone.py b/bali/utils/timezone.py index 855118b..24477f2 100644 --- a/bali/utils/timezone.py +++ b/bali/utils/timezone.py @@ -1,5 +1,6 @@ +import calendar import os -from datetime import datetime +from datetime import datetime, date, timedelta from typing import Union import pytz @@ -66,7 +67,7 @@ def make_naive( return value.astimezone(timezone).replace(tzinfo=None) -def localtime(value: datetime = None, timezone: StrTzInfoType = None): +def localtime(value: datetime = None, timezone: StrTzInfoType = None) -> datetime: value, timezone = value or now(), timezone or get_current_timezone() if isinstance(timezone, str): timezone = pytz.timezone(timezone) @@ -75,5 +76,26 @@ def localtime(value: datetime = None, timezone: StrTzInfoType = None): return value.astimezone(timezone) -def localdate(value: datetime = None, timezone: StrTzInfoType = None): +def localdate(value: datetime = None, timezone: StrTzInfoType = None) -> date: return localtime(value, timezone).date() + + +def start_of( + granularity: str, + value: datetime = None, + *, + timezone: StrTzInfoType = None, +) -> datetime: + value = localtime(value, timezone) + if granularity == "year": + value = value.replace(month=1, day=1) + elif granularity == "month": + value = value.replace(day=1) + elif granularity == "week": + value = value - timedelta(days=calendar.weekday(value.year, value.month, value.day)) + elif granularity == "day": + pass + else: + raise ValueError("Granularity must be year, month, week or day") + + return value.replace(hour=0, minute=0, second=0, microsecond=0)