简单ORM的实现

Posted zx125

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了简单ORM的实现相关的知识,希望对你有一定的参考价值。

简单的orm实现

我们在使用各种框架的时候,关于数据库这方面的使用,框架给我们提供了很好的封装,这个就是orm

关系映射

orm的底层无非就是做了关系映射

数据库的表(table) --> 类(class)
记录(record,行数据)--> 对象(object)
字段(field)--> 对象的属性(attribute)

ORM设计

字段

首先是字段,每个字段都有很多的字段属性,然后考虑到,每个表的字段可能都不同,为了给他提供更好的拓展性,所以这里我们选择用类来封装

class Field(object):
    def __init__(self, name, column_type, primary_key, default):
        self.name = name
        self.column_type = column_type
        self.primary_key = primary_key
        self.default = default

class StringField(Field):
    def __init__(self,name,
                 column_type='varchar=(255)',
                 primary_key=False,
                 default=None):
        super().__init__(name,column_type,primary_key,default)

class IntegerField(Field):
    def __init__(self,
                 name,
                 column_type='int',
                 primary_key=False,
                 default=None):
        super().__init__(name, column_type, primary_key, default)

表有表名和字段信息等信息

class Teacher(Models):
    print("teacher")
    table_name='teacher'
    tid = IntegerField(name='tid',primary_key=True)
    tname = StringField(name='tname')

为了能更好的展示,dict数据类型更适合我们,所以最后让他继承dict,但是字典取值只能通过key取值,这样不方便,所以我们也要改写他的取值方式为.取值

class Models(dict,metaclass=ModelMetaClass):
    print("Models")
    def __init__(self,**kwargs):
        print(f'Models_init')
        super().__init__(self,**kwargs)

    def __getattr__(self, item):
        return self.get(item,"没有该值")

    def __setattr__(self, key, value):
        self[key]=value

为了确保我们自己定义表的的时候不会有错误,我们需要加一步,字段的检测步骤,并对字段更好整理

class ModelMetaClass(type):
    print("ModelMetaClass")
    def __new__(cls,class_name,class_base,class_attrs):
        print("ModelMetaClass_new")
        #实例化对象的时候也会执行,我们要把这一次拦截掉
        if class_name == 'Models':
            #为了能让实例化顺利完成,返回一个空对象就行
            return type.__new__(cls,class_name,class_base,class_attrs)
        #获取表名
        table_name = class_attrs.get('table_name',class_name)

        #定义一个存主键的的变量
        primary_key = None

        #定义一个字典存储字段信息
        mapping = 

        #name='tid',primary_key=True
        #for来找到主键字段
        for k,v in class_attrs.items():
            #判断信息是否是字段
            if isinstance(v,Field):
                mapping[k] = v
                #寻找主键
                if v.primary_key:
                    if primary_key:
                        raise TypeError("主键只有一个")
                    primary_key=v.name

        #将重复的键值对删除,因为已经放入了mapping
        for k in mapping.keys():
            class_attrs.pop(k)
        if not primary_key:
            raise TypeError("表必须要有一个主键")
        class_attrs['table_name']=table_name
        class_attrs['primary_key']=primary_key
        class_attrs['mapping']=mapping
        return type.__new__(cls,class_name,class_base,class_attrs)

数据库操作

数据库操作最好就是放在类里面,然后使用类方法

    #查找
    @classmethod
    def select(cls,**kwargs):
        ms=mysql()

        #如果没有参数默认是查询全部的
        if not kwargs:
            sql='select * from %s'%cls.table_name
            res=ms.select(sql)
        else:
            k = list(kwargs.keys())[0]
            v = kwargs.get(k)
            sql='select * from %s where %s=?'%(cls.table_name,k)

            #防sql注入
            sql=sql.replace('?','%s')

            res=ms.select(sql,v)
        if res:
            return [cls(**i) for i in res]

    #新增
    def save(self):
        ms=MySQL()

        #存字段名
        fields=[]
        #存值
        values=[]
        args=[]

        for k,v in self.mapping.items():
            #主键自增,不用给他赋值
            if not v.primary_key:
                fields.append(v.name)
                args.append("?")
                values.append(getattr(self,v.name))

            sql = "insert into %s(%s) values(%s)"%(self.table_name,",".join(fields),",".join((args)))

            sql = sql.replace('?','%s')

        ms.execute(sql,values)

    def update(self):
        ms = MySQL()
        fields = []
        valuse = []
        pr = None
        for k,v in self.mapping.items():
            #获取主键值
            if v.primary_key:
                pr = getattr(self,v.name,v.default)
            else:
                fields.append(v.name+'=?')
                valuse.append(getattr(self,v.name,v.default))
            print(fields,valuse)
        sql = 'update %s set %s where %s = %s'%(self.table_name,','.join(fields),self.primary_key,pr)

        sql = sql.replace('?',"%s")

        ms.execute(sql,valuse)

Mysql连接

import pymysql


class MySQL:

    #单例模式
    __instance = None

    def __new__(cls, *args, **kwargs):
        if not cls.__instance:
            cls.__instance = object.__new__(cls)
        return cls.__instance

    def __init__(self):
        self.mysql = pymysql.connect(
            host='127.0.0.1',
            port=3306,
            user='root',
            database='orm_demo',
            password='root',
            charset='utf8',
            autocommit=True
        )

        #获取游标
        self.cursor = self.mysql.cursor(
            pymysql.cursors.DictCursor
        )

    #查看
    def select(self,sql,args=None):
        #提交sql语句
        self.cursor.execute(sql,args)

        #获取查询的结果
        res = self.cursor.fetchall()
        return res

    #提交
    def execute(self,sql,args):
        #提交语句可能会发生异常
        try:
            self.cursor.execute(sql,args)
        except Exception as e:
            print(e)

    def close(self):
        self.cursor.close()
        self.mysql.close()

