From b97eb22c6e7fea8ca190af4837276154d6016164 Mon Sep 17 00:00:00 2001 From: JoshYuJump Date: Wed, 13 Oct 2021 20:32:57 +0800 Subject: [PATCH 01/11] Upgraded to sql-wrapper v5.0.0dev --- .editorconfig | 2 +- CHANGELOG | 4 +++- bali/__init__.py | 2 +- bali/db/connection.py | 40 +++++++++++++++++++++++++------------ bali/db/models.py | 46 +++++++++++++++++++++---------------------- requirements.txt | 5 ++--- tests/test_db.py | 4 ++-- 7 files changed, 59 insertions(+), 44 deletions(-) diff --git a/.editorconfig b/.editorconfig index cf51bc5..041328a 100644 --- a/.editorconfig +++ b/.editorconfig @@ -16,4 +16,4 @@ indent_style = space tab_width = 2 [*.py] -max_line_length = 100 +max_line_length = 79 diff --git a/CHANGELOG b/CHANGELOG index 76d2f65..557e332 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -3,8 +3,10 @@ # This project adheres to [Semantic Versioning](http://semver.org/). # includes Added / Changed / Fixed -## [2.1.0] UNRELEASED +## [3.0.0] UNRELEASED ### Added +- Supported uvicorn 0.15 +- Database model support asynchronous ### Changed - Adjusted version range dependency packages diff --git a/bali/__init__.py b/bali/__init__.py index a33997d..4eb28e3 100644 --- a/bali/__init__.py +++ b/bali/__init__.py @@ -1 +1 @@ -__version__ = '2.1.0' +__version__ = '3.0.0' diff --git a/bali/db/connection.py b/bali/db/connection.py index aa4edcc..7ee0506 100644 --- a/bali/db/connection.py +++ b/bali/db/connection.py @@ -13,28 +13,42 @@ # noinspection PyPep8Naming class DB: def __init__(self): - self._session = None - - def connect(self, database_uri, **kwargs): - kwargs.setdefault("pool_size", 5) - kwargs.setdefault("pool_recycle", 2 * 60 * 60) - - # developers need to know when the ORM object needs to reload from the db - kwargs.setdefault("expire_on_commit", False) - self._session = SQLAlchemy(database_uri, **kwargs) + self._db = None + + def connect( + self, + database_uri, + engine_options=None, + session_options=None, + **kwargs + ): + engine_options = engine_options or {} + engine_options.setdefault("pool_size", 5) + engine_options.setdefault("pool_recycle", 2 * 60 * 60) + + session_options = session_options or {} + # developers need to know when the ORM object needs to reload + # from the db + session_options.setdefault("expire_on_commit", False) + + self._db = SQLAlchemy( + database_uri, + engine_options=engine_options, + session_options=session_options, + ) def __getattribute__(self, attr, *args, **kwargs): try: return super().__getattribute__(attr) except AttributeError: - if not self._session: + if not self._db: raise Exception('Database session not initialized') # BaseModel if attr == 'BaseModel': return get_base_model(self) - return getattr(self._session, attr) + return getattr(self._db, attr) db = DB() @@ -56,8 +70,8 @@ def wrapper(*args, **kwargs): if any(msg in e.message for msg in lock_messages_error) \ and attempt_count <= MAXIMUM_RETRY_ON_DEADLOCK: error_logger.error( - 'Deadlock detected. Trying sql transaction once more. Attempts count: %s' % - (attempt_count + 1) + 'Deadlock detected. Trying sql transaction once more. Attempts count: %s' + % (attempt_count + 1) ) else: raise diff --git a/bali/db/models.py b/bali/db/models.py index 4ca2dee..763c9e6 100644 --- a/bali/db/models.py +++ b/bali/db/models.py @@ -46,8 +46,8 @@ class BaseModel(db.Model): @classmethod def exists(cls, **attrs): """Returns whether an object with these attributes exists.""" - equery = cls.query().filter_by(**attrs).exists() - return bool(db.session.query(equery).scalar()) + query = cls.query().filter_by(**attrs).exists() + return bool(db.s.query(query).scalar()) @classmethod def create(cls, **attrs): @@ -61,7 +61,7 @@ def create_or_first(cls, **attrs): try: return cls.create(**attrs) except IntegrityError: - db.session.rollback() + db.s.rollback() return cls.first(**attrs) @classmethod @@ -80,20 +80,20 @@ def first_or_error(cls, **attrs): @classmethod def query(cls): - return db.session.query(cls) + return db.s.query(cls) def save(self): """Override default model's save""" global context_auto_commit - db.session.add(self) - db.session.commit() if context_auto_commit.get() else db.session.flush() + db.s.add(self) + db.s.commit() if context_auto_commit.get() else db.s.flush() return self def delete(self): """Override default model's delete""" global context_auto_commit - db.session.delete(self) - db.session.commit() if context_auto_commit.get() else db.session.flush() + db.s.delete(self) + db.s.commit() if context_auto_commit.get() else db.s.flush() def to_dict(self): return {c.name: getattr(self, c.name, None) for c in self.__table__.columns} @@ -103,7 +103,7 @@ def dict(self): @classmethod def count(cls, **attrs) -> int: - return db.session.query(func.count(cls.id)).filter_by(**attrs).scalar() + return db.s.query(func.count(cls.id)).filter_by(**attrs).scalar() @classmethod def get_fields(cls) -> List[str]: @@ -111,18 +111,18 @@ def get_fields(cls) -> List[str]: @classmethod def get_or_create(cls, defaults: Dict = None, **kwargs): - instance = db.session.query(cls).filter_by(**kwargs).one_or_none() + instance = db.s.query(cls).filter_by(**kwargs).one_or_none() if instance: return instance, False instance = cls(**{**kwargs, **(defaults or {})}) # noqa try: - db.session.add(instance) - db.session.commit() + db.s.add(instance) + db.s.commit() return instance, True except SQLAlchemyError: - db.session.rollback() - instance = db.session.query(cls).filter_by(**kwargs).one() + db.s.rollback() + instance = db.s.query(cls).filter_by(**kwargs).one() return instance, False @classmethod @@ -130,7 +130,7 @@ def update_or_create(cls, defaults: Dict = None, **kwargs): try: try: instance = ( - db.session.query(cls) + db.s.query(cls) .filter_by(**kwargs) .populate_existing() .with_for_update() @@ -139,12 +139,12 @@ def update_or_create(cls, defaults: Dict = None, **kwargs): except NoResultFound: instance = cls(**{**kwargs, **(defaults or {})}) # noqa try: - db.session.add(instance) - db.session.commit() + db.s.add(instance) + db.s.commit() except SQLAlchemyError: - db.session.rollback() + db.s.rollback() instance = ( - db.session.query(cls) + db.s.query(cls) .filter_by(**kwargs) .populate_existing() .with_for_update() @@ -153,14 +153,14 @@ def update_or_create(cls, defaults: Dict = None, **kwargs): else: return instance, True except SQLAlchemyError: - db.session.rollback() + db.s.rollback() raise else: for k, v in defaults.items(): setattr(instance, k, v) - db.session.add(instance) - db.session.commit() - db.session.refresh(instance) + db.s.add(instance) + db.s.commit() + db.s.refresh(instance) return instance, False return BaseModel diff --git a/requirements.txt b/requirements.txt index ab11f7e..56b4926 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,6 @@ pydantic-sqlalchemy>=0.0.7,<1 python-jose[cryptography]==3.2.0 pytz==2021.1 redis==3.5.3 -sqla-wrapper==4.200628 -SQLAlchemy==1.3.19 +sqla-wrapper>=5.0.0.dev5,<6 typer>=0.3.2 -uvicorn==0.12.3 +uvicorn>=0.12.3,<=0.15 diff --git a/tests/test_db.py b/tests/test_db.py index 7273476..69b928b 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -8,7 +8,7 @@ # noinspection PyProtectedMember def test_db_connect(): db.connect(DB_URI) - assert db._session is not None + assert db._db is not None def test_base_model(): @@ -21,7 +21,7 @@ class User(db.BaseModel): db.create_all() # using the exists columns - users = db.query(User.created_time, User.updated_time, User.is_active).all() + users = db.s.query(User.created_time, User.updated_time, User.is_active).all() assert users == [] From eb523f19ad35b78e1a9dcec0b01fbbac047caf71 Mon Sep 17 00:00:00 2001 From: JoshYuJump Date: Wed, 13 Oct 2021 20:45:23 +0800 Subject: [PATCH 02/11] Supported uvicorn 0.15 --- CHANGELOG | 1 + bali/application.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG b/CHANGELOG index 557e332..2a37282 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -5,6 +5,7 @@ ## [3.0.0] UNRELEASED ### Added +- Upgraded to sql-wrapper v5.0.0dev - Supported uvicorn 0.15 - Database model support asynchronous diff --git a/bali/application.py b/bali/application.py index 2a02fe2..fb8826b 100644 --- a/bali/application.py +++ b/bali/application.py @@ -47,8 +47,9 @@ def __getattribute__(self, attr, *args, **kwargs): return super().__getattribute__(attr) except AttributeError: if not self._app: + print(f'attr: {attr}') # uvicorn entry is __call__ - if attr == '__call__': + if attr == '__call__' or attr == '__getstate__': self.http() return getattr(self._app, attr) From 75ae9e6a3d6dae4fd0d2881ce23b1f58539ef7c3 Mon Sep 17 00:00:00 2001 From: JoshYuJump Date: Thu, 14 Oct 2021 10:17:58 +0800 Subject: [PATCH 03/11] Removed main.py default launch behavior --- CHANGELOG | 2 ++ bali/application.py | 7 +++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/CHANGELOG b/CHANGELOG index 2a37282..6998732 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -8,6 +8,8 @@ - Upgraded to sql-wrapper v5.0.0dev - Supported uvicorn 0.15 - Database model support asynchronous +### Changed +- Removed main.py default launch behavior ### Changed - Adjusted version range dependency packages diff --git a/bali/application.py b/bali/application.py index fb8826b..f4d8d25 100644 --- a/bali/application.py +++ b/bali/application.py @@ -106,14 +106,13 @@ def http(self): add_pagination(self._app) def launch(self, http: bool = False, rpc: bool = False): - start_all = not any([http, rpc]) - if start_all: - return self._start_all() + if not http and not rpc: + typer.echo('Please provided launch service type: --http or --rpc') if http: self._launch_http() - if start_all or rpc: + if rpc: self._launch_rpc() def start(self): From 70e9ccd9653a1bef03f1613dbbee15cc64f3fe15 Mon Sep 17 00:00:00 2001 From: JoshYuJump Date: Thu, 14 Oct 2021 18:31:46 +0800 Subject: [PATCH 04/11] Created async SQLAlchemy engine in db layer --- CHANGELOG | 3 ++- bali/db/connection.py | 28 ++++++++++++++++++++-- bali/db/connection.pyi | 5 +++- bali/interceptors.py | 2 +- bali/middlewares.py | 2 +- requirements_dev.txt | 2 ++ tests/test_db.py | 54 ++++++++++++++++++++++++++++-------------- 7 files changed, 72 insertions(+), 24 deletions(-) diff --git a/CHANGELOG b/CHANGELOG index 6998732..37e83f2 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -7,7 +7,8 @@ ### Added - Upgraded to sql-wrapper v5.0.0dev - Supported uvicorn 0.15 -- Database model support asynchronous +- Model support asynchronous +- Resource support asynchronous ### Changed - Removed main.py default launch behavior diff --git a/bali/db/connection.py b/bali/db/connection.py index 7ee0506..cca4cf9 100644 --- a/bali/db/connection.py +++ b/bali/db/connection.py @@ -1,19 +1,32 @@ import logging from functools import wraps +from contextlib import asynccontextmanager from sqla_wrapper import SQLAlchemy from sqlalchemy.exc import OperationalError -from .models import get_base_model, AwareDateTime +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.orm import sessionmaker + +from .models import get_base_model # from core.config import settings error_logger = logging.getLogger('error') +database_schema_async_maps = [ + ('sqlite://', 'sqlite+aiosqlite://'), + ('mysql+pymysql://', 'mysql+aiomysql://'), + ('postgres://', 'postgresql+asyncpg://'), +] + # noinspection PyPep8Naming class DB: def __init__(self): - self._db = None + self._db = None # sync engine + self._async_engine = None # async engine + self.async_session = None # async session maker def connect( self, @@ -31,11 +44,22 @@ def connect( # from the db session_options.setdefault("expire_on_commit", False) + # Sync mode db instance self._db = SQLAlchemy( database_uri, engine_options=engine_options, session_options=session_options, ) + async_database_uri = database_uri + for sync_schema, async_schema in database_schema_async_maps: + async_database_uri = async_database_uri.replace( + sync_schema, async_schema + ) + self._async_engine = create_async_engine(async_database_uri) + + self.async_session = sessionmaker( + self._async_engine, class_=AsyncSession, expire_on_commit=False + ) def __getattribute__(self, attr, *args, **kwargs): try: diff --git a/bali/db/connection.pyi b/bali/db/connection.pyi index 076f25f..6fc5430 100644 --- a/bali/db/connection.pyi +++ b/bali/db/connection.pyi @@ -55,13 +55,16 @@ class BaseModel: class DB(SQLAlchemy): BaseModel: BaseModel + _async_engine: Any + def connect(self, database_uri: str) -> None: ... @property def session(self) -> Session: ... - def transaction(self) -> None: ... + async def async_session(self) -> Any: ... + db = DB() diff --git a/bali/interceptors.py b/bali/interceptors.py index 34679b1..766b977 100644 --- a/bali/interceptors.py +++ b/bali/interceptors.py @@ -21,7 +21,7 @@ def setup(self): def teardown(self): try: - db.remove() + db.s.remove() except Exception: pass diff --git a/bali/middlewares.py b/bali/middlewares.py index 7bfa974..f213007 100644 --- a/bali/middlewares.py +++ b/bali/middlewares.py @@ -17,6 +17,6 @@ async def process_middleware(request: Request, call_next): # remove db session when FastAPI request ended try: - db.remove() + db.s.remove() except Exception: pass diff --git a/requirements_dev.txt b/requirements_dev.txt index f083d88..c86c78f 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -1,5 +1,7 @@ -r requirements.txt pytest==6.1.2 +pytest-asyncio>=0.15.0 pytest-cov==2.12.0 +aiosqlite twine==3.2.0 wheel==0.35.1 diff --git a/tests/test_db.py b/tests/test_db.py index 69b928b..af307ca 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -1,39 +1,57 @@ +import pytest from sqlalchemy import Column, Integer +from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import declarative_base from bali.db import db DB_URI = 'sqlite:///:memory:' +db.connect(DB_URI) -# noinspection PyProtectedMember -def test_db_connect(): - db.connect(DB_URI) - assert db._db is not None +Base = declarative_base() -def test_base_model(): - db.connect(DB_URI) - class User(db.BaseModel): - __tablename__ = "users" - id = Column(Integer, primary_key=True) +class User(db.BaseModel): + __tablename__ = "users" + id = Column(Integer, primary_key=True) + + +db.create_all() + + +def test_db_connect(): + assert db._db is not None + assert db._async_engine is not None + assert isinstance(db.async_session, sessionmaker) - db.create_all() +def test_base_model(): # using the exists columns - users = db.s.query(User.created_time, User.updated_time, User.is_active).all() - assert users == [] + users = db.s.query(User.created_time, User.updated_time, + User.is_active).all() + assert len(users) >= 0 def test_base_model_create_entity(): - db.connect(DB_URI) + user = User.create() + assert user.id > 0 # create successfully + assert user.created_time <= user.updated_time + assert user.is_active - class User(db.BaseModel): - __tablename__ = "users" - id = Column(Integer, primary_key=True) - db.create_all() - user = User.create() +@pytest.mark.asyncio +async def test_base_model_create_entity_async(): + # Create model schema to database + async with db._async_engine.begin() as conn: + await conn.run_sync(db.BaseModel.metadata.create_all) + + async with db.async_session() as async_session: + user = User() + async_session.add(user) + await async_session.commit() + assert user.id > 0 # create successfully assert user.created_time <= user.updated_time assert user.is_active From e478b5af0ff1d56364499fd1cc4d80199d37c13d Mon Sep 17 00:00:00 2001 From: JoshYuJump Date: Fri, 15 Oct 2021 10:04:03 +0800 Subject: [PATCH 05/11] Tested async engine could access in models layer --- tests/test_db.py | 35 +++++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/tests/test_db.py b/tests/test_db.py index af307ca..5ac6338 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -1,5 +1,6 @@ import pytest -from sqlalchemy import Column, Integer +from sqlalchemy import Column, Integer, String +from sqlalchemy.future import select from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import declarative_base @@ -12,10 +13,22 @@ Base = declarative_base() - class User(db.BaseModel): __tablename__ = "users" id = Column(Integer, primary_key=True) + username = Column(String(20), index=True) + + @classmethod + async def get_by_username(cls, username): + async with db.async_session() as async_session: + result = await async_session.execute( + select(User).filter(User.username == username) + ) + return result.scalars().first() + + @classmethod + def get_by_username_sync(cls, username): + return User.first(username=username) db.create_all() @@ -35,7 +48,7 @@ def test_base_model(): def test_base_model_create_entity(): - user = User.create() + user = User.create(username='Lorry') assert user.id > 0 # create successfully assert user.created_time <= user.updated_time assert user.is_active @@ -48,10 +61,24 @@ async def test_base_model_create_entity_async(): await conn.run_sync(db.BaseModel.metadata.create_all) async with db.async_session() as async_session: - user = User() + user = User(username='Ary') async_session.add(user) await async_session.commit() assert user.id > 0 # create successfully assert user.created_time <= user.updated_time assert user.is_active + + +def test_fetch_entity_sync(): + # Create model schema to database + user = User.get_by_username_sync('Lorry') + assert user.username == 'Lorry' + + +@pytest.mark.asyncio +async def test_fetch_entity_async(): + async with db.async_session() as async_session: + user = await User.get_by_username('Ary') + + assert user.username == 'Ary' From 75945ecc16e5001dbf222d4303d10f6d1e3d744e Mon Sep 17 00:00:00 2001 From: JoshYuJump Date: Fri, 15 Oct 2021 10:07:20 +0800 Subject: [PATCH 06/11] Optimized fetch entity async tests --- tests/test_db.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_db.py b/tests/test_db.py index 5ac6338..2ed6068 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -78,7 +78,5 @@ def test_fetch_entity_sync(): @pytest.mark.asyncio async def test_fetch_entity_async(): - async with db.async_session() as async_session: - user = await User.get_by_username('Ary') - + user = await User.get_by_username('Ary') assert user.username == 'Ary' From 6dc75ac8949357b15e8e776d96b3630e9da8a947 Mon Sep 17 00:00:00 2001 From: JoshYuJump Date: Sat, 16 Oct 2021 16:25:35 +0800 Subject: [PATCH 07/11] gRPC asyncio example --- bali/aio/__init__.py | 0 bali/aio/interceptors.py | 39 +++++++++++++++++++++ bali/application.py | 11 ++++-- bali/utils/__init__.py | 6 ++++ examples/grpc_server_async.py | 66 +++++++++++++++++++++++++++++++++++ examples/main.py | 3 +- requirements.txt | 2 +- requirements_dev.txt | 3 +- 8 files changed, 124 insertions(+), 6 deletions(-) create mode 100644 bali/aio/__init__.py create mode 100644 bali/aio/interceptors.py create mode 100644 examples/grpc_server_async.py diff --git a/bali/aio/__init__.py b/bali/aio/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bali/aio/interceptors.py b/bali/aio/interceptors.py new file mode 100644 index 0000000..1310e45 --- /dev/null +++ b/bali/aio/interceptors.py @@ -0,0 +1,39 @@ +""" +gRPC interceptors +""" +import grpc +import logging + +from google.protobuf import json_format + +from bali.core import db + +from typing import Callable, Any +from grpc.aio import ServerInterceptor +from ..core import _settings + +logger = logging.getLogger('bali') + + +class ProcessInterceptor(ServerInterceptor): + def setup(self): + pass + + def teardown(self): + try: + db.s.remove() + except Exception: + pass + + async def intercept_service( + self, + continuation: Callable, + handler_call_details: grpc.HandlerCallDetails, + ) -> Any: + self.setup() + try: + result = await continuation() + finally: + self.teardown() + + return result diff --git a/bali/application.py b/bali/application.py index f4d8d25..4d2afd8 100644 --- a/bali/application.py +++ b/bali/application.py @@ -1,4 +1,5 @@ import gzip +import inspect from multiprocessing import Process from typing import Callable @@ -12,6 +13,7 @@ from ._utils import singleton from .middlewares import process_middleware +from .utils import sync_exec class GzipRequest(Request): @@ -65,11 +67,14 @@ def _launch_http(self): self._app = FastAPI(**self.base_settings) uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True, access_log=True) - def _launch_rpc(self): + async def _launch_rpc(self): service = self.kwargs.get('rpc_service') if not service: raise Exception('rpc_service not provided') - service.serve() + if inspect.iscoroutinefunction(service.serve): + await service.serve() + else: + service.serve() def _start_all(self): process_http = Process(target=self._launch_http) @@ -113,7 +118,7 @@ def launch(self, http: bool = False, rpc: bool = False): self._launch_http() if rpc: - self._launch_rpc() + sync_exec(self._launch_rpc()) def start(self): typer.run(self.launch) diff --git a/bali/utils/__init__.py b/bali/utils/__init__.py index 0bd17dc..b4cdd62 100644 --- a/bali/utils/__init__.py +++ b/bali/utils/__init__.py @@ -1,3 +1,4 @@ +import asyncio from enum import Enum from datetime import datetime, date from decimal import Decimal @@ -130,3 +131,8 @@ def get_beginning_datetime( ) -> datetime: _datetime = datetime(year, month, day) return make_aware(_datetime, timezone=timezone, is_dst=is_dst) + + +def sync_exec(coro): + loop = asyncio.get_event_loop() + return loop.run_until_complete(coro) diff --git a/examples/grpc_server_async.py b/examples/grpc_server_async.py new file mode 100644 index 0000000..37cf657 --- /dev/null +++ b/examples/grpc_server_async.py @@ -0,0 +1,66 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The Python implementation of the GRPC helloworld.Greeter server.""" + +import logging +import asyncio + +from concurrent import futures + +import grpc + +import helloworld_pb2 +import helloworld_pb2 as pb2 +import helloworld_pb2_grpc +from bali.aio.interceptors import ProcessInterceptor +from bali.mixins import ServiceMixin +from resources import GreeterResource, ItemResource + + +class GrpcServer(helloworld_pb2_grpc.GreeterServicer, ServiceMixin): + def SayHello(self, request, context): + print('Greeter.SayHello') + return helloworld_pb2.HelloReply(message='Hello, %s!' % request.name) + + def GetGreeter(self, request, context): + print('Greeter.GetGreeter') + return GreeterResource(request, context, pb2.ItemResponse).get() + + def ListGreeter(self, request, context): + print('Greeter.ListGreeter') + return GreeterResource(request, context, pb2.ListResponse).list() + + def CreateGreeter(self, request, context): + print('Greeter.CreateGreeter') + return GreeterResource(request, context, pb2.ItemResponse).create() + + def GetItem(self, request, context): + return ItemResource(request, context, pb2.ItemResponse).get() + + def ListItems(self, request, context): + return ItemResource(request, context, pb2.ListResponse).list() + + +async def serve(): + server = grpc.aio.server() + helloworld_pb2_grpc.add_GreeterServicer_to_server(GrpcServer(), server) + server.add_insecure_port('[::]:50051') + await server.start() + print("gRPC Service Hello world started") + await server.wait_for_termination() + + +if __name__ == '__main__': + logging.basicConfig() + asyncio.run(serve()) diff --git a/examples/main.py b/examples/main.py index 4944191..e8bfb6b 100644 --- a/examples/main.py +++ b/examples/main.py @@ -1,6 +1,7 @@ # noinspection PyUnresolvedReferences import config import grpc_server +import grpc_server_async from bali.core import Bali from v1.app import router from fastapi_pagination import LimitOffsetPage, add_pagination, paginate @@ -13,7 +14,7 @@ 'prefix': '/v1', }], backend_cors_origins=['http://127.0.0.1'], - rpc_service=grpc_server, + rpc_service=grpc_server_async, ) app.settings(title='Bali Example') diff --git a/requirements.txt b/requirements.txt index 56b4926..689d8fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ fastapi-pagination==0.7.4 grpcio>=1.32.0,<1.40 grpcio-tools>=1.32.0,<1.40 grpc-interceptor==0.13.0 -PyMySQL==0.10.1 +PyMySQL<=0.9.3,>=0.9 passlib[bcrypt]==1.7.2 pillow>=7.2.0,<8.3 protobuf==3.13.0 diff --git a/requirements_dev.txt b/requirements_dev.txt index c86c78f..750b336 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -1,7 +1,8 @@ -r requirements.txt +aiomysql~=0.0.21 +aiosqlite pytest==6.1.2 pytest-asyncio>=0.15.0 pytest-cov==2.12.0 -aiosqlite twine==3.2.0 wheel==0.35.1 From a67e511bccab9bf0c8a230e25bffcad88babb905 Mon Sep 17 00:00:00 2001 From: JoshYuJump Date: Sat, 16 Oct 2021 16:34:07 +0800 Subject: [PATCH 08/11] Fixed asyncio gRPC interceptor continuation calls --- bali/aio/interceptors.py | 2 +- examples/grpc_server_async.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bali/aio/interceptors.py b/bali/aio/interceptors.py index 1310e45..e23bbf1 100644 --- a/bali/aio/interceptors.py +++ b/bali/aio/interceptors.py @@ -32,7 +32,7 @@ async def intercept_service( ) -> Any: self.setup() try: - result = await continuation() + result = await continuation(handler_call_details) finally: self.teardown() diff --git a/examples/grpc_server_async.py b/examples/grpc_server_async.py index 37cf657..1f5ec55 100644 --- a/examples/grpc_server_async.py +++ b/examples/grpc_server_async.py @@ -53,7 +53,7 @@ def ListItems(self, request, context): async def serve(): - server = grpc.aio.server() + server = grpc.aio.server(interceptors=(ProcessInterceptor(), )) helloworld_pb2_grpc.add_GreeterServicer_to_server(GrpcServer(), server) server.add_insecure_port('[::]:50051') await server.start() From b8875beccd7aab96cc80f4049d191e1f1e00a9f5 Mon Sep 17 00:00:00 2001 From: JoshYuJump Date: Mon, 18 Oct 2021 10:07:17 +0800 Subject: [PATCH 09/11] Tests multi thread mode vs asyncio mode performance --- bali/decorators.py | 61 ++++++++++++++++++++++++++++++- examples/benchmark.py | 17 +++++++++ examples/grpc_client.py | 68 +++++++++++++++++------------------ examples/grpc_server.py | 36 +++++++++---------- examples/grpc_server_async.py | 45 ++++++++++++----------- examples/main.py | 2 +- examples/resources/item.py | 17 +++++++++ setup.py | 2 +- 8 files changed, 172 insertions(+), 76 deletions(-) create mode 100644 examples/benchmark.py diff --git a/bali/decorators.py b/bali/decorators.py index f3f38ef..69128f2 100644 --- a/bali/decorators.py +++ b/bali/decorators.py @@ -1,4 +1,5 @@ import functools +import inspect from fastapi.dependencies.utils import get_typed_signature from fastapi_pagination import LimitOffsetParams, set_page @@ -69,7 +70,65 @@ def wrapper(self, *args, **kwargs): return func(self, *args, **kwargs) - return wrapper + @functools.wraps(func) + async def wrapper_async(self, *args, **kwargs): + + # Put args to inner function from request object + if self._is_rpc: + request_data = MessageToDict( + self._request, + including_default_value_fields=True, + preserving_proto_field_name=True, + ) + + if func.__name__ == 'get': + pk = self._request.id + result = await func(self, pk) + if not isinstance(result, dict): + result = result.dict() + response_data = {'data': result} + + elif func.__name__ == 'list': + schema_in = get_schema_in(func) + result = await func(self, schema_in(**request_data)) + # Paginated the result queryset or iterable object + if isinstance(result, BaseModel): + raise ReturnTypeError('Generic actions `list` should return a sequence') + else: + set_page(Page) + params = LimitOffsetParams( + limit=request_data.get('limit') or 10, + offset=request_data.get('offset'), + ) + response_data = paginate(result, params=params, is_rpc=True) + + elif func.__name__ in ['create', 'update']: + schema_in = get_schema_in(func) + data = request_data.get('data') + result = await func(self, schema_in(**data)) + if not isinstance(result, dict): + result = result.dict() + response_data = {'data': result} + + elif func.__name__ == 'delete': + pk = self._request.id + result = await func(self, pk) + response_data = {'result': bool(result)} + + else: + # custom action + schema_in = get_schema_in(func) + result = await func(self, schema_in(**request_data)) + if not isinstance(result, dict): + result = result.dict() + response_data = result + + # Convert response data to gRPC response + return ParseDict(response_data, self._response_message(), ignore_unknown_fields=True) + + return func(self, *args, **kwargs) + + return wrapper_async if inspect.iscoroutinefunction(func) else wrapper def action(methods=None, detail=None, **kwargs): diff --git a/examples/benchmark.py b/examples/benchmark.py new file mode 100644 index 0000000..3c73aba --- /dev/null +++ b/examples/benchmark.py @@ -0,0 +1,17 @@ +import time +from multiprocessing import Process + +from grpc_client import run + +if __name__ == '__main__': + t1 = time.time() + processes = [] + execute_count = 1000 + for i in range(execute_count): + p = Process(target=run) + p.start() + processes.append(p) + for p in processes: + p.join() + t2 = time.time() + print('Execution took %s seconds' % (t2 - t1)) diff --git a/examples/grpc_client.py b/examples/grpc_client.py index f9fedd5..aad3a1f 100644 --- a/examples/grpc_client.py +++ b/examples/grpc_client.py @@ -28,44 +28,44 @@ def run(): # NOTE(gRPC Python Team): .close() is possible on a channel and should be # used in circumstances in which the with statement does not fit the needs # of the code. - with grpc.insecure_channel('localhost:50051') as channel: - stub = helloworld_pb2_grpc.GreeterStub(channel) - response = stub.SayHello(helloworld_pb2.HelloRequest(name='you')) - print("Greeter client received : %s" % response.message) - - with grpc.insecure_channel('localhost:50051') as channel: - stub = helloworld_pb2_grpc.GreeterStub(channel) - response = stub.GetGreeter(helloworld_pb2.GetRequest(id=3)) - print("Greeter client received : %s" % MessageToDict(response)) - - with grpc.insecure_channel('localhost:50051') as channel: - stub = helloworld_pb2_grpc.GreeterStub(channel) - data = { - 'id': 1, - 'content': 'Greeter', - } - request_pb = ParseDict( - {'data': data}, - helloworld_pb2.CreateRequest(), - ignore_unknown_fields=True, - ) - response = stub.CreateGreeter(request_pb) - print("Greeter client received : %s" % MessageToDict(response)) - - with grpc.insecure_channel('localhost:50051') as channel: - stub = helloworld_pb2_grpc.GreeterStub(channel) - response = stub.ListGreeter(helloworld_pb2.ListRequest(limit=2, offset=3)) - print("Greeter client received : %s" % MessageToDict(response)) - - with grpc.insecure_channel('localhost:50051') as channel: - stub = helloworld_pb2_grpc.GreeterStub(channel) - response = stub.GetItem(helloworld_pb2.GetRequest(id=1)) - print("Greeter client received : %s" % MessageToDict(response)) + # with grpc.insecure_channel('localhost:50051') as channel: + # stub = helloworld_pb2_grpc.GreeterStub(channel) + # response = stub.SayHello(helloworld_pb2.HelloRequest(name='you')) + # # print("Greeter client received : %s" % response.message) + # + # with grpc.insecure_channel('localhost:50051') as channel: + # stub = helloworld_pb2_grpc.GreeterStub(channel) + # response = stub.GetGreeter(helloworld_pb2.GetRequest(id=3)) + # # print("Greeter client received : %s" % MessageToDict(response)) + # + # with grpc.insecure_channel('localhost:50051') as channel: + # stub = helloworld_pb2_grpc.GreeterStub(channel) + # data = { + # 'id': 1, + # 'content': 'Greeter', + # } + # request_pb = ParseDict( + # {'data': data}, + # helloworld_pb2.CreateRequest(), + # ignore_unknown_fields=True, + # ) + # response = stub.CreateGreeter(request_pb) + # # print("Greeter client received : %s" % MessageToDict(response)) + # + # with grpc.insecure_channel('localhost:50051') as channel: + # stub = helloworld_pb2_grpc.GreeterStub(channel) + # response = stub.ListGreeter(helloworld_pb2.ListRequest(limit=2, offset=3)) + # # print("Greeter client received : %s" % MessageToDict(response)) + # + # with grpc.insecure_channel('localhost:50051') as channel: + # stub = helloworld_pb2_grpc.GreeterStub(channel) + # response = stub.GetItem(helloworld_pb2.GetRequest(id=1)) + # print("Greeter client received : %s" % MessageToDict(response)) with grpc.insecure_channel('localhost:50051') as channel: stub = helloworld_pb2_grpc.GreeterStub(channel) response = stub.ListItems(helloworld_pb2.ListRequest(limit=2, offset=3)) - print("Greeter client received : %s" % MessageToDict(response)) + # print("Greeter client received : %s" % MessageToDict(response)) if __name__ == '__main__': diff --git a/examples/grpc_server.py b/examples/grpc_server.py index 2bf5804..068c284 100644 --- a/examples/grpc_server.py +++ b/examples/grpc_server.py @@ -27,24 +27,24 @@ class GrpcServer(helloworld_pb2_grpc.GreeterServicer, ServiceMixin): - def SayHello(self, request, context): - print('Greeter.SayHello') - return helloworld_pb2.HelloReply(message='Hello, %s!' % request.name) - - def GetGreeter(self, request, context): - print('Greeter.GetGreeter') - return GreeterResource(request, context, pb2.ItemResponse).get() - - def ListGreeter(self, request, context): - print('Greeter.ListGreeter') - return GreeterResource(request, context, pb2.ListResponse).list() - - def CreateGreeter(self, request, context): - print('Greeter.CreateGreeter') - return GreeterResource(request, context, pb2.ItemResponse).create() - - def GetItem(self, request, context): - return ItemResource(request, context, pb2.ItemResponse).get() + # def SayHello(self, request, context): + # print('Greeter.SayHello') + # return helloworld_pb2.HelloReply(message='Hello, %s!' % request.name) + # + # def GetGreeter(self, request, context): + # print('Greeter.GetGreeter') + # return GreeterResource(request, context, pb2.ItemResponse).get() + # + # def ListGreeter(self, request, context): + # print('Greeter.ListGreeter') + # return GreeterResource(request, context, pb2.ListResponse).list() + # + # def CreateGreeter(self, request, context): + # print('Greeter.CreateGreeter') + # return GreeterResource(request, context, pb2.ItemResponse).create() + # + # def GetItem(self, request, context): + # return ItemResource(request, context, pb2.ItemResponse).get() def ListItems(self, request, context): return ItemResource(request, context, pb2.ListResponse).list() diff --git a/examples/grpc_server_async.py b/examples/grpc_server_async.py index 1f5ec55..4e128d1 100644 --- a/examples/grpc_server_async.py +++ b/examples/grpc_server_async.py @@ -29,31 +29,34 @@ class GrpcServer(helloworld_pb2_grpc.GreeterServicer, ServiceMixin): - def SayHello(self, request, context): - print('Greeter.SayHello') - return helloworld_pb2.HelloReply(message='Hello, %s!' % request.name) + # def SayHello(self, request, context): + # print('Greeter.SayHello') + # return helloworld_pb2.HelloReply(message='Hello, %s!' % request.name) + # + # def GetGreeter(self, request, context): + # print('Greeter.GetGreeter') + # return GreeterResource(request, context, pb2.ItemResponse).get() + # + # def ListGreeter(self, request, context): + # print('Greeter.ListGreeter') + # return GreeterResource(request, context, pb2.ListResponse).list() + # + # def CreateGreeter(self, request, context): + # print('Greeter.CreateGreeter') + # return GreeterResource(request, context, pb2.ItemResponse).create() + # + # def GetItem(self, request, context): + # return ItemResource(request, context, pb2.ItemResponse).get() - def GetGreeter(self, request, context): - print('Greeter.GetGreeter') - return GreeterResource(request, context, pb2.ItemResponse).get() - - def ListGreeter(self, request, context): - print('Greeter.ListGreeter') - return GreeterResource(request, context, pb2.ListResponse).list() - - def CreateGreeter(self, request, context): - print('Greeter.CreateGreeter') - return GreeterResource(request, context, pb2.ItemResponse).create() - - def GetItem(self, request, context): - return ItemResource(request, context, pb2.ItemResponse).get() - - def ListItems(self, request, context): - return ItemResource(request, context, pb2.ListResponse).list() + async def ListItems(self, request, context): + return await ItemResource(request, context, + pb2.ListResponse).list_async() async def serve(): - server = grpc.aio.server(interceptors=(ProcessInterceptor(), )) + server = grpc.aio.server( + interceptors=(ProcessInterceptor(), ), maximum_concurrent_rpcs=10 + ) helloworld_pb2_grpc.add_GreeterServicer_to_server(GrpcServer(), server) server.add_insecure_port('[::]:50051') await server.start() diff --git a/examples/main.py b/examples/main.py index e8bfb6b..75b18de 100644 --- a/examples/main.py +++ b/examples/main.py @@ -1,6 +1,6 @@ # noinspection PyUnresolvedReferences import config -import grpc_server +# import grpc_server import grpc_server_async from bali.core import Bali from v1.app import router diff --git a/examples/resources/item.py b/examples/resources/item.py index 9b93244..f63f39e 100644 --- a/examples/resources/item.py +++ b/examples/resources/item.py @@ -1,4 +1,7 @@ +import asyncio +import time from typing import Optional +from bali.core import db from pydantic import BaseModel @@ -10,6 +13,9 @@ from permissions import IsAuthenticated from schemas import ItemModel +from sqlalchemy.future import select + + class QFilter(BaseModel): name: str @@ -31,8 +37,19 @@ def get(self, pk=None): @action() def list(self, schema_in: ListRequest = None): + time.sleep(2) return Item.query().filter(*get_filters_expr(Item, **schema_in.filters)) + @action() + async def list_async(self, schema_in: ListRequest = None): + async with db.async_session() as async_session: + stmt = select(Item).filter( + *get_filters_expr(Item, **schema_in.filters) + ) + result = await async_session.execute(stmt) + await asyncio.sleep(2) + return [] + @action() def create(self, schema_in: schema = None): return Item.create(**schema_in.dict()) diff --git a/setup.py b/setup.py index ae2a1f3..51c7f29 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ def read(f): author_email='josh.yu_8@live.com', license='MIT', install_requires=INSTALL_REQUIREMENTS, - packages=find_packages(), + packages=find_packages(exclude=['examples', 'examples.*', 'tests']), package_data={'bali': ['db/*.pyi']}, include_package_data=True, zip_safe=False, From 56db0893a5951f5951cc55dc4ac81513cf6ed466 Mon Sep 17 00:00:00 2001 From: JoshYuJump Date: Mon, 18 Oct 2021 10:22:00 +0800 Subject: [PATCH 10/11] Implemented issue #35 and bump version 3.0.0-rc.1 --- bali/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bali/__init__.py b/bali/__init__.py index 4eb28e3..e96725b 100644 --- a/bali/__init__.py +++ b/bali/__init__.py @@ -1 +1 @@ -__version__ = '3.0.0' +__version__ = '3.0.0-rc.1' From fc3445d9c764e6c046115e594ec2152ab86cc009 Mon Sep 17 00:00:00 2001 From: JoshYuJump Date: Mon, 18 Oct 2021 10:30:15 +0800 Subject: [PATCH 11/11] Updated removed main.py default launch behavior docs --- README.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/README.md b/README.md index 50b7676..cea7219 100644 --- a/README.md +++ b/README.md @@ -28,9 +28,6 @@ app.settings(base_settings={'title': 'Bali App'}) Launch ```bash -# lauch RPC and HTTP service -python main.py - # lauch RPC python main.py --rpc