try:
import cPickle as pickle
except ImportError:
import pickle
import re
import sys
PY2 = sys.version_info[0] == 2
# Conditional standard library imports.
try:
from cStringIO import StringIO
except ImportError:
if sys.version_info[0] == 2:
from StringIO import StringIO
else:
from io import StringIO
try:
import bz2
except ImportError:
bz2 = None
try:
import zlib
except ImportError:
zlib = None
try:
from Crypto.Cipher import AES
from Crypto import Random
except ImportError:
AES = Random = None
try:
from bcrypt import hashpw, gensalt
except ImportError:
hashpw = gensalt = None
from peewee import *
from peewee import binary_construct
from peewee import Field
from peewee import FieldDescriptor
from peewee import SelectQuery
from peewee import unicode_type
if hashpw and gensalt:
class PasswordHash(bytes):
def check_password(self, password):
password = password.encode('utf-8')
return hashpw(password, self) == self
class PasswordField(BlobField):
def __init__(self, iterations=12, *args, **kwargs):
if None in (hashpw, gensalt):
raise ValueError('Missing library required for PasswordField: bcrypt')
self.bcrypt_iterations = iterations
self.raw_password = None
super(PasswordField, self).__init__(*args, **kwargs)
def db_value(self, value):
"""Convert the python value for storage in the database."""
if isinstance(value, PasswordHash):
return bytes(value)
if isinstance(value, unicode_type):
value = value.encode('utf-8')
salt = gensalt(self.bcrypt_iterations)
return value if value is None else hashpw(value, salt)
def python_value(self, value):
"""Convert the database value to a pythonic value."""
if isinstance(value, unicode_type):
value = value.encode('utf-8')
return PasswordHash(value)
class DeferredThroughModel(object):
def set_field(self, model_class, field, name):
self.model_class = model_class
self.field = field
self.name = name
def set_model(self, through_model):
self.field._through_model = through_model
self.field.add_to_class(self.model_class, self.name)
class ManyToManyField(Field):
def __init__(self, rel_model, related_name=None, through_model=None,
_is_backref=False):
if through_model is not None and not (
isinstance(through_model, (Proxy, DeferredThroughModel)) or
issubclass(through_model, Model)):
raise TypeError('Unexpected value for `through_model`. Expected '
'`Model`, `Proxy` or `DeferredThroughModel`.')
self.rel_model = rel_model
self._related_name = related_name
self._through_model = through_model
self._is_backref = _is_backref
self.primary_key = False
self.verbose_name = None
def _get_descriptor(self):
return ManyToManyFieldDescriptor(self)
def add_to_class(self, model_class, name):
if isinstance(self._through_model, Proxy):
def callback(through_model):
self._through_model = through_model
self.add_to_class(model_class, name)
self._through_model.attach_callback(callback)
return
elif isinstance(self._through_model, DeferredThroughModel):
self._through_model.set_field(model_class, self, name)
return
self.name = name
self.model_class = model_class
if not self.verbose_name:
self.verbose_name = re.sub('_+', ' ', name).title()
setattr(model_class, name, self._get_descriptor())
if not self._is_backref:
backref = ManyToManyField(
self.model_class,
through_model=self._through_model,
_is_backref=True)
related_name = self._related_name or model_class._meta.name + 's'
backref.add_to_class(self.rel_model, related_name)
def get_models(self):
return [model for _, model in sorted((
(self._is_backref, self.model_class),
(not self._is_backref, self.rel_model)))]
def get_through_model(self):
if not self._through_model:
lhs, rhs = self.get_models()
tables = [model._meta.db_table for model in (lhs, rhs)]
class Meta:
database = self.model_class._meta.database
db_table = '%s_%s_through' % tuple(tables)
indexes = (
((lhs._meta.name, rhs._meta.name),
True),)
validate_backrefs = False
attrs = {
lhs._meta.name: ForeignKeyField(rel_model=lhs),
rhs._meta.name: ForeignKeyField(rel_model=rhs)}
attrs['Meta'] = Meta
self._through_model = type(
'%s%sThrough' % (lhs.__name__, rhs.__name__),
(Model,),
attrs)
return self._through_model
class ManyToManyFieldDescriptor(FieldDescriptor):
def __init__(self, field):
super(ManyToManyFieldDescriptor, self).__init__(field)
self.model_class = field.model_class
self.rel_model = field.rel_model
self.through_model = field.get_through_model()
self.src_fk = self.through_model._meta.rel_for_model(self.model_class)
self.dest_fk = self.through_model._meta.rel_for_model(self.rel_model)
def __get__(self, instance, instance_type=None):
if instance is not None:
return (ManyToManyQuery(instance, self, self.rel_model)
.select()
.join(self.through_model)
.join(self.model_class)
.where(self.src_fk == instance))
return self.field
def __set__(self, instance, value):
query = self.__get__(instance)
query.add(value, clear_existing=True)
class ManyToManyQuery(SelectQuery):
def __init__(self, instance, field_descriptor, *args, **kwargs):
self._instance = instance
self._field_descriptor = field_descriptor
super(ManyToManyQuery, self).__init__(*args, **kwargs)
def clone(self):
query = type(self)(
self._instance,
self._field_descriptor,
self.model_class)
query.database = self.database
return self._clone_attributes(query)
def _id_list(self, model_or_id_list):
if isinstance(model_or_id_list[0], Model):
return [obj.get_id() for obj in model_or_id_list]
return model_or_id_list
def add(self, value, clear_existing=False):
if clear_existing:
self.clear()
fd = self._field_descriptor
if isinstance(value, SelectQuery):
query = value.select(
SQL(str(self._instance.get_id())),
fd.rel_model._meta.primary_key)
fd.through_model.insert_from(
fields=[fd.src_fk, fd.dest_fk],
query=query).execute()
else:
if not isinstance(value, (list, tuple)):
value = [value]
if not value:
return
inserts = [{
fd.src_fk.name: self._instance.get_id(),
fd.dest_fk.name: rel_id}
for rel_id in self._id_list(value)]
fd.through_model.insert_many(inserts).execute()
def remove(self, value):
fd = self._field_descriptor
if isinstance(value, SelectQuery):
subquery = value.select(value.model_class._meta.primary_key)
return (fd.through_model
.delete()
.where(
(fd.dest_fk << subquery) &
(fd.src_fk == self._instance.get_id()))
.execute())
else:
if not isinstance(value, (list, tuple)):
value = [value]
if not value:
return
return (fd.through_model
.delete()
.where(
(fd.dest_fk << self._id_list(value)) &
(fd.src_fk == self._instance.get_id()))
.execute())
def clear(self):
return (self._field_descriptor.through_model
.delete()
.where(self._field_descriptor.src_fk == self._instance)
.execute())
class CompressedField(BlobField):
ZLIB = 'zlib'
BZ2 = 'bz2'
algorithm_to_import = {
ZLIB: zlib,
BZ2: bz2,
}
def __init__(self, compression_level=6, algorithm=ZLIB, *args,
**kwargs):
self.compression_level = compression_level
if algorithm not in self.algorithm_to_import:
raise ValueError('Unrecognized algorithm %s' % algorithm)
compress_module = self.algorithm_to_import[algorithm]
if compress_module is None:
raise ValueError('Missing library required for %s.' % algorithm)
self.algorithm = algorithm
self.compress = compress_module.compress
self.decompress = compress_module.decompress
super(CompressedField, self).__init__(*args, **kwargs)
if PY2:
def db_value(self, value):
if value is not None:
return binary_construct(
self.compress(value, self.compression_level))
def python_value(self, value):
if value is not None:
return self.decompress(value)
else:
def db_value(self, value):
if value is not None:
return self.compress(
binary_construct(value), self.compression_level)
def python_value(self, value):
if value is not None:
return self.decompress(value).decode('utf-8')
class PickledField(BlobField):
def db_value(self, value):
if value is not None:
return pickle.dumps(value)
def python_value(self, value):
if value is not None:
if isinstance(value, unicode_type):
value = value.encode('raw_unicode_escape')
return pickle.loads(value)
if AES and Random:
class AESEncryptedField(BlobField):
def __init__(self, key, *args, **kwargs):
self.key = key
super(AESEncryptedField, self).__init__(*args, **kwargs)
def get_cipher(self, key, iv):
if len(key) > 32:
raise ValueError('Key length cannot exceed 32 bytes.')
key = key + ' ' * (32 - len(key))
return AES.new(key, AES.MODE_CFB, iv)
def encrypt(self, value):
iv = Random.get_random_bytes(AES.block_size)
cipher = self.get_cipher(self.key, iv)
return iv + cipher.encrypt(value)
def decrypt(self, value):
iv = value[:AES.block_size]
cipher = self.get_cipher(self.key, iv)
return cipher.decrypt(value[AES.block_size:])
if PY2:
def db_value(self, value):
if value is not None:
return binary_construct(self.encrypt(value))
def python_value(self, value):
if value is not None:
return self.decrypt(value)
else:
def db_value(self, value):
if value is not None:
return self.encrypt(value)
def python_value(self, value):
if value is not None:
return self.decrypt(value)