ORM(Object-Relational Mapping,对象关系映射)即把数据库中的一张数据表映射到代码里的一个类上,表的字段对应类的属性;再将增删改查等基本操作封装为类的方法,从而写出更干净、更有层次的代码。
以查询数据为例,原始的写法需要把Python代码和SQL语句混在一起,示例代码如下:
import os
import sys

import MySQLdb


def main():
    """Dump every row of the ``users`` table of the local ``xdyweb`` database.

    Prints whether the fetched result is a tuple, its type and length, then
    each row followed by each of its column values, and finally the row count
    reported by ``cursor.execute``.
    """
    # NOTE(review): the connection parameters (including the password) are
    # hard-coded for this demo; move them to configuration before real use.
    conn = MySQLdb.connect(host="localhost", port=3306, passwd='toor', user='root')
    try:
        conn.select_db("xdyweb")
        cursor = conn.cursor()
        try:
            count = cursor.execute("select * from users")
            # fetchall() instead of fetchmany(): with no size argument
            # fetchmany() returns at most cursor.arraysize rows, so the loop
            # below would not see the whole result set.
            result = cursor.fetchall()
            print(isinstance(result, tuple))
            print(type(result))
            print(len(result))
            for row in result:
                print(row)
                for column in row:
                    print(column)
            print("row count is %s" % count)
        finally:
            # Release the cursor even if the query or the printing fails.
            cursor.close()
    finally:
        # Always close the connection, whatever happened above.
        conn.close()


if __name__ == "__main__":
    # Make the current working directory importable before running the demo.
    cp = os.path.abspath('.')
    sys.path.append(cp)
    main()
而我们现在想要实现的是类似这样的效果:
# Target API sketch -- look up a record:
u=user.get(id=1)
# Target API sketch -- create a record and insert it:
u=user(name='y',password='y',email='1@q.com')
u.insert()
实现思路是:遍历Model的属性,得出要操作的字段,然后根据不同的操作(增、删、改、查)动态生成相应的SQL语句。
#coding:utf-8
#author:xudongyang
#19:25 2015/4/15
import logging
import os
import sys
import threading
import time

import test as db
# logging.basicConfig(level=logging.INFO,format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',datefmt='%a, %d %b %Y %H:%M:%S')
logging.basicConfig(level=logging.INFO)


class Field(object):
    """Describe one column of a mapped table.

    Keyword arguments (all optional):
        name: column name (default ``None``).
        default: default value, or a callable producing it (default ``None``).
        primary_key: whether the column is the primary key (default ``False``).
        nullable: whether the column may be NULL (default ``False``).
        updatable: whether the column is written on update (default ``True``).
        insertable: whether the column is written on insert (default ``True``).
        ddl: column type text such as ``varchar(255)`` (default ``''``).
    """

    # Number of Field objects created so far; each new field records the
    # current value as its ``_order`` so fields can be sorted by creation order.
    _count = 0

    def __init__(self, **kw):
        self.name = kw.get('name', None)
        self._default = kw.get('default', None)
        self.primary_key = kw.get('primary_key', False)
        self.nullable = kw.get('nullable', False)
        self.updatable = kw.get('updatable', True)
        self.insertable = kw.get('insertable', True)
        self.ddl = kw.get('ddl', '')
        self._order = Field._count
        Field._count = Field._count + 1

    @property
    def default(self):
        """Return the default value, calling it first when it is callable."""
        d = self._default
        return d() if callable(d) else d

    def __repr__(self):
        """Return a short description so log lines show the field contents."""
        return '<%s:%s,%s,default(%r)>' % (
            self.__class__.__name__, self.name, self.ddl, self._default)


class StringField(Field):
    """Field whose default is ``''`` and whose ddl is ``varchar(255)``."""

    def __init__(self, **kw):
        # Only fill in values the caller did not supply.
        kw.setdefault('default', '')
        kw.setdefault('ddl', 'varchar(255)')
        super(StringField, self).__init__(**kw)


class IntegerField(Field):
    """Field whose default is ``0`` and whose ddl is ``bigint``."""

    def __init__(self, **kw):
        kw.setdefault('default', 0)
        kw.setdefault('ddl', 'bigint')
        super(IntegerField, self).__init__(**kw)
