# May you do good and not evil
# May you find forgiveness for yourself and forgive others
# May you share freely, never taking more than you give. -- SQLite source code
#
# As we enjoy great advantages from the inventions of others, we should be glad
# of an opportunity to serve others by an invention of ours, and this we should
# do freely and generously. -- Ben Franklin
#
# (\
# ( \ /(o)\ caw!
# ( \/ ()/ /)
# ( `;.))'".)
# `(/////.-'
# =====))=))===()
# ///'
# //
# '
import calendar
import datetime
import decimal
import hashlib
import itertools
import logging
import operator
import re
import sys
import threading
import time
import uuid
import weakref
from bisect import bisect_left
from bisect import bisect_right
from collections import deque
from collections import namedtuple
try:
from collections import OrderedDict
except ImportError:
OrderedDict = dict
from copy import deepcopy
from functools import wraps
from inspect import isclass
# Version string of this module.
__version__ = '2.8.4'
# Public names exported by a star-import of this module. The list is kept in
# case-insensitive alphabetical order.
__all__ = [
    'BareField',
    'BigIntegerField',
    'BlobField',
    'BooleanField',
    'CharField',
    'Check',
    'Clause',
    'CompositeKey',
    'DatabaseError',
    'DataError',
    'DateField',
    'DateTimeField',
    'DecimalField',
    'DeferredRelation',
    'DoesNotExist',
    'DoubleField',
    'DQ',
    'Field',
    'FixedCharField',
    'FloatField',
    'fn',
    'ForeignKeyField',
    'ImproperlyConfigured',
    'IntegerField',
    'IntegrityError',
    'InterfaceError',
    'InternalError',
    'JOIN',
    'JOIN_FULL',
    'JOIN_INNER',
    'JOIN_LEFT_OUTER',
    'Model',
    'MySQLDatabase',
    'NotSupportedError',
    'OperationalError',
    'Param',
    'PostgresqlDatabase',
    'prefetch',
    'PrimaryKeyField',
    'ProgrammingError',
    'Proxy',
    'R',
    'SmallIntegerField',
    'SqliteDatabase',
    'SQL',
    'TextField',
    'TimeField',
    'TimestampField',
    'Using',
    'UUIDField',
    'Window',
]
# Set default logging handler to avoid "No handlers could be found for logger
# "peewee"" warnings.
try:  # Python 2.7+
    from logging import NullHandler
except ImportError:
    # Fallback used when the logging module does not provide NullHandler.
    class NullHandler(logging.Handler):
        """Logging handler that silently discards every record."""
        def emit(self, record):
            pass
# All peewee-generated logs are logged to this namespace.
logger = logging.getLogger('peewee')
logger.addHandler(NullHandler())
# Python 2/3 compatibility helpers. These helpers are used internally and are
# not exported.
# Name given to the helper class created by `with_metaclass`.
_METACLASS_ = '_metaclass_helper_'


def with_metaclass(meta, base=object):
    """
    Create a helper class by calling `meta` directly, which works on both
    Python 2 and Python 3. The helper is named `_METACLASS_`, derives from
    `base` and has an empty namespace.
    """
    bases = (base,)
    namespace = {}
    return meta(_METACLASS_, bases, namespace)
# Interpreter version flags used to select the compatibility shims below.
PY2 = sys.version_info[0] == 2
PY3 = sys.version_info[0] == 3
PY26 = sys.version_info[:2] == (2, 6)
if PY3:
    import builtins
    try:
        # The ABC aliases that used to live in `collections` were removed in
        # Python 3.10, so prefer their canonical home.
        from collections.abc import Callable
    except ImportError:
        from collections import Callable
    from functools import reduce
    # Provide the Python 2 names that the rest of the module relies on.
    callable = lambda c: isinstance(c, Callable)
    unicode_type = str
    string_type = bytes
    basestring = str
    print_ = getattr(builtins, 'print')
    binary_construct = lambda s: bytes(s.encode('raw_unicode_escape'))
    long = int

    def reraise(tp, value, tb=None):
        """Re-raise `value`, attaching traceback `tb` if it differs."""
        if value.__traceback__ is not tb:
            raise value.with_traceback(tb)
        raise value
elif PY2:
    unicode_type = unicode
    string_type = basestring
    binary_construct = buffer

    def print_(s):
        """Write `s` followed by a newline to standard output."""
        sys.stdout.write(s)
        sys.stdout.write('\n')
    # The three-argument raise is a syntax error on Python 3, so the function
    # has to be compiled from a string.
    exec('def reraise(tp, value, tb=None): raise tp, value, tb')
else:
    raise RuntimeError('Unsupported python version.')
if PY26:
    # timedelta.total_seconds() is not used on Python 2.6; compute the value
    # by hand from the timedelta's components instead.
    _M = 10**6
    total_seconds = lambda t: (t.microseconds + 0.0 + (t.seconds + t.days * 24 * 3600) * _M) / _M
else:
    total_seconds = lambda t: t.total_seconds()
# By default, peewee supports Sqlite, MySQL and Postgresql.
try:
from pysqlite2 import dbapi2 as pysq3
except ImportError:
pysq3 = None
try:
import sqlite3
except ImportError:
sqlite3 = pysq3
else:
if pysq3 and pysq3.sqlite_version_info >= sqlite3.sqlite_version_info:
sqlite3 = pysq3
try:
from psycopg2cffi import compat
compat.register()
except ImportError:
pass
try:
import psycopg2
from psycopg2 import extensions as pg_extensions
except ImportError:
psycopg2 = None
try:
import MySQLdb as mysql # prefer the C module.
except ImportError:
try:
import pymysql as mysql
except ImportError:
mysql = None
try:
from playhouse._speedups import format_date_time
from playhouse._speedups import sort_models_topologically
from playhouse._speedups import strip_parens
except ImportError:
def format_date_time(value, formats, post_process=None):
    """
    Parse `value` with the first pattern in `formats` that matches and
    return the resulting datetime, passed through `post_process` when one is
    given. If no pattern matches, `value` is returned unchanged.
    """
    if not post_process:
        post_process = lambda parsed: parsed
    strptime = datetime.datetime.strptime
    for pattern in formats:
        try:
            result = post_process(strptime(value, pattern))
        except ValueError:
            # This pattern did not fit; try the next one.
            continue
        else:
            return result
    return value
def sort_models_topologically(models):
    """
    Return the given models as a list ordered so that a model comes before
    the models that reference it through a foreign key.
    """
    candidates = set(models)
    visited = set()
    post_order = []

    def visit(model):
        if model not in candidates or model in visited:
            return
        visited.add(model)
        for foreign_key in model._meta.reverse_rel.values():
            visit(foreign_key.model_class)
        post_order.append(model)  # Recorded after all of its descendants.

    def sort_key(model):
        meta = model._meta
        return (meta.name, meta.db_table)

    # Start from a name/table ordering so the result is deterministic.
    for model in sorted(candidates, key=sort_key, reverse=True):
        visit(model)
    post_order.reverse()
    return post_order
def strip_parens(s):
    """
    Remove redundant pairs of parentheses that wrap the whole of `s`.

    A leading "(" and trailing ")" are only removed when they match each
    other, so "((a))" becomes "a" while "(a)+(b)" is returned unchanged.
    Empty or non-parenthesized input is returned as-is.
    """
    # Quick sanity check.
    if not s or s[0] != '(':
        return s

    # Count the "(" ... ")" pairs that wrap the string symmetrically. After
    # the loop, s[start:end] is the interior left once they are peeled off.
    ct = 0
    start, end = 0, len(s)
    while start < end and s[start] == '(' and s[end - 1] == ')':
        ct += 1
        start += 1
        end -= 1

    if ct:
        # If we ever end up with negatively-balanced parentheses, then we
        # know that one of the outer parentheses was required.
        unbalanced_ct = 0
        required = 0
        # The whole interior must be scanned: `end` has already been moved
        # in by `ct`, so subtracting `ct` again would skip the tail and miss
        # closing parentheses that belong to an outer pair.
        for char in s[start:end]:
            if char == '(':
                unbalanced_ct += 1
            elif char == ')':
                unbalanced_ct -= 1
            if unbalanced_ct < 0:
                required += 1
                unbalanced_ct = 0
            if required == ct:
                break
        ct -= required

    if ct > 0:
        return s[ct:-ct]
    return s
try:
from playhouse._speedups import _DictQueryResultWrapper
from playhouse._speedups import _ModelQueryResultWrapper
from playhouse._speedups import _SortedFieldList
from playhouse._speedups import _TuplesQueryResultWrapper
except ImportError:
_DictQueryResultWrapper = _ModelQueryResultWrapper = _SortedFieldList =\
_TuplesQueryResultWrapper = None
if sqlite3:
    # When a sqlite driver is available, store Decimal, date and time values
    # using their `str()` representation.
    sqlite3.register_adapter(decimal.Decimal, str)
    sqlite3.register_adapter(datetime.date, str)
    sqlite3.register_adapter(datetime.time, str)
# Names of the date/time components, from largest to smallest unit.
DATETIME_PARTS = ['year', 'month', 'day', 'hour', 'minute', 'second']
# Set form of DATETIME_PARTS for membership tests.
DATETIME_LOOKUPS = set(DATETIME_PARTS)
# Sqlite does not support the `date_part` SQL function, so we will define an
# implementation in python.
# These are the strptime patterns tried, in order, when parsing stored values.
SQLITE_DATETIME_FORMATS = (
    '%Y-%m-%d %H:%M:%S',
    '%Y-%m-%d %H:%M:%S.%f',
    '%Y-%m-%d',
    '%H:%M:%S',
    '%H:%M:%S.%f',
    '%H:%M')
def _sqlite_date_part(lookup_type, datetime_string):
    """
    Return the `lookup_type` component (one of DATETIME_LOOKUPS) of the
    date/time encoded in `datetime_string`, or None for empty input.
    """
    assert lookup_type in DATETIME_LOOKUPS
    if datetime_string:
        parsed = format_date_time(datetime_string, SQLITE_DATETIME_FORMATS)
        return getattr(parsed, lookup_type)
    return None
# Format string used to truncate a date/time to each supported precision.
SQLITE_DATE_TRUNC_MAPPING = {
    'year': '%Y',
    'month': '%Y-%m',
    'day': '%Y-%m-%d',
    'hour': '%Y-%m-%d %H',
    'minute': '%Y-%m-%d %H:%M',
    'second': '%Y-%m-%d %H:%M:%S'}
# The MySQL variant is identical except that minutes are written as `%i`.
MYSQL_DATE_TRUNC_MAPPING = SQLITE_DATE_TRUNC_MAPPING.copy()
MYSQL_DATE_TRUNC_MAPPING['minute'] = '%Y-%m-%d %H:%i'
MYSQL_DATE_TRUNC_MAPPING['second'] = '%Y-%m-%d %H:%i:%S'
def _sqlite_date_trunc(lookup_type, datetime_string):
    """
    Return the date/time encoded in `datetime_string` re-formatted with the
    SQLITE_DATE_TRUNC_MAPPING pattern for `lookup_type`, or None for empty
    input.
    """
    assert lookup_type in SQLITE_DATE_TRUNC_MAPPING
    if datetime_string:
        parsed = format_date_time(datetime_string, SQLITE_DATETIME_FORMATS)
        pattern = SQLITE_DATE_TRUNC_MAPPING[lookup_type]
        return parsed.strftime(pattern)
    return None
def _sqlite_regexp(regex, value):
    """
    Case-insensitive regular expression search.

    Returns True when `regex` matches anywhere in `value`, False when it
    does not, and None when either operand is None.
    """
    if regex is None or value is None:
        # A NULL operand cannot be searched: `re.search` would raise a
        # TypeError. Yield None (SQL NULL) instead of failing.
        return None
    return re.search(regex, value, re.I) is not None
class attrdict(dict):
    """Dictionary whose keys can also be read as attributes."""
    def __getattr__(self, attr):
        try:
            return self[attr]
        except KeyError:
            # Attribute access must fail with AttributeError so that
            # `hasattr()` and `getattr(obj, name, default)` keep working.
            raise AttributeError(attr)
# Operators used in binary expressions.
OP = attrdict(
    AND='and',
    OR='or',
    ADD='+',
    SUB='-',
    MUL='*',
    DIV='/',
    BIN_AND='&',
    BIN_OR='|',
    XOR='^',
    MOD='%',
    EQ='=',
    LT='<',
    LTE='<=',
    GT='>',
    GTE='>=',
    NE='!=',
    IN='in',
    NOT_IN='not in',
    IS='is',
    IS_NOT='is not',
    LIKE='like',
    ILIKE='ilike',
    BETWEEN='between',
    REGEXP='regexp',
    CONCAT='||',
)
# Supported join types.
JOIN = attrdict(
    INNER='INNER',
    LEFT_OUTER='LEFT OUTER',
    RIGHT_OUTER='RIGHT OUTER',
    FULL='FULL',
)
# Module-level aliases for individual join types.
JOIN_INNER = JOIN.INNER
JOIN_LEFT_OUTER = JOIN.LEFT_OUTER
JOIN_FULL = JOIN.FULL
# Identifiers for the different kinds of query result representation.
RESULTS_NAIVE = 1
RESULTS_MODELS = 2
RESULTS_TUPLES = 3
RESULTS_DICTS = 4
RESULTS_AGGREGATE_MODELS = 5
# To support "django-style" double-underscore filters, create a mapping between
# operation name and operation code, e.g. "__eq" == OP.EQ.
DJANGO_MAP = {
    'eq': OP.EQ,
    'lt': OP.LT,
    'lte': OP.LTE,
    'gt': OP.GT,
    'gte': OP.GTE,
    'ne': OP.NE,
    'in': OP.IN,
    'is': OP.IS,
    'like': OP.LIKE,
    'ilike': OP.ILIKE,
    'regexp': OP.REGEXP,
}
# Helper functions that are used in various parts of the codebase.
def merge_dict(source, overrides):
    """
    Return a copy of `source` updated with the items of `overrides`.
    `source` itself is not modified.
    """
    result = source.copy()
    result.update(overrides)
    return result
def returns_clone(func):
    """
    Method decorator that will "clone" the object before applying the given
    method. This ensures that state is mutated in a more predictable fashion,
    and promotes the use of method-chaining.

    The decorated method returns the clone; the return value of `func` is
    discarded. The undecorated function stays reachable as the wrapper's
    `call_local` attribute.
    """
    @wraps(func)  # Keep the wrapped method's name and docstring.
    def inner(self, *args, **kwargs):
        clone = self.clone()  # Assumes object implements `clone`.
        func(clone, *args, **kwargs)
        return clone
    inner.call_local = func  # Provide a way to call without cloning.
    return inner
def not_allowed(func):
    """
    Method decorator to indicate a method is not allowed to be called. Will
    raise a `NotImplementedError`.
    """
    @wraps(func)  # Keep the wrapped method's name and docstring.
    def inner(self, *args, **kwargs):
        raise NotImplementedError('%s is not allowed on %s instances' % (
            func, type(self).__name__))
    return inner
class Proxy(object):
    """
    Placeholder that stands in for an object which will only be supplied
    later. Once `initialize` has been called, attribute reads are forwarded
    to the wrapped object.
    """
    __slots__ = ['obj', '_callbacks']

    def __init__(self):
        self._callbacks = []
        self.initialize(None)

    def initialize(self, obj):
        """Set the wrapped object and notify every attached callback."""
        self.obj = obj
        for notify in self._callbacks:
            notify(obj)

    def attach_callback(self, callback):
        """Register `callback` to run on `initialize`; returns `callback`."""
        self._callbacks.append(callback)
        return callback

    def __getattr__(self, attr):
        target = self.obj
        if target is not None:
            return getattr(target, attr)
        raise AttributeError('Cannot use uninitialized Proxy.')

    def __setattr__(self, attr, value):
        # Only the proxy's own slots may be assigned.
        if attr in self.__slots__:
            return super(Proxy, self).__setattr__(attr, value)
        raise AttributeError('Cannot set attribute on proxy.')
class DeferredRelation(object):
    """
    Stand-in for a related model that has not been defined yet. Fields
    register themselves with `set_field`; `set_model` later binds them to
    the real model.
    """
    # Named relations that are still waiting for their model.
    _unresolved = set()

    def __init__(self, rel_model_name=None):
        self.fields = []
        if rel_model_name is None:
            return
        # Only named relations can be resolved automatically by `resolve`.
        self._rel_model_name = rel_model_name.lower()
        self._unresolved.add(self)

    def set_field(self, model_class, field, name):
        """Remember a field that should be bound once the model is known."""
        entry = (model_class, field, name)
        self.fields.append(entry)

    def set_model(self, rel_model):
        """Point every registered field at `rel_model` and bind it."""
        for owner, field, attr_name in self.fields:
            field.rel_model = rel_model
            field.add_to_class(owner, attr_name)

    @staticmethod
    def resolve(model_cls):
        """Resolve pending relations whose name matches `model_cls`."""
        # Iterate over a snapshot because the set is modified in the loop.
        for pending in list(DeferredRelation._unresolved):
            if pending._rel_model_name != model_cls.__name__.lower():
                continue
            pending.set_model(model_cls)
            DeferredRelation._unresolved.discard(pending)
class _CDescriptor(object):
    """
    Descriptor that returns an `Entity` built from the owning instance's
    `_alias`; accessed on the class it returns the descriptor itself.
    """
    def __get__(self, instance, instance_type=None):
        if instance is None:
            return self
        return Entity(instance._alias)
# Classes representing the query tree.
class Node(object):
    """Base-class for any part of a query which shall be composable."""
    c = _CDescriptor()
    _node_type = 'node'

    def __init__(self):
        self._negated = False
        self._alias = None
        self._bind_to = None
        self._ordering = None  # ASC or DESC.

    @classmethod
    def extend(cls, name=None, clone=False):
        """
        Return a decorator that installs the decorated function on this class
        as a method, under `name` or the function's own name. With `clone`
        set, the method is wrapped with `returns_clone` first.
        """
        def decorator(method):
            attr_name = name or method.__name__
            installed = returns_clone(method) if clone else method
            setattr(cls, attr_name, installed)
            return installed
        return decorator

    def clone_base(self):
        """Create a fresh instance of the same type."""
        return type(self)()

    def clone(self):
        """Copy this node, carrying over the state shared by all nodes."""
        inst = self.clone_base()
        for attr in ('_negated', '_alias', '_ordering', '_bind_to'):
            setattr(inst, attr, getattr(self, attr))
        return inst

    @returns_clone
    def __invert__(self):
        self._negated = not self._negated

    @returns_clone
    def alias(self, a=None):
        self._alias = a

    @returns_clone
    def bind_to(self, bt):
        """
        Attach the result of an expression to a particular model type. This
        is useful when an expression is added to a select and its value
        should end up on a joined instance.
        """
        self._bind_to = bt

    @returns_clone
    def asc(self):
        self._ordering = 'ASC'

    @returns_clone
    def desc(self):
        self._ordering = 'DESC'

    def __pos__(self):
        return self.asc()

    def __neg__(self):
        return self.desc()

    def _e(op, inv=False):
        """
        Small factory producing a method that combines the node with another
        operand into an Expression using `op`. With `inv` set the operands
        are swapped, which is what the reflected operators need.
        """
        def inner(self, rhs):
            left, right = (rhs, self) if inv else (self, rhs)
            return Expression(left, op, right)
        return inner

    __and__ = _e(OP.AND)
    __or__ = _e(OP.OR)

    __add__ = _e(OP.ADD)
    __sub__ = _e(OP.SUB)
    __mul__ = _e(OP.MUL)
    __div__ = __truediv__ = _e(OP.DIV)
    __xor__ = _e(OP.XOR)
    __radd__ = _e(OP.ADD, inv=True)
    __rsub__ = _e(OP.SUB, inv=True)
    __rmul__ = _e(OP.MUL, inv=True)
    __rdiv__ = __rtruediv__ = _e(OP.DIV, inv=True)
    __rand__ = _e(OP.AND, inv=True)
    __ror__ = _e(OP.OR, inv=True)
    __rxor__ = _e(OP.XOR, inv=True)

    def __eq__(self, rhs):
        # Comparing against None produces an IS test.
        op = OP.IS if rhs is None else OP.EQ
        return Expression(self, op, rhs)

    def __ne__(self, rhs):
        # Comparing against None produces an IS NOT test.
        op = OP.IS_NOT if rhs is None else OP.NE
        return Expression(self, op, rhs)

    __lt__ = _e(OP.LT)
    __le__ = _e(OP.LTE)
    __gt__ = _e(OP.GT)
    __ge__ = _e(OP.GTE)
    __lshift__ = _e(OP.IN)
    __rshift__ = _e(OP.IS)
    __mod__ = _e(OP.LIKE)
    __pow__ = _e(OP.ILIKE)

    bin_and = _e(OP.BIN_AND)
    bin_or = _e(OP.BIN_OR)

    # Special expressions.
    def in_(self, rhs):
        return Expression(self, OP.IN, rhs)

    def not_in(self, rhs):
        return Expression(self, OP.NOT_IN, rhs)

    def is_null(self, is_null=True):
        op = OP.IS if is_null else OP.IS_NOT
        return Expression(self, op, None)

    def contains(self, rhs):
        pattern = '%%%s%%' % rhs
        return Expression(self, OP.ILIKE, pattern)

    def startswith(self, rhs):
        pattern = '%s%%' % rhs
        return Expression(self, OP.ILIKE, pattern)

    def endswith(self, rhs):
        pattern = '%%%s' % rhs
        return Expression(self, OP.ILIKE, pattern)

    def between(self, low, high):
        bounds = Clause(low, R('AND'), high)
        return Expression(self, OP.BETWEEN, bounds)

    def regexp(self, expression):
        return Expression(self, OP.REGEXP, expression)

    def concat(self, rhs):
        return Expression(self, OP.CONCAT, rhs)
class SQL(Node):
    """A raw, unescaped piece of SQL together with optional parameters."""
    _node_type = 'sql'

    def __init__(self, value, *params):
        super(SQL, self).__init__()
        self.value = value
        self.params = params

    def clone_base(self):
        value, params = self.value, self.params
        return SQL(value, *params)


R = SQL  # backwards-compat.
class Entity(Node):
    """A quoted name made of one or more path parts, e.g. "table"."column"."""
    _node_type = 'entity'

    def __init__(self, *path):
        super(Entity, self).__init__()
        self.path = path

    def clone_base(self):
        return Entity(*self.path)

    def __getattr__(self, attr):
        # Unknown attributes extend the path; falsy parts are dropped.
        extended = self.path + (attr,)
        return Entity(*filter(None, extended))
class Func(Node):
    """A call to an arbitrary SQL function."""
    _node_type = 'func'
    # Function names (lower-case) for which coercion is off by default.
    _no_coerce = set(('count', 'sum'))

    def __init__(self, name, *arguments):
        self.name = name
        self.arguments = arguments
        if name:
            self._coerce = name.lower() not in self._no_coerce
        else:
            self._coerce = False
        super(Func, self).__init__()

    @returns_clone
    def coerce(self, coerce=True):
        self._coerce = coerce

    def clone_base(self):
        copy = Func(self.name, *self.arguments)
        copy._coerce = self._coerce
        return copy

    def over(self, partition_by=None, order_by=None, window=None):
        """
        Build a `<func> OVER <window>` clause. A Window may be given either
        as `window` or as the first positional argument; otherwise an inline
        window is built from `partition_by` and `order_by`.
        """
        if window is None and isinstance(partition_by, Window):
            window = partition_by
        if window is not None:
            # Refer to the window by its alias.
            over_sql = SQL(window._alias)
        else:
            over_sql = Window(
                partition_by=partition_by, order_by=order_by).__sql__()
        return Clause(self, SQL('OVER'), over_sql)

    def __getattr__(self, attr):
        # Any unknown attribute becomes a factory for a function named `attr`.
        def dec(*args, **kwargs):
            return Func(attr, *args, **kwargs)
        return dec


# fn is a factory for creating `Func` objects and supports a more friendly
# API.  So instead of `Func("LOWER", param)`, `fn.LOWER(param)`.
fn = Func(None)
class Expression(Node):
    """Two operands joined by an operator, e.g `foo + 1` or `bar < 7`."""
    _node_type = 'expression'

    def __init__(self, lhs, op, rhs, flat=False):
        super(Expression, self).__init__()
        self.lhs, self.op, self.rhs = lhs, op, rhs
        self.flat = flat

    def clone_base(self):
        return Expression(self.lhs, self.op, self.rhs, self.flat)
class Param(Node):
    """
    A value that the query compiler must treat as a query parameter. This
    is handy for `list` values, which otherwise get special treatment in
    `IN` lookups.
    """
    _node_type = 'param'

    def __init__(self, value, adapt=None):
        super(Param, self).__init__()
        self.value = value
        self.adapt = adapt

    def clone_base(self):
        return Param(self.value, self.adapt)
class Passthrough(Param):
    """A `Param` tagged with the distinct node type 'passthrough'."""
    _node_type = 'passthrough'
class Clause(Node):
    """One or more Node objects joined by spaces, forming a SQL clause."""
    _node_type = 'clause'
    glue = ' '
    parens = False

    def __init__(self, *nodes, **kwargs):
        # `glue` and `parens` override the class defaults only when given.
        for option in ('glue', 'parens'):
            if option in kwargs:
                setattr(self, option, kwargs[option])
        super(Clause, self).__init__()
        self.nodes = list(nodes)

    def clone_base(self):
        return Clause(*self.nodes, glue=self.glue, parens=self.parens)
class CommaClause(Clause):
    """One or more Node objects joined by commas, no parens."""
    # Separator placed between the nodes.
    glue = ', '
class EnclosedClause(CommaClause):
    """One or more Node objects joined by commas and enclosed in parens."""
    # Flag requesting surrounding parentheses.
    parens = True
class Window(Node):
    """A window definition with optional partitioning and ordering."""
    def __init__(self, partition_by=None, order_by=None):
        super(Window, self).__init__()
        self.partition_by = partition_by
        self.order_by = order_by
        # Windows default to the alias "w".
        self._alias = self._alias or 'w'

    def __sql__(self):
        """Return the parenthesized PARTITION BY / ORDER BY clause."""
        sections = (
            ('PARTITION BY', self.partition_by),
            ('ORDER BY', self.order_by),
        )
        over_clauses = [
            Clause(SQL(keyword), CommaClause(*nodes))
            for keyword, nodes in sections if nodes]
        return EnclosedClause(Clause(*over_clauses))

    def clone_base(self):
        return Window(self.partition_by, self.order_by)
def Check(value):
    """Wrap `value` in a `CHECK (...)` constraint and return it as SQL."""
    constraint = 'CHECK (%s)' % value
    return SQL(constraint)
class DQ(Node):
    """Holds a "django-style" filter given as keywords, e.g. foo__eq='x'."""
    def __init__(self, **query):
        super(DQ, self).__init__()
        self.query = query

    def clone_base(self):
        params = self.query
        return DQ(**params)
