Skip to content

Commit

Permalink
Bring BaseModel's methods to Async model manager
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshYuJump committed Mar 12, 2022
1 parent 3dcde6b commit fe3094c
Show file tree
Hide file tree
Showing 8 changed files with 240 additions and 28 deletions.
2 changes: 1 addition & 1 deletion bali/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '3.0.0-rc.2'
__version__ = '3.0.0-rc.3'
45 changes: 39 additions & 6 deletions bali/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
import warnings
from functools import wraps

from sqla_wrapper import SQLAlchemy
from sqla_wrapper import SQLAlchemy, BaseModel
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.decl_api import DeclarativeMeta

from .models import get_base_model
from .models import included_models

# TODO: Removed logging according 12factor
error_logger = logging.getLogger('error')
Expand Down Expand Up @@ -43,11 +44,19 @@ def connect(
engine_options=engine_options,
session_options=session_options,
)
self._db.Model = self._db.registry.generate_base(
cls=BaseModel,
name="Model",
metaclass=AsyncModelDeclarativeMeta,
)

async_database_uri = get_async_database_uri(database_uri)
self._async_engine = create_async_engine(async_database_uri)

self.async_session = sessionmaker(
self._async_engine, class_=AsyncSession, expire_on_commit=False
self._async_engine,
class_=AsyncSession,
expire_on_commit=False,
)

def __getattribute__(self, attr, *args, **kwargs):
Expand All @@ -57,9 +66,9 @@ def __getattribute__(self, attr, *args, **kwargs):
if not self._db:
raise Exception('Database session not initialized')

# BaseModel
if attr == 'BaseModel':
return get_base_model(self)
# BaseModels
if attr in included_models:
return included_models[attr](self)

return getattr(self._db, attr)

Expand Down Expand Up @@ -133,3 +142,27 @@ def wrapper(*args, **kwargs):
return result

return wrapper


class AsyncModelDeclarativeMeta(DeclarativeMeta):
"""Make db.BaseModel support async using this metaclass"""
def __getattribute__(self, attr):
if attr == 'aio':
aio = super().__getattribute__(attr)
if any([aio.db is None, aio.model is None]):
aio = type(
f'Aio{aio.__qualname__}',
aio.__bases__,
dict(aio.__dict__),
)
setattr(aio, 'db', self._db)
setattr(aio, 'model', self)
return aio

return super().__getattribute__(attr)

# noinspection PyMethodParameters
def __call__(self, *args, **kwargs):
instance = super().__call__(*args, **kwargs)
instance.aio = self.aio(instance)
return instance
129 changes: 111 additions & 18 deletions bali/db/models.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,30 @@
"""Model included
`db.BaseModel` is the most common model.
If you don't use `db.BaseModel`, you can compose Mixins to `db.Model`
Import Mixins in your project examples:
```python
from bali.db.models import GenericModelMixin, AsyncModelMixin
```
"""

from contextvars import ContextVar
from datetime import datetime
from typing import List, Dict

import pytz
from sqlalchemy import Column, DateTime, Boolean
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.future import select
from sqlalchemy.inspection import inspect
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.sql.functions import func
from sqlalchemy.types import TypeDecorator
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.inspection import inspect

from ..utils import timezone

Expand Down Expand Up @@ -39,13 +53,12 @@ def process_bind_param(self, value, _):


def get_base_model(db):
class BaseModel(db.Model):
class BaseModel(db.Model, GenericModelMixin, AsyncModelMixin):
__abstract__ = True
__asdict_include_hybrid_properties__ = False

created_time = Column(DateTime, default=datetime.utcnow)
updated_time = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
is_active = Column(Boolean, default=True)
# Bind SQLA-wrapper database to model
_db = db

@classmethod
def exists(cls, **attrs):
Expand All @@ -55,7 +68,9 @@ def exists(cls, **attrs):

@classmethod
def create(cls, **attrs):
"""Create and persist a new record for the model, and returns it."""
"""
Create and persist a new record for the model, and returns it.
"""
return cls(**attrs).save()

@classmethod
Expand Down Expand Up @@ -109,7 +124,9 @@ def _asdict(self, **kwargs):
for i in inspect(type(self)).all_orm_descriptors:
if isinstance(i, InstrumentedAttribute):
output_fields.append(i.key)
elif isinstance(i, hybrid_property) and include_hybrid_properties:
elif isinstance(
i, hybrid_property
) and include_hybrid_properties:
output_fields.append(i.__name__)

return {i: getattr(self, i, None) for i in output_fields}
Expand Down Expand Up @@ -145,11 +162,9 @@ def update_or_create(cls, defaults: Dict = None, **kwargs):
try:
try:
instance = (
db.s.query(cls)
.filter_by(**kwargs)
.populate_existing()
.with_for_update()
.one()
db.s.query(cls).filter_by(
**kwargs
).populate_existing().with_for_update().one()
)
except NoResultFound:
instance = cls(**{**kwargs, **(defaults or {})}) # noqa
Expand All @@ -159,11 +174,9 @@ def update_or_create(cls, defaults: Dict = None, **kwargs):
except SQLAlchemyError:
db.s.rollback()
instance = (
db.s.query(cls)
.filter_by(**kwargs)
.populate_existing()
.with_for_update()
.one()
db.s.query(cls).filter_by(
**kwargs
).populate_existing().with_for_update().one()
)
else:
return instance, True
Expand All @@ -179,3 +192,83 @@ def update_or_create(cls, defaults: Dict = None, **kwargs):
return instance, False

