Building Your Own ORM Framework, Part 2 -- The Secrets of IQueryable
Posted by 張暁磊
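I gave my previous article a careless title and got roasted for it, so this time I've learned my lesson.
The previous article covered how to parse an Expression into the corresponding SQL statement, some basic IQueryable concepts, and the idea behind the framework we are building. However, it did not yet tie these pieces together and put them to use. This article is even more blunt and hands-on. It shows how IQueryable implements deferred loading (deferred execution) in practice, and then, combined with the Expression parsing from the previous article, we implement our own "deferred loading".
If you are not yet familiar with parsing Expressions or with the basic concepts of IQueryable, you may want to read my previous article first.
Let's start with some groundwork. We define an IDataBase interface, which could declare query, delete, update, and insert methods. To save time, we only define a query method and a delete method, plus a method that returns an IQueryable<T> instance.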
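```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

public interface IDataBase
{
    // Query the rows of table T that match the condition
    List<T> FindAs<T>(Expression<Func<T, bool>> lambdawhere);
    // Delete the rows of table T that match the condition
    int Remove<T>(Expression<Func<T, bool>> lambdawhere);
    // Get an IQueryable<T> over table T
    IQueryable<T> Source<T>();
}
```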
Next, add a class DBSql that implements the IDataBase interface above. This class is responsible for operations against a SQL database.
```csharp
public class DBSql : IDataBase
{
    public List<T> FindAs<T>(Expression<Func<T, bool>> lambdawhere)
    {
        throw new NotImplementedException();
    }

    public int Remove<T>(Expression<Func<T, bool>> lambdawhere)
    {
        throw new NotImplementedException();
    }

    public IQueryable<T> Source<T>()
    {
        throw new NotImplementedException();
    }
}
```
IQueryable<T>
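A reader's comment on my previous article explained IQueryable perfectly: "IQueryable only stores the conditions and does not run immediately, which is what makes deferred loading possible." So how does it store the conditions, and how is the loading deferred?
To supply the object that public IQueryable<T> Source<T>() returns, we create a SqlQuery class that implements IQueryable<T>.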
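A quick way to see "stores the conditions, runs later" without any database is LINQ to Objects, where AsQueryable() wraps an in-memory list in an IQueryable<T>. In the sketch below (the class name DeferredDemo is only illustrative, and this is not the ORM from this article), a query is built first and the list is changed before the query is enumerated. Because the query only runs at ToList(), the element added afterwards is still included.

```csharp
using System.Collections.Generic;
using System.Linq;

public static class DeferredDemo
{
    // Builds a query, then mutates the data source before the query is enumerated.
    public static List<int> Run()
    {
        var numbers = new List<int> { 1, 2, 3 };

        // Only an expression tree is built here; nothing is filtered yet.
        IQueryable<int> query = numbers.AsQueryable().Where(n => n > 1);

        // Change the source after the query was defined.
        numbers.Add(4);

        // Enumeration (ToList) is what actually runs the query, so 4 is included.
        return query.ToList();
    }
}
```

DeferredDemo.Run() yields 2, 3, 4 rather than 2, 3.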
```csharp
public class SqlQuery<T> : IQueryable<T>
{
    public IEnumerator<T> GetEnumerator()
    {
        throw new NotImplementedException();
    }

    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        throw new NotImplementedException();
    }

    public Type ElementType
    {
        get { throw new NotImplementedException(); }
    }

    public Expression Expression
    {
        get { throw new NotImplementedException(); }
    }

    public IQueryProvider Provider
    {
        get { throw new NotImplementedException(); }
    }
}
```
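These members should look familiar:
GetEnumerator() comes from IEnumerable<T> and is what lets us use foreach. There is a generic and a non-generic version, hence two of them.
ElementType returns the type of the elements the query produces, i.e. typeof(T).
Expression stores the query conditions, i.e. the expression tree built up so far.
IQueryProvider translates literally as "query provider". It is responsible for creating queries and executing them. We write a SqlProvider class to implement it.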
```csharp
public class SqlProvider<T> : IQueryProvider
{
    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        throw new NotImplementedException();
    }

    public IQueryable CreateQuery(Expression expression)
    {
        throw new NotImplementedException();
    }

    public TResult Execute<TResult>(Expression expression)
    {
        throw new NotImplementedException();
    }

    public object Execute(Expression expression)
    {
        throw new NotImplementedException();
    }
}
```
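The sketch below checks the three IQueryable members described above against a real implementation. It inspects the IQueryable<T> returned by LINQ to Objects' AsQueryable(), not this article's SqlQuery<T>, and it relies on how .NET's EnumerableQuery<T> behaves. ElementType is the element type itself, the root Expression is a ConstantExpression, and after a Where call the expression becomes a method call.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

public static class QueryableMembersDemo
{
    // Reports ElementType and the root/filtered expression node types of an in-memory query.
    public static (Type elementType, ExpressionType before, ExpressionType after) Inspect()
    {
        IQueryable<string> source = new[] { "a", "bb", "ccc" }.AsQueryable();
        IQueryable<string> filtered = source.Where(s => s.Length > 1);

        return (
            source.ElementType,            // typeof(string): the element type, not the query class
            source.Expression.NodeType,    // Constant: the root just refers to the source itself
            filtered.Expression.NodeType); // Call: Queryable.Where(<root>, <quoted lambda>)
    }
}
```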
CreateQuery creates a new query object that carries the given conditions (expression), and Execute runs a query.
Normally we write:
```csharp
IQueryable<Staff> query = /* some data source */;
query = query.Where(x => x.Name == "123");
```
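What Where actually does here is combine the previous query's Expression property with the lambda passed to Where (x => x.Name == "123"). It builds a call to Queryable.Where whose arguments are the old expression and the quoted lambda, and then passes that new expression to the CreateQuery method of the query's Provider. Let's change our code as follows to check that this is really what happens.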
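The mechanism just described can be made concrete with MyWhere, a hypothetical extension method that mirrors what Queryable.Where does. It wraps source.Expression and the quoted lambda in a call to Queryable.Where and hands that expression to source.Provider.CreateQuery, without filtering anything itself. This is an illustration, not the framework's code. Against the SqlQuery<T> class, the same call would land in SqlProvider<T>.CreateQuery<TElement>.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

public static class WhereSketch
{
    // Hypothetical re-implementation of Queryable.Where: it only builds a bigger
    // expression and lets the provider turn it into a new query object.
    public static IQueryable<T> MyWhere<T>(this IQueryable<T> source, Expression<Func<T, bool>> predicate)
    {
        // Queryable.Where(source.Expression, Quote(predicate))
        MethodCallExpression call = Expression.Call(
            typeof(Queryable),
            nameof(Queryable.Where),
            new[] { typeof(T) },
            source.Expression,
            Expression.Quote(predicate));

        // The provider decides what the resulting query object is.
        return source.Provider.CreateQuery<T>(call);
    }
}
```

Against an AsQueryable() source, data.MyWhere(n => n % 2 == 0) returns the same elements as data.Where(n => n % 2 == 0).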
```csharp
public class DBSql : IDataBase
{
    public IQueryable<T> Source<T>()
    {
        return new SqlQuery<T>();
    }

    public List<T> FindAs<T>(Expression<Func<T, bool>> lambdawhere)
    {
        throw new NotImplementedException();
    }

    public int Remove<T>(Expression<Func<T, bool>> lambdawhere)
    {
        throw new NotImplementedException();
    }
}

public class SqlQuery<T> : IQueryable<T>
{
    private Expression _expression;
    private IQueryProvider _provider;

    // Root query: its expression is simply a constant that refers to the query itself
    public SqlQuery()
    {
        _provider = new SqlProvider<T>();
        _expression = Expression.Constant(this);
    }

    // Used by the provider to wrap the larger expression built by Where and friends
    public SqlQuery(Expression expression, IQueryProvider provider)
    {
        _expression = expression;
        _provider = provider;
    }

    public IEnumerator<T> GetEnumerator()
    {
        throw new NotImplementedException();
    }

    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        throw new NotImplementedException();
    }

    // The type of the elements, not the type of the query class
    public Type ElementType { get { return typeof(T); } }
    public Expression Expression { get { return _expression; } }
    public IQueryProvider Provider { get { return _provider; } }
}

public class SqlProvider<T> : IQueryProvider
{
    // Queryable.Where calls this with the combined expression
    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        IQueryable<TElement> query = new SqlQuery<TElement>(expression, this);
        return query;
    }

    public IQueryable CreateQuery(Expression expression)
    {
        throw new NotImplementedException();
    }

    public TResult Execute<TResult>(Expression expression)
    {
        throw new NotImplementedException();
    }

    public object Execute(Expression expression)
    {
        throw new NotImplementedException();
    }
}
```
```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

class Program
{
    public class Staff
    {
        public int ID { get; set; }
        public string Code { get; set; }
        public string Name { get; set; }
        public DateTime? Birthday { get; set; }
        public bool Deletion { get; set; }
    }

    static void Main(string[] args)
    {
        IDataBase db = new DBSql();
        IQueryable<Staff> query = db.Source<Staff>();
        string name = "张三";
        Expression express = null;

        query = query.Where(x => x.Name == "赵建华");
        express = query.Expression; // inspect in the debugger: a Where call on the root constant
        query = query.Where(x => x.Name == name);
        express = query.Expression; // now a Where call wrapping the previous Where call
    }
}
```
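Set a breakpoint on public IQueryable<TElement> CreateQuery<TElement>(Expression expression) in SqlProvider<T>.
Every call to query.Where stops there, and each incoming expression contains the previous one with the new condition added on top.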
By now the idea that "IQueryable only stores the conditions" should be clear.
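Each Where wraps the previous expression, so two Where calls produce a nested tree shaped like Where(Where(Constant(query), p1), p2). The sketch below uses a hypothetical helper, CollectPredicates. It peels that tree from the outside in and returns the lambdas in the order they were written. This is the same walk that SqlProvider<T>.Execute performs later in this article, and like that code, it assumes only Where is used.

```csharp
using System.Collections.Generic;
using System.Linq.Expressions;

public static class WhereChain
{
    // Returns the predicates of a Where(Where(<root>, p1), p2) chain as [p1, p2].
    // Stops at the first node that is not a Where call (only Where is supported).
    public static List<LambdaExpression> CollectPredicates(Expression expression)
    {
        var predicates = new List<LambdaExpression>();
        var call = expression as MethodCallExpression;
        while (call != null && call.Method.Name == "Where")
        {
            // Arguments[1] is a Quote node wrapping the lambda; unwrap it.
            var quote = (UnaryExpression)call.Arguments[1];
            predicates.Add((LambdaExpression)quote.Operand);

            // Arguments[0] is the expression of the query Where was called on.
            call = call.Arguments[0] as MethodCallExpression;
        }
        predicates.Reverse(); // collected outermost-first, so flip to source order
        return predicates;
    }
}
```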
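What about deferred loading? When does the data actually get loaded? When we foreach over the query or call ToList/ToArray. What does that bring to mind? GetEnumerator(). Inside GetEnumerator(), we call the Provider's Execute(Expression). Execute parses the Expression, generates the SQL statement, creates the entity instances through reflection, and returns them one by one. Done! The code follows. I have also rewritten the class that parses Expressions; this version is even more blunt and hardcore.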
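Here is the same deferred-loading mechanism in a database-free form. MemoryQuery<T> and MemoryProvider<T> are hypothetical stand-ins for SqlQuery<T> and SqlProvider<T>. GetEnumerator calls Provider.Execute, and Execute unwraps the Where calls, compiles each lambda, and filters an in-memory List<T> instead of generating SQL. The ExecuteCount property exists only to show that nothing runs until enumeration. As a sketch, it supports only Where and assumes Execute is asked for an IEnumerable<T>.

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical in-memory counterpart of SqlQuery<T>.
public class MemoryQuery<T> : IQueryable<T>
{
    private readonly Expression _expression;
    private readonly IQueryProvider _provider;

    public MemoryQuery(IQueryProvider provider, Expression expression = null)
    {
        _provider = provider;
        // Root query: the expression is a constant pointing at the query itself.
        _expression = expression ?? Expression.Constant(this);
    }

    public Type ElementType => typeof(T);
    public Expression Expression => _expression;
    public IQueryProvider Provider => _provider;

    // Deferred execution: the provider only runs when someone enumerates.
    public IEnumerator<T> GetEnumerator() =>
        _provider.Execute<IEnumerable<T>>(_expression).GetEnumerator();

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}

// Hypothetical in-memory counterpart of SqlProvider<T>.
public class MemoryProvider<T> : IQueryProvider
{
    private readonly List<T> _rows;

    public MemoryProvider(List<T> rows) { _rows = rows; }

    // How many times a query was actually executed.
    public int ExecuteCount { get; private set; }

    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        // Sketch limitation: only Where is supported, which keeps the element type.
        if (typeof(TElement) != typeof(T))
            throw new NotSupportedException("Only Where is supported in this sketch.");
        return (IQueryable<TElement>)(object)new MemoryQuery<T>(this, expression);
    }

    public IQueryable CreateQuery(Expression expression) => CreateQuery<T>(expression);

    public TResult Execute<TResult>(Expression expression)
    {
        ExecuteCount++;
        IEnumerable<T> result = _rows;

        // Peel Where(Where(Constant, p1), p2) from the outside in, applying each predicate.
        var call = expression as MethodCallExpression;
        while (call != null && call.Method.Name == "Where")
        {
            var predicate = (Expression<Func<T, bool>>)((UnaryExpression)call.Arguments[1]).Operand;
            result = result.Where(predicate.Compile());
            call = call.Arguments[0] as MethodCallExpression;
        }

        // Materialize now; TResult is expected to be IEnumerable<T>.
        return (TResult)(object)result.ToList();
    }

    public object Execute(Expression expression) => Execute<IEnumerable<T>>(expression);
}
```

Building new MemoryQuery<int>(provider).Where(n => n > 1).Where(n => n < 5) leaves ExecuteCount at 0. Only the later ToList() or foreach triggers a single Execute.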
```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Data.SqlClient;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

public class ResolveExpression
{
    public Dictionary<string, object> Argument;
    public string SqlWhere;
    public SqlParameter[] Paras;
    private int index = 0;

    /// <summary>
    /// Parses a lambda and produces the SQL WHERE condition plus its parameters.
    /// </summary>
    public void ResolveToSql(Expression expression)
    {
        this.index = 0;
        this.Argument = new Dictionary<string, object>();
        this.SqlWhere = Resolve(expression);
        this.Paras = Argument.Select(x => new SqlParameter(x.Key, x.Value)).ToArray();
    }

    private object GetValue(Expression expression)
    {
        if (expression is ConstantExpression)
            return (expression as ConstantExpression).Value;
        // A captured local variable is a field on a compiler-generated closure object
        MemberExpression member = expression as MemberExpression;
        if (member != null && member.Expression is ConstantExpression && member.Member is FieldInfo)
            return (member.Member as FieldInfo).GetValue((member.Expression as ConstantExpression).Value);
        // Anything else (conversions, properties, method calls): compile and run the sub-expression
        return Expression.Lambda(expression).Compile().DynamicInvoke();
    }

    private string Resolve(Expression expression)
    {
        if (expression is LambdaExpression)
        {
            LambdaExpression lambda = expression as LambdaExpression;
            return Resolve(lambda.Body);
        }
        // Comparison such as x.Name == "abc"; and/or nodes are handled at the end of this method
        if (expression is BinaryExpression && !IsLogical(expression.NodeType))
        {
            BinaryExpression binary = expression as BinaryExpression;
            if (binary.Left is MemberExpression)
            {
                object value = GetValue(binary.Right);
                return ResolveFunc(binary.Left, value, binary.NodeType);
            }
            if (binary.Left is MethodCallExpression) // e.g. x.Name.Count() > 3
            {
                object value = GetValue(binary.Right);
                return ResolveLinqToObject(binary.Left, value, binary.NodeType);
            }
        }
        if (expression is UnaryExpression) // negation such as !x.Deletion or !x.Name.Contains("a")
        {
            UnaryExpression unary = expression as UnaryExpression;
            if (unary.Operand is MethodCallExpression)
                return ResolveLinqToObject(unary.Operand, false);
            if (unary.Operand is MemberExpression)
                return ResolveFunc(unary.Operand, false, ExpressionType.Equal);
        }
        if (expression is MethodCallExpression) // method such as Contains
            return ResolveLinqToObject(expression, true);
        if (expression is MemberExpression) // boolean property such as x.Deletion
            return ResolveFunc(expression, true, ExpressionType.Equal);

        var body = expression as BinaryExpression;
        if (body == null)
            throw new NotSupportedException("Cannot resolve " + expression);
        var Operator = GetOperator(body.NodeType);
        var Left = Resolve(body.Left);
        var Right = Resolve(body.Right);
        return string.Format("({0} {1} {2})", Left, Operator, Right);
    }

    private static bool IsLogical(ExpressionType type)
    {
        return type == ExpressionType.And || type == ExpressionType.AndAlso
            || type == ExpressionType.Or || type == ExpressionType.OrElse;
    }

    /// <summary>
    /// Maps an expression node type to the corresponding SQL operator.
    /// </summary>
    private string GetOperator(ExpressionType expressiontype)
    {
        switch (expressiontype)
        {
            case ExpressionType.And:
            case ExpressionType.AndAlso: return "and";
            case ExpressionType.Or:
            case ExpressionType.OrElse: return "or";
            case ExpressionType.Equal: return "=";
            case ExpressionType.NotEqual: return "<>";
            case ExpressionType.LessThan: return "<";
            case ExpressionType.LessThanOrEqual: return "<=";
            case ExpressionType.GreaterThan: return ">";
            case ExpressionType.GreaterThanOrEqual: return ">=";
            default: throw new NotSupportedException(string.Format("Operator {0} is not supported!", expressiontype));
        }
    }

    private string ResolveFunc(Expression left, object value, ExpressionType expressiontype)
    {
        string Name = (left as MemberExpression).Member.Name;
        string Operator = GetOperator(expressiontype);
        string CompName = SetArgument(Name, value);
        return string.Format("({0} {1} {2})", Name, Operator, CompName);
    }

    private string ResolveLinqToObject(Expression expression, object value, ExpressionType? expressiontype = null)
    {
        var MethodCall = expression as MethodCallExpression;
        var MethodName = MethodCall.Method.Name;
        switch (MethodName) // could also be dispatched via reflection instead of a switch
        {
            case "Contains":
                if (MethodCall.Object != null)
                    return Like(MethodCall, value); // x.Name.Contains("a")
                return In(MethodCall, value);       // array.Contains(x.ID)
            case "Count":
            case "LongCount":
                return Len(MethodCall, value, expressiontype.Value);
            default:
                throw new NotSupportedException(string.Format("Method {0} is not supported!", MethodName));
        }
    }

    // Registers a uniquely named SQL parameter and returns its name
    private string SetArgument(string name, object value)
    {
        name = "@" + name;
        string temp = name;
        while (Argument.ContainsKey(temp))
        {
            temp = name + index;
            index = index + 1;
        }
        Argument[temp] = value ?? DBNull.Value;
        return temp;
    }

    private string In(MethodCallExpression expression, object isTrue)
    {
        var Argument1 = expression.Arguments[0];                    // the collection
        var Argument2 = expression.Arguments[1] as MemberExpression; // the column, e.g. x.ID
        var values = GetValue(Argument1) as IEnumerable;            // works for int[], List<string>, ...
        List<string> SetInPara = new List<string>();
        int i = 0;
        foreach (var item in values)
        {
            SetInPara.Add(SetArgument("InParameter" + i, item));
            i++;
        }
        string Name = Argument2.Member.Name;
        string Operator = Convert.ToBoolean(isTrue) ? "in" : "not in";
        string CompName = string.Join(",", SetInPara);
        return string.Format("{0} {1} ({2})", Name, Operator, CompName);
    }

    private string Like(MethodCallExpression expression, object isTrue)
    {
        object Temp_Vale = GetValue(expression.Arguments[0]);
        string Value = string.Format("%{0}%", Temp_Vale);
        string Name = (expression.Object as MemberExpression).Member.Name;
        string CompName = SetArgument(Name, Value);
        string Operator = Convert.ToBoolean(isTrue) ? "like" : "not like";
        return string.Format("{0} {1} {2}", Name, Operator, CompName);
    }

    private string Len(MethodCallExpression expression, object value, ExpressionType expressiontype)
    {
        string Name = (expression.Arguments[0] as MemberExpression).Member.Name;
        string Operator = GetOperator(expressiontype);
        string CompName = SetArgument(Name, value);
        return string.Format("len({0}){1}{2}", Name, Operator, CompName);
    }
}
```
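GetValue relies on how the C# compiler represents captured variables. A local such as name in x => x.Name == name becomes a field of a compiler-generated closure object. As a result, the right-hand side is a MemberExpression whose Expression is a ConstantExpression. The sketch below confirms that shape; ClosureValueDemo and its nested Staff are illustrative names. It also shows the general fallback of wrapping any sub-expression in a parameterless lambda, compiling it, and invoking it to obtain the value.

```csharp
using System;
using System.Linq.Expressions;

public static class ClosureValueDemo
{
    // Trimmed-down entity used only for this illustration.
    public class Staff { public string Name { get; set; } }

    public static (bool isMember, bool overConstant, object value) Inspect()
    {
        string name = "张三";
        Expression<Func<Staff, bool>> predicate = x => x.Name == name;

        Expression right = ((BinaryExpression)predicate.Body).Right;
        var member = right as MemberExpression;

        // The captured local is read as closure.name: a member access on a constant.
        bool overConstant = member != null && member.Expression is ConstantExpression;

        // General fallback: compile the sub-expression and invoke it to get the value.
        object value = Expression.Lambda(right).Compile().DynamicInvoke();

        return (member != null, overConstant, value);
    }
}
```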
```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Data.SqlClient;
using System.Linq;
using System.Linq.Expressions;

namespace Data.DataBase
{
    public class DBSql : IDataBase
    {
        private readonly static string ConnectionString = @"Data Source=.;Initial Catalog=btmmcms-Standard;Persist Security Info=True;User ID=sa;Password=sa;";

        public IQueryable<T> Source<T>()
        {
            return new SqlQuery<T>();
        }

        public List<T> FindAs<T>(Expression<Func<T, bool>> lambdawhere)
        {
            using (SqlConnection Conn = new SqlConnection(ConnectionString))
            using (SqlCommand Command = new SqlCommand())
            {
                Command.Connection = Conn;
                Conn.Open();
                string sql = string.Format("select * from {0}", typeof(T).Name);
                if (lambdawhere != null)
                {
                    ResolveExpression resolve = new ResolveExpression();
                    resolve.ResolveToSql(lambdawhere);
                    sql = string.Format("{0} where {1}", sql, resolve.SqlWhere);
                    Command.Parameters.AddRange(resolve.Paras);
                }
                // Print the SQL for testing purposes
                Console.WriteLine(sql);
                Command.CommandText = sql;
                List<T> ListEntity = new List<T>();
                using (SqlDataReader dataReader = Command.ExecuteReader())
                {
                    while (dataReader.Read())
                    {
                        // Create the entity via its parameterless constructor, then fill each property by column name
                        T Entity = (T)typeof(T).GetConstructor(Type.EmptyTypes).Invoke(null);
                        foreach (var item in typeof(T).GetProperties())
                        {
                            var value = dataReader[item.Name];
                            if (value is DBNull) value = null;
                            item.SetValue(Entity, value, null);
                        }
                        ListEntity.Add(Entity);
                    }
                }
                return ListEntity;
            }
        }

        public int Remove<T>(Expression<Func<T, bool>> lambdawhere)
        {
            throw new NotImplementedException();
        }
    }

    public class SqlQuery<T> : IQueryable<T>
    {
        private Expression _expression;
        private IQueryProvider _provider;

        public SqlQuery()
        {
            _provider = new SqlProvider<T>();
            _expression = Expression.Constant(this);
        }

        public SqlQuery(Expression expression, IQueryProvider provider)
        {
            _expression = expression;
            _provider = provider;
        }

        // Deferred loading: the SQL only runs when the query is enumerated (foreach, ToList, ...)
        public IEnumerator<T> GetEnumerator()
        {
            var result = _provider.Execute<List<T>>(_expression);
            if (result == null) yield break;
            foreach (var item in result)
                yield return item;
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }

        public Type ElementType { get { return typeof(T); } }
        public Expression Expression { get { return _expression; } }
        public IQueryProvider Provider { get { return _provider; } }
    }

    public class SqlProvider<T> : IQueryProvider
    {
        public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
        {
            IQueryable<TElement> query = new SqlQuery<TElement>(expression, this);
            return query;
        }

        public IQueryable CreateQuery(Expression expression)
        {
            throw new NotImplementedException();
        }

        public TResult Execute<TResult>(Expression expression)
        {
            // Walk Where(Where(Constant(query), p1), p2) from the outside in and AND the predicates together.
            // The merged lambda is only used to generate SQL and is never compiled.
            MethodCallExpression methodCall = expression as MethodCallExpression;
            Expression<Func<T, bool>> result = null;
            while (methodCall != null)
            {
                if (methodCall.Method.Name != "Where")
                    throw new NotSupportedException("Only Where is supported: " + methodCall.Method.Name);
                Expression method = methodCall.Arguments[0]; // the inner query
                Expression lambda = methodCall.Arguments[1]; // the quoted predicate
                LambdaExpression right = (lambda as UnaryExpression).Operand as LambdaExpression;
                if (result == null)
                {
                    result = Expression.Lambda<Func<T, bool>>(right.Body, right.Parameters);
                }
                else
                {
                    Expression left = result.Body;
                    Expression temp = Expression.AndAlso(right.Body, left);
                    result = Expression.Lambda<Func<T, bool>>(temp, result.Parameters);
                }
                methodCall = method as MethodCallExpression;
            }
            var source = new DBSql().FindAs<T>(result);
            return (TResult)(object)source;
        }

        public object Execute(Expression expression)
        {
            throw new NotImplementedException();
        }
    }
}
```
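SqlProvider<T>.Execute merges the bodies of separate Where lambdas, and each lambda has its own parameter x. That is harmless for SQL generation, which only reads member names. However, calling Compile() on the merged lambda would fail, because its body refers to a parameter the lambda does not declare. A common fix is to rewrite the second body with an ExpressionVisitor so that both sides share one parameter, and then combine them with AndAlso. PredicateCombiner below is a hypothetical helper that sketches this fix.

```csharp
using System;
using System.Linq.Expressions;

public static class PredicateCombiner
{
    // Replaces every occurrence of one ParameterExpression with another.
    private sealed class ParameterReplacer : ExpressionVisitor
    {
        private readonly ParameterExpression _from;
        private readonly ParameterExpression _to;

        public ParameterReplacer(ParameterExpression from, ParameterExpression to)
        {
            _from = from;
            _to = to;
        }

        protected override Expression VisitParameter(ParameterExpression node) =>
            node == _from ? _to : base.VisitParameter(node);
    }

    // Combines two predicates with AndAlso so that both bodies use the same parameter.
    public static Expression<Func<T, bool>> AndAlso<T>(
        Expression<Func<T, bool>> left, Expression<Func<T, bool>> right)
    {
        ParameterExpression p = left.Parameters[0];
        Expression rightBody = new ParameterReplacer(right.Parameters[0], p).Visit(right.Body);
        return Expression.Lambda<Func<T, bool>>(Expression.AndAlso(left.Body, rightBody), p);
    }
}
```

In Execute, the manual combination could be replaced by result = PredicateCombiner.AndAlso(result, next). This requires first casting the unwrapped lambda to Expression<Func<T, bool>>.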
That's it. Change the connection string to point at your own database, add an entity class that matches a table (as below), and you can start using it:
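```csharp
using System;
using System.Linq;
using Data.DataBase;

class Program
{
    // The class name and property names must match the table and its columns
    public class Staff
    {
        public int ID { get; set; }
        public string Code { get; set; }
        public string Name { get; set; }
        public DateTime? Birthday { get; set; }
        public bool Deletion { get; set; }
    }

    static void Main(string[] args)
    {
        IDataBase db = new DBSql();
        IQueryable<Staff> query = db.Source<Staff>();
        query = query.Where(x => x.Name == "张三");
        // The SQL is generated and executed only here, when the query is enumerated
        foreach (var item in query)
        {
            Console.WriteLine(item.Name);
        }
    }
}
```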
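Pretty simple, right?
There's a lot of information here, but if you work through it slowly and digest it, I believe it will help you a great deal!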