class FloatField(Field):
    """Field whose default is ``0.0`` and whose ddl is ``real``."""

    def __init__(self, **kw):
        kw.setdefault('default', 0.0)
        kw.setdefault('ddl', 'real')
        super(FloatField, self).__init__(**kw)


class BooleanField(Field):
    """Field whose default is ``False`` and whose ddl is ``bool``."""

    def __init__(self, **kw):
        kw.setdefault('default', False)
        kw.setdefault('ddl', 'bool')
        super(BooleanField, self).__init__(**kw)


class TextField(Field):
    """Field whose default is ``''`` and whose ddl is ``text``."""

    def __init__(self, **kw):
        kw.setdefault('default', '')
        kw.setdefault('ddl', 'text')
        super(TextField, self).__init__(**kw)


class BlobField(Field):
    """Field whose default is ``''`` and whose ddl is ``blob``."""

    def __init__(self, **kw):
        kw.setdefault('default', '')
        kw.setdefault('ddl', 'blob')
        super(BlobField, self).__init__(**kw)


class VersionField(Field):
    """``bigint`` field with default ``0``; only the name is configurable."""

    def __init__(self, name=None):
        super(VersionField, self).__init__(name=name, default=0, ddl='bigint')


def _gen_sql(table_name, mappings):
    """Build the ``create table`` statement for a mapped class.

    Args:
        table_name: name of the table to create.
        mappings: dict of attribute name -> Field.

    Returns:
        The SQL text, one clause per line, columns in field creation order.

    Raises:
        ValueError: if a mapped object has no ``ddl`` attribute.
    """
    print(__name__+'is called'+str(time.time()))
    pk = None
    sql = ['-- generating SQL for %s:' % table_name, 'create table `%s` (' % table_name]
    # A key function replaces the Python 2-only cmp-style comparator; the
    # resulting order (ascending ``_order``) is the same.
    for f in sorted(mappings.values(), key=lambda field: field._order):
        if not hasattr(f, 'ddl'):
            # The original formatted an undefined name here, which turned the
            # intended error into a NameError. ValueError is still caught by
            # Python 2 callers that handle StandardError.
            raise ValueError('no ddl in field "%s".' % getattr(f, 'name', None))
        ddl = f.ddl
        if f.primary_key:
            pk = f.name
        column = ' `%s` %s,' if f.nullable else ' `%s` %s not null,'
        sql.append(column % (f.name, ddl))
    sql.append(' primary key(`%s`)' % pk)
    sql.append(');')
    sql = '\n'.join(sql)
    logging.info('sql is :'+sql)
    return sql


class ModelMetaClass(type):
    """Metaclass that collects the Field attributes of a model class.

    Answers to the two questions the author left here:

    * ``__new__`` runs once for every class created through this metaclass,
      i.e. once for the base ``Model`` class and once for each subclass, so
      defining one model makes it run twice in total.
    * The Field objects are popped from the class namespace because an
      attribute found on the class wins over ``__getattr__``. Once they are
      removed, attribute access on an instance falls through to the model's
      ``__getattr__`` and yields the stored value instead of the Field.
    """

    def __new__(cls, name, base, attrs):
        logging.info("cls is:"+str(cls))
        logging.info("name is:"+str(name))
        logging.info("base is:"+str(base))
        logging.info("attrs is:"+str(attrs))
        print('new is called at '+str(cls)+str(time.time()))
        # The base class itself maps no table; create it untouched.
        if name == "Model":
            return type.__new__(cls, name, base, attrs)
        mapping = dict()
        # Must be initialised once, outside the loop: resetting it on every
        # iteration made the duplicate check unreachable and usually left
        # ``__primary_key__`` as None.
        primary_key = None
        for k, v in attrs.items():
            if isinstance(v, Field):
                if not v.name:
                    v.name = k
                mapping[k] = v
                # Check whether this field is the primary key.
                if v.primary_key:
                    if primary_key:
                        raise TypeError("There should be only one primary_key")
                    if v.updatable:
                        logging.warning('primary_key should not be changed')
                        v.updatable = False
                    if v.nullable:
                        logging.warning('pri.. not be.null')
                        v.nullable = False
                    primary_key = v
        for k in mapping:
            attrs.pop(k)
        # Default the table name to the lower-cased class name so that
        # ``__sql__`` and ``cls.__table__`` work when none was declared.
        if '__table__' not in attrs:
            attrs['__table__'] = name.lower()
        attrs['__mappings__'] = mapping
        logging.info('mapping is :'+str(mapping))
        attrs['__primary_key__'] = primary_key
        attrs['__sql__'] = lambda self: _gen_sql(attrs['__table__'], mapping)
        return type.__new__(cls, name, base, attrs)