class _StripParens(Node):
    """Wrapper node of type 'strip_parens' around another node."""
    _node_type = 'strip_parens'

    def __init__(self, node):
        super(_StripParens, self).__init__()
        self.node = node
# Immutable record describing a join; instances are built by
# `Join._join_metadata`.
JoinMetadata = namedtuple('JoinMetadata', (
    'src_model',  # Source Model class.
    'dest_model',  # Dest Model class.
    'src',   # Source, may be Model, ModelAlias
    'dest',  # Dest, may be Model, ModelAlias, or SelectQuery.
    'attr',  # Attribute name joined instance(s) should be assigned to.
    'primary_key',  # Primary key being joined on.
    'foreign_key',  # Foreign key being joined from.
    'is_backref',  # Is this a backref, i.e. 1 -> N.
    'alias',  # Explicit alias given to join expression.
    'is_self_join',  # Is this a self-join?
    'is_expression',  # Is the join ON clause an Expression?
))
class Join(namedtuple('_Join', ('src', 'dest', 'join_type', 'on'))):
    """
    A join from `src` to `dest` of type `join_type`, optionally restricted
    by `on`. Derived details are available through the `metadata` property.
    """
    def get_foreign_key(self, source, dest, field=None):
        """
        Find the field relating `source` and `dest`.

        Returns a `(field, is_backref)` pair: the forward relation from
        `source` to `dest` is tried first, then the reverse relation. The
        result is `(None, None)` when no relation is found or when either
        side is a SelectQuery.
        """
        if isinstance(source, SelectQuery) or isinstance(dest, SelectQuery):
            return None, None
        fk_field = source._meta.rel_for_model(dest, field)
        if fk_field is not None:
            return fk_field, False
        reverse_rel = source._meta.reverse_rel_for_model(dest, field)
        if reverse_rel is not None:
            return reverse_rel, True
        return None, None

    def get_join_type(self):
        """Return the join type, defaulting to an inner join."""
        return self.join_type or JOIN.INNER

    def model_from_alias(self, model_or_alias):
        """Return the model class behind a ModelAlias or SelectQuery."""
        if isinstance(model_or_alias, ModelAlias):
            return model_or_alias.model_class
        elif isinstance(model_or_alias, SelectQuery):
            return model_or_alias.model_class
        return model_or_alias

    def _join_metadata(self):
        """Compute the JoinMetadata record describing this join."""
        # Get the actual tables being joined.
        src = self.model_from_alias(self.src)
        dest = self.model_from_alias(self.dest)

        # An alias on the ON node, if any, names the target attribute.
        join_alias = isinstance(self.on, Node) and self.on._alias or None
        is_expression = isinstance(self.on, (Expression, Func, SQL))

        on_field = isinstance(self.on, (Field, FieldProxy)) and self.on or None
        if on_field:
            # Joining on an explicit field: it is a back-reference when the
            # field's name is not among the source model's fields.
            fk_field = on_field
            is_backref = on_field.name not in src._meta.fields
        else:
            fk_field, is_backref = self.get_foreign_key(src, dest, self.on)
            if fk_field is None and self.on is not None:
                # Retry the lookup without passing `on` along.
                fk_field, is_backref = self.get_foreign_key(src, dest)

        if fk_field is not None:
            primary_key = fk_field.to_field
        else:
            primary_key = None

        if not join_alias:
            if fk_field is not None:
                if is_backref:
                    target_attr = dest._meta.db_table
                else:
                    target_attr = fk_field.name
            else:
                # No relation found: fall back to the name of the ON
                # expression's left-hand side, or the destination table.
                try:
                    target_attr = self.on.lhs.name
                except AttributeError:
                    target_attr = dest._meta.db_table
        else:
            target_attr = None

        return JoinMetadata(
            src_model=src,
            dest_model=dest,
            src=self.src,
            dest=self.dest,
            attr=join_alias or target_attr,
            primary_key=primary_key,
            foreign_key=fk_field,
            is_backref=is_backref,
            alias=join_alias,
            is_self_join=src is dest,
            is_expression=is_expression)

    @property
    def metadata(self):
        """The JoinMetadata for this join, computed once and then cached."""
        if not hasattr(self, '_cached_metadata'):
            self._cached_metadata = self._join_metadata()
        return self._cached_metadata
class FieldDescriptor(object):
    """
    Descriptor through which a field is exposed on a model, so that reads
    and writes go through the instance's underlying "raw" data.
    """
    def __init__(self, field):
        self.field = field
        self.att_name = field.name

    def __get__(self, instance, instance_type=None):
        # Class-level access yields the field itself.
        if instance is None:
            return self.field
        return instance._data.get(self.att_name)

    def __set__(self, instance, value):
        key = self.att_name
        instance._data[key] = value
        instance._dirty.add(key)  # Record the attribute as modified.
class Field(Node):
    """A table column."""
    _field_counter = 0
    _order = 0
    _node_type = 'field'
    db_field = 'unknown'

    def __init__(self, null=False, index=False, unique=False,
                 verbose_name=None, help_text=None, db_column=None,
                 default=None, choices=None, primary_key=False, sequence=None,
                 constraints=None, schema=None):
        self.null = null
        self.index = index
        self.unique = unique
        self.verbose_name = verbose_name
        self.help_text = help_text
        self.db_column = db_column
        self.default = default
        self.choices = choices  # Used for metadata purposes, not enforced.
        self.primary_key = primary_key
        self.sequence = sequence  # Name of sequence, e.g. foo_id_seq.
        self.constraints = constraints  # List of column constraints.
        self.schema = schema  # Name of schema, e.g. 'public'.

        # A global counter records the order in which Fields were created so
        # that the definition order on the Model class can be recovered.
        Field._field_counter += 1
        self._order = Field._field_counter
        # Primary keys sort ahead of every other field.
        self._sort_key = (1 if self.primary_key else 2), self._order

        self._is_bound = False  # Whether the Field is "bound" to a Model.
        super(Field, self).__init__()

    def clone_base(self, **kwargs):
        """
        Create a new field of the same type with the same options. Extra
        keyword arguments are passed on to the constructor, which is how
        subclasses forward their own options.
        """
        field_class = type(self)
        inst = field_class(
            null=self.null,
            index=self.index,
            unique=self.unique,
            verbose_name=self.verbose_name,
            help_text=self.help_text,
            db_column=self.db_column,
            default=self.default,
            choices=self.choices,
            primary_key=self.primary_key,
            sequence=self.sequence,
            constraints=self.constraints,
            schema=self.schema,
            **kwargs)
        if self._is_bound:
            # A bound field's copy keeps pointing at the same model.
            inst.name = self.name
            inst.model_class = self.model_class
        inst._is_bound = self._is_bound
        return inst

    def add_to_class(self, model_class, name):
        """
        Bind the field to `model_class` under `name`: register it with the
        model's `_meta` and replace the class attribute with a
        `FieldDescriptor`. Called by the metaclass while the `Model` is being
        constructed.
        """
        self.name = name
        self.model_class = model_class
        if not self.db_column:
            self.db_column = self.name
        if not self.verbose_name:
            self.verbose_name = re.sub('_+', ' ', name).title()

        model_class._meta.add_field(self)
        setattr(model_class, name, FieldDescriptor(self))
        self._is_bound = True

    def get_database(self):
        return self.model_class._meta.database

    def get_column_type(self):
        compiler = self.get_database().compiler()
        return compiler.get_column_type(self.get_db_field())

    def get_db_field(self):
        return self.db_field

    def get_modifiers(self):
        return None

    def coerce(self, value):
        return value

    def db_value(self, value):
        """Convert the python value for storage in the database."""
        if value is None:
            return value
        return self.coerce(value)

    def python_value(self, value):
        """Convert the database value to a pythonic value."""
        if value is None:
            return value
        return self.coerce(value)

    def as_entity(self, with_table=False):
        """Return the column as an Entity, optionally table-qualified."""
        if not with_table:
            return Entity(self.db_column)
        return Entity(self.model_class._meta.db_table, self.db_column)

    def __ddl_column__(self, column_type):
        """Return the column type, e.g. VARCHAR(255) or REAL."""
        modifiers = self.get_modifiers()
        if not modifiers:
            return SQL(column_type)
        rendered = ', '.join(map(str, modifiers))
        return SQL('%s(%s)' % (column_type, rendered))

    def __ddl__(self, column_type):
        """Return a list of Node instances that defines the column."""
        parts = [self.as_entity(), self.__ddl_column__(column_type)]
        if not self.null:
            parts.append(SQL('NOT NULL'))
        if self.primary_key:
            parts.append(SQL('PRIMARY KEY'))
        if self.sequence:
            parts.append(SQL("DEFAULT NEXTVAL('%s')" % self.sequence))
        if self.constraints:
            parts.extend(self.constraints)
        return parts

    def __hash__(self):
        # Fields are hashed by "<field name>.<model class name>".
        return hash(self.name + '.' + self.model_class.__name__)
class BareField(Field):
    """Field of type 'bare' with an optional user-supplied coerce callable."""
    db_field = 'bare'

    def __init__(self, coerce=None, *args, **kwargs):
        super(BareField, self).__init__(*args, **kwargs)
        # Only replace the default `coerce` when one was provided.
        if coerce is None:
            return
        self.coerce = coerce

    def clone_base(self, **kwargs):
        parent = super(BareField, self)
        return parent.clone_base(coerce=self.coerce, **kwargs)
class IntegerField(Field):
    """Field of type 'int' whose non-None values are coerced with `int`."""
    db_field = 'int'
    coerce = int
class BigIntegerField(IntegerField):
    """Integer field of type 'bigint'."""
    db_field = 'bigint'
class SmallIntegerField(IntegerField):
    """Integer field of type 'smallint'."""
    db_field = 'smallint'
class PrimaryKeyField(IntegerField):
    """Integer field of type 'primary_key'; always a primary key."""
    db_field = 'primary_key'

    def __init__(self, *args, **kwargs):
        # Force the flag regardless of what the caller passed.
        kwargs.update(primary_key=True)
        super(PrimaryKeyField, self).__init__(*args, **kwargs)
class _AutoPrimaryKeyField(PrimaryKeyField):
    """
    Primary key field that may only be attached to a model under the fixed
    attribute name given by `_column_name`.
    """
    _column_name = None

    def add_to_class(self, model_class, name):
        """
        Bind the field to `model_class`.

        Raises `ValueError` when `name` differs from `_column_name`.
        """
        if name != self._column_name:
            # Report the name that is required, not the one that was
            # rejected.
            raise ValueError('%s must be named `%s`.' % (
                type(self), self._column_name))
        super(_AutoPrimaryKeyField, self).add_to_class(model_class, name)
class FloatField(Field):
    """Field of type 'float' whose non-None values are coerced with `float`."""
    db_field = 'float'
    coerce = float
class DoubleField(FloatField):
    """Float field of type 'double'."""
    db_field = 'double'
class DecimalField(Field):
    """Field of type 'decimal' holding `decimal.Decimal` values."""
    db_field = 'decimal'

    def __init__(self, max_digits=10, decimal_places=5, auto_round=False,
                 rounding=None, *args, **kwargs):
        self.max_digits = max_digits
        self.decimal_places = decimal_places
        self.auto_round = auto_round
        # Fall back to the default decimal context's rounding mode.
        self.rounding = rounding or decimal.DefaultContext.rounding
        super(DecimalField, self).__init__(*args, **kwargs)

    def clone_base(self, **kwargs):
        parent = super(DecimalField, self)
        return parent.clone_base(
            max_digits=self.max_digits,
            decimal_places=self.decimal_places,
            auto_round=self.auto_round,
            rounding=self.rounding,
            **kwargs)

    def get_modifiers(self):
        return [self.max_digits, self.decimal_places]

    def db_value(self, value):
        """
        Prepare `value` for storage. None stays None, any other falsy value
        becomes Decimal(0), and with `auto_round` enabled the value is
        quantized to `decimal_places` places.
        """
        if value is None:
            return value
        if not value:
            return decimal.Decimal(0)
        if not self.auto_round:
            return value
        quantum = decimal.Decimal(10) ** (-self.decimal_places)
        number = decimal.Decimal(str(value))
        return number.quantize(quantum, rounding=self.rounding)

    def python_value(self, value):
        """Return `value` as a Decimal, or None when it is None."""
        if value is None:
            return None
        if isinstance(value, decimal.Decimal):
            return value
        return decimal.Decimal(str(value))
def coerce_to_unicode(s, encoding='utf-8'):
    """
    Return `s` as text: text is returned unchanged, `string_type` values are
    decoded with `encoding`, and anything else is converted with
    `unicode_type`.
    """
    if not isinstance(s, unicode_type):
        if isinstance(s, string_type):
            s = s.decode(encoding)
        else:
            s = unicode_type(s)
    return s
class CharField(Field):
    """Field of type 'string' with an optional maximum length."""
    db_field = 'string'

    def __init__(self, max_length=255, *args, **kwargs):
        self.max_length = max_length
        super(CharField, self).__init__(*args, **kwargs)

    def clone_base(self, **kwargs):
        parent = super(CharField, self)
        return parent.clone_base(max_length=self.max_length, **kwargs)

    def get_modifiers(self):
        # A falsy max_length means the column type carries no modifier.
        if self.max_length:
            return [self.max_length]
        return None

    def coerce(self, value):
        # Falsy values are treated as the empty string.
        text = value or ''
        return coerce_to_unicode(text)
class FixedCharField(CharField):
    """Character field of type 'fixed_char'."""
    db_field = 'fixed_char'

    def python_value(self, value):
        """Return the value with surrounding whitespace stripped."""
        converted = super(FixedCharField, self).python_value(value)
        return converted.strip() if converted else converted
class TextField(Field):
    """Field of type 'text'."""
    db_field = 'text'

    def coerce(self, value):
        # Falsy values are treated as the empty string.
        text = value or ''
        return coerce_to_unicode(text)
class BlobField(Field):
    """Field of type 'blob' for binary data."""
    db_field = 'blob'
    _constructor = binary_construct

    def add_to_class(self, model_class, name):
        database = model_class._meta.database
        if isinstance(database, Proxy):
            # The real database is not known yet; pick up its binary type
            # once the proxy is initialized.
            database.attach_callback(self._set_constructor)
        return super(BlobField, self).add_to_class(model_class, name)

    def _set_constructor(self, database):
        self._constructor = database.get_binary_type()

    def db_value(self, value):
        """
        Prepare `value` for storage: text is encoded with
        'raw_unicode_escape', and `basestring` values are passed through the
        binary constructor.
        """
        if isinstance(value, unicode_type):
            value = value.encode('raw_unicode_escape')
        if not isinstance(value, basestring):
            return value
        return self._constructor(value)
class UUIDField(Field):
    """Field of type 'uuid', stored as a hexadecimal string."""
    db_field = 'uuid'

    def db_value(self, value):
        """
        Return the hexadecimal form of `value`. Input that cannot be
        interpreted as a UUID is returned unchanged.
        """
        if isinstance(value, uuid.UUID):
            return value.hex
        try:
            return uuid.UUID(value).hex
        except Exception:
            # Best-effort conversion: pass unparseable input through as-is.
            # `Exception` rather than a bare `except` so KeyboardInterrupt
            # and SystemExit are never swallowed here.
            return value

    def python_value(self, value):
        """Return `value` as a `uuid.UUID`, or None when it is None."""
        if isinstance(value, uuid.UUID):
            return value
        return None if value is None else uuid.UUID(value)
def _date_part(date_part):
    """
    Return a method that asks the field's database to extract `date_part`
    from the field, via the database's `extract_date`.
    """
    def dec(self):
        database = self.model_class._meta.database
        return database.extract_date(date_part, self)
    return dec
class _BaseFormattedField(Field):
    """Base class for fields that carry a list of date/time formats."""
    formats = None

    def __init__(self, formats=None, *args, **kwargs):
        # Keep the class-level formats unless an override is supplied.
        if formats is not None:
            self.formats = formats
        super(_BaseFormattedField, self).__init__(*args, **kwargs)

    def clone_base(self, **kwargs):
        parent = super(_BaseFormattedField, self)
        return parent.clone_base(formats=self.formats, **kwargs)
class DateTimeField(_BaseFormattedField):
    """Field of type 'datetime'."""
    db_field = 'datetime'
    formats = [
        '%Y-%m-%d %H:%M:%S.%f',
        '%Y-%m-%d %H:%M:%S',
        '%Y-%m-%d',
    ]

    def python_value(self, value):
        """Parse non-empty strings using `formats`; pass others through."""
        if not value or not isinstance(value, basestring):
            return value
        return format_date_time(value, self.formats)

    year = property(_date_part('year'))
    month = property(_date_part('month'))
    day = property(_date_part('day'))
    hour = property(_date_part('hour'))
    minute = property(_date_part('minute'))
    second = property(_date_part('second'))
class DateField(_BaseFormattedField):
    """Field of type 'date'."""
    db_field = 'date'
    formats = [
        '%Y-%m-%d',
        '%Y-%m-%d %H:%M:%S',
        '%Y-%m-%d %H:%M:%S.%f',
    ]

    def python_value(self, value):
        """
        Reduce strings and datetimes to their date; other values, and falsy
        ones, are returned unchanged.
        """
        if not value:
            return value
        if isinstance(value, basestring):
            to_date = lambda parsed: parsed.date()
            return format_date_time(value, self.formats, to_date)
        if isinstance(value, datetime.datetime):
            return value.date()
        return value

    year = property(_date_part('year'))
    month = property(_date_part('month'))
    day = property(_date_part('day'))
class TimeField(_BaseFormattedField):
    """Field of type 'time'."""
    db_field = 'time'
    formats = [
        '%H:%M:%S.%f',
        '%H:%M:%S',
        '%H:%M',
        '%Y-%m-%d %H:%M:%S.%f',
        '%Y-%m-%d %H:%M:%S',
    ]

    def python_value(self, value):
        """
        Reduce strings, datetimes and timedeltas to a time of day; any other
        value is returned unchanged.
        """
        if value:
            if isinstance(value, basestring):
                to_time = lambda parsed: parsed.time()
                return format_date_time(value, self.formats, to_time)
            if isinstance(value, datetime.datetime):
                return value.time()
        # Checked outside the truthiness test because a zero timedelta is
        # falsy but still converts to a time.
        if isinstance(value, datetime.timedelta):
            return (datetime.datetime.min + value).time()
        return value

    hour = property(_date_part('hour'))
    minute = property(_date_part('minute'))
    second = property(_date_part('second'))
class TimestampField(IntegerField):
    """Integer field that stores date/times as a timestamp."""
    # Support second -> microsecond resolution.
    valid_resolutions = [10**i for i in range(7)]

    def __init__(self, *args, **kwargs):
        self.resolution = kwargs.pop('resolution', 1) or 1
        if self.resolution not in self.valid_resolutions:
            choices = ', '.join(str(i) for i in self.valid_resolutions)
            raise ValueError(
                'TimestampField resolution must be one of: %s' % choices)

        self.utc = kwargs.pop('utc', False) or False
        # Pick the UTC or local-time flavor of the conversion helpers.
        if self.utc:
            self._conv = datetime.datetime.utcfromtimestamp
            now = datetime.datetime.utcnow
        else:
            self._conv = datetime.datetime.fromtimestamp
            now = datetime.datetime.now
        kwargs.setdefault('default', now)
        super(TimestampField, self).__init__(*args, **kwargs)

    def get_db_field(self):
        # At a resolution of one second, 4 bytes are enough to hold the
        # timestamp (for a while, i.e. as long as they're not > ~2038).
        # Finer resolutions need a BigInteger type.
        if self.resolution == 1:
            return self.db_field
        return BigIntegerField.db_field

    def db_value(self, value):
        """
        Convert a datetime, date or number into the stored integer, i.e. the
        timestamp multiplied by the resolution and rounded.
        """
        if value is None:
            return

        # datetime is a subclass of date, so anything that is not a date is
        # treated as a plain number.
        if not isinstance(value, datetime.date):
            return int(round(value * self.resolution))
        if not isinstance(value, datetime.datetime):
            value = datetime.datetime(value.year, value.month, value.day)

        if self.utc:
            timestamp = calendar.timegm(value.utctimetuple())
        else:
            timestamp = time.mktime(value.timetuple())
        timestamp += (value.microsecond * .000001)
        if self.resolution > 1:
            timestamp *= self.resolution
        return int(round(timestamp))

    def python_value(self, value):
        """
        Convert a stored number back into a datetime. Zero yields None and
        non-numeric values are returned unchanged.
        """
        if value is None or not isinstance(value, (int, float, long)):
            return value
        if value == 0:
            return
        if self.resolution > 1:
            # Split the value into whole seconds and sub-second ticks.
            ticks_to_microsecond = 1000000 // self.resolution
            value, ticks = divmod(value, self.resolution)
            microseconds = ticks * ticks_to_microsecond
            return self._conv(value).replace(microsecond=microseconds)
        return self._conv(value)
class BooleanField(Field):
    """Field of type 'bool' whose non-None values are coerced with `bool`."""
    db_field = 'bool'
    coerce = bool
class RelationDescriptor(FieldDescriptor):
    """Foreign-key descriptor that exposes the related model instance."""
    def __init__(self, field, rel_model):
        self.rel_model = rel_model
        super(RelationDescriptor, self).__init__(field)

    def get_object_or_id(self, instance):
        """
        Return the related object, fetching and caching it when needed. With
        no related id, None is returned for nullable fields; otherwise the
        related model's `DoesNotExist` is raised.
        """
        name = self.att_name
        rel_id = instance._data.get(name)
        if name in instance._obj_cache:
            return instance._obj_cache[name]
        if rel_id is not None:
            obj = self.rel_model.get(self.field.to_field == rel_id)
            instance._obj_cache[name] = obj
            return instance._obj_cache[name]
        if not self.field.null:
            raise self.rel_model.DoesNotExist
        return rel_id

    def __get__(self, instance, instance_type=None):
        if instance is None:
            return self.field
        return self.get_object_or_id(instance)

    def __set__(self, instance, value):
        name = self.att_name
        if isinstance(value, self.rel_model):
            # Store the related object's key and cache the object itself.
            instance._data[name] = getattr(value, self.field.to_field.name)
            instance._obj_cache[name] = value
        else:
            previous = instance._data.get(name)
            instance._data[name] = value
            # A changed raw id invalidates any cached related object.
            if previous != value and name in instance._obj_cache:
                del instance._obj_cache[name]
        instance._dirty.add(name)
class ReverseRelationDescriptor(object):
    """Back-reference to expose related objects as a `SelectQuery`."""

    def __init__(self, field):
        self.field = field
        self.rel_model = field.model_class

    def __get__(self, instance, instance_type=None):
        # Class-level access returns the descriptor itself.
        if instance is None:
            return self
        query = self.rel_model.select()
        return query.where(
            self.field == getattr(instance, self.field.to_field.name))
class ObjectIdDescriptor(object):
    """Gives direct access to the underlying id"""

    def __init__(self, field):
        self.attr_name = field.name
        # Only a weak reference to the field is kept.
        self.field = weakref.ref(field)

    def __get__(self, instance, instance_type=None):
        if instance is None:
            # Class-level access dereferences the weak reference.
            return self.field()
        return instance._data.get(self.attr_name)

    def __set__(self, instance, value):
        # Delegates to whatever is installed under the field's own name.
        setattr(instance, self.attr_name, value)
class ForeignKeyField(IntegerField):
def __init__(self, rel_model, related_name=None, on_delete=None,
             on_update=None, extra=None, to_field=None, *args, **kwargs):
    """
    Declare a foreign key pointing at ``rel_model``.

    :param rel_model: a ``Model`` subclass, a ``Proxy``, a
        ``DeferredRelation`` or the string ``'self'``.
    :param related_name: back-reference name, or a callable producing it.
    :param on_delete: stored for the ``ON DELETE`` part of the constraint.
    :param on_update: stored for the ``ON UPDATE`` part of the constraint.
    :param extra: stored as-is on the field.
    :param to_field: referenced field (or its name); resolved when the
        field is bound to a model class.
    :raises TypeError: if ``rel_model`` is none of the accepted kinds.
    """
    deferred = isinstance(rel_model, (Proxy, DeferredRelation))
    if rel_model != 'self' and not deferred:
        # ``issubclass`` raises its own, unrelated TypeError when handed a
        # non-class (e.g. a misspelled string), so check ``isclass`` first
        # to make sure callers always get the descriptive message below.
        if not (isclass(rel_model) and issubclass(rel_model, Model)):
            raise TypeError('Unexpected value for `rel_model`. Expected '
                            '`Model`, `Proxy`, `DeferredRelation`, or "self"')
    self.rel_model = rel_model
    self._related_name = related_name
    self.deferred = deferred
    self.on_delete = on_delete
    self.on_update = on_update
    self.extra = extra
    self.to_field = to_field
    super(ForeignKeyField, self).__init__(*args, **kwargs)
def clone_base(self, **kwargs):
    """Delegate to the parent ``clone_base``, forwarding this field's
    foreign-key settings along with any extra ``kwargs``."""
    parent = super(ForeignKeyField, self)
    return parent.clone_base(
        rel_model=self.rel_model,
        related_name=self._get_related_name(),
        on_delete=self.on_delete,
        on_update=self.on_update,
        extra=self.extra,
        to_field=self.to_field,
        **kwargs)
def _get_descriptor(self):
    """Build the descriptor installed under the field's name."""
    descriptor = RelationDescriptor(self, self.rel_model)
    return descriptor
def _get_id_descriptor(self):
    """Build the descriptor that exposes the raw stored id."""
    descriptor = ObjectIdDescriptor(self)
    return descriptor
def _get_backref_descriptor(self):
    """Build the descriptor installed on the related model."""
    descriptor = ReverseRelationDescriptor(self)
    return descriptor
def _get_related_name(self):
    """Resolve the back-reference name: call it if callable, use it if
    set, otherwise default to ``<model name>_set``."""
    related = self._related_name
    if not related:
        return '%s_set' % self.model_class._meta.name
    if callable(related):
        return related(self)
    return related
