"""
Collection of postgres-specific extensions, currently including:
* Support for hstore, a key/value type storage
"""
import uuid
from peewee import *
from peewee import Expression
from peewee import logger
from peewee import Node
from peewee import OP
from peewee import Param
from peewee import Passthrough
from peewee import returns_clone
from peewee import QueryCompiler
from peewee import SelectQuery
from peewee import UUIDField # For backwards-compatibility.
try:
from psycopg2cffi import compat
compat.register()
except ImportError:
pass
from psycopg2.extensions import adapt
from psycopg2.extensions import AsIs
from psycopg2.extensions import register_adapter
from psycopg2.extras import register_hstore
try:
    from psycopg2.extras import Json
except ImportError:
    # This psycopg2 build does not expose ``Json``.  JSONField checks for the
    # ``None`` sentinel and raises a descriptive error at construction time.
    # Only ImportError is caught so that unrelated failures (and
    # KeyboardInterrupt / SystemExit) are no longer silently swallowed.
    Json = None
@Node.extend(clone=False)
def cast(self, as_type):
    """Build a ``::`` cast expression with this node as the left-hand side.

    :param as_type: Name of the target SQL type; it is wrapped in ``SQL`` and
        used as the right-hand side of the ``OP.CAST`` expression.
    :returns: The resulting ``Expression``.
    """
    target_type = SQL(as_type)
    return Expression(self, OP.CAST, target_type)
class _LookupNode(Node):
    """Base node pairing a source node with a sequence of lookup parts."""

    def __init__(self, node, parts):
        """Store the source ``node`` and its lookup ``parts``."""
        # Attributes are assigned before the base initializer runs.
        self.node = node
        self.parts = parts
        super(_LookupNode, self).__init__()

    def clone_base(self):
        """Return a new instance of the same type with a copied parts list."""
        copied_parts = list(self.parts)
        return type(self)(self.node, copied_parts)

    def cast(self, as_type):
        """Cast the parenthesized lookup to the SQL type named ``as_type``."""
        wrapped = Clause(self, parens=True)
        return Expression(wrapped, OP.CAST, SQL(as_type))
class _JsonLookupBase(_LookupNode):
    """Shared behavior for JSON lookups.

    The ``_as_json`` flag is read by the compiler's JSON parsers to choose
    between the JSON-returning operators (``->`` / ``#>``) and the
    text-returning ones (``->>`` / ``#>>``).
    """

    def __init__(self, node, parts, as_json=False):
        """Store the lookup target, its parts and the ``as_json`` flag."""
        super(_JsonLookupBase, self).__init__(node, parts)
        self._as_json = as_json

    def clone_base(self):
        """Return a same-typed copy, preserving the ``_as_json`` flag."""
        copied_parts = list(self.parts)
        return type(self)(self.node, copied_parts, self._as_json)

    @returns_clone
    def as_json(self, as_json=True):
        """Set the ``_as_json`` flag (applied through ``returns_clone``)."""
        self._as_json = as_json

    def _json_key_expression(self, op, keys):
        """Build ``<lookup as JSON> <op> <list of keys>``."""
        json_node = self.as_json(True)
        return Expression(json_node, op, Passthrough(list(keys)))

    def contains(self, other):
        """Build a JSONB containment or key-existence test.

        Lists and dicts are wrapped in ``Json`` and compared with
        ``OP.JSONB_CONTAINS``; any other value is passed unchanged to
        ``OP.JSONB_EXISTS``.
        """
        json_node = self.as_json(True)
        if not isinstance(other, (list, dict)):
            return Expression(json_node, OP.JSONB_EXISTS, other)
        return Expression(json_node, OP.JSONB_CONTAINS, Json(other))

    def contains_any(self, *keys):
        """Build an ``OP.JSONB_CONTAINS_ANY_KEY`` test against ``keys``."""
        return self._json_key_expression(OP.JSONB_CONTAINS_ANY_KEY, keys)

    def contains_all(self, *keys):
        """Build an ``OP.JSONB_CONTAINS_ALL_KEYS`` test against ``keys``."""
        return self._json_key_expression(OP.JSONB_CONTAINS_ALL_KEYS, keys)