整体代码部分

MySQL.py

import pymysql


class MySQL:

    #单例模式
    __instance = None

    def __new__(cls, *args, **kwargs):
        if not cls.__instance:
            cls.__instance = object.__new__(cls)
        return cls.__instance

    def __init__(self):
        self.mysql = pymysql.connect(
            host='127.0.0.1',
            port=3306,
            user='root',
            database='orm_demo',
            password='root',
            charset='utf8',
            autocommit=True
        )

        #获取游标
        self.cursor = self.mysql.cursor(
            pymysql.cursors.DictCursor
        )

    #查看
    def select(self,sql,args=None):
        #提交sql语句
        self.cursor.execute(sql,args)

        #获取查询的结果
        res = self.cursor.fetchall()
        return res

    #提交
    def execute(self,sql,args):
        #提交语句可能会发生异常
        try:
            self.cursor.execute(sql,args)
        except Exception as e:
            print(e)

    def close(self):
        self.cursor.close()
        self.mysql.close()

orm.py

from MySQL import MySQL

# 定义字段类
class Field(object):
    def __init__(self, name, column_type, primary_key, default):
        self.name = name
        self.column_type = column_type
        self.primary_key = primary_key
        self.default = default

class StringField(Field):
    def __init__(self,name,
                 column_type='varchar=(255)',
                 primary_key=False,
                 default=None):
        super().__init__(name,column_type,primary_key,default)

class IntegerField(Field):
    def __init__(self,
                 name,
                 column_type='int',
                 primary_key=False,
                 default=None):
        super().__init__(name, column_type, primary_key, default)

class ModelMetaClass(type):
    print("ModelMetaClass")
    def __new__(cls,class_name,class_base,class_attrs):
        print("ModelMetaClass_new")
        #实例化对象的时候也会执行,我们要把这一次拦截掉
        if class_name == 'Models':
            #为了能让实例化顺利完成,返回一个空对象就行
            return type.__new__(cls,class_name,class_base,class_attrs)
        #获取表名
        table_name = class_attrs.get('table_name',class_name)

        #定义一个存主键的的变量
        primary_key = None

        #定义一个字典存储字段信息
        mapping = 

        #name='tid',primary_key=True
        #for来找到主键字段
        for k,v in class_attrs.items():
            #判断信息是否是字段
            if isinstance(v,Field):
                mapping[k] = v
                #寻找主键
                if v.primary_key:
                    if primary_key:
                        raise TypeError("主键只有一个")
                    primary_key=v.name

        #将重复的键值对删除,因为已经放入了mapping
        for k in mapping.keys():
            class_attrs.pop(k)
        if not primary_key:
            raise TypeError("表必须要有一个主键")
        class_attrs['table_name']=table_name
        class_attrs['primary_key']=primary_key
        class_attrs['mapping']=mapping
        return type.__new__(cls,class_name,class_base,class_attrs)

class Models(dict,metaclass=ModelMetaClass):
    print("Models")
    def __init__(self,**kwargs):
        print(f'Models_init')
        super().__init__(self,**kwargs)

    def __getattr__(self, item):
        return self.get(item,"没有该值")

    def __setattr__(self, key, value):
        self[key]=value

    #查找
    @classmethod
    def select(cls,**kwargs):
        ms=MySQL()

        #如果没有参数默认是查询全部的
        if not kwargs:
            sql='select * from %s'%cls.table_name
            res=ms.select(sql)
        else:
            k = list(kwargs.keys())[0]
            v = kwargs.get(k)
            sql='select * from %s where %s=?'%(cls.table_name,k)

            #防sql注入
            sql=sql.replace('?','%s')

            res=ms.select(sql,v)
        if res:
            return [cls(**i) for i in res]

    #新增
    def save(self):
        ms=MySQL()

        #存字段名
        fields=[]
        #存值
        values=[]
        args=[]

        for k,v in self.mapping.items():
            #主键自增,不用给他赋值
            if not v.primary_key:
                fields.append(v.name)
                args.append("?")
                values.append(getattr(self,v.name))

            sql = "insert into %s(%s) values(%s)"%(self.table_name,",".join(fields),",".join((args)))

            sql = sql.replace('?','%s')

        ms.execute(sql,values)

    def update(self):
        ms = MySQL()
        fields = []
        valuse = []
        pr = None
        for k,v in self.mapping.items():
            #获取主键值
            if v.primary_key:
                pr = getattr(self,v.name,v.default)
            else:
                fields.append(v.name+'=?')
                valuse.append(getattr(self,v.name,v.default))
            print(fields,valuse)
        sql = 'update %s set %s where %s = %s'%(self.table_name,','.join(fields),self.primary_key,pr)

        sql = sql.replace('?',"%s")

        ms.execute(sql,valuse)


class Teacher(Models):
    print("teacher")
    table_name='teacher'
    tid = IntegerField(name='tid',primary_key=True)
    tname = StringField(name='tname')

if __name__ == '__main__':
    # tea=Teacher(tname="haha")
    tea2=Teacher(tname="haha",tid=5)
    # print(Teacher.select(tid=1))
    # Teacher.save(tea)
    print(Teacher.update(tea2))

以上是关于简单ORM的实现的主要内容,如果未能解决你的问题,请参考以下文章

利用python实现ORM

自定义注解实现简单的orm映射框架

通过java反射实现简单的关于MongoDB的对象关系映射(ORM).

D2010 RTTI + Attribute 简单实现ORM

python中通过元类(TYPE)简单实现对象关系映射(ORM)

[python] 理解metaclass并实现一个简单ORM框架