def add_to_class(self, model_class, name):
# Bind this foreign key to `model_class` under attribute `name`: register
# the field on the model's meta, resolve `rel_model`/`to_field`, optionally
# validate the back-reference and object-id names, then install the three
# descriptors and record the relation on both models.
if isinstance(self.rel_model, Proxy):
# The target is not known yet: re-run the binding once the proxy
# hands over the real model.
def callback(rel_model):
self.rel_model = rel_model
self.add_to_class(model_class, name)
self.rel_model.attach_callback(callback)
return
elif isinstance(self.rel_model, DeferredRelation):
# The deferred relation is handed the field and owns the binding.
self.rel_model.set_field(model_class, self, name)
return
self.name = name
self.model_class = model_class
# Default column name is "<name>_id"; the object-id attribute uses the
# same name unless that would collide with the field's own name.
self.db_column = obj_id_name = self.db_column or '%s_id' % self.name
if obj_id_name == self.name:
obj_id_name += '_id'
if not self.verbose_name:
self.verbose_name = re.sub('_+', ' ', name).title()
model_class._meta.add_field(self)
self.related_name = self._get_related_name()
# 'self' denotes a self-referential foreign key.
if self.rel_model == 'self':
self.rel_model = self.model_class
# Resolve `to_field`: a field name is looked up on the related model;
# when none is given, the related model's primary key is used.
if self.to_field is not None:
if not isinstance(self.to_field, Field):
self.to_field = getattr(self.rel_model, self.to_field)
else:
self.to_field = self.rel_model._meta.primary_key
# TODO: factor into separate method.
if model_class._meta.validate_backrefs:
# Helper raising AttributeError with the message interpolated from
# the field, back-reference and object-id names.
def invalid(msg, **context):
context.update(
field='%s.%s' % (model_class._meta.name, name),
backref=self.related_name,
obj_id_name=obj_id_name)
raise AttributeError(msg % context)
if self.related_name in self.rel_model._meta.fields:
invalid('The related_name of %(field)s ("%(backref)s") '
'conflicts with a field of the same name.')
elif self.related_name in self.rel_model._meta.reverse_rel:
invalid('The related_name of %(field)s ("%(backref)s") '
'is already in use by another foreign key.')
if obj_id_name in model_class._meta.fields:
invalid('The object id descriptor of %(field)s conflicts '
'with a field named %(obj_id_name)s')
elif obj_id_name in model_class.__dict__:
invalid('Model attribute "%(obj_id_name)s" would be shadowed '
'by the object id descriptor of %(field)s.')
# Install the related-object descriptor, the raw-id descriptor and the
# back-reference descriptor on the related model.
setattr(model_class, name, self._get_descriptor())
setattr(model_class, obj_id_name, self._get_id_descriptor())
setattr(self.rel_model,
self.related_name,
self._get_backref_descriptor())
self._is_bound = True
# Record the relation in both directions.
model_class._meta.rel[self.name] = self
self.rel_model._meta.reverse_rel[self.related_name] = self
def get_db_field(self):
    """
    Overridden to ensure Foreign Keys use same column type as the primary
    key they point to.
    """
    if isinstance(self.to_field, PrimaryKeyField):
        return super(ForeignKeyField, self).get_db_field()
    return self.to_field.get_db_field()
def get_modifiers(self):
    """Return the referenced field's modifiers, except when it is a
    ``PrimaryKeyField``, in which case the parent's are used."""
    if isinstance(self.to_field, PrimaryKeyField):
        return super(ForeignKeyField, self).get_modifiers()
    return self.to_field.get_modifiers()
def coerce(self, value):
    """Coerce ``value`` using the referenced field's ``coerce``."""
    target = self.to_field
    return target.coerce(value)
def db_value(self, value):
    """
    Convert ``value`` for storage using the referenced field.

    When ``value`` is an instance of the related model, the attribute
    named by ``to_field`` is extracted first. Reading the primary key
    instead would store the wrong value for foreign keys that reference
    a non-primary-key column, and would disagree with how the relation
    descriptor stores related instances.
    """
    if isinstance(value, self.rel_model):
        value = getattr(value, self.to_field.name)
    return self.to_field.db_value(value)
def python_value(self, value):
    """Return related-model instances untouched; convert anything else
    with the referenced field's ``python_value``."""
    if not isinstance(value, self.rel_model):
        return self.to_field.python_value(value)
    return value
class CompositeKey(object):
    """A primary key composed of multiple columns."""
    sequence = None

    def __init__(self, *field_names):
        self.field_names = field_names

    def add_to_class(self, model_class, name):
        """Bind the key to ``model_class`` and install it as attribute
        ``name`` (the key acts as its own descriptor)."""
        self.name = name
        self.model_class = model_class
        setattr(model_class, name, self)

    def __get__(self, instance, instance_type=None):
        # Instance access yields a tuple of the component values, in
        # ``field_names`` order; class access yields the key itself.
        if instance is not None:
            return tuple([getattr(instance, field_name)
                          for field_name in self.field_names])
        return self

    def __set__(self, instance, value):
        # Assignment to the composite attribute is ignored.
        pass

    def __eq__(self, other):
        """AND together one equality comparison per component field,
        pairing ``field_names`` with the values in ``other``."""
        expressions = [(self.model_class._meta.fields[field] == value)
                       for field, value in zip(self.field_names, other)]
        return reduce(operator.and_, expressions)

    def __ne__(self, other):
        """Negation of ``__eq__``.

        Without this, ``key != values`` falls back to the interpreter's
        default, which yields a plain bool (or an identity comparison)
        rather than a negated form of the equality expression.
        """
        return ~(self == other)

    def __hash__(self):
        return hash((self.model_class.__name__, self.field_names))
class AliasMap(object):
prefix = 't'
def __init__(self, start=0):
self._alias_map = {}
self._counter = start
def __repr__(self):
return '<AliasMap: %s>' % self._alias_map
def add(self, obj, alias=None):
if obj in self._alias_map:
return
self._counter += 1
self._alias_map[obj] = alias or '%s%s' % (self.prefix, self._counter)
def __getitem__(self, obj):
if obj not in self._alias_map:
self.add(obj)
return self._alias_map[obj]
def __contains__(self, obj):
return obj in self._alias_map
def update(self, alias_map):
if alias_map:
for obj, alias in alias_map._alias_map.items():
if obj not in self:
self._alias_map[obj] = alias
return self
class QueryCompiler(object):
# Turns query/node objects into an SQL string plus a list of parameters.
# The class-level maps below are defaults; the constructor merges
# per-instance overrides on top of `field_map` and `op_map`.
# Mapping of `db_type` to actual column type used by database driver.
# Database classes may provide additional column types or overrides.
field_map = {
'bare': '',
'bigint': 'BIGINT',
'blob': 'BLOB',
'bool': 'SMALLINT',
'date': 'DATE',
'datetime': 'DATETIME',
'decimal': 'DECIMAL',
'double': 'REAL',
'fixed_char': 'CHAR',
'float': 'REAL',
'int': 'INTEGER',
'primary_key': 'INTEGER',
'smallint': 'SMALLINT',
'string': 'VARCHAR',
'text': 'TEXT',
'time': 'TIME',
}
# Mapping of OP. to actual SQL operation. For most databases this will be
# the same, but some column types or databases may support additional ops.
# Like `field_map`, Database classes may extend or override these.
op_map = {
OP.EQ: '=',
OP.LT: '<',
OP.LTE: '<=',
OP.GT: '>',
OP.GTE: '>=',
OP.NE: '!=',
OP.IN: 'IN',
OP.NOT_IN: 'NOT IN',
OP.IS: 'IS',
OP.IS_NOT: 'IS NOT',
OP.BIN_AND: '&',
OP.BIN_OR: '|',
OP.LIKE: 'LIKE',
OP.ILIKE: 'ILIKE',
OP.BETWEEN: 'BETWEEN',
OP.ADD: '+',
OP.SUB: '-',
OP.MUL: '*',
OP.DIV: '/',
OP.XOR: '#',
OP.AND: 'AND',
OP.OR: 'OR',
OP.MOD: '%',
OP.REGEXP: 'REGEXP',
OP.CONCAT: '||',
}
# Mapping of JOIN constants to the SQL keywords emitted for them; join
# types missing from this map are emitted verbatim by `generate_joins`.
join_map = {
JOIN.INNER: 'INNER JOIN',
JOIN.LEFT_OUTER: 'LEFT OUTER JOIN',
JOIN.RIGHT_OUTER: 'RIGHT OUTER JOIN',
JOIN.FULL: 'FULL JOIN',
}
# Class instantiated whenever the compiler needs a fresh alias map.
alias_map_class = AliasMap
def __init__(self, quote_char='"', interpolation='?', field_overrides=None,
             op_overrides=None):
    """
    :param quote_char: character placed around quoted identifiers.
    :param interpolation: placeholder emitted for each parameter.
    :param field_overrides: optional mapping merged over ``field_map``.
    :param op_overrides: optional mapping merged over ``op_map``.
    """
    self.quote_char = quote_char
    self.interpolation = interpolation
    extra_fields = field_overrides or {}
    self._field_map = merge_dict(self.field_map, extra_fields)
    extra_ops = op_overrides or {}
    self._op_map = merge_dict(self.op_map, extra_ops)
    self._parse_map = self.get_parse_map()
    self._unknown_types = set(['param'])
def get_parse_map(self):
    """Return the dispatch table mapping a node type name to the bound
    ``_parse_<node type>`` handler."""
    # To avoid O(n) lookups when parsing nodes, use a lookup table for
    # common node types O(1).
    node_types = (
        'expression',
        'param',
        'passthrough',
        'func',
        'clause',
        'entity',
        'field',
        'sql',
        'select_query',
        'compound_select_query',
        'strip_parens',
    )
    return dict(
        (node_type, getattr(self, '_parse_%s' % node_type))
        for node_type in node_types)
def quote(self, s):
    """Wrap ``s`` in the configured quote character."""
    q = self.quote_char
    return '%s%s%s' % (q, s, q)
def get_column_type(self, f):
    """Return the mapped column type for ``f``; unmapped keys are
    upper-cased and returned as-is."""
    if f in self._field_map:
        return self._field_map[f]
    return f.upper()
def get_op(self, q):
    """Return the SQL operator registered for operation ``q``."""
    operations = self._op_map
    return operations[q]
def _sorted_fields(self, field_dict):
    """Return the ``(field, value)`` items of ``field_dict`` ordered by
    each field's ``_sort_key``."""
    def by_sort_key(item):
        return item[0]._sort_key
    return sorted(field_dict.items(), key=by_sort_key)
def _parse_default(self, node, alias_map, conv):
    """Treat ``node`` as a plain value: one placeholder, one parameter."""
    placeholder = self.interpolation
    return placeholder, [node]
def _parse_expression(self, node, alias_map, conv):
    """
    Render a binary expression as ``lhs OP rhs``, parenthesised unless
    the node is flat. A ``Field`` on the left-hand side becomes the
    converter for both operands.
    """
    if isinstance(node.lhs, Field):
        conv = node.lhs
    lhs, lparams = self.parse_node(node.lhs, alias_map, conv)
    rhs, rparams = self.parse_node(node.rhs, alias_map, conv)
    if node.op == OP.IN and rhs == '()' and not rparams:
        # IN against an empty list is rendered as a false condition.
        sql, params = '0 = 1', []
    else:
        sql = '%s %s %s' % (lhs, self.get_op(node.op), rhs)
        params = lparams + rparams
    if not node.flat:
        sql = '(%s)' % sql
    return sql, params
def _parse_passthrough(self, node, alias_map, conv):
    """Emit the node's value without applying ``conv``; only the node's
    own ``adapt`` callable, if any, is used."""
    if not node.adapt:
        return self.interpolation, [node.value]
    adapted = node.adapt(node.value)
    return self.parse_node(adapted, alias_map, None)
def _parse_param(self, node, alias_map, conv):
    """
    Render a parameter node.

    * With ``node.adapt`` set, the adapted value is parsed. If the adapter
      is the converter's own ``db_value``, the converter is dropped so the
      value is not converted a second time.
    * Otherwise, with a converter, the converted value is parsed.
    * Otherwise the raw value is emitted as a single parameter.
    """
    if node.adapt:
        # Bound methods must be compared with ``==``: each attribute access
        # creates a new method object, so an identity test never matches
        # and the value would be passed through ``db_value`` twice.
        if conv and conv.db_value == node.adapt:
            conv = None
        return self.parse_node(node.adapt(node.value), alias_map, conv)
    elif conv is not None:
        return self.parse_node(conv.db_value(node.value), alias_map)
    else:
        return self.interpolation, [node.value]
def _parse_func(self, node, alias_map, conv):
    """Render ``name(arg, ...)``; the converter is only passed on to the
    arguments when the node's ``_coerce`` flag is set."""
    if not (node._coerce and conv):
        conv = None
    args_sql, params = self.parse_node_list(node.arguments, alias_map, conv)
    call = '%s(%s)' % (node.name, strip_parens(args_sql))
    return call, params
def _parse_clause(self, node, alias_map, conv):
    """Join the clause's child nodes with its glue string, wrapping the
    result in a single pair of parentheses when ``node.parens`` is set."""
    joined, params = self.parse_node_list(
        node.nodes, alias_map, conv, node.glue)
    if not node.parens:
        return joined, params
    return '(%s)' % strip_parens(joined), params
def _parse_entity(self, node, alias_map, conv):
    """Quote every path component and join them with dots."""
    quoted_parts = [self.quote(part) for part in node.path]
    return '.'.join(quoted_parts), []
def _parse_sql(self, node, alias_map, conv):
    """Emit the node's SQL verbatim with a fresh list of its params."""
    params = list(node.params)
    return node.value, params
def _parse_field(self, node, alias_map, conv):
    """Render a field as its quoted column name, prefixed with the quoted
    alias of its model when an alias map is supplied."""
    if alias_map:
        table_alias = self.quote(alias_map[node.model_class])
        column = self.quote(node.db_column)
        return '.'.join((table_alias, column)), []
    return self.quote(node.db_column), []
def _parse_compound_select_query(self, node, alias_map, conv):
# Render both sides of a compound select and combine them with the
# node's operator. Returns the SQL and the concatenated parameters.
csq = 'compound_select_query'
# When only the right-hand side is itself a compound query, generate it
# first and remember (`inv`) to swap the rendered halves back afterwards.
if node.rhs._node_type == csq and node.lhs._node_type != csq:
first_q, second_q = node.rhs, node.lhs
inv = True
else:
first_q, second_q = node.lhs, node.rhs
inv = False
new_map = self.alias_map_class()
# A nested compound query continues numbering from the incoming map.
if first_q._node_type == csq:
new_map._counter = alias_map._counter
first, first_p = self.generate_select(first_q, new_map)
second, second_p = self.generate_select(
second_q,
self.calculate_alias_map(second_q, new_map))
# Restore the original left/right order of SQL and parameters.
if inv:
l, lp, r, rp = second, second_p, first, first_p
else:
l, lp, r, rp = first, first_p , second, second_p
# We add outer parentheses in the event the compound query is used in
# the `from_()` clause, in which case we'll need them.
if node.database.compound_select_parentheses:
sql = '((%s) %s (%s))' % (l, node.operator, r)
else:
sql = '(%s %s %s)' % (l, node.operator, r)
return sql, lp + rp
def _parse_select_query(self, node, alias_map, conv):
    """
    Render a select query used as a sub-expression.

    Works on a clone of the query. Without an explicit selection, the
    clone selects a single column: the converter's ``to_field`` when the
    converter is a foreign key, else the model's primary key.
    """
    subquery = node.clone()
    if not node._explicit_selection:
        use_fk_target = conv and isinstance(conv, ForeignKeyField)
        if use_fk_target:
            column = conv.to_field
        else:
            column = subquery.model_class._meta.primary_key
        subquery._select = (column,)
    inner_sql, params = self.generate_select(subquery, alias_map)
    return '(%s)' % strip_parens(inner_sql), params
def _parse_strip_parens(self, node, alias_map, conv):
    """Parse the wrapped node and pass its SQL through ``strip_parens``."""
    inner_sql, params = self.parse_node(node.node, alias_map, conv)
    stripped = strip_parens(inner_sql)
    return stripped, params
def _parse(self, node, alias_map, conv):
# Core dispatch: returns (sql, params, unknown). `unknown` is True when the
# params were emitted without any conversion having been applied.
# By default treat the incoming node as a raw value that should be
# parameterized.
node_type = getattr(node, '_node_type', None)
unknown = False
if node_type in self._parse_map:
# Known node type: use the registered handler.
sql, params = self._parse_map[node_type](node, alias_map, conv)
unknown = (node_type in self._unknown_types and
node.adapt is None and
conv is None)
elif isinstance(node, (list, tuple, set)):
# If you're wondering how to pass a list into your query, simply
# wrap it in Param().
sql, params = self.parse_node_list(node, alias_map, conv)
sql = '(%s)' % sql
elif isinstance(node, Model):
# A model instance becomes a single placeholder. With a foreign-key
# converter the value comes from the referenced field; otherwise the
# instance's primary key value is used.
sql = self.interpolation
if conv and isinstance(conv, ForeignKeyField):
to_field = conv.to_field
if isinstance(to_field, ForeignKeyField):
value = conv.db_value(node)
else:
value = to_field.db_value(getattr(node, to_field.name))
else:
value = node._get_pk_value()
params = [value]
elif (isclass(node) and issubclass(node, Model)) or \
isinstance(node, ModelAlias):
# A model class or alias is rendered as its entity with the alias
# taken from the alias map.
entity = node.as_entity().alias(alias_map[node])
sql, params = self.parse_node(entity, alias_map, conv)
elif conv is not None:
# Plain value with a converter: convert, then re-parse the result
# without a converter.
value = conv.db_value(node)
sql, params, _ = self._parse(value, alias_map, None)
else:
# Plain value, no converter: parameterize as-is.
sql, params = self._parse_default(node, alias_map, None)
unknown = True
return sql, params, unknown
def parse_node(self, node, alias_map=None, conv=None):
    """
    Convert ``node`` into an ``(sql, params)`` pair.

    :param node: a query node or a plain Python value.
    :param alias_map: optional alias map used when rendering fields and
        models.
    :param conv: optional field whose ``db_value`` converts raw values.
    :returns: the SQL string and the list of parameters.

    For ``Node`` instances the negation, alias and ordering decorations
    are appended to the generated SQL, in that order.
    """
    sql, params, unknown = self._parse(node, alias_map, conv)
    # Parameters that came through unconverted still need the converter.
    if unknown and (conv is not None) and params:
        params = [conv.db_value(i) for i in params]

    if isinstance(node, Node):
        if node._negated:
            sql = 'NOT %s' % sql
        if node._alias:
            sql = ' '.join((sql, 'AS', node._alias))
        if node._ordering:
            sql = ' '.join((sql, node._ordering))

    # A trailing loop used to re-parse every Node found in ``params`` and
    # throw the result away; it never affected the return value, so it
    # has been removed.
    return sql, params
def parse_node_list(self, nodes, alias_map, conv=None, glue=', '):
    """Parse each node in turn; return the SQL fragments joined with
    ``glue`` and all parameters in order."""
    fragments, all_params = [], []
    for child in nodes:
        fragment, child_params = self.parse_node(child, alias_map, conv)
        fragments.append(fragment)
        all_params.extend(child_params)
    return glue.join(fragments), all_params
def calculate_alias_map(self, query, alias_map=None):
    """
    Build an alias map covering the query's model and every join source
    and destination, then merge in entries from ``alias_map`` that are
    not already present.

    When ``alias_map`` is given, alias numbering continues from its
    counter.
    """
    result = self.alias_map_class()
    if alias_map is not None:
        result._counter = alias_map._counter

    result.add(query.model_class, query.model_class._meta.table_alias)
    for source, joins in query._joins.items():
        result.add(source, source._meta.table_alias)
        for join in joins:
            if isinstance(join.dest, Node):
                dest_alias = join.dest.alias
            else:
                dest_alias = join.dest._meta.table_alias
            result.add(join.dest, dest_alias)

    return result.update(alias_map)
def build_query(self, clauses, alias_map=None):
    """Wrap ``clauses`` in a single ``Clause`` and parse it."""
    statement = Clause(*clauses)
    return self.parse_node(statement, alias_map)
def generate_joins(self, joins, model_class, alias_map):
# Build one `<JOIN TYPE> <dest> ON <constraint>` clause per join reachable
# from `model_class` and return them as a list.
# Joins are implemented as an adjacency-list graph. Perform a
# depth-first search of the graph to generate all the necessary JOINs.
clauses = []
seen = set()
q = [model_class]
while q:
curr = q.pop()
# Skip models with no outgoing joins or that were already expanded.
if curr not in joins or curr in seen:
continue
seen.add(curr)
for join in joins[curr]:
src = curr
dest = join.dest
if isinstance(join.on, (Expression, Func, Clause, Entity)):
# Clear any alias on the join expression.
constraint = join.on.clone().alias()
else:
# No explicit predicate: derive it from the join's foreign key.
metadata = join.metadata
if metadata.is_backref:
fk_model = join.dest
pk_model = join.src
else:
fk_model = join.src
pk_model = join.dest
fk = metadata.foreign_key
if fk:
lhs = getattr(fk_model, fk.name)
rhs = getattr(pk_model, fk.to_field.name)
# For back-references the two sides are swapped.
if metadata.is_backref:
lhs, rhs = rhs, lhs
constraint = (lhs == rhs)
else:
raise ValueError('Missing required join predicate.')
if isinstance(dest, Node):
# TODO: ensure alias?
dest_n = dest
else:
# Model destinations are traversed further and rendered as an
# aliased entity.
q.append(dest)
dest_n = dest.as_entity().alias(alias_map[dest])
join_type = join.get_join_type()
# Join types missing from `join_map` are emitted verbatim.
if join_type in self.join_map:
join_sql = SQL(self.join_map[join_type])
else:
join_sql = SQL(join_type)
clauses.append(
Clause(join_sql, dest_n, SQL('ON'), constraint))
return clauses
def generate_select(self, query, alias_map=None):
    """
    Compile a select query into an ``(sql, params)`` pair.

    Clauses are emitted in this order: SELECT [DISTINCT [ON (...)]]
    columns FROM sources (or the compound query itself), JOINs, WHERE,
    GROUP BY, HAVING, WINDOW, ORDER BY, LIMIT, OFFSET, FOR UPDATE.

    :param query: the select (or compound select) query to compile.
    :param alias_map: optional outer alias map to extend.
    """
    model = query.model_class
    db = model._meta.database

    alias_map = self.calculate_alias_map(query, alias_map)

    if isinstance(query, CompoundSelect):
        clauses = [_StripParens(query)]
    else:
        if not query._distinct:
            clauses = [SQL('SELECT')]
        else:
            clauses = [SQL('SELECT DISTINCT')]
            # Anything other than a plain boolean is a column list for
            # DISTINCT ON.
            if query._distinct not in (True, False):
                clauses += [SQL('ON'), EnclosedClause(*query._distinct)]

        select_clause = Clause(*query._select)
        select_clause.glue = ', '

        clauses.extend((select_clause, SQL('FROM')))
        if query._from is None:
            clauses.append(model.as_entity().alias(alias_map[model]))
        else:
            clauses.append(CommaClause(*query._from))

    join_clauses = self.generate_joins(query._joins, model, alias_map)
    if join_clauses:
        clauses.extend(join_clauses)

    if query._where is not None:
        clauses.extend([SQL('WHERE'), query._where])

    if query._group_by:
        clauses.extend([SQL('GROUP BY'), CommaClause(*query._group_by)])

    if query._having:
        clauses.extend([SQL('HAVING'), query._having])

    # The WINDOW clause belongs after HAVING and before ORDER BY; emitting
    # it directly after FROM yields invalid SQL as soon as the query also
    # has joins, a WHERE, a GROUP BY or a HAVING clause.
    if query._windows is not None:
        clauses.append(SQL('WINDOW'))
        clauses.append(CommaClause(*[
            Clause(
                SQL(window._alias),
                SQL('AS'),
                window.__sql__())
            for window in query._windows]))

    if query._order_by:
        clauses.extend([SQL('ORDER BY'), CommaClause(*query._order_by)])

    # An OFFSET without a LIMIT falls back to the database's `limit_max`
    # when one is defined.
    if query._limit is not None or (query._offset and db.limit_max):
        limit = query._limit if query._limit is not None else db.limit_max
        clauses.append(SQL('LIMIT %s' % limit))
    if query._offset is not None:
        clauses.append(SQL('OFFSET %s' % query._offset))

    for_update, no_wait = query._for_update
    if for_update:
        stmt = 'FOR UPDATE NOWAIT' if no_wait else 'FOR UPDATE'
        clauses.append(SQL(stmt))

    return self.build_query(clauses, alias_map)
def generate_update(self, query):
    """
    Compile an update query into an ``(sql, params)`` pair:
    ``UPDATE [OR <conflict>] <table> SET col = value, ...`` followed by
    the optional WHERE and RETURNING clauses.
    """
    model = query.model_class
    alias_map = self.alias_map_class()
    alias_map.add(model, model._meta.db_table)

    statement = 'UPDATE'
    if query._on_conflict:
        statement = 'UPDATE OR %s' % query._on_conflict
    clauses = [SQL(statement), model.as_entity(), SQL('SET')]

    assignments = []
    for field, value in self._sorted_fields(query._update):
        # Plain values are wrapped so the field's db_value adapts them.
        if not isinstance(value, (Node, Model)):
            value = Param(value, adapt=field.db_value)
        # No outer parens, no table alias.
        assignment = Expression(
            field.as_entity(with_table=False),
            OP.EQ,
            value,
            flat=True)
        assignments.append(assignment)
    clauses.append(CommaClause(*assignments))

    if query._where:
        clauses.extend([SQL('WHERE'), query._where])

    if query._returning is not None:
        returning = Clause(*query._returning)
        returning.glue = ', '
        clauses.extend([SQL('RETURNING'), returning])

    return self.build_query(clauses, alias_map)
def _get_field_clause(self, fields, clause_type=EnclosedClause):
    """Wrap the table-less entities of ``fields`` in ``clause_type``."""
    entities = [field.as_entity(with_table=False) for field in fields]
    return clause_type(*entities)
def generate_insert(self, query):
# Compile an insert query into (sql, params). Handles three shapes:
# INSERT ... SELECT, INSERT ... VALUES for one or more rows, and a bare
# insert relying on the database's default-insert clause.
model = query.model_class
meta = model._meta
alias_map = self.alias_map_class()
alias_map.add(model, model._meta.db_table)
# Choose the statement keyword: the database's upsert SQL, an
# "INSERT OR <conflict> INTO", or a plain "INSERT INTO".
if query._upsert:
statement = meta.database.upsert_sql
elif query._on_conflict:
statement = 'INSERT OR %s INTO' % query._on_conflict
else:
statement = 'INSERT INTO'
clauses = [SQL(statement), model.as_entity()]
if query._query is not None:
# This INSERT query is of the form INSERT INTO ... SELECT FROM.
if query._fields:
clauses.append(self._get_field_clause(query._fields))
clauses.append(_StripParens(query._query))
elif query._rows is not None:
fields, value_clauses = [], []
have_fields = False
for row_dict in query._iter_rows():
# The column list is taken from the first row only and sorted by
# each field's `_sort_key`; later rows are read in that order.
if not have_fields:
fields = sorted(
row_dict.keys(), key=operator.attrgetter('_sort_key'))
have_fields = True
values = []
for field in fields:
value = row_dict[field]
# Plain values are wrapped so the field's db_value adapts them.
if not isinstance(value, (Node, Model)):
value = Param(value, adapt=field.db_value)
values.append(value)
value_clauses.append(EnclosedClause(*values))
if fields:
clauses.extend([
self._get_field_clause(fields),
SQL('VALUES'),
CommaClause(*value_clauses)])
elif query.model_class._meta.auto_increment:
# Bare insert, use default value for primary key.
clauses.append(query.database.default_insert_clause(
query.model_class))
# RETURNING: the primary key fields when `is_insert_returning` is set,
# otherwise the explicitly requested columns, if any.
if query.is_insert_returning:
clauses.extend([
SQL('RETURNING'),
self._get_field_clause(
meta.get_primary_key_fields(),
clause_type=CommaClause)])
elif query._returning is not None:
returning_clause = Clause(*query._returning)
returning_clause.glue = ', '
clauses.extend([SQL('RETURNING'), returning_clause])
return self.build_query(clauses, alias_map)
def generate_delete(self, query):
    """Compile a delete query: ``DELETE FROM <table>`` plus the optional
    WHERE and RETURNING clauses. No alias map is used."""
    table = query.model_class.as_entity()
    clauses = [SQL('DELETE FROM'), table]
    if query._where:
        clauses.extend([SQL('WHERE'), query._where])
    if query._returning is not None:
        returning = Clause(*query._returning)
        returning.glue = ', '
        clauses.extend([SQL('RETURNING'), returning])
    return self.build_query(clauses)