class JsonLookup(_JsonLookupBase):
    """Chained JSON item lookup, compiled by the ``json_lookup`` parser."""
    _node_type = 'json_lookup'

    def __getitem__(self, value):
        """Return a new lookup that descends one level further to ``value``."""
        extended_parts = self.parts + [value]
        return JsonLookup(self.node, extended_parts, self._as_json)
class JsonPath(_JsonLookupBase):
    """JSON path lookup, compiled by the ``json_path`` parser."""
    _node_type = 'json_path'
class ObjectSlice(_LookupNode):
    """Index or slice lookup on a node, compiled by the ``object_slice`` parser.

    ``parts`` holds the zero-based index (one element) or the start/stop
    pair (two elements) supplied by the caller.
    """
    _node_type = 'object_slice'

    @classmethod
    def create(cls, node, value):
        """Build an ``ObjectSlice`` for ``node`` from an index-like ``value``.

        :param node: The node being indexed or sliced.
        :param value: A ``slice`` (missing bounds are treated as ``0``), an
            ``int`` index, or a string of colon-separated integers such as
            ``'1:3'``.
        :returns: A new instance of ``cls``.
        :raises ValueError: If a string ``value`` has a non-integer piece.
        """
        if isinstance(value, slice):
            parts = [value.start or 0, value.stop or 0]
        elif isinstance(value, int):
            parts = [value]
        else:
            # Build a real list: on Python 3 ``map`` returns a one-shot
            # iterator, which would be exhausted by the first clone or
            # compile pass and leave later uses with no parts.
            parts = [int(piece) for piece in value.split(':')]
        return cls(node, parts)

    def __getitem__(self, value):
        """Index or slice into this lookup again (nested access)."""
        return ObjectSlice.create(self, value)
class _Array(object):
    """Pairs array items with their owning field so the adapter can cast them."""

    def __init__(self, field, items):
        """Remember the owning ``field`` and the ``items`` to be adapted."""
        self.field, self.items = field, items
def adapt_array(arr):
    """Adapt an ``_Array`` into an SQL literal with an explicit array cast.

    The items are adapted by psycopg2 and prepared against the connection of
    the field's model database. The result is rendered as
    ``<items>::<column type>`` followed by one ``[]`` per field dimension.

    :param arr: The ``_Array`` wrapper to adapt.
    :returns: An ``AsIs`` object carrying the rendered SQL.
    """
    field = arr.field
    conn = field.model_class._meta.database.get_conn()
    adapted_items = adapt(arr.items)
    adapted_items.prepare(conn)
    column_type = field.get_column_type()
    brackets = '[]' * field.dimensions
    return AsIs('%s::%s%s' % (adapted_items, column_type, brackets))


# Make psycopg2 use adapt_array whenever an _Array is passed as a parameter.
register_adapter(_Array, adapt_array)
class IndexedFieldMixin(object):
    """Mixin for fields that are indexed by default with a chosen index type.

    Subclasses may override ``default_index_type``.
    """
    default_index_type = 'GiST'

    def __init__(self, index_type=None, *args, **kwargs):
        """Initialize the field and record which index type to use.

        :param index_type: Index type to use; falls back to
            ``default_index_type`` when falsy.
        """
        # Index by default unless the caller explicitly said otherwise.
        if 'index' not in kwargs:
            kwargs['index'] = True
        super(IndexedFieldMixin, self).__init__(*args, **kwargs)

        # ``self.index`` is expected to have been set by the cooperating
        # base initializer.
        if not self.index:
            self.index_type = None
        else:
            self.index_type = index_type or self.default_index_type