class ModelMetaclass(type):
    '''
    Metaclass for model objects.

    For every class other than ``Model`` it records the class name in
    ``ModelMetaclass.subclasses``, moves the Field attributes into
    ``__mappings__``, and sets ``__table__`` (defaulting to the lower-cased
    class name), ``__primary_key__`` and ``__sql__`` on the new class.

    Raises TypeError when a class declares more than one primary key or none.
    '''
    def __new__(cls, name, bases, attrs):
        # skip base Model class:
        if name == 'Model':
            return type.__new__(cls, name, bases, attrs)

        # store all subclasses info:
        if not hasattr(cls, 'subclasses'):
            cls.subclasses = {}
        if name not in cls.subclasses:
            cls.subclasses[name] = name
        else:
            logging.warning('Redefine class: %s' % name)

        logging.info('Scan ORMapping %s...' % name)
        mappings = dict()
        primary_key = None
        # items() works on both Python 2 and 3 (iteritems() is Python 2 only).
        for k, v in attrs.items():
            if isinstance(v, Field):
                if not v.name:
                    v.name = k
                logging.info('Found mapping: %s => %s' % (k, v))
                # check duplicate primary key:
                if v.primary_key:
                    if primary_key:
                        raise TypeError('Cannot define more than 1 primary key in class: %s' % name)
                    if v.updatable:
                        logging.warning('NOTE: change primary key to non-updatable.')
                        v.updatable = False
                    if v.nullable:
                        logging.warning('NOTE: change primary key to non-nullable.')
                        v.nullable = False
                    primary_key = v
                mappings[k] = v
        # check exist of primary key:
        if not primary_key:
            raise TypeError('Primary key not defined in class: %s' % name)
        # Remove the Field objects from the class namespace so they do not
        # shadow the per-instance values.
        for k in mappings:
            attrs.pop(k)
        if '__table__' not in attrs:
            attrs['__table__'] = name.lower()
        attrs['__mappings__'] = mappings
        attrs['__primary_key__'] = primary_key
        attrs['__sql__'] = lambda self: _gen_sql(attrs['__table__'], mappings)
        # NOTE: registration of default (None) triggers is disabled here:
        # for trigger in _triggers:
        # if not trigger in attrs:
        # attrs[trigger] = None
        return type.__new__(cls, name, bases, attrs)