def field_definition(self, field):
    """Return the column-definition clause for ``field``, built from the
    field's own DDL for the resolved column type."""
    db_field = field.get_db_field()
    column_type = self.get_column_type(db_field)
    return Clause(*field.__ddl__(column_type))
def foreign_key_constraint(self, field):
    """Return the ``FOREIGN KEY (...) REFERENCES table (...)`` clause for
    ``field``, with ON DELETE / ON UPDATE parts when configured."""
    parts = [
        SQL('FOREIGN KEY'),
        EnclosedClause(field.as_entity()),
        SQL('REFERENCES'),
        field.rel_model.as_entity(),
        EnclosedClause(field.to_field.as_entity()),
    ]
    if field.on_delete:
        parts.append(SQL('ON DELETE %s' % field.on_delete))
    if field.on_update:
        parts.append(SQL('ON UPDATE %s' % field.on_update))
    return Clause(*parts)
def return_parsed_node(function_name):
    """Return a method that calls ``self.<function_name>`` with its
    arguments and passes the result through ``self.parse_node``."""
    # TODO: treat all `generate_` functions as returning clauses, instead
    # of SQL/params.
    def inner(self, *args, **kwargs):
        clause = getattr(self, function_name)(*args, **kwargs)
        return self.parse_node(clause)
    return inner
def _create_foreign_key(self, model_class, field, constraint=None):
    """Return the ``ALTER TABLE ... ADD CONSTRAINT`` clause for ``field``.
    The default constraint name is
    ``fk_<table>_<column>_refs_<related table>``."""
    if not constraint:
        constraint = 'fk_%s_%s_refs_%s' % (
            model_class._meta.db_table,
            field.db_column,
            field.rel_model._meta.db_table)
    fk_nodes = self.foreign_key_constraint(field).nodes
    return Clause(
        SQL('ALTER TABLE'),
        model_class.as_entity(),
        SQL('ADD CONSTRAINT'),
        Entity(constraint),
        *fk_nodes)
create_foreign_key = return_parsed_node('_create_foreign_key')
def _create_table(self, model_class, safe=False):
    """
    Return the ``CREATE TABLE`` clause for ``model_class``.

    Column definitions come first, followed by the composite primary key
    (if any), foreign-key constraints for non-deferred foreign keys and
    the model's extra constraints. ``safe`` adds ``IF NOT EXISTS``.
    """
    meta = model_class._meta
    statement = 'CREATE TABLE IF NOT EXISTS' if safe else 'CREATE TABLE'

    columns = []
    constraints = []
    if meta.composite_key:
        pk_entities = [meta.fields[field_name].as_entity()
                       for field_name in meta.primary_key.field_names]
        constraints.append(Clause(
            SQL('PRIMARY KEY'), EnclosedClause(*pk_entities)))

    for field in meta.declared_fields:
        columns.append(self.field_definition(field))
        if isinstance(field, ForeignKeyField) and not field.deferred:
            constraints.append(self.foreign_key_constraint(field))

    # Extra constraints given as plain strings are wrapped in SQL().
    for extra in (meta.constraints or ()):
        if not isinstance(extra, Node):
            extra = SQL(extra)
        constraints.append(extra)

    return Clause(
        SQL(statement),
        model_class.as_entity(),
        EnclosedClause(*(columns + constraints)))
create_table = return_parsed_node('_create_table')
def _drop_table(self, model_class, fail_silently=False, cascade=False):
    """Return the ``DROP TABLE`` clause; ``fail_silently`` adds
    ``IF EXISTS`` and ``cascade`` appends ``CASCADE``."""
    if fail_silently:
        statement = 'DROP TABLE IF EXISTS'
    else:
        statement = 'DROP TABLE'
    parts = [SQL(statement), model_class.as_entity()]
    if cascade:
        parts.append(SQL('CASCADE'))
    return Clause(*parts)
drop_table = return_parsed_node('_drop_table')
def _truncate_table(self, model_class, restart_identity=False,
                    cascade=False):
    """Return the ``TRUNCATE TABLE`` clause with optional
    ``RESTART IDENTITY`` and ``CASCADE`` suffixes."""
    parts = [SQL('TRUNCATE TABLE'), model_class.as_entity()]
    for enabled, keyword in ((restart_identity, 'RESTART IDENTITY'),
                             (cascade, 'CASCADE')):
        if enabled:
            parts.append(SQL(keyword))
    return Clause(*parts)
truncate_table = return_parsed_node('_truncate_table')
def index_name(self, table, columns):
    """Return ``<table>_<col>_<col>...``; names longer than 64 characters
    are shortened to a table prefix plus an MD5-derived suffix."""
    full_name = '%s_%s' % (table, '_'.join(columns))
    if len(full_name) <= 64:
        return full_name
    digest = hashlib.md5(full_name.encode('utf-8')).hexdigest()
    # 55 + 1 + 8 = 64
    return '%s_%s' % (table[:55], digest[:8])
def _create_index(self, model_class, fields, unique, *extra):
    """Return the ``CREATE [UNIQUE] INDEX`` clause over ``fields``; any
    ``extra`` nodes are appended to the clause."""
    table = model_class._meta.db_table
    name = self.index_name(table, [f.db_column for f in fields])
    statement = 'CREATE UNIQUE INDEX' if unique else 'CREATE INDEX'
    indexed_columns = EnclosedClause(*[f.as_entity() for f in fields])
    return Clause(
        SQL(statement),
        Entity(name),
        SQL('ON'),
        model_class.as_entity(),
        indexed_columns,
        *extra)
create_index = return_parsed_node('_create_index')
def _drop_index(self, model_class, fields, fail_silently=False):
    """Return the ``DROP INDEX`` clause for the index over ``fields``;
    ``fail_silently`` adds ``IF EXISTS``."""
    table = model_class._meta.db_table
    if fail_silently:
        statement = 'DROP INDEX IF EXISTS'
    else:
        statement = 'DROP INDEX'
    name = self.index_name(table, [f.db_column for f in fields])
    return Clause(SQL(statement), Entity(name))
drop_index = return_parsed_node('_drop_index')
def _create_sequence(self, sequence_name):
    """Return the ``CREATE SEQUENCE <name>`` clause."""
    sequence = Entity(sequence_name)
    return Clause(SQL('CREATE SEQUENCE'), sequence)
create_sequence = return_parsed_node('_create_sequence')
def _drop_sequence(self, sequence_name):
    """Return the ``DROP SEQUENCE <name>`` clause."""
    sequence = Entity(sequence_name)
    return Clause(SQL('DROP SEQUENCE'), sequence)
drop_sequence = return_parsed_node('_drop_sequence')
class SqliteQueryCompiler(QueryCompiler):
    """Query compiler variant whose table truncation is expressed as a
    delete query."""

    def truncate_table(self, model_class, restart_identity=False,
                       cascade=False):
        # `restart_identity` and `cascade` are accepted but not used.
        delete_query = model_class.delete()
        return delete_query.sql()
class ResultIterator(object):
    """
    Independent cursor over a result wrapper's rows.

    Rows already in the wrapper's ``_result_cache`` are served from there;
    further rows are pulled with ``qrw.iterate()`` and appended to the
    cache, so several iterators can share one wrapper.
    """

    def __init__(self, qrw):
        self.qrw = qrw
        self._idx = 0

    def __iter__(self):
        # Iterators must be iterable themselves, otherwise calling
        # ``iter()`` on one (as ``itertools`` helpers do) raises TypeError.
        return self

    def next(self):
        """Return the next row, or raise ``StopIteration`` when the
        wrapper is exhausted."""
        if self._idx < self.qrw._ct:
            obj = self.qrw._result_cache[self._idx]
        elif not self.qrw._populated:
            # May raise StopIteration once the cursor runs dry.
            obj = self.qrw.iterate()
            self.qrw._result_cache.append(obj)
            self.qrw._ct += 1
        else:
            raise StopIteration
        self._idx += 1
        return obj
    __next__ = next
class QueryResultWrapper(object):
    """
    Provides an iterator over the results of a raw Query, additionally doing
    two things:
    - converts rows from the database into python representations
    - ensures that multiple iterations do not result in multiple queries
    """

    def __init__(self, model, cursor, meta=None):
        """
        :param model: model class associated with the rows.
        :param cursor: database cursor to read rows from.
        :param meta: optional ``(column_meta, join_meta)`` pair.
        """
        self.model = model
        self.cursor = cursor

        self._ct = 0    # Number of rows cached so far.
        self._idx = 0   # Position used by ``next()``.

        self._result_cache = []
        self._populated = False
        self._initialized = False

        if meta is not None:
            self.column_meta, self.join_meta = meta
        else:
            self.column_meta = self.join_meta = None

    def __iter__(self):
        # Once everything is cached, iterate the cache directly.
        if self._populated:
            return iter(self._result_cache)
        else:
            return ResultIterator(self)

    @property
    def count(self):
        """Number of rows; reads all remaining rows into the cache."""
        self.fill_cache()
        return self._ct

    def __len__(self):
        return self.count

    def process_row(self, row):
        """Convert a raw row; the base implementation returns it as-is."""
        return row

    def iterate(self):
        """
        Fetch and convert the next row from the cursor.

        On exhaustion the wrapper is marked populated, the cursor is
        closed unless it has a truthy ``name``, and ``StopIteration`` is
        raised. ``initialize`` is called once, before the first row is
        processed.
        """
        row = self.cursor.fetchone()
        if not row:
            self._populated = True
            if not getattr(self.cursor, 'name', None):
                self.cursor.close()
            raise StopIteration
        elif not self._initialized:
            self.initialize(self.cursor.description)  # Lazy initialization.
            self._initialized = True
        return self.process_row(row)

    def iterator(self):
        """Yield converted rows straight from the cursor without caching
        them."""
        while True:
            # A StopIteration escaping from a generator body is turned
            # into RuntimeError (PEP 479), so end the generator explicitly.
            try:
                row = self.iterate()
            except StopIteration:
                return
            yield row

    def next(self):
        """Return the next row, serving from the cache when possible."""
        if self._idx < self._ct:
            inst = self._result_cache[self._idx]
            self._idx += 1
            return inst
        elif self._populated:
            raise StopIteration

        obj = self.iterate()
        self._result_cache.append(obj)
        self._ct += 1
        self._idx += 1
        return obj
    __next__ = next

    def fill_cache(self, n=None):
        """
        Read rows into the cache until it holds ``n`` rows or the cursor
        is exhausted; a falsy ``n`` means "all rows".

        :raises ValueError: if ``n`` is negative.
        """
        n = n or float('Inf')
        if n < 0:
            raise ValueError('Negative values are not supported.')
        self._idx = self._ct
        while not self._populated and (n > self._ct):
            try:
                next(self)
            except StopIteration:
                break
class ExtQueryResultWrapper(QueryResultWrapper):
def initialize(self, description):
    """
    Build ``self.conv``, one ``(index, name, converter)`` entry per column.

    Columns covered by ``column_meta`` are set up from their node, falling
    back to a lookup by cursor column name. Any columns beyond the
    metadata (or all columns when there is no metadata) are set up by name.

    :param description: cursor description; item ``[i][0]`` is the name of
        column ``i``.
    """
    n_cols = len(description)
    self.conv = []

    start = 0
    if self.column_meta is not None:
        n_meta = len(self.column_meta)
        for i, node in enumerate(self.column_meta):
            if not self._initialize_node(node, i):
                self._initialize_by_name(description[i][0], i)
        if n_cols == n_meta:
            return
        # Continue after the last column described by the metadata.
        # Restarting from the loop variable would register that column a
        # second time and fail outright for empty metadata.
        start = n_meta

    for i in range(start, n_cols):
        self._initialize_by_name(description[i][0], i)
def _initialize_by_name(self, name, i):
    """Register column ``i`` by its cursor name: columns matching a model
    column use that field's name and converter, others keep the cursor
    name and get no converter."""
    known_columns = self.model._meta.columns
    if name not in known_columns:
        self.conv.append((i, name, None))
        return
    field = known_columns[name]
    self.conv.append((i, field.name, field.python_value))
def _initialize_node(self, node, i):
    """
    Try to register column ``i`` from its selected ``node``.

    Handles fields directly, and functions whose first argument is a
    field (using that field's converter only when the function's
    ``_coerce`` flag is set). Returns ``True`` when an entry was added.
    """
    if isinstance(node, Field):
        label = node._alias or node.name
        self.conv.append((i, label, node.python_value))
        return True

    if isinstance(node, Func) and len(node.arguments):
        first_arg = node.arguments[0]
        if isinstance(first_arg, Field):
            label = node._alias or first_arg._alias or first_arg.name
            converter = None
            if node._coerce:
                converter = first_arg.python_value or None
            self.conv.append((i, label, converter))
            return True

    return False
class TuplesQueryResultWrapper(ExtQueryResultWrapper):
    """Result wrapper that yields each row as a tuple of converted
    values."""

    def process_row(self, row):
        converted = []
        for i, col in enumerate(row):
            func = self.conv[i][2]
            converted.append(col if func is None else func(col))
        return tuple(converted)

if _TuplesQueryResultWrapper is None:
    _TuplesQueryResultWrapper = TuplesQueryResultWrapper
class NaiveQueryResultWrapper(ExtQueryResultWrapper):
    """Result wrapper that sets every column as an attribute on a single
    model instance per row."""

    def process_row(self, row):
        instance = self.model()
        for i, column, f in self.conv:
            value = row[i]
            if f is not None:
                value = f(value)
            setattr(instance, column, value)
        instance._prepare_instance()
        return instance

if _ModelQueryResultWrapper is None:
    _ModelQueryResultWrapper = NaiveQueryResultWrapper
class DictQueryResultWrapper(ExtQueryResultWrapper):
    """Result wrapper that yields each row as a dict keyed by column
    name."""

    def process_row(self, row):
        result = {}
        for i, column, f in self.conv:
            value = row[i]
            if f is not None:
                value = f(value)
            result[column] = value
        return result

if _DictQueryResultWrapper is None:
    _DictQueryResultWrapper = DictQueryResultWrapper
class ModelQueryResultWrapper(QueryResultWrapper):
def initialize(self, description):
    """Prepare the column map, the set of selected fields and the join
    list used when processing rows. ``description`` is not used."""
    column_map, model_set = self.generate_column_map()
    self.column_map = column_map
    self._col_set = set(
        node for node in self.column_meta if isinstance(node, Field))
    self.join_list = self.generate_join_list(model_set)
def generate_column_map(self):
    """
    Describe every selected column as ``(key, constructor, attr, conv)``.

    ``key`` identifies the instance the value belongs to, ``constructor``
    creates that instance, ``attr`` is the attribute name (``None`` when
    it has to come from the cursor) and ``conv`` the optional converter.

    :returns: the list of column tuples and the set of keys seen
        (always including ``self.model``).
    """
    column_map = []
    models = set([self.model])
    for node in self.column_meta:
        attr = None
        conv = None
        if isinstance(node, Field):
            if isinstance(node, FieldProxy):
                key = node._model_alias
                constructor = node.model
                conv = node.field_instance.python_value
            else:
                key = constructor = node.model_class
                conv = node.python_value
            attr = node._alias or node.name
        else:
            # Non-field nodes attach to the query's model unless they are
            # bound to another one.
            if node._bind_to is None:
                key = constructor = self.model
            else:
                key = constructor = node._bind_to
            if isinstance(node, Node) and node._alias:
                attr = node._alias
            elif isinstance(node, Entity):
                attr = node.path[-1]
        column_map.append((key, constructor, attr, conv))
        models.add(key)

    return column_map, models
def generate_join_list(self, models):
# Walk the join graph from `self.model` and return a list of
# (metadata, check, fk_present, pk_present) tuples for every join whose
# destination is among `models`.
join_list = []
joins = self.join_meta
stack = [self.model]
while stack:
current = stack.pop()
if current not in joins:
continue
for join in joins[current]:
metadata = join.metadata
if metadata.dest in models or metadata.dest_model in models:
if metadata.foreign_key is not None:
# Record whether the join's foreign key / primary key were
# selected; `check` is only truthy for nullable foreign keys
# with at least one of the two selected.
fk_present = metadata.foreign_key in self._col_set
pk_present = metadata.primary_key in self._col_set
check = metadata.foreign_key.null and (fk_present or
pk_present)
else:
check = fk_present = pk_present = False
join_list.append((
metadata,
check,
fk_present,
pk_present))
# Continue the traversal from the join's destination.
stack.append(join.dest)
return join_list
def process_row(self, row):
    """Build the instances for ``row``, link them along the joins,
    prepare each one and return the first."""
    collected = self.construct_instances(row)
    prepared = self.follow_joins(collected)
    for instance in prepared:
        instance._prepare_instance()
    return prepared[0]
def construct_instances(self, row, keys=None):
    """
    Create one instance per key in the column map and set the row's
    values on them.

    :param row: sequence of raw column values.
    :param keys: optional container restricting which keys are built.
    :returns: dict mapping each key to its instance.
    """
    instances = {}
    for idx, (key, constructor, attr, conv) in enumerate(self.column_map):
        if keys is not None and key not in keys:
            continue
        value = row[idx]
        if key not in instances:
            instances[key] = constructor()
        target = instances[key]
        # Columns without a known attribute use the cursor's column name.
        if attr is None:
            attr = self.cursor.description[idx][0]
        if conv is not None:
            value = conv(value)
        setattr(target, attr, value)
    return instances
def follow_joins(self, collected):
    """Link the instances built for one row according to ``self.join_list``.

    ``collected`` maps model keys to instances (as returned by
    ``construct_instances``). Returns the list of linked instances, with the
    instance for ``self.model`` first.
    """
    prepared = [collected[self.model]]
    for (metadata, check_null, fk_present, pk_present) in self.join_list:
        inst = collected[metadata.src]
        # The joined instance is stored under the join destination, falling
        # back to the destination model.
        try:
            joined_inst = collected[metadata.dest]
        except KeyError:
            joined_inst = collected[metadata.dest_model]

        # When flagged for a NULL check, a falsy FK (or joined PK) value
        # means there is no related row, so the join is skipped.
        has_fk = True
        if check_null:
            if fk_present:
                has_fk = inst._data.get(metadata.foreign_key.name)
            elif pk_present:
                has_fk = joined_inst._data.get(metadata.primary_key.name)

        if not has_fk:
            continue

        # Can we populate a value on the joined instance using the current?
        mpk = metadata.primary_key is not None
        can_populate_joined_pk = (
            mpk and
            (metadata.attr in inst._data) and
            (getattr(joined_inst, metadata.primary_key.name) is None))
        if can_populate_joined_pk:
            setattr(
                joined_inst,
                metadata.primary_key.name,
                inst._data[metadata.attr])

        if metadata.is_backref:
            # For a back-reference, point the joined instance's foreign key
            # at the current instance if it is not already set.
            can_populate_joined_fk = (
                mpk and
                (metadata.foreign_key is not None) and
                (getattr(inst, metadata.primary_key.name) is not None) and
                (joined_inst._data.get(metadata.foreign_key.name) is None))
            if can_populate_joined_fk:
                setattr(
                    joined_inst,
                    metadata.foreign_key.name,
                    inst)

        setattr(inst, metadata.attr, joined_inst)
        prepared.append(joined_inst)

    return prepared
# Pairs a join's metadata with the attribute name under which joined
# instances are attached.
JoinCache = namedtuple('JoinCache', ('metadata', 'attr'))
class AggregateQueryResultWrapper(ModelQueryResultWrapper):
    """Result wrapper that folds consecutive joined rows into one instance.

    Rows whose "parent" columns repeat are merged: related instances are
    collected into lists on the parent rather than yielding one parent per
    row.
    """

    def __init__(self, *args, **kwargs):
        # Holds a row that was fetched while looking ahead but belongs to
        # the next result; ``iterate`` consumes it first.
        self._row = []
        super(AggregateQueryResultWrapper, self).__init__(*args, **kwargs)

    def initialize(self, description):
        """Precompute the lookup tables used by ``iterate``."""
        super(AggregateQueryResultWrapper, self).initialize(description)

        # Collect the set of all models (and ModelAlias objects) queried.
        self.all_models = set()
        for key, _, _, _ in self.column_map:
            self.all_models.add(key)

        # Prepare data structures for analyzing unique rows. Also cache
        # foreign key and attribute names for joined models.
        self.models_with_aggregate = set()
        self.back_references = {}
        self.source_to_dest = {}
        self.dest_to_source = {}

        for (metadata, _, _, _) in self.join_list:
            if metadata.is_backref:
                att_name = metadata.foreign_key.related_name
            else:
                att_name = metadata.attr

            # Self-joins are treated like back-references: the source side
            # is the one that accumulates a list of related rows.
            is_backref = metadata.is_backref or metadata.is_self_join
            if is_backref:
                self.models_with_aggregate.add(metadata.src)
            else:
                self.dest_to_source.setdefault(metadata.dest, set())
                self.dest_to_source[metadata.dest].add(metadata.src)

            self.source_to_dest.setdefault(metadata.src, {})
            self.source_to_dest[metadata.src][metadata.dest] = JoinCache(
                metadata=metadata,
                attr=metadata.alias or att_name)

        # Determine which columns could contain "duplicate" data, e.g. if
        # getting Users and their Tweets, this would be the User columns.
        self.columns_to_compare = {}
        key_to_columns = {}
        for idx, (key, model_class, col_name, _) in enumerate(self.column_map):
            if key in self.models_with_aggregate:
                self.columns_to_compare.setdefault(key, [])
                self.columns_to_compare[key].append((idx, col_name))

            key_to_columns.setdefault(key, [])
            key_to_columns[key].append((idx, col_name))

        # Also compare columns for joins -> many-related model.
        for model_or_alias in self.models_with_aggregate:
            if model_or_alias not in self.columns_to_compare:
                continue
            sources = self.dest_to_source.get(model_or_alias, ())
            for joined_model in sources:
                self.columns_to_compare[model_or_alias].extend(
                    key_to_columns[joined_model])

    def read_model_data(self, row):
        """Return ``{model: [values]}`` for the columns being compared."""
        models = {}
        for model_class, column_data in self.columns_to_compare.items():
            models[model_class] = []
            for idx, col_name in column_data:
                models[model_class].append(row[idx])
        return models

    def iterate(self):
        """Return the next primary instance with related rows attached.

        Raises ``StopIteration`` once the cursor is exhausted.
        """
        # Start from the look-ahead row left by the previous call, if any.
        if self._row:
            row = self._row.pop()
        else:
            row = self.cursor.fetchone()

        if not row:
            self._populated = True
            # Cursors with a truthy ``name`` attribute are left open.
            if not getattr(self.cursor, 'name', None):
                self.cursor.close()
            raise StopIteration
        elif not self._initialized:
            self.initialize(self.cursor.description)
            self._initialized = True

        def _get_pk(instance):
            # Composite keys are represented as a tuple of field values.
            if instance._meta.composite_key:
                return tuple([
                    instance._data[field_name]
                    for field_name in instance._meta.primary_key.field_names])
            return instance._get_pk_value()

        # Per model: ordered mapping of primary key -> instance for the
        # group of rows being merged.
        identity_map = {}
        _constructed = self.construct_instances(row)
        primary_instance = _constructed[self.model]
        for model_or_alias, instance in _constructed.items():
            identity_map[model_or_alias] = OrderedDict()
            identity_map[model_or_alias][_get_pk(instance)] = instance

        model_data = self.read_model_data(row)
        while True:
            cur_row = self.cursor.fetchone()
            if cur_row is None:
                break

            duplicate_models = set()
            cur_row_data = self.read_model_data(cur_row)
            for model_class, data in cur_row_data.items():
                if model_data[model_class] == data:
                    duplicate_models.add(model_class)

            # Nothing repeats: this row starts the next result, so stash it.
            if not duplicate_models:
                self._row.append(cur_row)
                break

            different_models = self.all_models - duplicate_models

            new_instances = self.construct_instances(cur_row, different_models)
            for model_or_alias, instance in new_instances.items():
                # Do not include any instances which are comprised solely of
                # NULL values.
                all_none = True
                for value in instance._data.values():
                    if value is not None:
                        all_none = False
                if not all_none:
                    identity_map[model_or_alias][_get_pk(instance)] = instance

        # Walk the join graph and wire the collected instances together.
        stack = [self.model]
        instances = [primary_instance]
        while stack:
            current = stack.pop()
            if current not in self.join_meta:
                continue

            for join in self.join_meta[current]:
                try:
                    metadata, attr = self.source_to_dest[current][join.dest]
                except KeyError:
                    continue

                if metadata.is_backref or metadata.is_self_join:
                    # Every source instance gets a list, even if it ends
                    # up empty.
                    for instance in identity_map[current].values():
                        setattr(instance, attr, [])

                    if join.dest not in identity_map:
                        continue

                    for pk, inst in identity_map[join.dest].items():
                        if pk is None:
                            continue
                        try:
                            # XXX: if no FK exists, unable to join.
                            joined_inst = identity_map[current][
                                inst._data[metadata.foreign_key.name]]
                        except KeyError:
                            continue

                        getattr(joined_inst, attr).append(inst)
                        instances.append(inst)
                elif attr:
                    if join.dest not in identity_map:
                        continue

                    for pk, instance in identity_map[current].items():
                        # XXX: if no FK exists, unable to join.
                        joined_inst = identity_map[join.dest][
                            instance._data[metadata.foreign_key.name]]
                        setattr(
                            instance,
                            metadata.foreign_key.name,
                            joined_inst)
                        instances.append(joined_inst)

                stack.append(join.dest)

        for instance in instances:
            instance._prepare_instance()

        return primary_instance