class ArrayField(IndexedFieldMixin, Field):
    """Array column whose element type is taken from an inner field."""
    default_index_type = 'GIN'

    def __init__(self, field_class=IntegerField, dimensions=1, *args,
                 **kwargs):
        """Create the array field.

        :param field_class: Field class describing one array element.
        :param dimensions: Number of ``[]`` suffixes appended to the type.

        The remaining positional and keyword arguments are forwarded both to
        ``field_class`` and to the base initializer.
        """
        element_field = field_class(*args, **kwargs)
        self.__field = element_field
        self.dimensions = dimensions
        self.db_field = element_field.get_db_field()
        super(ArrayField, self).__init__(*args, **kwargs)

    def __ddl_column__(self, column_type):
        """Return the element field's column DDL with array brackets added."""
        column_sql = self.__field.__ddl_column__(column_type)
        column_sql.value += '[]' * self.dimensions
        return column_sql

    def db_value(self, value):
        """Wrap ``value`` in an ``_Array`` for adaptation; ``None`` stays ``None``."""
        if value is None:
            return None
        if isinstance(value, (list, _Array)):
            items = value
        else:
            # Any other iterable is materialized into a list first.
            items = list(value)
        return _Array(self, items)

    def __getitem__(self, value):
        """Return an ``ObjectSlice`` for indexing or slicing this column."""
        return ObjectSlice.create(self, value)

    def contains(self, *items):
        """Build an ``OP.ACONTAINS`` expression against ``items``."""
        rhs = Param(items)
        return Expression(self, OP.ACONTAINS, rhs)

    def contains_any(self, *items):
        """Build an ``OP.ACONTAINS_ANY`` expression against ``items``."""
        rhs = Param(items)
        return Expression(self, OP.ACONTAINS_ANY, rhs)
class DateTimeTZField(DateTimeField):
    """Datetime field using the ``datetime_tz`` column type."""
    db_field = 'datetime_tz'
class HStoreField(IndexedFieldMixin, Field):
    """Field using the ``hash`` column type, with hstore query helpers."""
    db_field = 'hash'

    def __getitem__(self, key):
        """Build an ``OP.HKEY`` lookup of ``key``."""
        return Expression(self, OP.HKEY, Param(key))

    def keys(self):
        """Call ``akeys`` on this column."""
        return fn.akeys(self)

    def values(self):
        """Call ``avals`` on this column."""
        return fn.avals(self)

    def items(self):
        """Call ``hstore_to_matrix`` on this column."""
        return fn.hstore_to_matrix(self)

    def slice(self, *args):
        """Call ``slice`` on this column with ``args`` passed as a list."""
        wanted_keys = Passthrough(list(args))
        return fn.slice(self, wanted_keys)

    def exists(self, key):
        """Call ``exist`` on this column for ``key``."""
        return fn.exist(self, key)

    def defined(self, key):
        """Call ``defined`` on this column for ``key``."""
        return fn.defined(self, key)

    def update(self, **data):
        """Build an ``OP.HUPDATE`` expression from the keyword ``data``."""
        return Expression(self, OP.HUPDATE, data)

    def delete(self, *keys):
        """Call ``delete`` on this column with ``keys`` passed as a list."""
        doomed_keys = Passthrough(list(keys))
        return fn.delete(self, doomed_keys)

    def contains(self, value):
        """Build a containment test whose operation depends on ``value``.

        A dict uses ``OP.HCONTAINS_DICT``, a list or tuple uses
        ``OP.HCONTAINS_KEYS``, and anything else uses ``OP.HCONTAINS_KEY``
        with the value passed through unchanged.
        """
        if isinstance(value, dict):
            op, rhs = OP.HCONTAINS_DICT, Passthrough(value)
        elif isinstance(value, (list, tuple)):
            op, rhs = OP.HCONTAINS_KEYS, Passthrough(value)
        else:
            op, rhs = OP.HCONTAINS_KEY, value
        return Expression(self, op, rhs)

    def contains_any(self, *keys):
        """Build an ``OP.HCONTAINS_ANY_KEY`` test against ``keys``."""
        candidate_keys = Passthrough(list(keys))
        return Expression(self, OP.HCONTAINS_ANY_KEY, candidate_keys)
class JSONField(Field):
    """Field using the ``json`` column type."""
    db_field = 'json'

    def __init__(self, dumps=None, *args, **kwargs):
        """Create the field.

        :param dumps: Passed as ``dumps=`` to ``Json`` when values are
            wrapped for the database.
        :raises Exception: If ``Json`` could not be imported from psycopg2.
        """
        if Json is None:
            raise Exception('Your version of psycopg2 does not support JSON.')
        self.dumps = dumps
        super(JSONField, self).__init__(*args, **kwargs)

    def db_value(self, value):
        """Wrap ``value`` in ``Json``; ``None`` and ``Json`` values are returned as-is."""
        if value is None or isinstance(value, Json):
            return value
        return Json(value, dumps=self.dumps)

    def __getitem__(self, value):
        """Start a ``JsonLookup`` on this column at ``value``."""
        first_part = [value]
        return JsonLookup(self, first_part)

    def path(self, *keys):
        """Build a ``JsonPath`` on this column from ``keys``."""
        return JsonPath(self, keys)
