Skip to content

Commit

Permalink
Merge pull request #42 from JoshYuJump/3.0-stable
Browse files Browse the repository at this point in the history
3.0 stable
  • Loading branch information
JoshYuJump authored Oct 18, 2021
2 parents dc38e57 + fc3445d commit 2789ccc
Show file tree
Hide file tree
Showing 24 changed files with 435 additions and 133 deletions.
2 changes: 1 addition & 1 deletion .editorconfig
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ indent_style = space
tab_width = 2

[*.py]
max_line_length = 100
max_line_length = 79
8 changes: 7 additions & 1 deletion CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@
# This project adheres to [Semantic Versioning](http://semver.org/).
# includes Added / Changed / Fixed

## [2.1.0] UNRELEASED
## [3.0.0] UNRELEASED
### Added
- Upgraded to sql-wrapper v5.0.0dev
- Supported uvicorn 0.15
- Model support asynchronous
- Resource support asynchronous
### Changed
- Removed main.py default launch behavior

### Changed
- Adjusted version range dependency packages
Expand Down
3 changes: 0 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@ app.settings(base_settings={'title': 'Bali App'})
Launch

```bash
# lauch RPC and HTTP service
python main.py

# lauch RPC
python main.py --rpc

Expand Down
2 changes: 1 addition & 1 deletion bali/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.1.0'
__version__ = '3.0.0-rc.1'
Empty file added bali/aio/__init__.py
Empty file.
39 changes: 39 additions & 0 deletions bali/aio/interceptors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""
gRPC interceptors
"""
import grpc
import logging

from google.protobuf import json_format

from bali.core import db

from typing import Callable, Any
from grpc.aio import ServerInterceptor
from ..core import _settings

logger = logging.getLogger('bali')


class ProcessInterceptor(ServerInterceptor):
def setup(self):
pass

def teardown(self):
try:
db.s.remove()
except Exception:
pass

async def intercept_service(
self,
continuation: Callable,
handler_call_details: grpc.HandlerCallDetails,
) -> Any:
self.setup()
try:
result = await continuation(handler_call_details)
finally:
self.teardown()

return result
21 changes: 13 additions & 8 deletions bali/application.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import gzip
import inspect
from multiprocessing import Process
from typing import Callable

Expand All @@ -12,6 +13,7 @@

from ._utils import singleton
from .middlewares import process_middleware
from .utils import sync_exec


class GzipRequest(Request):
Expand Down Expand Up @@ -47,8 +49,9 @@ def __getattribute__(self, attr, *args, **kwargs):
return super().__getattribute__(attr)
except AttributeError:
if not self._app:
print(f'attr: {attr}')
# uvicorn entry is __call__
if attr == '__call__':
if attr == '__call__' or attr == '__getstate__':
self.http()
return getattr(self._app, attr)

Expand All @@ -64,11 +67,14 @@ def _launch_http(self):
self._app = FastAPI(**self.base_settings)
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True, access_log=True)

def _launch_rpc(self):
async def _launch_rpc(self):
service = self.kwargs.get('rpc_service')
if not service:
raise Exception('rpc_service not provided')
service.serve()
if inspect.iscoroutinefunction(service.serve):
await service.serve()
else:
service.serve()

def _start_all(self):
process_http = Process(target=self._launch_http)
Expand Down Expand Up @@ -105,15 +111,14 @@ def http(self):
add_pagination(self._app)

def launch(self, http: bool = False, rpc: bool = False):
start_all = not any([http, rpc])
if start_all:
return self._start_all()
if not http and not rpc:
typer.echo('Please provided launch service type: --http or --rpc')

if http:
self._launch_http()

if start_all or rpc:
self._launch_rpc()
if rpc:
sync_exec(self._launch_rpc())

def start(self):
typer.run(self.launch)
66 changes: 52 additions & 14 deletions bali/db/connection.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,78 @@
import logging
from functools import wraps

from contextlib import asynccontextmanager
from sqla_wrapper import SQLAlchemy
from sqlalchemy.exc import OperationalError
from .models import get_base_model, AwareDateTime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import sessionmaker

from .models import get_base_model

# from core.config import settings

error_logger = logging.getLogger('error')

database_schema_async_maps = [
('sqlite://', 'sqlite+aiosqlite://'),
('mysql+pymysql://', 'mysql+aiomysql://'),
('postgres://', 'postgresql+asyncpg://'),
]


# noinspection PyPep8Naming
class DB:
def __init__(self):
self._session = None

def connect(self, database_uri, **kwargs):
kwargs.setdefault("pool_size", 5)
kwargs.setdefault("pool_recycle", 2 * 60 * 60)

# developers need to know when the ORM object needs to reload from the db
kwargs.setdefault("expire_on_commit", False)
self._session = SQLAlchemy(database_uri, **kwargs)
self._db = None # sync engine
self._async_engine = None # async engine
self.async_session = None # async session maker

def connect(
self,
database_uri,
engine_options=None,
session_options=None,
**kwargs
):
engine_options = engine_options or {}
engine_options.setdefault("pool_size", 5)
engine_options.setdefault("pool_recycle", 2 * 60 * 60)

session_options = session_options or {}
# developers need to know when the ORM object needs to reload
# from the db
session_options.setdefault("expire_on_commit", False)

# Sync mode db instance
self._db = SQLAlchemy(
database_uri,
engine_options=engine_options,
session_options=session_options,
)
async_database_uri = database_uri
for sync_schema, async_schema in database_schema_async_maps:
async_database_uri = async_database_uri.replace(
sync_schema, async_schema
)
self._async_engine = create_async_engine(async_database_uri)

self.async_session = sessionmaker(
self._async_engine, class_=AsyncSession, expire_on_commit=False
)

def __getattribute__(self, attr, *args, **kwargs):
try:
return super().__getattribute__(attr)
except AttributeError:
if not self._session:
if not self._db:
raise Exception('Database session not initialized')

# BaseModel
if attr == 'BaseModel':
return get_base_model(self)

return getattr(self._session, attr)
return getattr(self._db, attr)


db = DB()
Expand All @@ -56,8 +94,8 @@ def wrapper(*args, **kwargs):
if any(msg in e.message for msg in lock_messages_error) \
and attempt_count <= MAXIMUM_RETRY_ON_DEADLOCK:
error_logger.error(
'Deadlock detected. Trying sql transaction once more. Attempts count: %s' %
(attempt_count + 1)
'Deadlock detected. Trying sql transaction once more. Attempts count: %s'
% (attempt_count + 1)
)
else:
raise
Expand Down
5 changes: 4 additions & 1 deletion bali/db/connection.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,16 @@ class BaseModel:
class DB(SQLAlchemy):
BaseModel: BaseModel

_async_engine: Any

def connect(self, database_uri: str) -> None: ...

@property
def session(self) -> Session: ...


def transaction(self) -> None: ...

async def async_session(self) -> Any: ...


db = DB()
46 changes: 23 additions & 23 deletions bali/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ class BaseModel(db.Model):
@classmethod
def exists(cls, **attrs):
"""Returns whether an object with these attributes exists."""
equery = cls.query().filter_by(**attrs).exists()
return bool(db.session.query(equery).scalar())
query = cls.query().filter_by(**attrs).exists()
return bool(db.s.query(query).scalar())

@classmethod
def create(cls, **attrs):
Expand All @@ -61,7 +61,7 @@ def create_or_first(cls, **attrs):
try:
return cls.create(**attrs)
except IntegrityError:
db.session.rollback()
db.s.rollback()
return cls.first(**attrs)

@classmethod
Expand All @@ -80,20 +80,20 @@ def first_or_error(cls, **attrs):

@classmethod
def query(cls):
return db.session.query(cls)
return db.s.query(cls)

def save(self):
"""Override default model's save"""
global context_auto_commit
db.session.add(self)
db.session.commit() if context_auto_commit.get() else db.session.flush()
db.s.add(self)
db.s.commit() if context_auto_commit.get() else db.s.flush()
return self

def delete(self):
"""Override default model's delete"""
global context_auto_commit
db.session.delete(self)
db.session.commit() if context_auto_commit.get() else db.session.flush()
db.s.delete(self)
db.s.commit() if context_auto_commit.get() else db.s.flush()

def to_dict(self):
return {c.name: getattr(self, c.name, None) for c in self.__table__.columns}
Expand All @@ -103,34 +103,34 @@ def dict(self):

@classmethod
def count(cls, **attrs) -> int:
return db.session.query(func.count(cls.id)).filter_by(**attrs).scalar()
return db.s.query(func.count(cls.id)).filter_by(**attrs).scalar()

@classmethod
def get_fields(cls) -> List[str]:
return [c.name for c in cls.__table__.columns]

@classmethod
def get_or_create(cls, defaults: Dict = None, **kwargs):
instance = db.session.query(cls).filter_by(**kwargs).one_or_none()
instance = db.s.query(cls).filter_by(**kwargs).one_or_none()
if instance:
return instance, False

instance = cls(**{**kwargs, **(defaults or {})}) # noqa
try:
db.session.add(instance)
db.session.commit()
db.s.add(instance)
db.s.commit()
return instance, True
except SQLAlchemyError:
db.session.rollback()
instance = db.session.query(cls).filter_by(**kwargs).one()
db.s.rollback()
instance = db.s.query(cls).filter_by(**kwargs).one()
return instance, False

@classmethod
def update_or_create(cls, defaults: Dict = None, **kwargs):
try:
try:
instance = (
db.session.query(cls)
db.s.query(cls)
.filter_by(**kwargs)
.populate_existing()
.with_for_update()
Expand All @@ -139,12 +139,12 @@ def update_or_create(cls, defaults: Dict = None, **kwargs):
except NoResultFound:
instance = cls(**{**kwargs, **(defaults or {})}) # noqa
try:
db.session.add(instance)
db.session.commit()
db.s.add(instance)
db.s.commit()
except SQLAlchemyError:
db.session.rollback()
db.s.rollback()
instance = (
db.session.query(cls)
db.s.query(cls)
.filter_by(**kwargs)
.populate_existing()
.with_for_update()
Expand All @@ -153,14 +153,14 @@ def update_or_create(cls, defaults: Dict = None, **kwargs):
else:
return instance, True
except SQLAlchemyError:
db.session.rollback()
db.s.rollback()
raise
else:
for k, v in defaults.items():
setattr(instance, k, v)
db.session.add(instance)
db.session.commit()
db.session.refresh(instance)
db.s.add(instance)
db.s.commit()
db.s.refresh(instance)
return instance, False

return BaseModel
Loading

0 comments on commit 2789ccc

Please sign in to comment.