class Query(Node):
    """Base class representing a database query on one or more tables."""
    require_commit = True

    def __init__(self, model_class):
        super(Query, self).__init__()

        self.model_class = model_class
        self.database = model_class._meta.database

        self._dirty = True
        self._query_ctx = model_class
        self._joins = {self.model_class: []}  # Join graph as adjacency list.
        self._where = None

    def __repr__(self):
        sql, params = self.sql()
        return '%s %s %s' % (self.model_class, sql, params)

    def clone(self):
        """Return a copy of this query, keeping its ``database``."""
        query = type(self)(self.model_class)
        query.database = self.database
        return self._clone_attributes(query)

    def _clone_attributes(self, query):
        """Copy the WHERE clause, join graph and context onto ``query``."""
        if self._where is not None:
            query._where = self._where.clone()
        query._joins = self._clone_joins()
        query._query_ctx = self._query_ctx
        return query

    def _clone_joins(self):
        # Copy each adjacency list so the clone can add joins independently.
        return dict(
            (mc, list(j)) for mc, j in self._joins.items())

    def _add_query_clauses(self, initial, expressions, conjunction=None):
        """AND ``expressions`` together and combine them with ``initial``.

        ``conjunction`` (default ``operator.and_``) joins the reduced
        expressions to ``initial``. With no expressions, ``initial`` is
        returned unchanged instead of failing inside ``reduce()``.
        """
        if not expressions:
            return initial
        reduced = reduce(operator.and_, expressions)
        if initial is None:
            return reduced
        conjunction = conjunction or operator.and_
        return conjunction(initial, reduced)

    def _model_shorthand(self, args):
        """Expand models and model aliases in ``args`` into their fields.

        Nodes are kept as-is; this includes queries, because ``Query``
        subclasses ``Node``. Anything else is dropped.
        """
        accum = []
        for arg in args:
            if isinstance(arg, Node):
                accum.append(arg)
            elif isinstance(arg, ModelAlias):
                accum.extend(arg.get_proxy_fields())
            elif isclass(arg) and issubclass(arg, Model):
                accum.extend(arg._meta.declared_fields)
        return accum

    @returns_clone
    def where(self, *expressions):
        self._where = self._add_query_clauses(self._where, expressions)

    @returns_clone
    def orwhere(self, *expressions):
        self._where = self._add_query_clauses(
            self._where, expressions, operator.or_)

    @returns_clone
    def join(self, dest, join_type=None, on=None):
        """Join the current query context to ``dest``.

        Raises ``ValueError`` when no ``on`` condition is given and one is
        required: ``dest`` is a ``SelectQuery``, or a class with no
        relation to the current context.
        """
        src = self._query_ctx
        if not on:
            require_join_condition = (
                isinstance(dest, SelectQuery) or
                (isclass(dest) and not src._meta.rel_exists(dest)))
            if require_join_condition:
                raise ValueError('A join condition must be specified.')
        elif isinstance(on, basestring):
            # A string names a field on the source model.
            on = src._meta.fields[on]
        self._joins.setdefault(src, [])
        self._joins[src].append(Join(src, dest, join_type, on))
        # Joining on a sub-select does not move the query context.
        if not isinstance(dest, SelectQuery):
            self._query_ctx = dest

    @returns_clone
    def switch(self, model_class=None):
        """Change or reset the query context."""
        self._query_ctx = model_class or self.model_class

    def ensure_join(self, lm, rm, on=None, **join_kwargs):
        """Join ``lm`` to ``rm`` unless already joined; keep the context."""
        ctx = self._query_ctx
        for join in self._joins.get(lm, []):
            if join.dest == rm:
                return self
        return self.switch(lm).join(rm, on=on, **join_kwargs).switch(ctx)

    def convert_dict_to_node(self, qdict):
        """Translate ``field__op=value`` lookups into expressions.

        Returns ``(expressions, joins)``, where ``joins`` lists the
        relationship attributes traversed while resolving the keys.
        """
        accum = []
        joins = []
        relationship = (ForeignKeyField, ReverseRelationDescriptor)
        for key, value in sorted(qdict.items()):
            curr = self.model_class
            # A trailing "__<op>" picks the operation; default is equality.
            if '__' in key and key.rsplit('__', 1)[1] in DJANGO_MAP:
                key, op = key.rsplit('__', 1)
                op = DJANGO_MAP[op]
            else:
                op = OP.EQ
            for piece in key.split('__'):
                model_attr = getattr(curr, piece)
                if isinstance(model_attr, relationship):
                    curr = model_attr.rel_model
                    joins.append(model_attr)
            accum.append(Expression(model_attr, op, value))
        return accum, joins

    def filter(self, *args, **kwargs):
        """Filter with expressions and/or ``field__op=value`` keywords.

        Joins needed by the keyword lookups are added automatically. With
        no arguments at all, an unmodified clone is returned.
        """
        if not args and not kwargs:
            # Nothing to filter on; the walk below assumes an Expression.
            return self.clone()

        # normalize args and kwargs into a new expression
        dq_node = Node()
        if args:
            dq_node &= reduce(operator.and_, [a.clone() for a in args])
        if kwargs:
            dq_node &= DQ(**kwargs)

        # dq_node should now be an Expression, lhs = Node(), rhs = ...
        q = deque([dq_node])
        dq_joins = set()
        while q:
            curr = q.popleft()
            if not isinstance(curr, Expression):
                continue
            for side, piece in (('lhs', curr.lhs), ('rhs', curr.rhs)):
                if isinstance(piece, DQ):
                    query, joins = self.convert_dict_to_node(piece.query)
                    dq_joins.update(joins)
                    expression = reduce(operator.and_, query)
                    # Apply values from the DQ object.
                    expression._negated = piece._negated
                    expression._alias = piece._alias
                    setattr(curr, side, expression)
                else:
                    q.append(piece)

        # Drop the placeholder Node() that seeded the expression.
        dq_node = dq_node.rhs

        query = self.clone()
        for field in dq_joins:
            if isinstance(field, ForeignKeyField):
                lm, rm = field.model_class, field.rel_model
                field_obj = field
            elif isinstance(field, ReverseRelationDescriptor):
                lm, rm = field.field.rel_model, field.rel_model
                field_obj = field.field
            query = query.ensure_join(lm, rm, field_obj)
        return query.where(dq_node)

    def compiler(self):
        return self.database.compiler()

    def sql(self):
        raise NotImplementedError

    def _execute(self):
        sql, params = self.sql()
        return self.database.execute_sql(sql, params, self.require_commit)

    def execute(self):
        raise NotImplementedError

    def scalar(self, as_tuple=False, convert=False):
        """Return the first column of the first row.

        With ``as_tuple`` the whole row is returned. With ``convert`` the
        row is fetched through ``tuples().first()`` instead of the raw
        cursor.
        """
        if convert:
            row = self.tuples().first()
        else:
            row = self._execute().fetchone()
        if row and not as_tuple:
            return row[0]
        else:
            return row
class RawQuery(Query):
    """
    Execute a SQL query, returning a standard iterable interface that returns
    model instances.
    """
    def __init__(self, model, query, *params):
        self._sql = query
        self._params = list(params)
        self._qr = None
        self._tuples = False
        self._dicts = False
        super(RawQuery, self).__init__(model)

    def clone(self):
        """Return a copy with the same SQL, params, flags and database."""
        query = RawQuery(self.model_class, self._sql, *self._params)
        # Keep a per-query database override, as Query.clone() does.
        query.database = self.database
        query._tuples = self._tuples
        query._dicts = self._dicts
        return query

    join = not_allowed('joining')
    where = not_allowed('where')
    switch = not_allowed('switch')

    @returns_clone
    def tuples(self, tuples=True):
        self._tuples = tuples

    @returns_clone
    def dicts(self, dicts=True):
        self._dicts = dicts

    def sql(self):
        return self._sql, self._params

    def execute(self):
        """Run the query once and return the cached result wrapper."""
        if self._qr is None:
            if self._tuples:
                QRW = self.database.get_result_wrapper(RESULTS_TUPLES)
            elif self._dicts:
                QRW = self.database.get_result_wrapper(RESULTS_DICTS)
            else:
                QRW = self.database.get_result_wrapper(RESULTS_NAIVE)
            self._qr = QRW(self.model_class, self._execute(), None)
        return self._qr

    def __iter__(self):
        return iter(self.execute())
def allow_extend(orig, new_val, **kwargs):
    """Return ``new_val``, or ``orig + new_val`` when ``extend=True``.

    An empty result is returned as ``None``. Raises ``ValueError`` for any
    keyword argument other than ``extend``.
    """
    should_extend = kwargs.pop('extend', False)
    if kwargs:
        raise ValueError('"extend" is the only valid keyword argument.')

    if not should_extend:
        return new_val if new_val else None

    combined = (orig or []) + new_val
    return combined or None
class SelectQuery(Query):
    """SELECT query builder.

    Builder methods are wrapped in ``returns_clone`` (defined elsewhere in
    this file; presumably it applies the change to a copy -- confirm).
    """
    _node_type = 'select_query'

    def __init__(self, model_class, *selection):
        super(SelectQuery, self).__init__(model_class)
        # Whether a SELECT is followed by a commit is decided by the database.
        self.require_commit = self.database.commit_select
        self.__select(*selection)
        self._from = None
        self._group_by = None
        self._having = None
        self._order_by = None
        self._windows = None
        self._limit = None
        self._offset = None
        self._distinct = False
        self._for_update = (False, False)
        self._naive = False
        self._tuples = False
        self._dicts = False
        self._aggregate_rows = False
        self._alias = None
        self._qr = None

    def _clone_attributes(self, query):
        """Copy all SELECT-specific state onto ``query``."""
        query = super(SelectQuery, self)._clone_attributes(query)
        query._explicit_selection = self._explicit_selection
        query._select = list(self._select)
        if self._from is not None:
            query._from = []
            for f in self._from:
                if isinstance(f, Node):
                    query._from.append(f.clone())
                else:
                    query._from.append(f)
        if self._group_by is not None:
            query._group_by = list(self._group_by)
        if self._having:
            query._having = self._having.clone()
        if self._order_by is not None:
            query._order_by = list(self._order_by)
        if self._windows is not None:
            query._windows = list(self._windows)
        query._limit = self._limit
        query._offset = self._offset
        query._distinct = self._distinct
        query._for_update = self._for_update
        query._naive = self._naive
        query._tuples = self._tuples
        query._dicts = self._dicts
        query._aggregate_rows = self._aggregate_rows
        query._alias = self._alias
        return query

    # Class-body helper: builds the operator methods below. Each one raises
    # ValueError when the database does not list the compound operation.
    def compound_op(operator):
        def inner(self, other):
            supported_ops = self.model_class._meta.database.compound_operations
            if operator not in supported_ops:
                raise ValueError(
                    'Your database does not support %s' % operator)
            return CompoundSelect(self.model_class, self, operator, other)
        return inner
    # Static handle on the helper so union_all() can call it at runtime.
    _compound_op_static = staticmethod(compound_op)
    __or__ = compound_op('UNION')
    __and__ = compound_op('INTERSECT')
    __sub__ = compound_op('EXCEPT')

    def __xor__(self, rhs):
        # Symmetric difference, should just be (self | rhs) - (self & rhs)...
        wrapped_rhs = self.model_class.select(SQL('*')).from_(
            EnclosedClause((self & rhs)).alias('_')).order_by()
        return (self | rhs) - wrapped_rhs

    def union_all(self, rhs):
        return SelectQuery._compound_op_static('UNION ALL')(self, rhs)

    def __select(self, *selection):
        # Record whether columns were named explicitly; otherwise fall back
        # to the model's declared fields.
        self._explicit_selection = len(selection) > 0
        selection = selection or self.model_class._meta.declared_fields
        self._select = self._model_shorthand(selection)
    select = returns_clone(__select)

    @returns_clone
    def from_(self, *args):
        self._from = list(args) if args else None

    @returns_clone
    def group_by(self, *args, **kwargs):
        # NOTE(review): ``kwargs`` is accepted but not used here.
        self._group_by = self._model_shorthand(args) if args else None

    @returns_clone
    def having(self, *expressions):
        self._having = self._add_query_clauses(self._having, expressions)

    @returns_clone
    def order_by(self, *args, **kwargs):
        # ``extend=True`` appends to the existing ordering (see allow_extend).
        self._order_by = allow_extend(self._order_by, list(args), **kwargs)

    @returns_clone
    def window(self, *windows, **kwargs):
        self._windows = allow_extend(self._windows, list(windows), **kwargs)

    @returns_clone
    def limit(self, lim):
        self._limit = lim

    @returns_clone
    def offset(self, off):
        self._offset = off

    @returns_clone
    def paginate(self, page, paginate_by=20):
        # Pages are 1-based; a page of 0 or less is treated as offset 0
        # multiplied by that value (no clamping is done here).
        if page > 0:
            page -= 1
        self._limit = paginate_by
        self._offset = page * paginate_by

    @returns_clone
    def distinct(self, is_distinct=True):
        self._distinct = is_distinct

    @returns_clone
    def for_update(self, for_update=True, nowait=False):
        self._for_update = (for_update, nowait)

    @returns_clone
    def naive(self, naive=True):
        self._naive = naive

    @returns_clone
    def tuples(self, tuples=True):
        self._tuples = tuples

    @returns_clone
    def dicts(self, dicts=True):
        self._dicts = dicts

    @returns_clone
    def aggregate_rows(self, aggregate_rows=True):
        self._aggregate_rows = aggregate_rows

    @returns_clone
    def alias(self, alias=None):
        self._alias = alias

    def annotate(self, rel_model, annotation=None):
        """Add an aggregate over ``rel_model`` to the selection.

        Defaults to a COUNT of ``rel_model``'s primary key aliased as
        ``count``; joins ``rel_model`` if needed and groups by the current
        selection when no GROUP BY is set.
        """
        if annotation is None:
            annotation = fn.Count(rel_model._meta.primary_key).alias('count')
        if self._query_ctx == rel_model:
            query = self.switch(self.model_class)
        else:
            query = self.clone()
        query = query.ensure_join(query._query_ctx, rel_model)
        if not query._group_by:
            query._group_by = [x.alias() for x in query._select]
        query._select = tuple(query._select) + (annotation,)
        return query

    def _aggregate(self, aggregation=None):
        # Ordering is cleared and the selection is replaced by the
        # aggregation (COUNT(*) by default).
        if aggregation is None:
            aggregation = fn.Count(SQL('*'))
        query = self.order_by()
        query._select = [aggregation]
        return query

    def aggregate(self, aggregation=None, convert=True):
        return self._aggregate(aggregation).scalar(convert=convert)

    def count(self, clear_limit=False):
        """Return the number of rows matched by the query."""
        # DISTINCT / GROUP BY / LIMIT / OFFSET change what a plain COUNT
        # would return, so those queries are counted via a wrapping SELECT.
        if self._distinct or self._group_by or self._limit or self._offset:
            return self.wrapped_count(clear_limit=clear_limit)

        # defaults to a count() of the primary key
        return self.aggregate(convert=False) or 0

    def wrapped_count(self, clear_limit=False):
        """Count rows by wrapping this query's SQL in ``SELECT COUNT(1)``."""
        clone = self.order_by()
        if clear_limit:
            clone._limit = clone._offset = None

        sql, params = clone.sql()
        wrapped = 'SELECT COUNT(1) FROM (%s) AS wrapped_select' % sql
        rq = self.model_class.raw(wrapped, *params)
        return rq.scalar() or 0

    def exists(self):
        """Return True if the query matches at least one row."""
        clone = self.paginate(1, 1)
        clone._select = [SQL('1')]
        return bool(clone.scalar())

    def get(self):
        """Return the first matching instance.

        Raises the model's ``DoesNotExist`` when nothing matches.
        """
        clone = self.paginate(1, 1)
        try:
            return next(clone.execute())
        except StopIteration:
            raise self.model_class.DoesNotExist(
                'Instance matching query does not exist:\nSQL: %s\nPARAMS: %s'
                % self.sql())

    def peek(self, n=1):
        """Return the first ``n`` results (a single object when ``n == 1``).

        Returns ``None`` when there are no results.
        """
        res = self.execute()
        res.fill_cache(n)
        models = res._result_cache[:n]
        if models:
            return models[0] if n == 1 else models

    def first(self, n=1):
        # NOTE: this sets the limit on this query object itself rather than
        # on a clone, and marks it dirty so it is re-executed.
        if self._limit != n:
            self._limit = n
            self._dirty = True
        return self.peek(n=n)

    def sql(self):
        return self.compiler().generate_select(self)

    def verify_naive(self):
        """Return True if every selected node belongs to ``model_class``."""
        model_class = self.model_class
        for node in self._select:
            if isinstance(node, Field) and node.model_class != model_class:
                return False
            elif isinstance(node, Node) and node._bind_to is not None:
                if node._bind_to != model_class:
                    return False
        return True

    def get_query_meta(self):
        return (self._select, self._joins)

    def _get_result_wrapper(self):
        # Precedence: tuples, dicts, naive, aggregate rows, then models.
        if self._tuples:
            return self.database.get_result_wrapper(RESULTS_TUPLES)
        elif self._dicts:
            return self.database.get_result_wrapper(RESULTS_DICTS)
        elif self._naive or not self._joins or self.verify_naive():
            return self.database.get_result_wrapper(RESULTS_NAIVE)
        elif self._aggregate_rows:
            return self.database.get_result_wrapper(RESULTS_AGGREGATE_MODELS)
        else:
            return self.database.get_result_wrapper(RESULTS_MODELS)

    def execute(self):
        """Run the query if dirty or not yet run; return the result wrapper."""
        if self._dirty or self._qr is None:
            model_class = self.model_class
            query_meta = self.get_query_meta()
            ResultWrapper = self._get_result_wrapper()
            self._qr = ResultWrapper(model_class, self._execute(), query_meta)
            self._dirty = False
            return self._qr
        else:
            return self._qr

    def __iter__(self):
        return iter(self.execute())

    def iterator(self):
        return iter(self.execute().iterator())

    def __getitem__(self, value):
        res = self.execute()
        if isinstance(value, slice):
            index = value.stop
        else:
            index = value
        # Fill the cache far enough to cover the request; a negative index
        # (or open-ended slice) passes None to fill_cache.
        if index is not None:
            index = index + 1 if index >= 0 else None
        res.fill_cache(index)
        return res._result_cache[value]

    def __len__(self):
        return len(self.execute())

    if PY3:
        # Hash by identity when running on Python 3.
        def __hash__(self):
            return id(self)
class NoopSelectQuery(SelectQuery):
    """SELECT query that emits the database's no-op SQL."""

    def sql(self):
        noop_sql = self.database.get_noop_sql()
        return (noop_sql, ())

    def get_query_meta(self):
        # Neither a selection nor a join graph.
        return (None, None)

    def _get_result_wrapper(self):
        # Results are always returned as plain tuples.
        lookup = self.database.get_result_wrapper
        return lookup(RESULTS_TUPLES)
class CompoundSelect(SelectQuery):
    """Two SELECT queries combined by a compound operator."""
    _node_type = 'compound_select_query'

    def __init__(self, model_class, lhs=None, operator=None, rhs=None):
        self.lhs, self.operator, self.rhs = lhs, operator, rhs
        super(CompoundSelect, self).__init__(model_class, [])

    def _clone_attributes(self, query):
        cloned = super(CompoundSelect, self)._clone_attributes(query)
        cloned.lhs = self.lhs
        cloned.operator = self.operator
        cloned.rhs = self.rhs
        return cloned

    def count(self, clear_limit=False):
        # Compound queries are always counted through a wrapping SELECT.
        return self.wrapped_count(clear_limit=clear_limit)

    def get_query_meta(self):
        # Metadata comes from the left-hand query.
        return self.lhs.get_query_meta()

    def verify_naive(self):
        return self.lhs.verify_naive() and self.rhs.verify_naive()

    def _get_result_wrapper(self):
        lookup = self.database.get_result_wrapper
        if self._tuples:
            return lookup(RESULTS_TUPLES)
        if self._dicts:
            return lookup(RESULTS_DICTS)
        if self._aggregate_rows:
            return lookup(RESULTS_AGGREGATE_MODELS)

        joined = self.lhs._joins or self.rhs._joins
        naive = self.lhs._naive or self.rhs._naive or self._naive
        if naive or not joined or self.verify_naive():
            return lookup(RESULTS_NAIVE)
        return lookup(RESULTS_MODELS)
class _WriteQuery(Query):
    """Shared base for write queries that may carry a RETURNING clause."""

    def __init__(self, model_class):
        self._returning = None
        self._tuples = False
        self._dicts = False
        self._qr = None
        super(_WriteQuery, self).__init__(model_class)

    def _clone_attributes(self, query):
        query = super(_WriteQuery, self)._clone_attributes(query)
        # The row-format flags only matter together with RETURNING.
        if self._returning:
            query._returning = list(self._returning)
            query._tuples = self._tuples
            query._dicts = self._dicts
        return query

    def requires_returning(method):
        """Class-body decorator guarding methods that need RETURNING.

        The wrapped method raises ``ValueError`` when the model's database
        does not set ``returning_clause``.
        """
        @wraps(method)  # keep the wrapped method's name and docstring
        def inner(self, *args, **kwargs):
            db = self.model_class._meta.database
            if not db.returning_clause:
                raise ValueError('RETURNING is not supported by your '
                                 'database: %s' % type(db))
            return method(self, *args, **kwargs)
        return inner

    @requires_returning
    @returns_clone
    def returning(self, *selection):
        """Set the RETURNING selection.

        A single ``None`` clears it; no arguments selects the model's
        declared fields.
        """
        if len(selection) == 1 and selection[0] is None:
            self._returning = None
        else:
            if not selection:
                selection = self.model_class._meta.declared_fields
            self._returning = self._model_shorthand(selection)

    @requires_returning
    @returns_clone
    def tuples(self, tuples=True):
        self._tuples = tuples

    @requires_returning
    @returns_clone
    def dicts(self, dicts=True):
        self._dicts = dicts

    def get_result_wrapper(self):
        """Pick the wrapper class for returned rows (naive by default)."""
        if self._returning is not None:
            if self._tuples:
                return self.database.get_result_wrapper(RESULTS_TUPLES)
            elif self._dicts:
                return self.database.get_result_wrapper(RESULTS_DICTS)
        return self.database.get_result_wrapper(RESULTS_NAIVE)

    def _execute_with_result_wrapper(self):
        """Execute the query and cache a wrapper over the returned rows."""
        ResultWrapper = self.get_result_wrapper()
        meta = (self._returning, {self.model_class: []})
        self._qr = ResultWrapper(self.model_class, self._execute(), meta)
        return self._qr
class UpdateQuery(_WriteQuery):
    """UPDATE query; ``update`` holds the values to set."""

    def __init__(self, model_class, update=None):
        self._update = update
        self._on_conflict = None
        super(UpdateQuery, self).__init__(model_class)

    def _clone_attributes(self, query):
        query = super(UpdateQuery, self)._clone_attributes(query)
        # ``update`` defaults to None, and dict(None) raises TypeError, so
        # only copy when a mapping was actually supplied.
        if self._update is None:
            query._update = None
        else:
            query._update = dict(self._update)
        query._on_conflict = self._on_conflict
        return query

    @returns_clone
    def on_conflict(self, action=None):
        self._on_conflict = action

    join = not_allowed('joining')

    def sql(self):
        return self.compiler().generate_update(self)

    def execute(self):
        """Run the update.

        Returns the cached result wrapper when a RETURNING selection is
        set, otherwise the number of rows affected.
        """
        if self._returning is not None and self._qr is None:
            return self._execute_with_result_wrapper()
        elif self._qr is not None:
            return self._qr
        else:
            return self.database.rows_affected(self._execute())

    def __iter__(self):
        if not self.model_class._meta.database.returning_clause:
            raise ValueError('UPDATE queries cannot be iterated over unless '
                             'they specify a RETURNING clause, which is not '
                             'supported by your database.')
        return iter(self.execute())

    def iterator(self):
        return iter(self.execute().iterator())
class InsertQuery(_WriteQuery):
    """INSERT query for one row, many rows, or rows from a sub-query."""

    def __init__(self, model_class, field_dict=None, rows=None,
                 fields=None, query=None, validate_fields=False):
        super(InsertQuery, self).__init__(model_class)

        self._upsert = False
        self._is_multi_row_insert = rows is not None or query is not None
        self._return_id_list = False
        if rows is not None:
            self._rows = rows
        else:
            self._rows = [field_dict or {}]

        self._fields = fields
        self._query = query
        self._validate_fields = validate_fields
        self._on_conflict = None

    def _iter_rows(self):
        """Yield one field->value dict per row, with defaults filled in.

        Raises ``KeyError`` for an unknown key when field validation is on.
        """
        model_meta = self.model_class._meta
        if self._validate_fields:
            valid_fields = model_meta.valid_fields
            def validate_field(field):
                if field not in valid_fields:
                    raise KeyError('"%s" is not a recognized field.' % field)

        defaults = model_meta._default_dict
        callables = model_meta._default_callables

        for row_dict in self._rows:
            field_row = defaults.copy()
            seen = set()
            for key in row_dict:
                if self._validate_fields:
                    validate_field(key)
                # Field names are replaced by their field objects.
                if key in model_meta.fields:
                    field = model_meta.fields[key]
                else:
                    field = key
                field_row[field] = row_dict[key]
                seen.add(field)
            # Callable defaults run once per row, only where the row gave
            # no value.
            if callables:
                for field in callables:
                    if field not in seen:
                        field_row[field] = callables[field]()
            yield field_row

    def _clone_attributes(self, query):
        query = super(InsertQuery, self)._clone_attributes(query)
        query._rows = self._rows
        query._upsert = self._upsert
        query._is_multi_row_insert = self._is_multi_row_insert
        query._fields = self._fields
        query._query = self._query
        query._return_id_list = self._return_id_list
        query._validate_fields = self._validate_fields
        query._on_conflict = self._on_conflict
        return query

    join = not_allowed('joining')
    where = not_allowed('where clause')

    @returns_clone
    def upsert(self, upsert=True):
        self._upsert = upsert

    @returns_clone
    def on_conflict(self, action=None):
        self._on_conflict = action

    @returns_clone
    def return_id_list(self, return_id_list=True):
        self._return_id_list = return_id_list

    @property
    def is_insert_returning(self):
        if self.database.insert_returning:
            if not self._is_multi_row_insert or self._return_id_list:
                return True
        return False

    def sql(self):
        return self.compiler().generate_insert(self)

    def _insert_with_loop(self):
        """Insert the rows one at a time.

        Returns the list of ids when ``return_id_list`` is set, otherwise
        the id of the last inserted row.
        """
        id_list = []
        last_id = None
        return_id_list = self._return_id_list
        for row in self._rows:
            # Each single-row insert carries this query's upsert, conflict
            # and validation settings, so the loop fallback behaves like
            # the bulk path.
            single = InsertQuery(
                self.model_class,
                row,
                validate_fields=self._validate_fields)
            single._upsert = self._upsert
            single._on_conflict = self._on_conflict
            last_id = single.execute()
            if return_id_list:
                id_list.append(last_id)

        if return_id_list:
            return id_list
        else:
            return last_id

    def execute(self):
        """Run the insert.

        Returns the new primary key for a single-row insert (a list for a
        composite key), a list of ids when ``return_id_list`` is set, a
        result wrapper when a RETURNING selection is set, and ``True`` for
        other multi-row inserts.
        """
        # Databases without multi-row INSERT fall back to one query per row.
        insert_with_loop = (
            self._is_multi_row_insert and
            self._query is None and
            self._returning is None and
            not self.database.insert_many)
        if insert_with_loop:
            return self._insert_with_loop()

        if self._returning is not None and self._qr is None:
            return self._execute_with_result_wrapper()
        elif self._qr is not None:
            return self._qr
        else:
            cursor = self._execute()
            if not self._is_multi_row_insert:
                if self.database.insert_returning:
                    pk_row = cursor.fetchone()
                    meta = self.model_class._meta
                    clean_data = [
                        field.python_value(column)
                        for field, column
                        in zip(meta.get_primary_key_fields(), pk_row)]
                    if self.model_class._meta.composite_key:
                        return clean_data
                    return clean_data[0]
                return self.database.last_insert_id(cursor, self.model_class)
            elif self._return_id_list:
                # Build a real list: map() is a one-shot iterator on
                # Python 3, while the loop path already returns a list.
                return [row[0] for row in cursor.fetchall()]
            else:
                return True