class BinaryJSONField(IndexedFieldMixin, JSONField):
    """Field using the ``jsonb`` column type, indexed with GIN by default."""
    db_field = 'jsonb'
    default_index_type = 'GIN'

    def contains(self, other):
        """Build a containment or key-existence test.

        Lists and dicts are wrapped in ``Json`` and compared with
        ``OP.JSONB_CONTAINS``; anything else is wrapped in ``Passthrough``
        and tested with ``OP.JSONB_EXISTS``.
        """
        if not isinstance(other, (list, dict)):
            return Expression(self, OP.JSONB_EXISTS, Passthrough(other))
        return Expression(self, OP.JSONB_CONTAINS, Json(other))

    def contained_by(self, other):
        """Build an ``OP.JSONB_CONTAINED_BY`` test against ``Json(other)``."""
        container = Json(other)
        return Expression(self, OP.JSONB_CONTAINED_BY, container)

    def contains_any(self, *items):
        """Build an ``OP.JSONB_CONTAINS_ANY_KEY`` test against ``items``."""
        candidate_keys = Passthrough(list(items))
        return Expression(self, OP.JSONB_CONTAINS_ANY_KEY, candidate_keys)

    def contains_all(self, *items):
        """Build an ``OP.JSONB_CONTAINS_ALL_KEYS`` test against ``items``."""
        required_keys = Passthrough(list(items))
        return Expression(self, OP.JSONB_CONTAINS_ALL_KEYS, required_keys)
class TSVectorField(IndexedFieldMixin, TextField):
    """Text field using the ``tsvector`` column type, indexed with GIN by default."""
    db_field = 'tsvector'
    default_index_type = 'GIN'

    def match(self, query):
        """Build ``<column> OP.TS_MATCH to_tsquery(query)``."""
        ts_query = fn.to_tsquery(query)
        return Expression(self, OP.TS_MATCH, ts_query)
def Match(field, query):
    """Build ``to_tsvector(field) OP.TS_MATCH to_tsquery(query)``.

    :param field: Field or expression to convert with ``to_tsvector``.
    :param query: Search query handed to ``to_tsquery``.
    :returns: The resulting ``Expression``.
    """
    document = fn.to_tsvector(field)
    ts_query = fn.to_tsquery(query)
    return Expression(document, OP.TS_MATCH, ts_query)
