import math
import sys
from flask import abort
from flask import render_template
from flask import request
from peewee import Database
from peewee import DoesNotExist
from peewee import Model
from peewee import Proxy
from peewee import SelectQuery
from playhouse.db_url import connect as db_url_connect
class PaginatedQuery(object):
def __init__(self, query_or_model, paginate_by, page_var='page', page=None,
check_bounds=False):
self.paginate_by = paginate_by
self.page_var = page_var
self.page = page or None
self.check_bounds = check_bounds
if isinstance(query_or_model, SelectQuery):
self.query = query_or_model
self.model = self.query.model
else:
self.model = query_or_model
self.query = self.model.select()
def get_page(self):
if self.page is not None:
return self.page
curr_page = request.args.get(self.page_var)
if curr_page and curr_page.isdigit():
return max(1, int(curr_page))
return 1
def get_page_count(self):
if not hasattr(self, '_page_count'):
self._page_count = int(math.ceil(
float(self.query.count()) / self.paginate_by))
return self._page_count
def get_object_list(self):
if self.check_bounds and self.get_page() > self.get_page_count():
abort(404)
return self.query.paginate(self.get_page(), self.paginate_by)
def get_page_range(self, page, total, show=5):
# Generate page buttons for a subset of pages, e.g. if the current page
# is 4, we have 10 pages, and want to show 5 buttons, this function
# returns us: [2, 3, 4, 5, 6]
start = max((page - (show // 2)), 1)
stop = min(start + show, total) + 1
start = max(min(start, stop - show), 1)
return list(range(start, stop)[:show])
def get_object_or_404(query_or_model, *query):
if not isinstance(query_or_model, SelectQuery):
query_or_model = query_or_model.select()
try:
return query_or_model.where(*query).get()
except DoesNotExist:
abort(404)
def object_list(template_name, query, context_variable='object_list',
paginate_by=20, page_var='page', page=None, check_bounds=True,
**kwargs):
paginated_query = PaginatedQuery(
query,
paginate_by=paginate_by,
page_var=page_var,
page=page,
check_bounds=check_bounds)
kwargs[context_variable] = paginated_query.get_object_list()
return render_template(
template_name,
pagination=paginated_query,
page=paginated_query.get_page(),
**kwargs)
def get_current_url():
if not request.query_string:
return request.path
return '%s?%s' % (request.path, request.query_string)
def get_next_url(default='/'):
if request.args.get('next'):
return request.args['next']
elif request.form.get('next'):
return request.form['next']
return default
class FlaskDB(object):
"""
Convenience wrapper for configuring a Peewee database for use with a Flask
application. Provides a base `Model` class and registers handlers to manage
the database connection during the request/response cycle.
Usage::
from flask import Flask
from peewee import *
from playhouse.flask_utils import FlaskDB
# The database can be specified using a database URL, or you can pass a
# Peewee database instance directly:
DATABASE = 'postgresql:///my_app'
DATABASE = PostgresqlDatabase('my_app')
# If we do not want connection-management on any views, we can specify
# the view names using FLASKDB_EXCLUDED_ROUTES. The db connection will
# not be opened/closed automatically when these views are requested:
FLASKDB_EXCLUDED_ROUTES = ('logout',)
app = Flask(__name__)
app.config.from_object(__name__)
# Now we can configure our FlaskDB:
flask_db = FlaskDB(app)
# Or use the "deferred initialization" pattern:
flask_db = FlaskDB()
flask_db.init_app(app)
# The `flask_db` provides a base Model-class for easily binding models
# to the configured database:
class User(flask_db.Model):
email = CharField()
"""
def __init__(self, app=None, database=None, model_class=Model,
excluded_routes=None):
self.database = None # Reference to actual Peewee database instance.
self.base_model_class = model_class
self._app = app
self._db = database # dict, url, Database, or None (default).
self._excluded_routes = excluded_routes or ()
if app is not None:
self.init_app(app)
def init_app(self, app):
self._app = app
if self._db is None:
if 'DATABASE' in app.config:
initial_db = app.config['DATABASE']
elif 'DATABASE_URL' in app.config:
initial_db = app.config['DATABASE_URL']
else:
raise ValueError('Missing required configuration data for '
'database: DATABASE or DATABASE_URL.')
else:
initial_db = self._db
if 'FLASKDB_EXCLUDED_ROUTES' in app.config:
self._excluded_routes = app.config['FLASKDB_EXCLUDED_ROUTES']
self._load_database(app, initial_db)
self._register_handlers(app)
def _load_database(self, app, config_value):
if isinstance(config_value, Database):
database = config_value
elif isinstance(config_value, dict):
database = self._load_from_config_dict(dict(config_value))
else:
# Assume a database connection URL.
database = db_url_connect(config_value)
if isinstance(self.database, Proxy):
self.database.initialize(database)
else:
self.database = database
def _load_from_config_dict(self, config_dict):
try:
name = config_dict.pop('name')
engine = config_dict.pop('engine')
except KeyError:
raise RuntimeError('DATABASE configuration must specify a '
'`name` and `engine`.')
if '.' in engine:
path, class_name = engine.rsplit('.', 1)
else:
path, class_name = 'peewee', engine
try:
__import__(path)
module = sys.modules[path]
database_class = getattr(module, class_name)
assert issubclass(database_class, Database)
except ImportError:
raise RuntimeError('Unable to import %s' % engine)
except AttributeError:
raise RuntimeError('Database engine not found %s' % engine)
except AssertionError:
raise RuntimeError('Database engine not a subclass of '
'peewee.Database: %s' % engine)
return database_class(name, **config_dict)
def _register_handlers(self, app):
app.before_request(self.connect_db)
app.teardown_request(self.close_db)
def get_model_class(self):
if self.database is None:
raise RuntimeError('Database must be initialized.')
class BaseModel(self.base_model_class):
class Meta:
database = self.database
return BaseModel
@property
def Model(self):
if self._app is None:
database = getattr(self, 'database', None)
if database is None:
self.database = Proxy()
if not hasattr(self, '_model_class'):
self._model_class = self.get_model_class()
return self._model_class
def connect_db(self):
if self._excluded_routes and request.endpoint in self._excluded_routes:
return
self.database.connect()
def close_db(self, exc):
if self._excluded_routes and request.endpoint in self._excluded_routes:
return
if not self.database.is_closed():
self.database.close()