Skip to content

Commit

Permalink
Merge pull request #51 from JoshYuJump/3.0-stable
Browse files Browse the repository at this point in the history
Async model manager
  • Loading branch information
JoshYuJump authored Mar 12, 2022
2 parents 0080987 + f51a277 commit 50d3320
Show file tree
Hide file tree
Showing 9 changed files with 273 additions and 45 deletions.
2 changes: 1 addition & 1 deletion bali/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '3.0.0-rc.2'
__version__ = '3.0.0-rc.3'
98 changes: 75 additions & 23 deletions bali/db/connection.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,19 @@
import logging
import warnings
from functools import wraps

from contextlib import asynccontextmanager
from sqla_wrapper import SQLAlchemy
from sqla_wrapper import SQLAlchemy, BaseModel
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.decl_api import DeclarativeMeta

from .models import get_base_model

# from core.config import settings
from .models import included_models

# TODO: Removed logging according 12factor
error_logger = logging.getLogger('error')

database_schema_async_maps = [
('sqlite://', 'sqlite+aiosqlite://'),
('mysql+pymysql://', 'mysql+aiomysql://'),
('postgres://', 'postgresql+asyncpg://'),
]


# noinspection PyPep8Naming
class DB:
Expand Down Expand Up @@ -50,15 +44,19 @@ def connect(
engine_options=engine_options,
session_options=session_options,
)
async_database_uri = database_uri
for sync_schema, async_schema in database_schema_async_maps:
async_database_uri = async_database_uri.replace(
sync_schema, async_schema
)
self._db.Model = self._db.registry.generate_base(
cls=BaseModel,
name="Model",
metaclass=AsyncModelDeclarativeMeta,
)

async_database_uri = get_async_database_uri(database_uri)
self._async_engine = create_async_engine(async_database_uri)

self.async_session = sessionmaker(
self._async_engine, class_=AsyncSession, expire_on_commit=False
self._async_engine,
class_=AsyncSession,
expire_on_commit=False,
)

def __getattribute__(self, attr, *args, **kwargs):
Expand All @@ -68,20 +66,45 @@ def __getattribute__(self, attr, *args, **kwargs):
if not self._db:
raise Exception('Database session not initialized')

# BaseModel
if attr == 'BaseModel':
return get_base_model(self)
# BaseModels
if attr in included_models:
return included_models[attr](self)

return getattr(self._db, attr)


db = DB()


def get_async_database_uri(database_uri):
"""
Transform populate database schema to async format,
which is used by SQLA-Wrapper
"""
uri = database_uri
database_schema_async_maps = [
('sqlite://', 'sqlite+aiosqlite://'),
('mysql+pymysql://', 'mysql+aiomysql://'),
('postgres://', 'postgresql+asyncpg://'),
]
for sync_schema, async_schema in database_schema_async_maps:
uri = uri.replace(sync_schema, async_schema)
return uri


MAXIMUM_RETRY_ON_DEADLOCK: int = 3


def retry_on_deadlock_decorator(func):
lock_messages_error = ['Deadlock found', 'Lock wait timeout exceeded']
warnings.warn(
'retry_on_deadlock_decorator will remove in 3.2',
DeprecationWarning,
)

lock_messages_error = [
'Deadlock found',
'Lock wait timeout exceeded',
]

@wraps(func)
def wrapper(*args, **kwargs):
Expand All @@ -94,8 +117,8 @@ def wrapper(*args, **kwargs):
if any(msg in e.message for msg in lock_messages_error) \
and attempt_count <= MAXIMUM_RETRY_ON_DEADLOCK:
error_logger.error(
'Deadlock detected. Trying sql transaction once more. Attempts count: %s'
% (attempt_count + 1)
'Deadlock detected. Trying sql transaction once more. '
'Attempts count: %s' % (attempt_count + 1)
)
else:
raise
Expand All @@ -105,6 +128,11 @@ def wrapper(*args, **kwargs):


def close_connection(func):
warnings.warn(
'retry_on_deadlock_decorator will remove in 3.2',
DeprecationWarning,
)

def wrapper(*args, **kwargs):
try:
result = func(*args, **kwargs)
Expand All @@ -114,3 +142,27 @@ def wrapper(*args, **kwargs):
return result

return wrapper


class AsyncModelDeclarativeMeta(DeclarativeMeta):
"""Make db.BaseModel support async using this metaclass"""
def __getattribute__(self, attr):
if attr == 'aio':
aio = super().__getattribute__(attr)
if any([aio.db is None, aio.model is None]):
aio = type(
f'Aio{aio.__qualname__}',
aio.__bases__,
dict(aio.__dict__),
)
setattr(aio, 'db', self._db)
setattr(aio, 'model', self)
return aio

return super().__getattribute__(attr)

# noinspection PyMethodParameters
def __call__(self, *args, **kwargs):
instance = super().__call__(*args, **kwargs)
instance.aio = self.aio(instance)
return instance
125 changes: 107 additions & 18 deletions bali/db/models.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,30 @@
"""Model included
`db.BaseModel` is the most common model.
If you don't use `db.BaseModel`, you can compose Mixins to `db.Model`
Import Mixins in your project examples:
```python
from bali.db.models import GenericModelMixin, AsyncModelMixin
```
"""

from contextvars import ContextVar
from datetime import datetime
from typing import List, Dict

