diff --git a/bali/__init__.py b/bali/__init__.py index 61a6551..96ec847 100644 --- a/bali/__init__.py +++ b/bali/__init__.py @@ -1 +1 @@ -__version__ = '3.0.0-rc.2' +__version__ = '3.0.0-rc.3' diff --git a/bali/db/connection.py b/bali/db/connection.py index cca4cf9..e23c089 100644 --- a/bali/db/connection.py +++ b/bali/db/connection.py @@ -1,25 +1,19 @@ import logging +import warnings from functools import wraps -from contextlib import asynccontextmanager -from sqla_wrapper import SQLAlchemy +from sqla_wrapper import SQLAlchemy, BaseModel from sqlalchemy.exc import OperationalError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm.decl_api import DeclarativeMeta -from .models import get_base_model - -# from core.config import settings +from .models import included_models +# TODO: Removed logging according 12factor error_logger = logging.getLogger('error') -database_schema_async_maps = [ - ('sqlite://', 'sqlite+aiosqlite://'), - ('mysql+pymysql://', 'mysql+aiomysql://'), - ('postgres://', 'postgresql+asyncpg://'), -] - # noinspection PyPep8Naming class DB: @@ -50,15 +44,19 @@ def connect( engine_options=engine_options, session_options=session_options, ) - async_database_uri = database_uri - for sync_schema, async_schema in database_schema_async_maps: - async_database_uri = async_database_uri.replace( - sync_schema, async_schema - ) + self._db.Model = self._db.registry.generate_base( + cls=BaseModel, + name="Model", + metaclass=AsyncModelDeclarativeMeta, + ) + + async_database_uri = get_async_database_uri(database_uri) self._async_engine = create_async_engine(async_database_uri) self.async_session = sessionmaker( - self._async_engine, class_=AsyncSession, expire_on_commit=False + self._async_engine, + class_=AsyncSession, + expire_on_commit=False, ) def __getattribute__(self, attr, *args, **kwargs): @@ -68,20 +66,45 @@ def __getattribute__(self, attr, *args, **kwargs): if not self._db: raise Exception('Database session not initialized') - # BaseModel - if attr == 'BaseModel': - return get_base_model(self) + # BaseModels + if attr in included_models: + return included_models[attr](self) return getattr(self._db, attr) db = DB() + +def get_async_database_uri(database_uri): + """ + Transform populate database schema to async format, + which is used by SQLA-Wrapper + """ + uri = database_uri + database_schema_async_maps = [ + ('sqlite://', 'sqlite+aiosqlite://'), + ('mysql+pymysql://', 'mysql+aiomysql://'), + ('postgres://', 'postgresql+asyncpg://'), + ] + for sync_schema, async_schema in database_schema_async_maps: + uri = uri.replace(sync_schema, async_schema) + return uri + + MAXIMUM_RETRY_ON_DEADLOCK: int = 3 def retry_on_deadlock_decorator(func): - lock_messages_error = ['Deadlock found', 'Lock wait timeout exceeded'] + warnings.warn( + 'retry_on_deadlock_decorator will remove in 3.2', + DeprecationWarning, + ) + + lock_messages_error = [ + 'Deadlock found', + 'Lock wait timeout exceeded', + ] @wraps(func) def wrapper(*args, **kwargs): @@ -94,8 +117,8 @@ def wrapper(*args, **kwargs): if any(msg in e.message for msg in lock_messages_error) \ and attempt_count <= MAXIMUM_RETRY_ON_DEADLOCK: error_logger.error( - 'Deadlock detected. Trying sql transaction once more. Attempts count: %s' - % (attempt_count + 1) + 'Deadlock detected. Trying sql transaction once more. ' + 'Attempts count: %s' % (attempt_count + 1) ) else: raise @@ -105,6 +128,11 @@ def wrapper(*args, **kwargs): def close_connection(func): + warnings.warn( + 'retry_on_deadlock_decorator will remove in 3.2', + DeprecationWarning, + ) + def wrapper(*args, **kwargs): try: result = func(*args, **kwargs) @@ -114,3 +142,27 @@ def wrapper(*args, **kwargs): return result return wrapper + + +class AsyncModelDeclarativeMeta(DeclarativeMeta): + """Make db.BaseModel support async using this metaclass""" + def __getattribute__(self, attr): + if attr == 'aio': + aio = super().__getattribute__(attr) + if any([aio.db is None, aio.model is None]): + aio = type( + f'Aio{aio.__qualname__}', + aio.__bases__, + dict(aio.__dict__), + ) + setattr(aio, 'db', self._db) + setattr(aio, 'model', self) + return aio + + return super().__getattribute__(attr) + + # noinspection PyMethodParameters + def __call__(self, *args, **kwargs): + instance = super().__call__(*args, **kwargs) + instance.aio = self.aio(instance) + return instance diff --git a/bali/db/models.py b/bali/db/models.py index cff28e0..191cdd5 100644 --- a/bali/db/models.py +++ b/bali/db/models.py @@ -1,3 +1,16 @@ +"""Model included + +`db.BaseModel` is the most common model. +If you don't use `db.BaseModel`, you can compose Mixins to `db.Model` + +Import Mixins in your project examples: + + ```python + from bali.db.models import GenericModelMixin, AsyncModelMixin + ``` + +""" + from contextvars import ContextVar from datetime import datetime from typing import List, Dict @@ -5,12 +18,13 @@ import pytz from sqlalchemy import Column, DateTime, Boolean from sqlalchemy.exc import IntegrityError, SQLAlchemyError +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.future import select +from sqlalchemy.inspection import inspect from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.sql.functions import func from sqlalchemy.types import TypeDecorator -from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.inspection import inspect from ..utils import timezone @@ -39,13 +53,12 @@ def process_bind_param(self, value, _): def get_base_model(db): - class BaseModel(db.Model): + class BaseModel(db.Model, GenericModelMixin, AsyncModelMixin): __abstract__ = True __asdict_include_hybrid_properties__ = False - created_time = Column(DateTime, default=datetime.utcnow) - updated_time = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) - is_active = Column(Boolean, default=True) + # Bind SQLA-wrapper database to model + _db = db @classmethod def exists(cls, **attrs): @@ -55,7 +68,9 @@ def exists(cls, **attrs): @classmethod def create(cls, **attrs): - """Create and persist a new record for the model, and returns it.""" + """ + Create and persist a new record for the model, and returns it. + """ return cls(**attrs).save() @classmethod @@ -109,7 +124,9 @@ def _asdict(self, **kwargs): for i in inspect(type(self)).all_orm_descriptors: if isinstance(i, InstrumentedAttribute): output_fields.append(i.key) - elif isinstance(i, hybrid_property) and include_hybrid_properties: + elif isinstance( + i, hybrid_property + ) and include_hybrid_properties: output_fields.append(i.__name__) return {i: getattr(self, i, None) for i in output_fields} @@ -145,11 +162,9 @@ def update_or_create(cls, defaults: Dict = None, **kwargs): try: try: instance = ( - db.s.query(cls) - .filter_by(**kwargs) - .populate_existing() - .with_for_update() - .one() + db.s.query(cls).filter_by( + **kwargs + ).populate_existing().with_for_update().one() ) except NoResultFound: instance = cls(**{**kwargs, **(defaults or {})}) # noqa @@ -159,11 +174,9 @@ def update_or_create(cls, defaults: Dict = None, **kwargs): except SQLAlchemyError: db.s.rollback() instance = ( - db.s.query(cls) - .filter_by(**kwargs) - .populate_existing() - .with_for_update() - .one() + db.s.query(cls).filter_by( + **kwargs + ).populate_existing().with_for_update().one() ) else: return instance, True @@ -179,3 +192,79 @@ def update_or_create(cls, defaults: Dict = None, **kwargs): return instance, False return BaseModel + + +class AsyncModelManager: + """Async model bind to aio""" + + db = None + model = None + + def __init__(self, instance): + self.instance = instance + + @classmethod + async def exists(cls, **attrs): + async with cls.db.async_session() as async_session: + stmt = select(cls.model).filter_by(**attrs) + result = await async_session.execute(stmt) + return bool(result.scalars().first()) + + @classmethod + async def create(cls, **attrs): + await cls.model(**attrs).aio.save() + + @classmethod + async def first(cls, **attrs): + async with cls.db.async_session() as async_session: + stmt = select(cls.model).filter_by(**attrs) + result = await async_session.execute(stmt) + return result.scalars().first() + + async def save(self): + async with self.db.async_session() as async_session: + async_session.add(self.instance) + await async_session.commit() + return self.instance + + async def delete(self): + async with self.db.async_session() as async_session: + async_session.delete(self.instance) + await async_session.commit() + + +# expose the include models and model creator +included_models = { + 'BaseModel': get_base_model, +} + +# -------------------- Models Mixins -------------------- # + + +class GenericModelMixin: + """Generic model include the following fields""" + created_time = Column(DateTime, default=datetime.utcnow) + updated_time = Column( + DateTime, default=datetime.utcnow, onupdate=datetime.utcnow + ) + is_active = Column(Boolean, default=True) + + +class AsyncModelMixin: + """Async models methods + + All async method accessed by `aio` + + ```python + # Model + async def get_first_user() + await User.aio.first() + + # Instance + user = User() + user.aio.save() + ``` + + """ + + aio = AsyncModelManager diff --git a/requirements.txt b/requirements.txt index 9b288a1..2954520 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +aiomysql dateparser==1.0.0 decamelize==0.1.2 fastapi[all]==0.63.0 diff --git a/tests/test_db.py b/tests/test_db.py index 2ed6068..0299826 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -1,3 +1,6 @@ +from sqlalchemy.orm.decl_api import DeclarativeMeta +DeclarativeMeta = None + import pytest from sqlalchemy import Column, Integer, String from sqlalchemy.future import select @@ -31,6 +34,12 @@ def get_by_username_sync(cls, username): return User.first(username=username) +class Book(db.BaseModel): + __tablename__ = "books" + id = Column(Integer, primary_key=True) + title = Column(String(20), index=True) + + db.create_all() diff --git a/tests/test_db_operators.py b/tests/test_db_operators.py index 1c07108..a164aa0 100644 --- a/tests/test_db_operators.py +++ b/tests/test_db_operators.py @@ -15,6 +15,11 @@ class User(db.BaseModel): username = Column(String(50), default='') age = Column(Integer) + class Book(db.BaseModel): + __tablename__ = "books" + id = Column(Integer, primary_key=True) + title = Column(String(20), index=True) + db.create_all() lucy = User.create(**{ 'username': 'Lucy', @@ -32,7 +37,9 @@ class User(db.BaseModel): assert lucy.id > 0 users = User.query().filter(User.username.like('%c%'), User.age > 0).all() - assert len(users) == 2, 'Fetch users count in common way should work properly' + assert len( + users + ) == 2, 'Fetch users count in common way should work properly' filters = { 'username__like': '%c%', diff --git a/tests/test_model.py b/tests/test_model.py new file mode 100644 index 0000000..c8d4423 --- /dev/null +++ b/tests/test_model.py @@ -0,0 +1,70 @@ +import pytest +from sqlalchemy import Column, Integer, String +from sqlalchemy.future import select +from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import declarative_base + +from bali.db import db + +DB_URI = 'sqlite:///:memory:' + +db.connect(DB_URI) + +Base = declarative_base() + + +class User(db.BaseModel): + __tablename__ = "users" + id = Column(Integer, primary_key=True) + username = Column(String(20), index=True) + + +class Book(db.BaseModel): + __tablename__ = "books" + id = Column(Integer, primary_key=True) + title = Column(String(20), index=True) + + +@pytest.mark.asyncio +async def test_db_async_model_exists(): + # Create model schema to database + async with db._async_engine.begin() as conn: + await conn.run_sync(db.BaseModel.metadata.create_all) + + username = 'Lorry' + non_username = 'Noname' + title = 'Design Pattern' + + async with db.async_session() as async_session: + user = User(username=username) + book = Book(title=title) + async_session.add(user) + async_session.add(book) + await async_session.commit() + + is_exists = await User.aio.exists(username=username) + assert isinstance(is_exists, bool) + assert is_exists, f'{username} should exists' + + is_exists = await Book.aio.exists(title=title) + assert isinstance(is_exists, bool) + assert is_exists, f'Book {title} should exists' + + is_exists = await User.aio.exists(username=non_username) + assert isinstance(is_exists, bool) + assert not is_exists, f'{non_username} should not exists' + + +@pytest.mark.asyncio +async def test_db_async_instance_save(): + # Create model schema to database + async with db._async_engine.begin() as conn: + await conn.run_sync(db.BaseModel.metadata.create_all) + + username = 'Jeff Inno' + user = User(username=username) + await user.aio.save() + + is_exists = await User.aio.exists(username=username) + assert isinstance(is_exists, bool) + assert is_exists, 'User should persisted into database' diff --git a/tests/test_model_resource.py b/tests/test_model_resource.py index 55f0040..b6e01d1 100644 --- a/tests/test_model_resource.py +++ b/tests/test_model_resource.py @@ -6,7 +6,7 @@ from bali.db import db from bali.db.operators import get_filters_expr from bali.decorators import action -from bali.resource import ModelResource +from bali.resources import ModelResource from bali.schemas import ListRequest from permissions import IsAuthenticated diff --git a/tests/test_resource.py b/tests/test_resource.py index 96f3a7a..fac7b7c 100644 --- a/tests/test_resource.py +++ b/tests/test_resource.py @@ -6,7 +6,7 @@ from bali.db import db from bali.db.operators import get_filters_expr from bali.decorators import action -from bali.resource import Resource +from bali.resources import Resource from bali.schemas import ListRequest from permissions import IsAuthenticated