return BaseModel


def get_async_base_model(db=None):
return AsyncModelManager


class AsyncModelManager:
"""Async model bind to aio"""

db = None
model = None

def __init__(self, instance):
self.instance = instance

@classmethod
async def exists(cls, **attrs):
async with cls.db.async_session() as async_session:
stmt = select(cls.model).filter_by(**attrs)
result = await async_session.execute(stmt)
return bool(result.scalars().first())

@classmethod
async def create(cls, **attrs):
await cls.model(**attrs).aio.save()

@classmethod
async def first(cls, **attrs):
async with cls.db.async_session() as async_session:
stmt = select(cls.model).filter_by(**attrs)
result = await async_session.execute(stmt)
return result.scalars().first()

async def save(self):
async with self.db.async_session() as async_session:
async_session.add(self.instance)
await async_session.commit()
return self.instance

async def delete(self):
async with self.db.async_session() as async_session:
async_session.delete(self.instance)
await async_session.commit()


# expose the include models and model creator
included_models = {
'BaseModel': get_base_model,
}

# -------------------- Models Mixins -------------------- #


class GenericModelMixin:
"""Generic model include the following fields"""
created_time = Column(DateTime, default=datetime.utcnow)
updated_time = Column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
)
is_active = Column(Boolean, default=True)


class AsyncModelMixin:
"""Async models methods
All async method accessed by `aio`
```python
# Model
async def get_first_user()
await User.aio.first()
# Instance
user = User()
user.aio.save()
```
"""

aio = AsyncModelManager
9 changes: 9 additions & 0 deletions tests/test_db.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from sqlalchemy.orm.decl_api import DeclarativeMeta
DeclarativeMeta = None

import pytest
from sqlalchemy import Column, Integer, String
from sqlalchemy.future import select
Expand Down Expand Up @@ -31,6 +34,12 @@ def get_by_username_sync(cls, username):
return User.first(username=username)


class Book(db.BaseModel):
__tablename__ = "books"
id = Column(Integer, primary_key=True)
title = Column(String(20), index=True)


db.create_all()


Expand Down
9 changes: 8 additions & 1 deletion tests/test_db_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ class User(db.BaseModel):
username = Column(String(50), default='')
age = Column(Integer)

class Book(db.BaseModel):
__tablename__ = "books"
id = Column(Integer, primary_key=True)
title = Column(String(20), index=True)

db.create_all()
lucy = User.create(**{
'username': 'Lucy',
Expand All @@ -32,7 +37,9 @@ class User(db.BaseModel):
assert lucy.id > 0

users = User.query().filter(User.username.like('%c%'), User.age > 0).all()
assert len(users) == 2, 'Fetch users count in common way should work properly'
assert len(
users
) == 2, 'Fetch users count in common way should work properly'

filters = {
'username__like': '%c%',
Expand Down
70 changes: 70 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import pytest
from sqlalchemy import Column, Integer, String
from sqlalchemy.future import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import declarative_base

from bali.db import db

DB_URI = 'sqlite:///:memory:'

db.connect(DB_URI)

Base = declarative_base()


class User(db.BaseModel):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
username = Column(String(20), index=True)


class Book(db.BaseModel):
__tablename__ = "books"
id = Column(Integer, primary_key=True)
title = Column(String(20), index=True)


@pytest.mark.asyncio
async def test_db_async_model_exists():
# Create model schema to database
async with db._async_engine.begin() as conn:
await conn.run_sync(db.BaseModel.metadata.create_all)

username = 'Lorry'
non_username = 'Noname'
title = 'Design Pattern'

async with db.async_session() as async_session:
user = User(username=username)
book = Book(title=title)
async_session.add(user)
async_session.add(book)
await async_session.commit()

is_exists = await User.aio.exists(username=username)
assert isinstance(is_exists, bool)
assert is_exists, f'{username} should exists'

is_exists = await Book.aio.exists(title=title)
assert isinstance(is_exists, bool)
assert is_exists, f'Book {title} should exists'

is_exists = await User.aio.exists(username=non_username)
assert isinstance(is_exists, bool)
assert not is_exists, f'{non_username} should not exists'


@pytest.mark.asyncio
async def test_db_async_instance_save():
# Create model schema to database
async with db._async_engine.begin() as conn:
await conn.run_sync(db.BaseModel.metadata.create_all)

username = 'Jeff Inno'
user = User(username=username)
await user.aio.save()

is_exists = await User.aio.exists(username=username)
assert isinstance(is_exists, bool)
assert is_exists, 'User should persisted into database'
2 changes: 1 addition & 1 deletion tests/test_model_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from bali.db import db
from bali.db.operators import get_filters_expr
from bali.decorators import action
from bali.resource import ModelResource
from bali.resources import ModelResource
from bali.schemas import ListRequest
from permissions import IsAuthenticated

Expand Down
2 changes: 1 addition & 1 deletion tests/test_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from bali.db import db
from bali.db.operators import get_filters_expr
from bali.decorators import action
from bali.resource import Resource
from bali.resources import Resource
from bali.schemas import ListRequest
from permissions import IsAuthenticated

Expand Down

0 comments on commit fe3094c

Please sign in to comment.