class DeleteQuery(_WriteQuery):
    """DELETE query; joins are not allowed."""
    join = not_allowed('joining')

    def sql(self):
        compiler = self.compiler()
        return compiler.generate_delete(self)

    def execute(self):
        """Run the delete.

        Returns the cached result wrapper when a RETURNING selection is
        set, otherwise the number of rows affected.
        """
        if self._qr is not None:
            return self._qr
        if self._returning is not None:
            return self._execute_with_result_wrapper()
        cursor = self._execute()
        return self.database.rows_affected(cursor)
# Lightweight records describing schema objects.
# NOTE(review): presumably returned by the Database introspection methods
# (get_indexes / get_columns / get_foreign_keys) -- confirm against the
# concrete database classes.
IndexMetadata = namedtuple(
    'IndexMetadata',
    ('name', 'sql', 'columns', 'unique', 'table'))
ColumnMetadata = namedtuple(
    'ColumnMetadata',
    ('name', 'data_type', 'null', 'primary_key', 'table'))
ForeignKeyMetadata = namedtuple(
    'ForeignKeyMetadata',
    ('column', 'dest_table', 'dest_column', 'table'))
class PeeweeException(Exception):
    """Root of the exception hierarchy defined in this module."""


class ImproperlyConfigured(PeeweeException):
    """Configuration error."""


class DatabaseError(PeeweeException):
    """Replaces a driver exception named ``DatabaseError``."""


class DataError(DatabaseError):
    """Replaces a driver exception named ``DataError``."""


class IntegrityError(DatabaseError):
    """Replaces a driver ``IntegrityError`` or ``ConstraintError``."""


class InterfaceError(PeeweeException):
    """Replaces a driver exception named ``InterfaceError``."""


class InternalError(DatabaseError):
    """Replaces a driver exception named ``InternalError``."""


class NotSupportedError(DatabaseError):
    """Replaces a driver exception named ``NotSupportedError``."""


class OperationalError(DatabaseError):
    """Replaces a driver exception named ``OperationalError``."""


class ProgrammingError(DatabaseError):
    """Replaces a driver exception named ``ProgrammingError``."""
class ExceptionWrapper(object):
    """Context manager that re-raises exceptions as mapped types.

    ``exceptions`` maps an exception class *name* to the class that should
    be raised in its place. Exceptions with unmapped names propagate
    unchanged.
    """
    __slots__ = ['exceptions']

    def __init__(self, exceptions):
        self.exceptions = exceptions

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            return
        try:
            replacement = self.exceptions[exc_type.__name__]
        except KeyError:
            # Not a mapped exception: let it propagate as-is.
            return
        exc_args = exc_value if PY26 else exc_value.args
        reraise(replacement, replacement(*exc_args), traceback)
class _BaseConnectionLocal(object):
def __init__(self, **kwargs):
super(_BaseConnectionLocal, self).__init__(**kwargs)
self.autocommit = None
self.closed = True
self.conn = None
self.context_stack = []
self.transactions = []
class _ConnectionLocal(_BaseConnectionLocal, threading.local):
    """Connection state kept in a ``threading.local``."""
class Database(object):
    """Base database wrapper; the attributes below are class-level defaults."""
    # Copied onto SelectQuery.require_commit.
    commit_select = False
    # Instantiated by compiler().
    compiler_class = QueryCompiler
    # Operators accepted by SelectQuery's compound operations.
    compound_operations = ['UNION', 'INTERSECT', 'EXCEPT', 'UNION ALL']
    compound_select_parentheses = False
    distinct_on = False
    drop_cascade = False
    # Passed to the compiler; extended per instance and by register_fields().
    field_overrides = {}
    foreign_keys = True
    for_update = False
    for_update_nowait = False
    # When False, multi-row inserts fall back to one INSERT per row.
    insert_many = True
    # When True, a single-row INSERT reads the new key from the cursor.
    insert_returning = False
    # Passed to the compiler together with quote_char.
    interpolation = '?'
    limit_max = None
    # Passed to the compiler; extended per instance and by register_ops().
    op_overrides = {}
    quote_char = '"'
    reserved_tables = []
    # Checked before RETURNING-related write-query methods are allowed.
    returning_clause = False
    # savepoint() raises NotImplementedError when this is False.
    savepoints = True
    sequences = False
    subquery_delete_same_table = True
    upsert_sql = None
    window_functions = False

    # Maps exception class names to the peewee exception raised in their
    # place (see exception_wrapper()).
    exceptions = {
        'ConstraintError': IntegrityError,
        'DatabaseError': DatabaseError,
        'DataError': DataError,
        'IntegrityError': IntegrityError,
        'InterfaceError': InterfaceError,
        'InternalError': InternalError,
        'NotSupportedError': NotSupportedError,
        'OperationalError': OperationalError,
        'ProgrammingError': ProgrammingError}
def __init__(self, database, threadlocals=True, autocommit=True,
             fields=None, ops=None, autorollback=False, use_speedups=True,
             **connect_kwargs):
    """Set up connection state and per-instance overrides.

    ``threadlocals`` selects thread-local connection state; ``fields`` and
    ``ops`` are merged over the class-level overrides; remaining keyword
    arguments are stored for the driver's connect call.
    """
    self.connect_kwargs = {}

    local_factory = _ConnectionLocal if threadlocals else _BaseConnectionLocal
    self._local = local_factory()
    self.init(database, **connect_kwargs)

    self._conn_lock = threading.Lock()
    self.autocommit = autocommit
    self.autorollback = autorollback
    self.use_speedups = use_speedups

    extra_fields = fields or {}
    extra_ops = ops or {}
    self.field_overrides = merge_dict(self.field_overrides, extra_fields)
    self.op_overrides = merge_dict(self.op_overrides, extra_ops)
def init(self, database, **connect_kwargs):
    """(Re)configure the database name and connection parameters.

    An open connection is closed first. A ``database`` of ``None`` marks
    the instance as deferred.
    """
    already_open = not self.is_closed()
    if already_open:
        self.close()
    self.database = database
    self.deferred = (database is None)
    self.connect_kwargs.update(connect_kwargs)
def exception_wrapper(self):
    """Return a context manager that maps exceptions via ``exceptions``."""
    mapping = self.exceptions
    return ExceptionWrapper(mapping)
def connect(self):
    """Open a connection and store it in the connection-local state.

    Raises ``Exception`` if the database is still deferred.
    """
    with self._conn_lock:
        if self.deferred:
            raise Exception('Error, database not properly initialized '
                            'before opening connection')
        with self.exception_wrapper():
            state = self._local
            state.conn = self._connect(self.database, **self.connect_kwargs)
            state.closed = False
            self.initialize_connection(state.conn)
def initialize_connection(self, conn):
    """Hook called by ``connect()`` with the new connection; no-op here."""
    return None
def close(self):
    """Close the current connection and mark the state as closed.

    Raises ``Exception`` if the database is still deferred. If the driver's
    close raises, the error propagates, but the state is still marked
    closed so the next ``get_conn()`` opens a fresh connection instead of
    reusing a broken one.
    """
    with self._conn_lock:
        if self.deferred:
            raise Exception('Error, database not properly initialized '
                            'before closing connection')
        try:
            with self.exception_wrapper():
                self._close(self._local.conn)
        finally:
            self._local.closed = True
def get_conn(self):
    """Return the active connection, connecting first if closed.

    A connection held by the innermost execution context wins when set.
    """
    state = self._local
    if state.context_stack:
        ctx_conn = state.context_stack[-1].connection
        if ctx_conn is not None:
            return ctx_conn
    if state.closed:
        self.connect()
    return state.conn
def is_closed(self):
return self._local.closed
def get_cursor(self):
return self.get_conn().cursor()
def _close(self, conn):
conn.close()
def _connect(self, database, **kwargs):
raise NotImplementedError
@classmethod
def register_fields(cls, fields):
    """Merge ``fields`` into the class-level ``field_overrides``."""
    merged = merge_dict(cls.field_overrides, fields)
    cls.field_overrides = merged


@classmethod
def register_ops(cls, ops):
    """Merge ``ops`` into the class-level ``op_overrides``."""
    merged = merge_dict(cls.op_overrides, ops)
    cls.op_overrides = merged
def get_result_wrapper(self, wrapper_type):
    """Return the result-wrapper class for ``wrapper_type``.

    Unknown types fall back to the naive wrapper. Where a speedup variant
    exists it is used when ``use_speedups`` is set.
    """
    fast = self.use_speedups
    if wrapper_type == RESULTS_NAIVE:
        pass  # handled by the shared fallback below
    elif wrapper_type == RESULTS_MODELS:
        return ModelQueryResultWrapper
    elif wrapper_type == RESULTS_TUPLES:
        return _TuplesQueryResultWrapper if fast else TuplesQueryResultWrapper
    elif wrapper_type == RESULTS_DICTS:
        return _DictQueryResultWrapper if fast else DictQueryResultWrapper
    elif wrapper_type == RESULTS_AGGREGATE_MODELS:
        return AggregateQueryResultWrapper
    return _ModelQueryResultWrapper if fast else NaiveQueryResultWrapper
def last_insert_id(self, cursor, model):
    """Return ``cursor.lastrowid`` for auto-increment models, else None."""
    if not model._meta.auto_increment:
        return None
    return cursor.lastrowid


def rows_affected(self, cursor):
    """Return the cursor's ``rowcount``."""
    count = cursor.rowcount
    return count
def compiler(self):
    """Instantiate ``compiler_class`` with this database's quote character,
    interpolation marker, field overrides and operator overrides."""
    factory = self.compiler_class
    return factory(
        self.quote_char,
        self.interpolation,
        self.field_overrides,
        self.op_overrides)
def execute(self, clause):
    """Parse ``clause`` with a fresh compiler and pass the parsed result,
    unpacked, to ``execute_sql``."""
    parsed = self.compiler().parse_node(clause)
    return self.execute_sql(*parsed)
def execute_sql(self, sql, params=None, require_commit=True):
    """Execute ``sql`` on a new cursor and return that cursor.

    ``(sql, params)`` is logged at debug level first. If the execution
    raises, the connection is rolled back when autocommit is on and
    ``autorollback`` is set, and the exception is re-raised. On success a
    commit is issued when ``require_commit`` is true and autocommit is on.
    """
    logger.debug((sql, params))
    with self.exception_wrapper():
        cursor = self.get_cursor()
        try:
            cursor.execute(sql, params or ())
        except Exception:
            if self.get_autocommit():
                if self.autorollback:
                    self.rollback()
            raise
        # Reached only on success; the handler above always re-raises.
        if require_commit and self.get_autocommit():
            self.commit()
    return cursor
def begin(self):
    """Start a transaction; this implementation does nothing."""
    return None
def commit(self):
    """Call ``commit()`` on the connection returned by ``get_conn()``."""
    conn = self.get_conn()
    conn.commit()
def rollback(self):
    """Call ``rollback()`` on the connection returned by ``get_conn()``."""
    conn = self.get_conn()
    conn.rollback()
def set_autocommit(self, autocommit):
    """Store ``autocommit`` on the per-thread local state."""
    local = self._local
    local.autocommit = autocommit
def get_autocommit(self):
    """Return the per-thread autocommit flag.

    When the flag is still ``None`` it is first set, via
    ``set_autocommit``, from ``self.autocommit``.
    """
    local = self._local
    if local.autocommit is None:
        self.set_autocommit(self.autocommit)
    return local.autocommit
def push_execution_context(self, transaction):
    """Append ``transaction`` to the per-thread execution-context stack."""
    stack = self._local.context_stack
    stack.append(transaction)
def pop_execution_context(self):
    """Remove the top of the per-thread execution-context stack; the
    removed entry is not returned."""
    stack = self._local.context_stack
    stack.pop()
def execution_context_depth(self):
    """Return how many entries are on the per-thread context stack."""
    stack = self._local.context_stack
    return len(stack)
def execution_context(self, with_transaction=True):
    """Build an ``ExecutionContext`` for this database, forwarding
    ``with_transaction``."""
    context = ExecutionContext(self, with_transaction=with_transaction)
    return context
def push_transaction(self, transaction):
    """Append ``transaction`` to the per-thread transaction stack."""
    stack = self._local.transactions
    stack.append(transaction)
def pop_transaction(self):
    """Remove the top of the per-thread transaction stack; the removed
    entry is not returned."""
    stack = self._local.transactions
    stack.pop()
def transaction_depth(self):
    """Return how many entries are on the per-thread transaction stack."""
    stack = self._local.transactions
    return len(stack)
def transaction(self):
    """Return a new module-level ``transaction`` helper bound to this
    database."""
    helper = transaction(self)
    return helper
def commit_on_success(self, func):
    """Decorator: run each call of ``func`` inside ``self.transaction()``."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        with self.transaction():
            result = func(*args, **kwargs)
        return result
    return wrapper
def savepoint(self, sid=None):
    """Return a module-level ``savepoint`` helper for this database.

    Raises ``NotImplementedError`` when ``self.savepoints`` is falsy.
    """
    if self.savepoints:
        return savepoint(self, sid)
    raise NotImplementedError
def atomic(self):
    """Return a new ``_atomic`` helper bound to this database."""
    helper = _atomic(self)
    return helper
def get_tables(self, schema=None):
    """Table introspection; this version only raises
    ``NotImplementedError``."""
    raise NotImplementedError
def get_indexes(self, table, schema=None):
    """Index introspection; this version only raises
    ``NotImplementedError``."""
    raise NotImplementedError
def get_columns(self, table, schema=None):
    """Column introspection; this version only raises
    ``NotImplementedError``."""
    raise NotImplementedError
def get_primary_keys(self, table, schema=None):
    """Primary-key introspection; this version only raises
    ``NotImplementedError``."""
    raise NotImplementedError
def get_foreign_keys(self, table, schema=None):
    """Foreign-key introspection; this version only raises
    ``NotImplementedError``."""
    raise NotImplementedError
def sequence_exists(self, seq):
    """Sequence lookup; this version only raises
    ``NotImplementedError``."""
    raise NotImplementedError
def create_table(self, model_class, safe=False):
    """Execute the compiler's ``create_table`` output for ``model_class``
    and return the result of ``execute_sql``."""
    statement = self.compiler().create_table(model_class, safe)
    return self.execute_sql(*statement)
def create_tables(self, models, safe=False):
    """Delegate to ``create_model_tables``, passing ``safe`` as its
    ``fail_silently`` argument."""
    create_model_tables(models, fail_silently=safe)
def create_index(self, model_class, fields, unique=False):
    """Execute the compiler's ``create_index`` statement.

    ``fields`` must be a list or tuple, otherwise ``ValueError`` is
    raised. String entries are looked up in ``model_class._meta.fields``;
    other entries are passed through unchanged.
    """
    qc = self.compiler()
    if not isinstance(fields, (list, tuple)):
        raise ValueError('Fields passed to "create_index" must be a list '
                         'or tuple: "%s"' % fields)
    field_objs = []
    for entry in fields:
        if isinstance(entry, basestring):
            entry = model_class._meta.fields[entry]
        field_objs.append(entry)
    statement = qc.create_index(model_class, field_objs, unique)
    return self.execute_sql(*statement)
def drop_index(self, model_class, fields, safe=False):
    """Execute the compiler's ``drop_index`` statement.

    ``fields`` must be a list or tuple, otherwise ``ValueError`` is
    raised. String entries are looked up in ``model_class._meta.fields``;
    other entries are passed through unchanged.
    """
    qc = self.compiler()
    if not isinstance(fields, (list, tuple)):
        raise ValueError('Fields passed to "drop_index" must be a list '
                         'or tuple: "%s"' % fields)
    field_objs = []
    for entry in fields:
        if isinstance(entry, basestring):
            entry = model_class._meta.fields[entry]
        field_objs.append(entry)
    statement = qc.drop_index(model_class, field_objs, safe)
    return self.execute_sql(*statement)
def create_foreign_key(self, model_class, field, constraint=None):
    """Execute the compiler's ``create_foreign_key`` statement and return
    the result of ``execute_sql``."""
    statement = self.compiler().create_foreign_key(
        model_class, field, constraint)
    return self.execute_sql(*statement)
def create_sequence(self, seq):
    """Execute the compiler's ``create_sequence`` statement when
    ``self.sequences`` is truthy; otherwise do nothing and return
    ``None``."""
    if not self.sequences:
        return None
    statement = self.compiler().create_sequence(seq)
    return self.execute_sql(*statement)
def drop_table(self, model_class, fail_silently=False, cascade=False):
    """Execute the compiler's ``drop_table`` statement and return the
    result of ``execute_sql``."""
    statement = self.compiler().drop_table(
        model_class, fail_silently, cascade)
    return self.execute_sql(*statement)
def drop_tables(self, models, safe=False, cascade=False):
    """Delegate to ``drop_model_tables``, passing ``safe`` as its
    ``fail_silently`` argument and forwarding ``cascade``."""
    drop_model_tables(models, fail_silently=safe, cascade=cascade)
def truncate_table(self, model_class, restart_identity=False,
                   cascade=False):
    """Execute the compiler's ``truncate_table`` statement and return the
    result of ``execute_sql``."""
    statement = self.compiler().truncate_table(
        model_class, restart_identity, cascade)
    return self.execute_sql(*statement)
def truncate_tables(self, models, restart_identity=False, cascade=False):
    """Call ``truncate_table`` on every model, visiting them in the
    reverse of the order ``sort_models_topologically`` returns."""
    ordered = sort_models_topologically(models)
    for model in reversed(ordered):
        model.truncate_table(restart_identity, cascade)
def drop_sequence(self, seq):
    """Execute the compiler's ``drop_sequence`` statement when
    ``self.sequences`` is truthy; otherwise do nothing and return
    ``None``."""
    if not self.sequences:
        return None
    statement = self.compiler().drop_sequence(seq)
    return self.execute_sql(*statement)
def extract_date(self, date_part, date_field):
    """Build ``EXTRACT(<date_part> FROM <date_field>)`` as a function
    expression."""
    inner = Clause(date_part, R('FROM'), date_field)
    return fn.EXTRACT(inner)
def truncate_date(self, date_part, date_field):
    """Build a ``DATE_TRUNC(date_part, date_field)`` function
    expression."""
    truncated = fn.DATE_TRUNC(date_part, date_field)
    return truncated
def default_insert_clause(self, model_class):
    """Return the ``DEFAULT VALUES`` SQL fragment; ``model_class`` is not
    used here."""
    clause = SQL('DEFAULT VALUES')
    return clause
def get_noop_sql(self):
    """Return the SQL text this backend uses for a no-op statement."""
    noop = 'SELECT 0 WHERE 0'
    return noop
def get_binary_type(self):
    """Return the module-level ``binary_construct`` object."""
    binary_type = binary_construct
    return binary_type
class SqliteDatabase(Database):
    """``Database`` implementation that talks to the ``sqlite3`` module.

    Each new connection gets the configured PRAGMAs applied and the
    ``date_part``, ``date_trunc`` and ``regexp`` SQL functions registered.
    """
    compiler_class = SqliteQueryCompiler
    field_overrides = {
        'bool': 'INTEGER',
        'smallint': 'INTEGER',
        'uuid': 'TEXT',
    }
    foreign_keys = False
    insert_many = sqlite3 and sqlite3.sqlite_version_info >= (3, 7, 11, 0)
    limit_max = -1
    op_overrides = {
        OP.LIKE: 'GLOB',
        OP.ILIKE: 'LIKE',
    }
    upsert_sql = 'INSERT OR REPLACE INTO'

    def __init__(self, database, pragmas=None, *args, **kwargs):
        """Store the PRAGMA list and initialise the base class.

        :param pragmas: optional iterable of ``(name, value)`` pairs that
            are applied to every new connection.
        :param journal_mode: legacy keyword; when truthy it is appended to
            the pragma list as ``('journal_mode', value)``.
        """
        # Copy so that appending the journal_mode pragma never mutates the
        # caller's sequence (and so that tuples are accepted too).
        self._pragmas = list(pragmas or ())
        journal_mode = kwargs.pop('journal_mode', None)  # Backwards-compat.
        if journal_mode:
            self._pragmas.append(('journal_mode', journal_mode))

        super(SqliteDatabase, self).__init__(database, *args, **kwargs)

    def _connect(self, database, **kwargs):
        """Open a sqlite3 connection and install the connection hooks.

        Raises ``ImproperlyConfigured`` when no sqlite3 module is
        available. If installing the hooks fails the connection is closed
        before the error propagates.
        """
        if not sqlite3:
            raise ImproperlyConfigured('pysqlite or sqlite3 must be installed.')
        conn = sqlite3.connect(database, **kwargs)
        conn.isolation_level = None
        try:
            self._add_conn_hooks(conn)
        except:
            conn.close()
            raise
        return conn

    def _add_conn_hooks(self, conn):
        """Apply PRAGMAs and register the helper SQL functions on ``conn``."""
        self._set_pragmas(conn)
        conn.create_function('date_part', 2, _sqlite_date_part)
        conn.create_function('date_trunc', 2, _sqlite_date_trunc)
        conn.create_function('regexp', 2, _sqlite_regexp)

    def _set_pragmas(self, conn):
        """Execute ``PRAGMA <name> = <value>;`` for each configured pair."""
        if self._pragmas:
            cursor = conn.cursor()
            for pragma, value in self._pragmas:
                cursor.execute('PRAGMA %s = %s;' % (pragma, value))
            cursor.close()

    def begin(self, lock_type='DEFERRED'):
        """Issue ``BEGIN <lock_type>`` without an automatic commit."""
        self.execute_sql('BEGIN %s' % lock_type, require_commit=False)

    def create_foreign_key(self, model_class, field, constraint=None):
        """Always raises ``OperationalError``."""
        raise OperationalError('SQLite does not support ALTER TABLE '
                               'statements to add constraints.')

    def get_tables(self, schema=None):
        """Return table names from ``sqlite_master``, ordered by name."""
        cursor = self.execute_sql('SELECT name FROM sqlite_master WHERE '
                                  'type = ? ORDER BY name;', ('table',))
        return [row[0] for row in cursor.fetchall()]

    def get_indexes(self, table, schema=None):
        """Return ``IndexMetadata`` for each index on ``table``, ordered by
        index name."""
        query = ('SELECT name, sql FROM sqlite_master '
                 'WHERE tbl_name = ? AND type = ? ORDER BY name')
        cursor = self.execute_sql(query, (table, 'index'))
        index_to_sql = dict(cursor.fetchall())

        # Determine which indexes have a unique constraint.
        unique_indexes = set()
        cursor = self.execute_sql('PRAGMA index_list("%s")' % table)
        for row in cursor.fetchall():
            name = row[1]
            is_unique = int(row[2]) == 1
            if is_unique:
                unique_indexes.add(name)

        # Retrieve the indexed columns.
        index_columns = {}
        for index_name in sorted(index_to_sql):
            cursor = self.execute_sql('PRAGMA index_info("%s")' % index_name)
            index_columns[index_name] = [row[2] for row in cursor.fetchall()]

        return [
            IndexMetadata(
                name,
                index_to_sql[name],
                index_columns[name],
                name in unique_indexes,
                table)
            for name in sorted(index_to_sql)]

    def get_columns(self, table, schema=None):
        """Return ``ColumnMetadata`` for each row of
        ``PRAGMA table_info``."""
        cursor = self.execute_sql('PRAGMA table_info("%s")' % table)
        return [ColumnMetadata(row[1], row[2], not row[3], bool(row[5]), table)
                for row in cursor.fetchall()]

    def get_primary_keys(self, table, schema=None):
        """Return column names whose last ``table_info`` value is truthy."""
        cursor = self.execute_sql('PRAGMA table_info("%s")' % table)
        return [row[1] for row in cursor.fetchall() if row[-1]]

    def get_foreign_keys(self, table, schema=None):
        """Return ``ForeignKeyMetadata`` for each row of
        ``PRAGMA foreign_key_list``."""
        cursor = self.execute_sql('PRAGMA foreign_key_list("%s")' % table)
        return [ForeignKeyMetadata(row[3], row[2], row[4], table)
                for row in cursor.fetchall()]

    def savepoint(self, sid=None):
        """Return a ``savepoint_sqlite`` helper bound to this database."""
        return savepoint_sqlite(self, sid)

    def extract_date(self, date_part, date_field):
        """Build a call to the registered ``date_part`` SQL function."""
        return fn.date_part(date_part, date_field)

    def truncate_date(self, date_part, date_field):
        """Build a ``strftime`` call using the format mapped to
        ``date_part`` in ``SQLITE_DATE_TRUNC_MAPPING``."""
        return fn.strftime(SQLITE_DATE_TRUNC_MAPPING[date_part], date_field)

    def get_binary_type(self):
        """Return ``sqlite3.Binary``."""
        return sqlite3.Binary
class PostgresqlDatabase(Database):
    """``Database`` implementation that talks to PostgreSQL via
    ``psycopg2``."""
    commit_select = True
    compound_select_parentheses = True
    distinct_on = True
    drop_cascade = True
    field_overrides = {
        'blob': 'BYTEA',
        'bool': 'BOOLEAN',
        'datetime': 'TIMESTAMP',
        'decimal': 'NUMERIC',
        'double': 'DOUBLE PRECISION',
        'primary_key': 'SERIAL',
        'uuid': 'UUID',
    }
    for_update = True
    for_update_nowait = True
    insert_returning = True
    interpolation = '%s'
    op_overrides = {
        OP.REGEXP: '~',
    }
    reserved_tables = ['user']
    returning_clause = True
    sequences = True
    window_functions = True

    register_unicode = True

    def _connect(self, database, encoding=None, **kwargs):
        """Open a psycopg2 connection and configure it.

        Raises ``ImproperlyConfigured`` when psycopg2 is unavailable. When
        ``register_unicode`` is set the UNICODE and UNICODEARRAY types are
        registered on the connection; a truthy ``encoding`` is applied as
        the client encoding. If either configuration step fails the
        connection is closed before the error propagates.
        """
        if not psycopg2:
            raise ImproperlyConfigured('psycopg2 must be installed.')
        conn = psycopg2.connect(database=database, **kwargs)
        try:
            if self.register_unicode:
                pg_extensions.register_type(pg_extensions.UNICODE, conn)
                pg_extensions.register_type(pg_extensions.UNICODEARRAY, conn)
            if encoding:
                conn.set_client_encoding(encoding)
        except:
            # Do not leak a half-configured connection.
            conn.close()
            raise
        return conn

    def _get_pk_sequence(self, model):
        """Return the sequence name for the model's primary key.

        An explicit ``primary_key.sequence`` wins; otherwise, for
        auto-increment models, ``<db_table>_<pk db_column>_seq`` is
        returned. Returns ``None`` in every other case.
        """
        meta = model._meta
        if meta.primary_key is not False and meta.primary_key.sequence:
            return meta.primary_key.sequence
        elif meta.auto_increment:
            return '%s_%s_seq' % (meta.db_table, meta.primary_key.db_column)

    def last_insert_id(self, cursor, model):
        """Return ``CURRVAL`` of the model's primary-key sequence, or
        ``None`` when the model has no sequence. Commits afterwards when
        autocommit is on."""
        sequence = self._get_pk_sequence(model)
        if not sequence:
            return

        meta = model._meta
        if meta.schema:
            schema = '%s.' % meta.schema
        else:
            schema = ''

        cursor.execute("SELECT CURRVAL('%s\"%s\"')" % (schema, sequence))
        result = cursor.fetchone()[0]
        if self.get_autocommit():
            self.commit()
        return result

    def get_tables(self, schema='public'):
        """Return table names in ``schema``, ordered by name."""
        query = ('SELECT tablename FROM pg_catalog.pg_tables '
                 'WHERE schemaname = %s ORDER BY tablename')
        return [r for r, in self.execute_sql(query, (schema,)).fetchall()]

    def get_indexes(self, table, schema='public'):
        """Return ``IndexMetadata`` for each index on ``table``."""
        query = """