import pytz
from sqlalchemy import Column, DateTime, Boolean
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.future import select
from sqlalchemy.inspection import inspect
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.sql.functions import func
from sqlalchemy.types import TypeDecorator
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.inspection import inspect

from ..utils import timezone

Expand Down Expand Up @@ -39,13 +53,12 @@ def process_bind_param(self, value, _):


def get_base_model(db):
class BaseModel(db.Model):
class BaseModel(db.Model, GenericModelMixin, AsyncModelMixin):
__abstract__ = True
__asdict_include_hybrid_properties__ = False

created_time = Column(DateTime, default=datetime.utcnow)
updated_time = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
is_active = Column(Boolean, default=True)
# Bind SQLA-wrapper database to model
_db = db

@classmethod
def exists(cls, **attrs):
Expand All @@ -55,7 +68,9 @@ def exists(cls, **attrs):

@classmethod
def create(cls, **attrs):
"""Create and persist a new record for the model, and returns it."""
"""
Create and persist a new record for the model, and returns it.
"""
return cls(**attrs).save()

@classmethod
Expand Down Expand Up @@ -109,7 +124,9 @@ def _asdict(self, **kwargs):
for i in inspect(type(self)).all_orm_descriptors:
if isinstance(i, InstrumentedAttribute):
output_fields.append(i.key)
elif isinstance(i, hybrid_property) and include_hybrid_properties:
elif isinstance(
i, hybrid_property
) and include_hybrid_properties:
output_fields.append(i.__name__)

return {i: getattr(self, i, None) for i in output_fields}
Expand Down Expand Up @@ -145,11 +162,9 @@ def update_or_create(cls, defaults: Dict = None, **kwargs):
try:
try:
instance = (
db.s.query(cls)
.filter_by(**kwargs)
.populate_existing()
.with_for_update()
.one()
db.s.query(cls).filter_by(
**kwargs
).populate_existing().with_for_update().one()
)
except NoResultFound:
instance = cls(**{**kwargs, **(defaults or {})}) # noqa
Expand All @@ -159,11 +174,9 @@ def update_or_create(cls, defaults: Dict = None, **kwargs):
except SQLAlchemyError:
db.s.rollback()
instance = (
db.s.query(cls)
.filter_by(**kwargs)
.populate_existing()
.with_for_update()
.one()
db.s.query(cls).filter_by(
**kwargs
).populate_existing().with_for_update().one()
)
else:
return instance, True
Expand All @@ -179,3 +192,79 @@ def update_or_create(cls, defaults: Dict = None, **kwargs):
return instance, False

return BaseModel


class AsyncModelManager:
"""Async model bind to aio"""

db = None
model = None

def __init__(self, instance):
self.instance = instance

@classmethod
async def exists(cls, **attrs):
async with cls.db.async_session() as async_session:
stmt = select(cls.model).filter_by(**attrs)
result = await async_session.execute(stmt)
return bool(result.scalars().first())

@classmethod
async def create(cls, **attrs):
await cls.model(**attrs).aio.save()

@classmethod
async def first(cls, **attrs):
async with cls.db.async_session() as async_session:
stmt = select(cls.model).filter_by(**attrs)
result = await async_session.execute(stmt)
return result.scalars().first()

async def save(self):
async with self.db.async_session() as async_session:
async_session.add(self.instance)
await async_session.commit()
return self.instance

async def delete(self):
async with self.db.async_session() as async_session:
async_session.delete(self.instance)
await async_session.commit()


# expose the include models and model creator
included_models = {
'BaseModel': get_base_model,
}

# -------------------- Models Mixins -------------------- #


class GenericModelMixin:
"""Generic model include the following fields"""
created_time = Column(DateTime, default=datetime.utcnow)
updated_time = Column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
)
is_active = Column(Boolean, default=True)


class AsyncModelMixin:
"""Async models methods
All async method accessed by `aio`
```python
# Model
async def get_first_user()
await User.aio.first()
# Instance
user = User()
user.aio.save()
```
"""

aio = AsyncModelManager
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
aiomysql
dateparser==1.0.0
decamelize==0.1.2
fastapi[all]==0.63.0
Expand Down
9 changes: 9 additions & 0 deletions tests/test_db.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from sqlalchemy.orm.decl_api import DeclarativeMeta
DeclarativeMeta = None

import pytest
from sqlalchemy import Column, Integer, String
from sqlalchemy.future import select
Expand Down Expand Up @@ -31,6 +34,12 @@ def get_by_username_sync(cls, username):
return User.first(username=username)


class Book(db.BaseModel):
__tablename__ = "books"
id = Column(Integer, primary_key=True)
title = Column(String(20), index=True)


db.create_all()


Expand Down
9 changes: 8 additions & 1 deletion tests/test_db_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ class User(db.BaseModel):
username = Column(String(50), default='')
age = Column(Integer)

class Book(db.BaseModel):
__tablename__ = "books"
id = Column(Integer, primary_key=True)
title = Column(String(20), index=True)

db.create_all()
lucy = User.create(**{
'username': 'Lucy',
Expand All @@ -32,7 +37,9 @@ class User(db.BaseModel):
assert lucy.id > 0

users = User.query().filter(User.username.like('%c%'), User.age > 0).all()
assert len(users) == 2, 'Fetch users count in common way should work properly'
assert len(
users
) == 2, 'Fetch users count in common way should work properly'

filters = {
'username__like': '%c%',
Expand Down
Loading

0 comments on commit 50d3320

Please sign in to comment.