# Add the operation codes used by the extension fields to peewee's OP mapping.
# Each value is the key that PostgresqlExtDatabase.register_ops() (further
# down in this file) maps to an actual SQL operator.
# NOTE(review): several hstore values do not mirror the operator they are
# registered with (e.g. HUPDATE is 'H@>' here but maps to '||').  Presumably
# only uniqueness of these strings matters -- confirm before renaming any.
OP.update(
    # hstore operations
    HKEY='key',
    HUPDATE='H@>',
    HCONTAINS_DICT='H?&',
    HCONTAINS_KEYS='H?',
    HCONTAINS_KEY='H?|',
    HCONTAINS_ANY_KEY='H||',
    # array operations
    ACONTAINS='A@>',
    ACONTAINS_ANY='A||',
    # full-text search
    TS_MATCH='T@@',
    # jsonb operations
    JSONB_CONTAINS='JB@>',
    JSONB_CONTAINED_BY='JB<@',
    JSONB_CONTAINS_ANY_KEY='JB?|',
    JSONB_CONTAINS_ALL_KEYS='JB?&',
    JSONB_EXISTS='JB?',
    # type cast
    CAST='::',
)
class PostgresqlExtCompiler(QueryCompiler):
    """Query compiler that handles index types, slices and JSON lookups."""

    def _create_index(self, model_class, fields, unique=False):
        """Build the index clause, adding ``USING <type>`` when a field asks for it.

        When several fields derive from ``IndexedFieldMixin``, the index type
        of the last one is used.
        """
        clause = super(PostgresqlExtCompiler, self)._create_index(
            model_class, fields, unique)
        # Fields may name their own index type (GiST, GIN, ...).
        typed_fields = [
            field for field in fields
            if isinstance(field, IndexedFieldMixin)]
        index_type = typed_fields[-1].index_type if typed_fields else None
        if index_type:
            # Insert just before the clause's final node.
            clause.nodes.insert(-1, SQL('USING %s' % index_type))
        return clause

    def _parse_object_slice(self, node, alias_map, conv):
        """Render an ``ObjectSlice`` as ``<sql>[a]`` or ``<sql>[a:b]``."""
        sql, params = self.parse_node(node.node, alias_map, conv)
        # Postgresql uses 1-based indexes, so every part is shifted by one.
        bounds = ':'.join(str(part + 1) for part in node.parts)
        return '%s[%s]' % (sql, bounds), params

    def _parse_json_lookup(self, node, alias_map, conv):
        """Render a ``JsonLookup`` as a chain of ``->`` operators.

        Unless the node is flagged ``_as_json``, the final step uses ``->>``
        so the result is converted to text.
        """
        base_sql, params = self.parse_node(node.node, alias_map, conv)
        lookups = [base_sql]
        for part in node.parts:
            part_sql, part_params = self.parse_node(part, alias_map, conv)
            lookups.append(part_sql)
            params.extend(part_params)

        if node._as_json:
            return '->'.join(lookups), params

        # The last lookup should be converted to text.
        leading = '->'.join(lookups[:-1])
        return leading + '->>' + lookups[-1], params

    def _parse_json_path(self, node, alias_map, conv):
        """Render a ``JsonPath`` using ``#>`` (JSON) or ``#>>`` (text).

        The path is sent as a single ``{a,b,c}`` parameter.
        """
        sql, params = self.parse_node(node.node, alias_map, conv)
        operand = '#>' if node._as_json else '#>>'
        path_literal = '{%s}' % ','.join(str(part) for part in node.parts)
        params.append(path_literal)
        return operand.join((sql, self.interpolation)), params

    def get_parse_map(self):
        """Return the base parse map extended with this module's node types."""
        parse_map = super(PostgresqlExtCompiler, self).get_parse_map()
        parse_map.update({
            'object_slice': self._parse_object_slice,
            'json_lookup': self._parse_json_lookup,
            'json_path': self._parse_json_path,
        })
        return parse_map
class PostgresqlExtDatabase(PostgresqlDatabase):
    """Postgres database with optional server-side cursors and hstore support."""
    compiler_class = PostgresqlExtCompiler

    def __init__(self, *args, **kwargs):
        """Create the database.

        Two extra keyword arguments are consumed here before the rest is
        forwarded to the base class:

        :param server_side_cursors: Use named cursors for SELECT statements
            (default ``False``).
        :param register_hstore: Register hstore on each new connection
            (default ``True``).
        """
        self.server_side_cursors = kwargs.pop('server_side_cursors', False)
        self.register_hstore = kwargs.pop('register_hstore', True)
        super(PostgresqlExtDatabase, self).__init__(*args, **kwargs)

    def get_cursor(self, name=None):
        """Return a cursor on the current connection, named when ``name`` is truthy."""
        conn = self.get_conn()
        if not name:
            return conn.cursor()
        return conn.cursor(name=name)

    def execute_sql(self, sql, params=None, require_commit=True,
                    named_cursor=False):
        """Execute ``sql`` and return the cursor.

        A named cursor is used when ``named_cursor`` is true, or when
        ``server_side_cursors`` is enabled and the statement starts with
        ``select`` (case-insensitive). Statements run on a named cursor are
        never committed here.

        :param sql: SQL statement to run.
        :param params: Statement parameters; an empty tuple is used if falsy.
        :param require_commit: Commit afterwards when autocommit is on.
        :param named_cursor: Force the use of a named cursor.
        :returns: The cursor the statement was executed on.
        """
        logger.debug((sql, params))
        use_named_cursor = named_cursor or (
            self.server_side_cursors and sql.lower().startswith('select'))

        with self.exception_wrapper():
            if use_named_cursor:
                # Each named cursor gets a fresh uuid1-based name.
                cursor = self.get_cursor(name=str(uuid.uuid1()))
                require_commit = False
            else:
                cursor = self.get_cursor()

            try:
                cursor.execute(sql, params or ())
            except Exception:
                if self.get_autocommit() and self.autorollback:
                    self.rollback()
                raise

            # Reached only when execute() succeeded; the handler re-raises.
            if require_commit and self.get_autocommit():
                self.commit()
        return cursor

    def _connect(self, database, **kwargs):
        """Open a connection, registering hstore on it when enabled."""
        conn = super(PostgresqlExtDatabase, self)._connect(database, **kwargs)
        if self.register_hstore:
            register_hstore(conn, globally=True)
        return conn