SELECT
i.relname, idxs.indexdef, idx.indisunique,
array_to_string(array_agg(cols.attname), ',')
FROM pg_catalog.pg_class AS t
INNER JOIN pg_catalog.pg_index AS idx ON t.oid = idx.indrelid
INNER JOIN pg_catalog.pg_class AS i ON idx.indexrelid = i.oid
INNER JOIN pg_catalog.pg_indexes AS idxs ON
(idxs.tablename = t.relname AND idxs.indexname = i.relname)
LEFT OUTER JOIN pg_catalog.pg_attribute AS cols ON
(cols.attrelid = t.oid AND cols.attnum = ANY(idx.indkey))
WHERE t.relname = %s AND t.relkind = %s AND idxs.schemaname = %s
GROUP BY i.relname, idxs.indexdef, idx.indisunique
ORDER BY idx.indisunique DESC, i.relname;"""
        cursor = self.execute_sql(query, (table, 'r', schema))
        return [IndexMetadata(row[0], row[1], row[3].split(','), row[2], table)
                for row in cursor.fetchall()]

    def get_columns(self, table, schema='public'):
        """Return ``ColumnMetadata`` for each column of ``table``, in
        ordinal position."""
        query = """
SELECT column_name, is_nullable, data_type
FROM information_schema.columns
WHERE table_name = %s AND table_schema = %s
ORDER BY ordinal_position"""
        cursor = self.execute_sql(query, (table, schema))
        pks = set(self.get_primary_keys(table, schema))
        return [ColumnMetadata(name, dt, null == 'YES', name in pks, table)
                for name, null, dt in cursor.fetchall()]

    def get_primary_keys(self, table, schema='public'):
        """Return the primary-key column names of ``table``."""
        query = """
SELECT kc.column_name
FROM information_schema.table_constraints AS tc
INNER JOIN information_schema.key_column_usage AS kc ON (
tc.table_name = kc.table_name AND
tc.table_schema = kc.table_schema AND
tc.constraint_name = kc.constraint_name)
WHERE
tc.constraint_type = %s AND
tc.table_name = %s AND
tc.table_schema = %s"""
        cursor = self.execute_sql(query, ('PRIMARY KEY', table, schema))
        return [row for row, in cursor.fetchall()]

    def get_foreign_keys(self, table, schema='public'):
        """Return ``ForeignKeyMetadata`` for each foreign key on
        ``table``."""
        sql = """
SELECT
kcu.column_name, ccu.table_name, ccu.column_name
FROM information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu
ON (tc.constraint_name = kcu.constraint_name AND
tc.constraint_schema = kcu.constraint_schema)
JOIN information_schema.constraint_column_usage AS ccu
ON (ccu.constraint_name = tc.constraint_name AND
ccu.constraint_schema = tc.constraint_schema)
WHERE
tc.constraint_type = 'FOREIGN KEY' AND
tc.table_name = %s AND
tc.table_schema = %s"""
        cursor = self.execute_sql(sql, (table, schema))
        return [ForeignKeyMetadata(row[0], row[1], row[2], table)
                for row in cursor.fetchall()]

    def sequence_exists(self, sequence):
        """Return ``True`` when a sequence relation named ``sequence``
        exists."""
        res = self.execute_sql("""
SELECT COUNT(*) FROM pg_class, pg_namespace
WHERE relkind='S'
AND pg_class.relnamespace = pg_namespace.oid
AND relname=%s""", (sequence,))
        return bool(res.fetchone()[0])

    def set_search_path(self, *search_path):
        """Issue ``SET search_path TO ...`` with one placeholder per
        entry."""
        path_params = ','.join(['%s'] * len(search_path))
        self.execute_sql('SET search_path TO %s' % path_params, search_path)

    def get_noop_sql(self):
        """Return the SQL text this backend uses for a no-op statement."""
        return 'SELECT 0 WHERE false'

    def get_binary_type(self):
        """Return ``psycopg2.Binary``."""
        return psycopg2.Binary
class MySQLDatabase(Database):
    """``Database`` implementation that talks to MySQL through the
    module-level ``mysql`` driver."""
    commit_select = True
    compound_operations = ['UNION', 'UNION ALL']
    field_overrides = {
        'bool': 'BOOL',
        'decimal': 'NUMERIC',
        'double': 'DOUBLE PRECISION',
        'float': 'FLOAT',
        'primary_key': 'INTEGER AUTO_INCREMENT',
        'text': 'LONGTEXT',
        'uuid': 'VARCHAR(40)',
    }
    for_update = True
    interpolation = '%s'
    limit_max = 2 ** 64 - 1  # MySQL quirk
    op_overrides = {
        OP.LIKE: 'LIKE BINARY',
        OP.ILIKE: 'LIKE',
        OP.XOR: 'XOR',
    }
    quote_char = '`'
    subquery_delete_same_table = False
    upsert_sql = 'REPLACE INTO'

    def _connect(self, database, **kwargs):
        """Open a driver connection.

        ``charset='utf8'`` and ``use_unicode=True`` are defaults that
        ``kwargs`` may override; a ``password`` key is passed to the
        driver as ``passwd``. Raises ``ImproperlyConfigured`` when no
        driver module is available.
        """
        if not mysql:
            raise ImproperlyConfigured('MySQLdb or PyMySQL must be installed.')
        options = dict(charset='utf8', use_unicode=True)
        options.update(kwargs)
        if 'password' in options:
            options['passwd'] = options.pop('password')
        return mysql.connect(db=database, **options)

    def get_tables(self, schema=None):
        """Return the first column of each ``SHOW TABLES`` row."""
        cursor = self.execute_sql('SHOW TABLES')
        return [table_name for table_name, in cursor]

    def get_indexes(self, table, schema=None):
        """Return ``IndexMetadata`` for each index reported by
        ``SHOW INDEX``; the SQL slot of the metadata is ``None``."""
        cursor = self.execute_sql('SHOW INDEX FROM `%s`' % table)
        unique_names = set()
        columns_by_index = {}
        for row in cursor.fetchall():
            index_name = row[2]
            # A falsy second column marks the index as unique.
            if not row[1]:
                unique_names.add(index_name)
            columns_by_index.setdefault(index_name, []).append(row[4])
        return [
            IndexMetadata(index_name, None, columns_by_index[index_name],
                          index_name in unique_names, table)
            for index_name in columns_by_index]

    def get_columns(self, table, schema=None):
        """Return ``ColumnMetadata`` for each column of ``table`` in the
        current database."""
        sql = """
SELECT column_name, is_nullable, data_type
FROM information_schema.columns
WHERE table_name = %s AND table_schema = DATABASE()"""
        cursor = self.execute_sql(sql, (table,))
        primary_keys = set(self.get_primary_keys(table))
        return [
            ColumnMetadata(name, data_type, nullable == 'YES',
                           name in primary_keys, table)
            for name, nullable, data_type in cursor.fetchall()]

    def get_primary_keys(self, table, schema=None):
        """Return the column names of the index named ``PRIMARY``."""
        cursor = self.execute_sql('SHOW INDEX FROM `%s`' % table)
        return [row[4] for row in cursor.fetchall() if row[2] == 'PRIMARY']

    def get_foreign_keys(self, table, schema=None):
        """Return ``ForeignKeyMetadata`` for each foreign key on
        ``table`` in the current database."""
        query = """
SELECT column_name, referenced_table_name, referenced_column_name
FROM information_schema.key_column_usage
WHERE table_name = %s
AND table_schema = DATABASE()
AND referenced_table_name IS NOT NULL
AND referenced_column_name IS NOT NULL"""
        cursor = self.execute_sql(query, (table,))
        return [
            ForeignKeyMetadata(column, dest_table, dest_column, table)
            for column, dest_table, dest_column in cursor.fetchall()]

    def extract_date(self, date_part, date_field):
        """Build ``EXTRACT(<date_part> FROM <date_field>)`` with the date
        part emitted as raw SQL."""
        inner = Clause(R(date_part), R('FROM'), date_field)
        return fn.EXTRACT(inner)

    def truncate_date(self, date_part, date_field):
        """Build a ``DATE_FORMAT`` call using the format mapped to
        ``date_part`` in ``MYSQL_DATE_TRUNC_MAPPING``."""
        date_format = MYSQL_DATE_TRUNC_MAPPING[date_part]
        return fn.DATE_FORMAT(date_field, date_format)

    def default_insert_clause(self, model_class):
        """Return ``(<primary key>) VALUES (DEFAULT)`` as a clause."""
        pk_columns = EnclosedClause(model_class._meta.primary_key)
        return Clause(pk_columns, SQL('VALUES (DEFAULT)'))

    def get_noop_sql(self):
        """Return the SQL text this backend uses for a no-op statement."""
        return 'DO 0'

    def get_binary_type(self):
        """Return ``mysql.Binary``."""
        return mysql.Binary
class _callable_context_manager(object):
    """Mixin that lets a context manager double as a decorator: calling
    the instance with a function returns a wrapper that runs the function
    inside ``with self``."""

    def __call__(self, fn):
        @wraps(fn)
        def wrapped(*args, **kwargs):
            with self:
                result = fn(*args, **kwargs)
            return result
        return wrapped
class ExecutionContext(_callable_context_manager):
    """Context manager that runs its body on a dedicated connection.

    On entry it pushes itself on the database's execution-context stack,
    opens a new connection via ``database._connect`` and, when
    ``with_transaction`` is true, enters a transaction. On exit it commits
    (if no exception is in flight), leaves the transaction, pops itself
    from the stack and closes the connection.
    """

    def __init__(self, database, with_transaction=True):
        self.database = database
        self.with_transaction = with_transaction
        self.connection = None

    def __enter__(self):
        with self.database._conn_lock:
            self.database.push_execution_context(self)
            conn = None
            try:
                conn = self.connection = self.database._connect(
                    self.database.database,
                    **self.database.connect_kwargs)
                if self.with_transaction:
                    self.txn = self.database.transaction()
                    self.txn.__enter__()
            except:
                # __exit__ is never invoked when __enter__ raises, so undo
                # the push (and close anything opened) here; otherwise a
                # stale context would stay on the thread's stack.
                self.database.pop_execution_context()
                if conn is not None:
                    self.connection = None
                    self.database._close(conn)
                raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        with self.database._conn_lock:
            if self.connection is None:
                self.database.pop_execution_context()
            else:
                try:
                    if self.with_transaction:
                        if not exc_type:
                            self.txn.commit(False)
                        self.txn.__exit__(exc_type, exc_val, exc_tb)
                finally:
                    # Always unwind the stack and release the connection,
                    # even if the commit or transaction exit failed.
                    self.database.pop_execution_context()
                    self.database._close(self.connection)
class Using(ExecutionContext):
    """Execution context that temporarily points ``models`` at
    ``database``.

    On entry each model's ``_meta.database`` is saved and replaced; on
    exit the saved databases are put back. The models are restored even
    when entering or leaving the underlying execution context raises.
    """

    def __init__(self, database, models, with_transaction=True):
        super(Using, self).__init__(database, with_transaction)
        self.models = models

    def _restore_models(self):
        # Put back the databases recorded by __enter__, in the same order.
        for i, model in enumerate(self.models):
            model._meta.database = self._orig[i]

    def __enter__(self):
        self._orig = []
        for model in self.models:
            self._orig.append(model._meta.database)
            model._meta.database = self.database
        try:
            return super(Using, self).__enter__()
        except:
            # __exit__ will not run, so undo the swap before propagating.
            self._restore_models()
            raise

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            super(Using, self).__exit__(exc_type, exc_val, exc_tb)
        finally:
            self._restore_models()
class _atomic(_callable_context_manager):
    """Context manager that wraps its body in a transaction when none is
    open on the database, and in a savepoint otherwise."""

    def __init__(self, db):
        self.db = db

    def __enter__(self):
        database = self.db
        if database.transaction_depth() == 0:
            helper = database.transaction()
        else:
            helper = database.savepoint()
        self._helper = helper
        return helper.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        helper = self._helper
        return helper.__exit__(exc_type, exc_val, exc_tb)
class transaction(_callable_context_manager):
    """Context manager / decorator for a database transaction.

    Autocommit is switched off for the duration of the block and put back
    on exit. An exception in the block triggers a rollback; otherwise the
    outermost transaction commits, rolling back and re-raising if the
    commit itself fails.
    """

    def __init__(self, db):
        self.db = db

    def _begin(self):
        self.db.begin()

    def commit(self, begin=True):
        """Commit, then start a new transaction unless ``begin`` is
        false."""
        self.db.commit()
        if begin:
            self._begin()

    def rollback(self, begin=True):
        """Roll back, then start a new transaction unless ``begin`` is
        false."""
        self.db.rollback()
        if begin:
            self._begin()

    def __enter__(self):
        database = self.db
        self._orig = database.get_autocommit()
        database.set_autocommit(False)
        # Only the outermost transaction issues BEGIN.
        if database.transaction_depth() == 0:
            self._begin()
        database.push_transaction(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            if exc_type:
                self.rollback(False)
            elif self.db.transaction_depth() == 1:
                try:
                    self.commit(False)
                except:
                    self.rollback(False)
                    raise
        finally:
            self.db.set_autocommit(self._orig)
            self.db.pop_transaction()
class savepoint(_callable_context_manager):
    """Context manager / decorator for a SQL savepoint.

    The savepoint id defaults to ``'s'`` followed by a random UUID hex
    string and is quoted with the database's compiler.
    """

    def __init__(self, db, sid=None):
        self.db = db
        _compiler = db.compiler()
        if not sid:
            sid = 's' + uuid.uuid4().hex
        self.sid = sid
        self.quoted_sid = _compiler.quote(self.sid)

    def _execute(self, query):
        self.db.execute_sql(query, require_commit=False)

    def commit(self):
        """Release the savepoint."""
        self._execute('RELEASE SAVEPOINT %s;' % self.quoted_sid)

    def rollback(self):
        """Roll back to the savepoint."""
        self._execute('ROLLBACK TO SAVEPOINT %s;' % self.quoted_sid)

    def __enter__(self):
        database = self.db
        self._orig_autocommit = database.get_autocommit()
        database.set_autocommit(False)
        self._execute('SAVEPOINT %s;' % self.quoted_sid)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            if exc_type:
                self.rollback()
            else:
                # A failed release is rolled back and re-raised.
                try:
                    self.commit()
                except:
                    self.rollback()
                    raise
        finally:
            self.db.set_autocommit(self._orig_autocommit)
class savepoint_sqlite(savepoint):
    """Savepoint variant that forces the connection's ``isolation_level``
    to ``None`` for the duration of the block and restores any previous
    non-``None`` value afterwards."""

    def __enter__(self):
        conn = self.db.get_conn()
        # For sqlite, the connection's isolation_level *must* be set to None.
        # The act of setting it, though, will break any existing savepoints,
        # so only write to it if necessary.
        previous = conn.isolation_level
        if previous is None:
            self._orig_isolation_level = None
        else:
            self._orig_isolation_level = previous
            conn.isolation_level = None
        return super(savepoint_sqlite, self).__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            return super(savepoint_sqlite, self).__exit__(
                exc_type, exc_val, exc_tb)
        finally:
            previous = self._orig_isolation_level
            if previous is not None:
                self.db.get_conn().isolation_level = previous
class FieldProxy(Field):
    """Field stand-in that ties a field instance to a model alias.

    ``coerce``, ``python_value`` and ``db_value`` are forwarded to the
    wrapped field. ``model_class`` resolves to the alias; every other
    attribute not found on the proxy is read from the wrapped field.
    """

    def __init__(self, alias, field_instance):
        self._model_alias = alias
        self.model = self._model_alias.model_class
        self.field_instance = field_instance

    def clone_base(self):
        return FieldProxy(self._model_alias, self.field_instance)

    def coerce(self, value):
        wrapped = self.field_instance
        return wrapped.coerce(value)

    def python_value(self, value):
        wrapped = self.field_instance
        return wrapped.python_value(value)

    def db_value(self, value):
        wrapped = self.field_instance
        return wrapped.db_value(value)

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails.
        if attr != 'model_class':
            return getattr(self.field_instance, attr)
        return self._model_alias
class ModelAlias(object):
    """Read-only alias of a model class.

    Attribute reads are forwarded to the model class, with ``Field``
    attributes wrapped in ``FieldProxy``. Setting attributes raises
    ``AttributeError``. Calling the alias instantiates the model.
    """

    def __init__(self, model_class):
        # Written through __dict__ because __setattr__ always raises.
        self.__dict__['model_class'] = model_class

    def __getattr__(self, attr):
        value = getattr(self.model_class, attr)
        if not isinstance(value, Field):
            return value
        return FieldProxy(self, value)

    def __setattr__(self, attr, value):
        raise AttributeError('Cannot set attributes on ModelAlias instances')

    def get_proxy_fields(self, declared_fields=False):
        """Return ``FieldProxy`` objects for the model's declared fields
        (when ``declared_fields`` is true) or its sorted fields."""
        meta = self.model_class._meta
        if declared_fields:
            source = meta.declared_fields
        else:
            source = meta.sorted_fields
        return [FieldProxy(self, field) for field in source]

    def select(self, *selection):
        """Build a ``SelectQuery`` on the alias, defaulting to all proxy
        fields and applying the model's default ordering if any."""
        if not selection:
            selection = self.get_proxy_fields()
        query = SelectQuery(self, *selection)
        if self._meta.order_by:
            query = query.order_by(*self._meta.order_by)
        return query

    def __call__(self, **kwargs):
        model = self.model_class
        return model(**kwargs)
if _SortedFieldList is None:
    class _SortedFieldList(object):
        """Fallback container that keeps items ordered by their
        ``_sort_key``, using two parallel lists (keys and items)."""
        __slots__ = ('_keys', '_items')

        def __init__(self):
            self._keys = []
            self._items = []

        def __getitem__(self, i):
            return self._items[i]

        def __iter__(self):
            return iter(self._items)

        def __contains__(self, item):
            # Narrow to the run of equal keys before testing membership.
            key = item._sort_key
            lo = bisect_left(self._keys, key)
            hi = bisect_right(self._keys, key)
            candidates = self._items[lo:hi]
            return item in candidates

        def index(self, field):
            """Return the position of the first entry whose key equals
            ``field._sort_key``; raises ``ValueError`` when absent."""
            return self._keys.index(field._sort_key)

        def insert(self, item):
            key = item._sort_key
            position = bisect_left(self._keys, key)
            self._keys.insert(position, key)
            self._items.insert(position, item)

        def remove(self, item):
            position = self.index(item)
            del self._items[position]
            del self._keys[position]
class DoesNotExist(Exception):
    """Exception type that serves as the base of each model's own
    ``DoesNotExist`` class."""
# Database used by models that do not name one: a SQLite file called
# 'peewee.db' when a sqlite3 module is available, otherwise None.
default_database = SqliteDatabase('peewee.db') if sqlite3 else None
class ModelOptions(object):
    """Holds the metadata of one model class: its fields, columns,
    defaults, table name, primary key, indexes and relations."""

    def __init__(self, cls, database=None, db_table=None, db_table_func=None,
                 indexes=None, order_by=None, primary_key=None,
                 table_alias=None, constraints=None, schema=None,
                 validate_backrefs=True, only_save_dirty=False, **kwargs):
        self.model_class = cls
        self.name = cls.__name__.lower()

        # Field registries, filled in by add_field()/remove_field().
        self.fields = {}
        self.columns = {}
        self.defaults = {}
        self._default_by_name = {}
        self._default_dict = {}
        self._default_callables = {}
        self._default_callable_list = []

        self._sorted_field_list = _SortedFieldList()
        self.sorted_fields = []
        self.sorted_field_names = []
        self.valid_fields = set()
        self.declared_fields = []

        if database is None:
            database = default_database
        self.database = database
        self.db_table = db_table
        self.db_table_func = db_table_func
        self.indexes = list(indexes or [])
        self.order_by = order_by
        self.primary_key = primary_key
        self.table_alias = table_alias
        self.constraints = constraints
        self.schema = schema
        self.validate_backrefs = validate_backrefs
        self.only_save_dirty = only_save_dirty

        self.auto_increment = None
        self.composite_key = False
        self.rel = {}
        self.reverse_rel = {}

        # Extra options become plain attributes; their names are remembered.
        for option, value in kwargs.items():
            setattr(self, option, value)
        self._additional_keys = set(kwargs.keys())

        if self.db_table_func and not self.db_table:
            self.db_table = self.db_table_func(cls)

    def __repr__(self):
        return '<%s: %s>' % (self.__class__.__name__, self.name)

    def prepared(self):
        """Normalise ``order_by`` into a list of ``field.asc()`` /
        ``field.desc()`` expressions. Entries may be ``Field`` objects or
        field names, a leading ``'-'`` meaning descending."""
        if not self.order_by:
            return
        normalized = []
        for entry in self.order_by:
            if isinstance(entry, Field):
                sign = '-' if entry._ordering == 'DESC' else ''
                entry = sign + entry.name
            field = self.fields[entry.lstrip('-')]
            if entry.startswith('-'):
                normalized.append(field.desc())
            else:
                normalized.append(field.asc())
        self.order_by = normalized

    def _update_field_lists(self):
        """Recompute the derived field lists from the sorted field list
        and the ``fields`` mapping."""
        ordered = list(self._sorted_field_list)
        self.sorted_fields = ordered
        self.sorted_field_names = [field.name for field in ordered]

        valid = set(self.fields.keys())
        valid.update(self.fields.values())
        valid.add(self.primary_key)
        self.valid_fields = valid

        self.declared_fields = [
            field for field in ordered
            if not isinstance(field, _AutoPrimaryKeyField)]

    def add_field(self, field):
        """Register ``field``, replacing any field of the same name, and
        record its default (callable or plain) when it has one."""
        self.remove_field(field.name)
        self.fields[field.name] = field
        self.columns[field.db_column] = field

        self._sorted_field_list.insert(field)
        self._update_field_lists()

        default = field.default
        if default is None:
            return
        self.defaults[field] = default
        if callable(default):
            self._default_callables[field] = default
            self._default_callable_list.append((field.name, default))
        else:
            self._default_dict[field] = default
            self._default_by_name[field.name] = default

    def remove_field(self, field_name):
        """Unregister the field called ``field_name`` along with its
        recorded default; unknown names are ignored."""
        if field_name not in self.fields:
            return
        original = self.fields.pop(field_name)
        del self.columns[original.db_column]
        self._sorted_field_list.remove(original)
        self._update_field_lists()

        if original.default is None:
            return
        del self.defaults[original]
        if self._default_callables.pop(original, None):
            callables = self._default_callable_list
            for position, (name, _) in enumerate(callables):
                if name == field_name:
                    del callables[position]
                    break
        else:
            self._default_dict.pop(original, None)
            self._default_by_name.pop(original.name, None)

    def get_default_dict(self):
        """Return a new name -> default mapping; callable defaults are
        invoked on every call."""
        values = self._default_by_name.copy()
        for name, factory in self._default_callable_list:
            values[name] = factory()
        return values

    def get_field_index(self, field):
        """Return the field's position in the sorted field list, or -1
        when the lookup raises ``ValueError``."""
        try:
            position = self._sorted_field_list.index(field)
        except ValueError:
            position = -1
        return position

    def get_primary_key_fields(self):
        """Return the primary key as a list of fields (several for a
        composite key)."""
        if not self.composite_key:
            return [self.primary_key]
        return [
            self.fields[field_name]
            for field_name in self.primary_key.field_names]

    def rel_for_model(self, model, field_obj=None, multi=False):
        """Find foreign keys on this model that point at ``model``.

        ``field_obj`` narrows the match by field name (``Field``) or by
        alias (other ``Node``). With ``multi`` a list of all matches is
        returned; otherwise the first match, or ``None``.
        """
        is_field = isinstance(field_obj, Field)
        is_node = not is_field and isinstance(field_obj, Node)

        matches = []
        for field in self.sorted_fields:
            if not isinstance(field, ForeignKeyField):
                continue
            if not (field.rel_model == model):
                continue
            is_match = (
                (field_obj is None) or
                (is_field and field_obj.name == field.name) or
                (is_node and field_obj._alias == field.name))
            if not is_match:
                continue
            if not multi:
                return field
            matches.append(field)

        if multi:
            return matches
        return None

    def reverse_rel_for_model(self, model, field_obj=None, multi=False):
        """Same lookup, performed from ``model`` towards this model."""
        return model._meta.rel_for_model(self.model_class, field_obj, multi)

    def rel_exists(self, model):
        """Return the forward relation to ``model`` if truthy, else the
        reverse one."""
        return self.rel_for_model(model) or self.reverse_rel_for_model(model)

    def related_models(self, backrefs=False):
        """Return this model plus every model reachable through ``rel``
        (and ``reverse_rel`` when ``backrefs`` is true), without
        duplicates."""
        found = []
        pending = [self.model_class]
        while pending:
            model = pending.pop()
            if model in found:
                continue
            found.append(model)
            for fk in model._meta.rel.values():
                pending.append(fk.rel_model)
            if backrefs:
                for fk in model._meta.reverse_rel.values():
                    pending.append(fk.model_class)
        return found
