db_base.py 1.6 KB
Newer Older
P
peng.xu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
import logging
from sqlalchemy import create_engine
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.orm.session import Session as SessionBase

logger = logging.getLogger(__name__)


class LocalSession(SessionBase):
    """SQLAlchemy session that keeps a reference to the ``DB`` that owns it.

    The owning ``DB`` is stored on ``self.db``. Unless the caller supplies
    an explicit ``bind``, the session is bound to ``db.engine``.
    """

    def __init__(self, db, autocommit=False, autoflush=True, **options):
        """Create a session attached to *db*.

        Args:
            db: Owning object; its ``engine`` attribute is used as the
                bind when no truthy ``bind`` option is given.
            autocommit: Forwarded unchanged to the base session.
            autoflush: Forwarded unchanged to the base session.
            **options: Remaining keyword arguments for the base session.
                A ``bind`` entry, if present, is removed from the options
                and takes precedence over ``db.engine`` when truthy.
        """
        self.db = db

        # ``db.engine`` is only read when no usable bind was passed in.
        explicit_bind = options.pop('bind', None)
        if explicit_bind:
            target = explicit_bind
        else:
            target = db.engine

        SessionBase.__init__(
            self,
            autocommit=autocommit,
            autoflush=autoflush,
            bind=target,
            **options
        )


class DB:
    """Bundle an engine, a scoped session factory and a declarative base.

    Attributes:
        echo: The ``echo`` flag given to the constructor.
        engine: Engine created by :meth:`init_db`, or ``None`` until then.
        uri: The URI last passed to :meth:`init_db`, or ``None``.
        url: Parsed form of ``uri`` (result of ``make_url``), or ``None``.
        session_factory: ``scoped_session`` producing ``LocalSession``
            objects constructed with ``db=self``.
    """

    # Declarative base whose metadata drives create_all() / drop_all().
    Model = declarative_base()

    def __init__(self, uri=None, echo=False):
        """Set up the session factory and, if *uri* is given, the engine.

        Args:
            uri: Optional database URI. When falsy, no engine is created
                and :meth:`init_db` must be called later.
            echo: Stored on ``self.echo`` and forwarded to :meth:`init_db`
                when *uri* is given.
        """
        self.echo = echo
        # Define these up front so the object has a consistent shape even
        # before init_db() has run.
        self.engine = None
        self.uri = None
        self.url = None
        if uri:
            self.init_db(uri, echo)
        self.session_factory = scoped_session(sessionmaker(class_=LocalSession, db=self))

    def init_db(self, uri, echo=False, pool_size=100, pool_recycle=5, pool_timeout=30, pool_pre_ping=True, max_overflow=0):
        """Create the engine for *uri* and remember the URI and parsed URL.

        Args:
            uri: Database URI, parsed with ``make_url``.
            echo: Passed to ``create_engine`` as ``echo``.
            pool_size: Passed to ``create_engine`` for non-SQLite backends.
            pool_recycle: Passed to ``create_engine`` for non-SQLite backends.
            pool_timeout: Passed to ``create_engine`` for non-SQLite backends.
            pool_pre_ping: Passed to ``create_engine`` for non-SQLite backends.
            max_overflow: Passed to ``create_engine`` for non-SQLite backends.
        """
        url = make_url(uri)
        if url.get_backend_name() == 'sqlite':
            # Pool sizing options are deliberately not forwarded for SQLite.
            # NOTE(review): presumably because SQLite's default pool class
            # may reject them -- confirm against the SQLAlchemy version in use.
            self.engine = create_engine(url, echo=echo)
        else:
            # create_engine() accepts only the URL positionally; every other
            # option must be passed by keyword.
            self.engine = create_engine(
                url,
                echo=echo,
                pool_size=pool_size,
                pool_recycle=pool_recycle,
                pool_timeout=pool_timeout,
                pool_pre_ping=pool_pre_ping,
                max_overflow=max_overflow
            )
        self.uri = uri
        self.url = url

    def __str__(self):
        """Return ``<DB: backend=...;database=...>`` for the current URL.

        Both fields are rendered as ``None`` when no URL has been set yet.
        """
        url = self.url
        if url is None:
            backend = None
            database = None
        else:
            backend = url.get_backend_name()
            database = url.database
        return '<DB: backend={};database={}>'.format(backend, database)

    @property
    def Session(self):
        """Return the session obtained by calling ``session_factory``."""
        return self.session_factory()

    def remove_session(self):
        """Call ``remove()`` on the scoped session factory."""
        self.session_factory.remove()

    def drop_all(self):
        """Drop every table registered on ``Model.metadata`` using the engine."""
        self.Model.metadata.drop_all(self.engine)

    def create_all(self):
        """Create every table registered on ``Model.metadata`` using the engine."""
        self.Model.metadata.create_all(self.engine)