class ServerSideSelectQuery(SelectQuery):
    """Select query that always executes on a named cursor."""

    @classmethod
    def clone_from_query(cls, query):
        """Copy ``query`` into a new ``ServerSideSelectQuery`` for the same model."""
        server_side_copy = ServerSideSelectQuery(query.model_class)
        return query._clone_attributes(server_side_copy)

    def _execute(self):
        """Run the query on a named cursor without committing."""
        statement, bindings = self.sql()
        return self.database.execute_sql(
            statement, bindings, require_commit=False, named_cursor=True)
# Map the db_field names declared by the extension fields above to the
# Postgres column types used for them.
PostgresqlExtDatabase.register_fields({
    'datetime_tz': 'timestamp with time zone',
    'hash': 'hstore',
    'json': 'json',
    'jsonb': 'jsonb',
    'tsvector': 'tsvector',
})
# Map each extension operation code (added to OP earlier in this file) to the
# SQL operator registered for it.
PostgresqlExtDatabase.register_ops({
    # hstore operators
    OP.HCONTAINS_DICT: '@>',
    OP.HCONTAINS_KEYS: '?&',
    OP.HCONTAINS_KEY: '?',
    OP.HCONTAINS_ANY_KEY: '?|',
    OP.HKEY: '->',
    OP.HUPDATE: '||',
    # array operators
    OP.ACONTAINS: '@>',
    OP.ACONTAINS_ANY: '&&',
    # full-text search
    OP.TS_MATCH: '@@',
    # jsonb operators
    OP.JSONB_CONTAINS: '@>',
    OP.JSONB_CONTAINED_BY: '<@',
    OP.JSONB_CONTAINS_ANY_KEY: '?|',
    OP.JSONB_CONTAINS_ALL_KEYS: '?&',
    OP.JSONB_EXISTS: '?',
    # type cast
    OP.CAST: '::',
})
def ServerSide(select_query):
    """Iterate over ``select_query`` using a server-side (named) cursor.

    The query is cloned into a ``ServerSideSelectQuery`` and executed inside
    a transaction on the clone's database. The transaction stays open while
    rows are being yielded.

    :param select_query: The select query to run; its ``_qr`` attribute is
        set to the result wrapper of the executed clone.
    :returns: A generator yielding objects from the result's ``iterator()``.
    """
    # Flag the query for execution on a named cursor.
    server_side_query = ServerSideSelectQuery.clone_from_query(select_query)
    with server_side_query.database.transaction():
        result_wrapper = server_side_query.execute()

        # Patch the result wrapper onto the original query.
        select_query._qr = result_wrapper

        # Expose the rows to the caller one at a time.
        for row in result_wrapper.iterator():
            yield row
def LateralJoin(lhs, rhs, join_type='LEFT', condition=True):
    """Build a ``<lhs> <join_type> JOIN LATERAL <rhs> ON <condition>`` clause.

    :param lhs: Left-hand side of the join.
    :param rhs: Right-hand side of the join.
    :param join_type: Join keyword placed before ``JOIN LATERAL``.
    :param condition: Value supplied as the parameter of the ``ON %s`` SQL.
    :returns: A ``Clause`` holding the four pieces in order.
    """
    join_keyword = SQL('%s JOIN LATERAL' % join_type)
    on_clause = SQL('ON %s', condition)
    return Clause(lhs, join_keyword, rhs, on_clause)