class Model(dict):
    """Dict-backed base class for table-mapped records.

    Column values are stored as dict items and are also readable and writable
    as attributes. Subclasses declare Field attributes; the metaclass turns
    them into ``__mappings__``, ``__primary_key__`` and ``__sql__``.

    All database access goes through the ``db`` module imported by this file.
    NOTE(review): the ``?`` placeholders and the return shapes of the ``db``
    helpers are assumed from how they are used here -- confirm against ``db``.
    """
    # Python 2 metaclass hook (this declaration has no effect on Python 3).
    __metaclass__ = ModelMetaClass

    def __init__(self, **kw):
        super(Model, self).__init__(**kw)

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails; fall back to the
        # dict contents.
        try:
            return self[key]
        except KeyError:
            raise AttributeError(r"'Dict' object has no attribute '%s'" % key)

    def __setattr__(self, key, value):
        self[key] = value

    def _run_trigger(self, name):
        """Call the optional hook ``name`` (e.g. ``pre_update``) if defined.

        The hooks are optional: a plain ``self.pre_update`` access raises
        AttributeError through ``__getattr__`` when no hook exists, which made
        update/delete/insert unusable for models without hooks.
        """
        hook = getattr(self, name, None)
        if callable(hook):
            hook()

    @classmethod
    def get(cls, pk):
        '''
        Get by primary key. Returns an instance, or None when nothing matched.
        '''
        d = db.select_one('select * from %s where %s=?' % (cls.__table__, cls.__primary_key__.name), pk)
        return cls(**d) if d else None

    @classmethod
    def find_first(cls, where, *args):
        '''
        Find by where clause and return one result. If multiple results found,
        only the first one returned. If no result found, return None.
        '''
        d = db.select_one('select * from %s %s' % (cls.__table__, where), *args)
        return cls(**d) if d else None

    @classmethod
    def find_all(cls, *args):
        '''
        Find all and return list. Positional arguments are accepted but unused.
        '''
        L = db.select('select * from `%s`' % cls.__table__)
        return [cls(**d) for d in L]

    @classmethod
    def find_by(cls, where, *args):
        '''
        Find by where clause and return list.
        '''
        L = db.select('select * from `%s` %s' % (cls.__table__, where), *args)
        return [cls(**d) for d in L]

    @classmethod
    def count_all(cls):
        '''
        Find by 'select count(pk) from table' and return integer.
        '''
        return db.select_int('select count(`%s`) from `%s`' % (cls.__primary_key__.name, cls.__table__))

    @classmethod
    def count_by(cls, where, *args):
        '''
        Find by 'select count(pk) from table where ... ' and return int.
        '''
        return db.select_int('select count(`%s`) from `%s` %s' % (cls.__primary_key__.name, cls.__table__, where), *args)

    def update(self):
        """Write every updatable field back to the row identified by the
        primary key, filling missing values from the field defaults.

        Returns ``self``.
        """
        self._run_trigger('pre_update')
        L = []
        args = []
        for k, v in self.__mappings__.items():
            if v.updatable:
                if hasattr(self, k):
                    arg = getattr(self, k)
                else:
                    # No value set yet: use the field default and remember it.
                    arg = v.default
                    setattr(self, k, arg)
                L.append('`%s`=?' % k)
                args.append(arg)
        pk = self.__primary_key__.name
        # The primary key value is the last parameter (the where clause).
        args.append(getattr(self, pk))
        db.update('update `%s` set %s where %s=?' % (self.__table__, ','.join(L), pk), *args)
        return self

    def delete(self):
        """Delete the row identified by the primary key. Returns ``self``."""
        self._run_trigger('pre_delete')
        pk = self.__primary_key__.name
        args = (getattr(self, pk), )
        db.update('delete from `%s` where `%s`=?' % (self.__table__, pk), *args)
        return self

    def insert(self):
        """Insert the insertable fields as a new row, filling missing values
        from the field defaults. Returns ``self``.
        """
        self._run_trigger('pre_insert')
        params = {}
        for k, v in self.__mappings__.items():
            if v.insertable:
                if not hasattr(self, k):
                    setattr(self, k, v.default)
                params[v.name] = getattr(self, k)
        db.insert('%s' % self.__table__, **params)
        return self
class user(Model):
    """Example model with a ``name`` primary key and a ``password`` column."""
    # Declared explicitly so that generating the SQL below has a table name.
    __table__ = 'user'
    name = StringField(name='name', primary_key=True)
    password = StringField(name='password')


def main():
    """Build a ``user`` record, log its generated SQL and mapping details,
    then change the password through attribute access and print it.
    """
    u = user(name='yy', password='yyp')
    # Call __sql__ so the generated statement is logged rather than the
    # repr of the bound method.
    logging.info(u.__sql__())
    logging.info(dir(u.__mappings__.values()))
    u.password = 'xxx'
    print(u.password)


if __name__ == '__main__':
    main()
要注意的是遍历Model属性的这部分代码:它利用Python 2的__metaclass__机制拦截了Model及其子类的创建过程,进而对类的属性进行遍历,具体代码见元类的__new__方法实现。文中给出了ModelMetaClass和ModelMetaclass两个版本,Model实际使用的是ModelMetaClass。
这是模仿廖老师的代码(http://liaoxuefeng.com)写的,在此感谢。另有两个疑问以注释的形式写在了代码中,希望看明白的朋友帮忙解惑。