class BaseModel(type):
    """Metaclass that turns a class body into a model class.

    It gathers ``Meta`` options (inheriting selected ones from base
    models), copies non-primary-key fields from bases, builds ``_meta``,
    settles the primary key, binds every field through ``add_to_class``,
    creates the per-model ``DoesNotExist`` exception and resolves deferred
    relations that target the new class.
    """
    # Meta options copied from a base model unless the subclass sets them.
    inheritable = set([
        'constraints', 'database', 'db_table_func', 'indexes', 'order_by',
        'primary_key', 'schema', 'validate_backrefs', 'only_save_dirty'])

    def __new__(cls, name, bases, attrs):
        # The metaclass helper class and its direct child are created
        # without any model processing.
        if name == _METACLASS_ or bases[0].__name__ == _METACLASS_:
            return super(BaseModel, cls).__new__(cls, name, bases, attrs)

        meta_options = {}
        meta = attrs.pop('Meta', None)
        if meta:
            for k, v in meta.__dict__.items():
                if not k.startswith('_'):
                    meta_options[k] = v

        model_pk = getattr(meta, 'primary_key', None)
        parent_pk = None

        # inherit any field descriptors by deep copying the underlying field
        # into the attrs of the new model, additionally see if the bases define
        # inheritable model options and swipe them
        for b in bases:
            if not hasattr(b, '_meta'):
                continue

            base_meta = getattr(b, '_meta')
            if parent_pk is None:
                parent_pk = deepcopy(base_meta.primary_key)
            all_inheritable = cls.inheritable | base_meta._additional_keys
            for (k, v) in base_meta.__dict__.items():
                if k in all_inheritable and k not in meta_options:
                    meta_options[k] = v

            for (k, v) in b.__dict__.items():
                if k in attrs:
                    continue
                if isinstance(v, FieldDescriptor):
                    if not v.field.primary_key:
                        attrs[k] = deepcopy(v.field)

        # initialize the new class and set the magic attributes
        cls = super(BaseModel, cls).__new__(cls, name, bases, attrs)
        ModelOptionsBase = meta_options.get('model_options_base', ModelOptions)
        cls._meta = ModelOptionsBase(cls, **meta_options)
        cls._data = None
        cls._meta.indexes = list(cls._meta.indexes)

        if not cls._meta.db_table:
            # Raw string: '\w' in a plain literal is an invalid escape
            # sequence on modern Python.
            cls._meta.db_table = re.sub(r'[^\w]+', '_', cls.__name__.lower())

        # replace fields with field descriptors, calling the add_to_class hook
        fields = []
        for name, attr in cls.__dict__.items():
            if isinstance(attr, Field):
                if attr.primary_key and model_pk:
                    raise ValueError('primary key is overdetermined.')
                elif attr.primary_key:
                    model_pk, pk_name = attr, name
                else:
                    fields.append((attr, name))

        composite_key = False
        if model_pk is None:
            # No explicit key: reuse the parent's, else add an automatic one.
            if parent_pk:
                model_pk, pk_name = parent_pk, parent_pk.name
            else:
                model_pk, pk_name = PrimaryKeyField(primary_key=True), 'id'
        elif isinstance(model_pk, CompositeKey):
            pk_name = '_composite_key'
            composite_key = True

        # ``primary_key = False`` means the model has no primary key.
        if model_pk is not False:
            model_pk.add_to_class(cls, pk_name)
            cls._meta.primary_key = model_pk
            cls._meta.auto_increment = (
                isinstance(model_pk, PrimaryKeyField) or
                bool(model_pk.sequence))
            cls._meta.composite_key = composite_key

        for field, name in fields:
            field.add_to_class(cls, name)

        # create a repr and error class before finalizing
        if hasattr(cls, '__unicode__'):
            setattr(cls, '__repr__', lambda self: '<%s: %r>' % (
                cls.__name__, self.__unicode__()))

        exc_name = '%sDoesNotExist' % cls.__name__
        exc_attrs = {'__module__': cls.__module__}
        exception_class = type(exc_name, (DoesNotExist,), exc_attrs)
        cls.DoesNotExist = exception_class
        cls._meta.prepared()

        if hasattr(cls, 'validate_model'):
            cls.validate_model()

        DeferredRelation.resolve(cls)

        return cls

    def __iter__(self):
        """Iterating a model class iterates ``self.select()``."""
        return iter(self.select())
class Model(with_metaclass(BaseModel)):
def __init__(self, *args, **kwargs):
    """Seed the instance with the model's defaults, mark each defaulted
    name as dirty, then assign every keyword argument as an attribute.
    Positional arguments are accepted but not used here."""
    defaults = self._meta.get_default_dict()
    self._data = defaults
    self._dirty = set(defaults)
    self._obj_cache = {}

    for attr_name, attr_value in kwargs.items():
        setattr(self, attr_name, attr_value)
@classmethod
def alias(cls):
    """Return a new ``ModelAlias`` wrapping this model class."""
    aliased = ModelAlias(cls)
    return aliased
@classmethod
def select(cls, *selection):
    """Build a ``SelectQuery`` for this model, applying the model's
    default ``order_by`` when one is set."""
    query = SelectQuery(cls, *selection)
    ordering = cls._meta.order_by
    if ordering:
        query = query.order_by(*ordering)
    return query
@classmethod
def update(cls, __data=None, **update):
    """Build an ``UpdateQuery`` for this model.

    :param __data: optional mapping of values to set; it is copied, so
        the caller's mapping is never modified.
    :param update: values keyed by field name; each name is resolved
        through ``cls._meta.fields``.
    """
    # Copy first: updating the caller's dict in place would leak the
    # keyword-derived Field keys back to the caller.
    fdict = __data.copy() if __data else {}
    fdict.update([(cls._meta.fields[f], update[f]) for f in update])
    return UpdateQuery(cls, fdict)
@classmethod
def insert(cls, __data=None, **insert):
    """Build an ``InsertQuery`` for this model.

    :param __data: optional mapping of values to insert; it is copied, so
        the caller's mapping is never modified.
    :param insert: values keyed by field name; each name is resolved
        through ``cls._meta.fields``.
    """
    # Copy first: updating the caller's dict in place would leak the
    # keyword-derived Field keys back to the caller.
    fdict = __data.copy() if __data else {}
    fdict.update([(cls._meta.fields[f], insert[f]) for f in insert])
    return InsertQuery(cls, fdict)
@classmethod
def insert_many(cls, rows, validate_fields=True):
    """Build an ``InsertQuery`` from ``rows``, forwarding
    ``validate_fields``."""
    query = InsertQuery(cls, rows=rows, validate_fields=validate_fields)
    return query
@classmethod
def insert_from(cls, fields, query):
    """Build an ``InsertQuery`` from ``fields`` and a source ``query``."""
    insert_query = InsertQuery(cls, fields=fields, query=query)
    return insert_query
@classmethod
def delete(cls):
    """Build a ``DeleteQuery`` for this model."""
    query = DeleteQuery(cls)
    return query
@classmethod
def raw(cls, sql, *params):
    """Build a ``RawQuery`` for this model from ``sql`` and ``params``."""
    query = RawQuery(cls, sql, *params)
    return query
@classmethod
def create(cls, **query):
    """Instantiate the model from ``query``, save it with
    ``force_insert=True``, call ``_prepare_instance()`` and return the
    instance."""
    instance = cls(**query)
    instance.save(force_insert=True)
    instance._prepare_instance()
    return instance
@classmethod
def get(cls, *query, **kwargs):
    """Return the result of ``.get()`` on a naive select, narrowed by
    ``where(*query)`` and ``filter(**kwargs)`` when they are given."""
    select_query = cls.select().naive()
    if query:
        select_query = select_query.where(*query)
    if kwargs:
        select_query = select_query.filter(**kwargs)
    return select_query.get()
@classmethod
def get_or_create(cls, **kwargs):
    """Fetch the row matching ``kwargs`` or create it.

    Returns ``(instance, created)``. Keys containing ``'__'`` are applied
    with ``filter``; the others as equality conditions. When no row is
    found one is created inside ``atomic()`` from the plain keys plus
    ``defaults``. If the insert raises ``IntegrityError`` the lookup is
    retried and the original error re-raised when it still finds nothing.
    """
    defaults = kwargs.pop('defaults', {})
    query = cls.select()
    for lookup, value in kwargs.items():
        if '__' in lookup:
            query = query.filter(**{lookup: value})
        else:
            query = query.where(getattr(cls, lookup) == value)

    try:
        return query.get(), False
    except cls.DoesNotExist:
        try:
            params = dict(
                (key, value) for key, value in kwargs.items()
                if '__' not in key)
            params.update(defaults)
            with cls._meta.database.atomic():
                return cls.create(**params), True
        except IntegrityError as exc:
            try:
                return query.get(), False
            except cls.DoesNotExist:
                raise exc
@classmethod
def create_or_get(cls, **kwargs):
    """Create a row from ``kwargs`` or, on ``IntegrityError``, fetch the
    existing one.

    Returns ``(instance, created)``. The fallback lookup compares only
    the supplied fields that are unique or primary keys.
    """
    try:
        with cls._meta.database.atomic():
            return cls.create(**kwargs), True
    except IntegrityError:
        conditions = []  # TODO: multi-column unique constraints.
        for field_name, value in kwargs.items():
            field = getattr(cls, field_name)
            if field.unique or field.primary_key:
                conditions.append(field == value)
        return cls.get(*conditions), False
@classmethod
def filter(cls, *dq, **query):
    """Shortcut for ``cls.select().filter(*dq, **query)``."""
    base_query = cls.select()
    return base_query.filter(*dq, **query)
@classmethod
def table_exists(cls):
    """Return whether the model's table name appears in the database's
    ``get_tables()`` result; the schema is passed only when one is set."""
    meta = cls._meta
    lookup_kwargs = {}
    if meta.schema:
        lookup_kwargs['schema'] = meta.schema
    return meta.db_table in meta.database.get_tables(**lookup_kwargs)
@classmethod
def create_table(cls, fail_silently=False):
    """Create the model's table and its indexes.

    Returns early when ``fail_silently`` is set and the table already
    exists. On databases with sequences, a primary-key sequence that does
    not exist yet is created first.
    """
    if fail_silently and cls.table_exists():
        return

    meta = cls._meta
    db = meta.database
    pk = meta.primary_key
    uses_sequence = db.sequences and pk is not False and pk.sequence
    if uses_sequence and not db.sequence_exists(pk.sequence):
        db.create_sequence(pk.sequence)

    db.create_table(cls)
    cls._create_indexes()
@classmethod
def _fields_to_index(cls):
    """Return the fields of ``cls`` that get a single-column index.

    Primary-key fields are skipped; of the rest, a field qualifies when it
    is flagged ``index`` or ``unique`` or is a ``ForeignKeyField``. The
    order of ``cls._meta.sorted_fields`` is kept.
    """
    return [
        field
        for field in cls._meta.sorted_fields
        if not field.primary_key and (
            field.index or
            field.unique or
            isinstance(field, ForeignKeyField))]
@classmethod
def _index_data(cls):
    """Iterate over ``(fields, is_unique)`` pairs describing every index.

    Single-field indexes from ``_fields_to_index()`` come first, followed
    by the entries of ``cls._meta.indexes`` (if any).
    """
    single_field = [
        ((field,), field.unique) for field in cls._fields_to_index()]
    declared = cls._meta.indexes or ()
    return itertools.chain(single_field, declared)
@classmethod
def _create_indexes(cls):
    """Ask the database to create every index listed by ``_index_data()``."""
    database = cls._meta.database
    for columns, unique in cls._index_data():
        database.create_index(cls, columns, unique)
@classmethod
def _drop_indexes(cls, safe=False):
    """Ask the database to drop every index listed by ``_index_data()``.

    ``safe`` is forwarded to ``drop_index()`` unchanged.
    """
    database = cls._meta.database
    # Uniqueness is irrelevant when dropping, so it is ignored here.
    for columns, _ in cls._index_data():
        database.drop_index(cls, columns, safe)
@classmethod
def sqlall(cls):
    """Return the SQL strings needed to create this model's schema.

    The list holds, in order: the sequence creation statement (only when
    the database supports sequences and the primary key names one), the
    table creation statement, one statement per single-field index, and
    one statement per entry in ``cls._meta.indexes``.

    :returns: list of SQL strings (query parameters are discarded).
    """
    queries = []
    database = cls._meta.database
    compiler = database.compiler()
    pk = cls._meta.primary_key
    # ``pk`` is ``False`` for models without a primary key; guard the
    # attribute access the same way ``create_table`` does.
    if database.sequences and pk is not False and pk.sequence:
        queries.append(compiler.create_sequence(pk.sequence))
    queries.append(compiler.create_table(cls))
    for field in cls._fields_to_index():
        queries.append(compiler.create_index(cls, [field], field.unique))
    if cls._meta.indexes:
        for field_names, unique in cls._meta.indexes:
            # Declared indexes are given by field name; resolve to fields.
            fields = [cls._meta.fields[f] for f in field_names]
            queries.append(compiler.create_index(cls, fields, unique))
    return [sql for sql, _ in queries]
@classmethod
def drop_table(cls, fail_silently=False, cascade=False):
    """Delegate to the database's ``drop_table`` for ``cls``."""
    database = cls._meta.database
    database.drop_table(cls, fail_silently, cascade)
@classmethod
def truncate_table(cls, restart_identity=False, cascade=False):
    """Delegate to the database's ``truncate_table`` for ``cls``."""
    database = cls._meta.database
    database.truncate_table(cls, restart_identity, cascade)
@classmethod
def as_entity(cls):
    """Return an ``Entity`` for the model's table, schema-qualified when a
    schema is configured.
    """
    meta = cls._meta
    if not meta.schema:
        return Entity(meta.db_table)
    return Entity(meta.schema, meta.db_table)
@classmethod
def noop(cls, *args, **kwargs):
    """Return a ``NoopSelectQuery`` for ``cls`` with the given arguments."""
    noop_query = NoopSelectQuery(cls, *args, **kwargs)
    return noop_query
def _get_pk_value(self):
    """Return the value stored under the primary key's attribute name."""
    pk_name = self._meta.primary_key.name
    return getattr(self, pk_name)
get_id = _get_pk_value  # Backwards-compatibility.
def _set_pk_value(self, value):
    """Store ``value`` under the primary key's attribute name.

    Does nothing when the model uses a composite key.
    """
    if self._meta.composite_key:
        return
    setattr(self, self._meta.primary_key.name, value)
set_id = _set_pk_value  # Backwards-compatibility.
def _pk_expr(self):
    """Return ``primary_key == <this instance's primary key value>``."""
    pk_field = self._meta.primary_key
    return pk_field == self._get_pk_value()
def _prepare_instance(self):
    """Clear the dirty-field set, then invoke the ``prepared`` hook."""
    dirty = self._dirty
    dirty.clear()
    self.prepared()
def prepared(self):
    """Hook invoked by ``_prepare_instance``; does nothing by default."""
    return None
def _prune_fields(self, field_dict, only):
    """Return a new dict with the entries of ``field_dict`` whose key is
    the ``name`` of one of the fields in ``only``.
    """
    wanted_names = [field.name for field in only]
    return dict(
        (name, field_dict[name])
        for name in wanted_names
        if name in field_dict)
def _populate_unsaved_relations(self, field_dict):
    """Refresh relation values in ``field_dict`` from cached objects.

    For every relation name in ``self._meta.rel`` that is dirty, present in
    ``field_dict`` with a ``None`` value, and has a non-``None`` entry in
    ``self._obj_cache``: the attribute is re-assigned to itself and the
    resulting ``self._data`` value is copied into ``field_dict``.
    """
    for key in self._meta.rel:
        if key not in self._dirty or key not in field_dict:
            continue
        if field_dict[key] is not None:
            continue
        if self._obj_cache.get(key) is None:
            continue
        # Re-assigning goes through the attribute's setter again.
        current = getattr(self, key)
        setattr(self, key, current)
        field_dict[key] = self._data[key]
def save(self, force_insert=False, only=None):
    """Write this instance to the database with an INSERT or an UPDATE.

    :param force_insert: when true, always issue an INSERT, even if a
        primary key value is already present.
    :param only: optional iterable of fields; when given, only the values
        of those fields are written.
    :returns: ``False`` if only dirty fields are being saved and there is
        nothing to write; the result of the UPDATE's ``execute()`` when
        updating; ``1`` after an INSERT.
    """
    # Work on a copy so pruning/popping below never touches ``self._data``.
    field_dict = dict(self._data)
    if self._meta.primary_key is not False:
        pk_field = self._meta.primary_key
        pk_value = self._get_pk_value()
    else:
        # ``primary_key`` is ``False``: there is no key field or value.
        pk_field = pk_value = None
    if only:
        field_dict = self._prune_fields(field_dict, only)
    elif self._meta.only_save_dirty and not force_insert:
        field_dict = self._prune_fields(
            field_dict,
            self.dirty_fields)
        if not field_dict:
            # Nothing is dirty: skip the query entirely.
            self._dirty.clear()
            return False

    self._populate_unsaved_relations(field_dict)
    if pk_value is not None and not force_insert:
        # UPDATE path: the key columns identify the row (see ``_pk_expr``)
        # and are therefore removed from the values being written.
        if self._meta.composite_key:
            for pk_part_name in pk_field.field_names:
                field_dict.pop(pk_part_name, None)
        else:
            field_dict.pop(pk_field.name, None)
        rows = self.update(**field_dict).where(self._pk_expr()).execute()
    elif pk_field is None:
        # INSERT for a model without a primary key: nothing to store back.
        self.insert(**field_dict).execute()
        rows = 1
    else:
        # INSERT: prefer the key reported by the insert, if any, over the
        # value that was already on the instance.
        pk_from_cursor = self.insert(**field_dict).execute()
        if pk_from_cursor is not None:
            pk_value = pk_from_cursor
        self._set_pk_value(pk_value)
        rows = 1
    self._dirty.clear()
    return rows
def is_dirty(self):
    """Return ``True`` when the instance's dirty set is non-empty."""
    return True if self._dirty else False
@property
def dirty_fields(self):
    """List the fields, in ``sorted_fields`` order, whose name is dirty."""
    dirty_names = self._dirty
    return [
        field
        for field in self._meta.sorted_fields
        if field.name in dirty_names]
def dependencies(self, search_nullable=False):
    """Yield ``(expression, foreign_key)`` pairs for rows that depend on
    this instance.

    Starting from this instance's model, the reverse relations of each
    model are walked once per model. Foreign keys pointing straight at this
    instance's model are compared against the instance's own data; others
    are matched against the query of the model they were reached from.
    Nullable foreign keys are only followed further when
    ``search_nullable`` is true, but are always yielded.
    """
    root_model = type(self)
    pending = [(root_model, self.select().where(self._pk_expr()))]
    visited = set()

    while pending:
        model, model_query = pending.pop()
        if model in visited:
            continue
        visited.add(model)

        for fk in model._meta.reverse_rel.values():
            referencing_model = fk.model_class
            if fk.rel_model is root_model:
                node = (fk == self._data[fk.to_field.name])
            else:
                node = fk << model_query
            subquery = referencing_model.select().where(node)
            if search_nullable or not fk.null:
                pending.append((referencing_model, subquery))
            yield (node, fk)
def delete_instance(self, recursive=False, delete_nullable=False):
    """Delete this instance's row and return the delete query's result.

    With ``recursive`` set, dependent rows reported by ``dependencies()``
    are handled first, in reverse order: rows behind a nullable foreign key
    have that key set to ``None`` unless ``delete_nullable`` is true, all
    other dependent rows are deleted.
    """
    if recursive:
        pending = list(self.dependencies(delete_nullable))
        pending.reverse()
        for query, fk in pending:
            model = fk.model_class
            if not fk.null or delete_nullable:
                model.delete().where(query).execute()
            else:
                model.update(**{fk.name: None}).where(query).execute()
    return self.delete().where(self._pk_expr()).execute()
def __hash__(self):
    """Hash the pair ``(class, primary key value)``."""
    identity = (self.__class__, self._get_pk_value())
    return hash(identity)
def __eq__(self, other):
    """Compare by class and primary key value.

    Instances whose own primary key value is ``None`` never compare equal.
    """
    if not (other.__class__ == self.__class__):
        return False
    own_pk = self._get_pk_value()
    if own_pk is None:
        return False
    return other._get_pk_value() == own_pk
def __ne__(self, other):
    """Return the negation of ``self == other``."""
    is_equal = self == other
    return not is_equal
def clean_prefetch_subquery(query):
    """Return a clone of ``query`` with ``_group_by`` and ``_having`` reset
    to ``None``; the original query is left as it was.
    """
    cleaned = query.clone()
    cleaned._group_by = None
    cleaned._having = None
    return cleaned
def prefetch_add_subquery(sq, subqueries):
    """Link each prefetch subquery to one of the queries that precede it.

    :param sq: the outermost query.
    :param subqueries: iterable whose items are a query, a model class, or
        a ``(query_or_model, target_model)`` tuple.
    :returns: list of ``PrefetchResult`` objects; the first wraps ``sq`` and
        each following one wraps a subquery restricted to rows related to
        an earlier query.
    :raises AttributeError: when no relationship can be found between a
        subquery and any earlier query.
    """
    fixed_queries = [PrefetchResult(sq)]
    for i, subquery in enumerate(subqueries):
        if isinstance(subquery, tuple):
            subquery, target_model = subquery
        else:
            target_model = None
        # A bare model class stands for "select everything from it".
        if not isinstance(subquery, Query) and issubclass(subquery, Model):
            subquery = subquery.select()
        subquery_model = subquery.model_class
        fks = backrefs = None
        # Search the queries collected so far, most recent first.
        for j in reversed(range(i + 1)):
            prefetch_result = fixed_queries[j]
            last_query = prefetch_result.query
            last_model = prefetch_result.model
            # NOTE(review): presumably the foreign keys on subquery_model
            # that reference last_model -- confirm against rel_for_model.
            rels = subquery_model._meta.rel_for_model(last_model, multi=True)
            if rels:
                fks = [getattr(subquery_model, fk.name) for fk in rels]
                pks = [getattr(last_model, fk.to_field.name) for fk in rels]
            else:
                # No relation in that direction; try the opposite one.
                backrefs = last_model._meta.rel_for_model(
                    subquery_model,
                    multi=True)

            # Stop at the first related query, unless a specific target
            # model was requested and this is not it.
            if (fks or backrefs) and ((target_model is last_model) or
                                      (target_model is None)):
                break

        if not (fks or backrefs):
            tgt_err = ' using %s' % target_model if target_model else ''
            raise AttributeError('Error: unable to find foreign key for '
                                 'query: %s%s' % (subquery, tgt_err))

        # ``last_query`` (and ``pks``) still hold the values from the loop
        # iteration that was reached last.
        if fks:
            cleaned = clean_prefetch_subquery(last_query)
            # OR together one "fk IN (subselect)" test per foreign key.
            expr = reduce(operator.or_, [
                (fk << cleaned.select(pk))
                for (fk, pk) in zip(fks, pks)])
            subquery = subquery.where(expr)
            fixed_queries.append(PrefetchResult(subquery, fks, False))
        elif backrefs:
            cleaned = clean_prefetch_subquery(last_query)
            expr = reduce(operator.or_, [
                (backref.to_field << cleaned.select(backref))
                for backref in backrefs])
            subquery = subquery.where(expr)
            fixed_queries.append(PrefetchResult(subquery, backrefs, True))
    return fixed_queries
# Plain record type that PrefetchResult extends with behavior.
__prefetched = namedtuple(
    '__prefetched',
    ('query', 'fields', 'backref', 'rel_models', 'field_to_name', 'model'))


class PrefetchResult(__prefetched):
    """One query taking part in a prefetch, plus the fields linking it to
    another query.

    ``fields`` are the linking fields and ``backref`` tells which side of
    the relationship they live on. ``rel_models``, ``field_to_name`` and
    ``model`` are derived in ``__new__``.
    """

    def __new__(cls, query, fields=None, backref=None, rel_models=None,
                field_to_name=None, model=None):
        # When ``fields`` is given, the related models and the attribute
        # name paired with each field are derived from it, replacing any
        # values passed in.
        if fields:
            if backref:
                pairs = [(fk.model_class, fk.to_field.name) for fk in fields]
            else:
                pairs = [(fk.rel_model, fk.name) for fk in fields]
            rel_models = [related for related, _ in pairs]
            field_to_name = list(zip(fields, [name for _, name in pairs]))
        # The model always comes from the query itself.
        return super(PrefetchResult, cls).__new__(
            cls, query, fields, backref, rel_models, field_to_name,
            query.model_class)

    def populate_instance(self, instance, id_map):
        """Attach related objects found in ``id_map`` to ``instance``.

        For back-references the single stored object is assigned to the
        field's attribute. Otherwise the list of stored objects is set on
        ``<related_name>_prefetch`` and each of them gets ``instance``
        assigned to its own linking attribute.
        """
        if not self.backref:
            for field, attname in self.field_to_name:
                key = (field, instance._data[field.to_field.name])
                related = id_map.get(key, [])
                dest = '%s_prefetch' % field.related_name
                for child in related:
                    setattr(child, attname, instance)
                setattr(instance, dest, related)
            return

        for field in self.fields:
            key = (field, instance._data[field.name])
            if key in id_map:
                setattr(instance, field.name, id_map[key])

    def store_instance(self, instance, id_map):
        """Index ``instance`` in ``id_map`` under ``(field, identity)``.

        Back-references store the instance itself; otherwise instances are
        collected in a list per key.
        """
        for field, attname in self.field_to_name:
            raw_value = instance._data[attname]
            key = (field, field.to_field.python_value(raw_value))
            if self.backref:
                id_map[key] = instance
            else:
                id_map.setdefault(key, []).append(instance)
def prefetch(sq, *subqueries):
    """Run ``sq`` and ``subqueries`` and wire the related rows together.

    Without subqueries ``sq`` is returned untouched. Otherwise the queries
    are processed from the last subquery back to ``sq``: each row is stored
    for the queries that depend on it and receives the rows collected by
    the queries processed earlier. The query of the last processed result
    is returned.
    """
    if not subqueries:
        return sq

    results = prefetch_add_subquery(sq, subqueries)
    deps = {}
    rel_map = {}
    for result in reversed(results):
        model = result.model
        if result.fields:
            for rel_model in result.rel_models:
                rel_map.setdefault(rel_model, []).append(result)

        id_map = deps[model] = {}
        # Results that were processed earlier and relate to this model.
        dependents = rel_map.get(model) or ()

        for instance in result.query:
            if result.fields:
                result.store_instance(instance, id_map)
            for rel in dependents:
                rel.populate_instance(instance, deps[rel.model])

    return result.query
def create_model_tables(models, **create_table_kwargs):
    """Create a table for each model, following the topological ordering.

    Keyword arguments are forwarded to every ``create_table()`` call.
    """
    ordered_models = sort_models_topologically(models)
    for model in ordered_models:
        model.create_table(**create_table_kwargs)
def drop_model_tables(models, **drop_table_kwargs):
    """Drop the table of each model, in reverse topological ordering.

    Keyword arguments are forwarded to every ``drop_table()`` call.
    """
    ordered_models = sort_models_topologically(models)
    for model in reversed(ordered_models):
        model.drop_table(**drop_table_kwargs)