#yyds干货盘点#30个类手写Spring核心原理之自定义ORM(下)

Posted Tom弹架构

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了#yyds干货盘点#30个类手写Spring核心原理之自定义ORM(下)相关的知识,希望对你有一定的参考价值。

3.1 ClassMappings

ClassMappings主要定义基础的映射类型,代码如下:


package com.tom.orm.framework;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.math.BigDecimal;
import java.sql.Date;
import java.sql.Timestamp;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * Registry of the Java types that the ORM maps to SQL columns, plus reflection helpers that
 * discover JavaBean-style public getters and setters of an entity class.
 *
 * <p>All members are static; the class cannot be instantiated.
 */
public class ClassMappings {

    private ClassMappings() {
        // static utility class, never instantiated
    }

    /** Property types that take part in automatic type conversion (enums are accepted separately). */
    static final Set<Class<?>> SUPPORTED_SQL_OBJECTS = new HashSet<Class<?>>();

    static {
        // Any type listed here is converted automatically.
        Class<?>[] classes = {
                boolean.class, Boolean.class,
                short.class, Short.class,
                int.class, Integer.class,
                long.class, Long.class,
                float.class, Float.class,
                double.class, Double.class,
                String.class,
                Date.class,
                Timestamp.class,
                BigDecimal.class
        };
        SUPPORTED_SQL_OBJECTS.addAll(Arrays.asList(classes));
    }

    /**
     * Tells whether a property of type {@code clazz} can be mapped to a column.
     *
     * @param clazz property type
     * @return {@code true} for enums and for every type in {@link #SUPPORTED_SQL_OBJECTS}
     */
    static boolean isSupportedSQLObject(Class<?> clazz) {
        return clazz.isEnum() || SUPPORTED_SQL_OBJECTS.contains(clazz);
    }

    /**
     * Collects the public, non-static, no-argument getters of {@code clazz} whose return type is
     * supported. Boolean properties may use the {@code isXxx} form; all others need {@code getXxx}.
     *
     * @param clazz class to inspect (inherited public methods are included)
     * @return property name -> getter
     */
    public static Map<String, Method> findPublicGetters(Class<?> clazz) {
        Map<String, Method> map = new HashMap<String, Method>();
        for (Method method : clazz.getMethods()) {
            if (Modifier.isStatic(method.getModifiers())
                    || method.getParameterTypes().length != 0
                    || method.getName().equals("getClass")) {
                continue;
            }
            Class<?> returnType = method.getReturnType();
            if (void.class.equals(returnType) || !isSupportedSQLObject(returnType)) {
                continue;
            }
            String name = method.getName();
            boolean isBoolean = returnType.equals(boolean.class) || returnType.equals(Boolean.class);
            if (isBoolean && name.startsWith("is") && name.length() > 2) {
                map.put(getGetterName(method), method);
            } else if (name.startsWith("get") && name.length() > 3) {
                map.put(getGetterName(method), method);
            }
        }
        return map;
    }

    /**
     * Returns the fields declared directly by {@code clazz} (inherited fields are not included).
     *
     * @param clazz class to inspect
     * @return declared fields of any visibility
     */
    public static Field[] findFields(Class<?> clazz) {
        return clazz.getDeclaredFields();
    }

    /**
     * Collects the public, non-static {@code void setXxx(value)} methods of {@code clazz} whose
     * single parameter type is supported.
     *
     * @param clazz class to inspect (inherited public methods are included)
     * @return property name -> setter
     */
    public static Map<String, Method> findPublicSetters(Class<?> clazz) {
        Map<String, Method> map = new HashMap<String, Method>();
        for (Method method : clazz.getMethods()) {
            if (Modifier.isStatic(method.getModifiers())
                    || !void.class.equals(method.getReturnType())
                    || method.getParameterTypes().length != 1) {
                continue;
            }
            String name = method.getName();
            if (!name.startsWith("set") || name.length() < 4) {
                continue;
            }
            if (!isSupportedSQLObject(method.getParameterTypes()[0])) {
                continue;
            }
            map.put(getSetterName(method), method);
        }
        return map;
    }

    /**
     * Derives the property name from a getter: strips {@code is}/{@code get} and lower-cases the
     * first remaining character ({@code getUserName} -> {@code userName}).
     *
     * @param getter getter method
     * @return property name
     */
    public static String getGetterName(Method getter) {
        String name = getter.getName();
        name = name.startsWith("is") ? name.substring(2) : name.substring(3);
        return Character.toLowerCase(name.charAt(0)) + name.substring(1);
    }

    /**
     * Derives the property name from a setter ({@code setUserName} -> {@code userName}).
     *
     * @param setter setter method
     * @return property name
     */
    private static String getSetterName(Method setter) {
        String name = setter.getName().substring(3);
        return Character.toLowerCase(name.charAt(0)) + name.substring(1);
    }
}

3.2 EntityOperation

EntityOperation主要实现数据库表结构和对象类结构的映射关系,代码如下:


package com.tom.orm.framework;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;
import javax.persistence.Column;
import javax.persistence.Entity;
import javax.persistence.Id;
import javax.persistence.Table;
import javax.persistence.Transient;
import org.apache.log4j.Logger;
import org.springframework.jdbc.core.RowMapper;
import javax.core.common.utils.StringUtils;

/**
 * 实体对象的反射操作
 *
 * @param <T>
 */
public class EntityOperation<T> 
   private Logger log = Logger.getLogger(EntityOperation.class);
   public Class<T> entityClass = null; // 泛型实体Class对象
   public final Map<String, PropertyMapping> mappings;
   public final RowMapper<T> rowMapper;

   public final String tableName;
   public String allColumn = "*";
   public Field pkField;

   public EntityOperation(Class<T> clazz,String pk) throws Exception
      if(!clazz.isAnnotationPresent(Entity.class))
         throw new Exception("在" + clazz.getName() + "中没有找到Entity注解,不能做ORM映射");
      
      this.entityClass = clazz;
      Table table = entityClass.getAnnotation(Table.class);
       if (table != null) 
             this.tableName = table.name();
        else 
             this.tableName =  entityClass.getSimpleName();
       
      Map<String, Method> getters = ClassMappings.findPublicGetters(entityClass);
       Map<String, Method> setters = ClassMappings.findPublicSetters(entityClass);
       Field[] fields = ClassMappings.findFields(entityClass);
       fillPkFieldAndAllColumn(pk,fields);
       this.mappings = getPropertyMappings(getters, setters, fields);
       this.allColumn = this.mappings.keySet().toString().replace("[", "").replace("]",""). replaceAll(" ","");
       this.rowMapper = createRowMapper();
   

    Map<String, PropertyMapping> getPropertyMappings(Map<String, Method> getters, Map<String, Method> setters, Field[] fields) 
        Map<String, PropertyMapping> mappings = new HashMap<String, PropertyMapping>();
        String name;
        for (Field field : fields) 
            if (field.isAnnotationPresent(Transient.class))
                continue;
            name = field.getName();
            if(name.startsWith("is"))
               name = name.substring(2);
            
            name = Character.toLowerCase(name.charAt(0)) + name.substring(1);
            Method setter = setters.get(name);
            Method getter = getters.get(name);
            if (setter == null || getter == null)
                continue;
            
            Column column = field.getAnnotation(Column.class);
            if (column == null) 
                mappings.put(field.getName(), new PropertyMapping(getter, setter, field));
             else 
                mappings.put(column.name(), new PropertyMapping(getter, setter, field));
            
        
        return mappings;
    

   RowMapper<T> createRowMapper() 
           return new RowMapper<T>() 
               public T mapRow(ResultSet rs, int rowNum) throws SQLException 
                   try 
                       T t = entityClass.newInstance();
                       ResultSetMetaData meta = rs.getMetaData();
                       int columns = meta.getColumnCount();
                       String columnName;
                       for (int i = 1; i <= columns; i++) 
                           Object value = rs.getObject(i);
                           columnName = meta.getColumnName(i);
                           fillBeanFieldValue(t,columnName,value);
                       
                       return t;
                   catch (Exception e) 
                       throw new RuntimeException(e);
                   
               
           ;
       

   protected void fillBeanFieldValue(T t, String columnName, Object value) 
       if (value != null) 
             PropertyMapping pm = mappings.get(columnName);
             if (pm != null) 
                 try 
               pm.set(t, value);
             catch (Exception e) 
               e.printStackTrace();
            
             
         
   

   private void fillPkFieldAndAllColumn(String pk, Field[] fields) 
      //设定主键
       try 
          if(!StringUtils.isEmpty(pk))
             pkField = entityClass.getDeclaredField(pk);
             pkField.setAccessible(true);
          
        catch (Exception e) 
             log.debug("没找到主键列,主键列名必须与属性名相同");
       
      for (int i = 0 ; i < fields.length ;i ++) 
         Field f = fields[i];
         if(StringUtils.isEmpty(pk))
            Id id = f.getAnnotation(Id.class);
            if(id != null)
               pkField = f;
               break;
            
         
      
   

   public T parse(ResultSet rs) 
      T t = null;
      if (null == rs) 
         return null;
      
      Object value = null;
      try 
         t = (T) entityClass.newInstance();
         for (String columnName : mappings.keySet()) 
            try 
               value = rs.getObject(columnName);
             catch (Exception e) 
               e.printStackTrace();
            
            fillBeanFieldValue(t,columnName,value);
         
       catch (Exception ex) 
         ex.printStackTrace();
      
      return t;
   

   public Map<String, Object> parse(T t) 
      Map<String, Object> _map = new TreeMap<String, Object>();
      try 

         for (String columnName : mappings.keySet()) 
            Object value = mappings.get(columnName).getter.invoke(t);
            if (value == null)
               continue;
            _map.put(columnName, value);

         
       catch (Exception e) 
         e.printStackTrace();
      
      return _map;
   

   public void println(T t) 
      try 
         for (String columnName : mappings.keySet()) 
            Object value = mappings.get(columnName).getter.invoke(t);
            if (value == null)
               continue;
            System.out.println(columnName + " = " + value);
         
       catch (Exception e) 
         e.printStackTrace();
      
   


class PropertyMapping 

    final boolean insertable;
    final boolean updatable;
    final String columnName;
    final boolean id;
    final Method getter;
    final Method setter;
    final Class enumClass;
    final String fieldName;

    public PropertyMapping(Method getter, Method setter, Field field) 
        this.getter = getter;
        this.setter = setter;
        this.enumClass = getter.getReturnType().isEnum() ? getter.getReturnType() : null;
        Column column = field.getAnnotation(Column.class);
        this.insertable = column == null || column.insertable();
        this.updatable = column == null || column.updatable();
        this.columnName = column == null ? ClassMappings.getGetterName(getter) : ("".equals(column.name()) ? ClassMappings.getGetterName(getter) : column.name());
        this.id = field.isAnnotationPresent(Id.class);
        this.fieldName = field.getName();
    

    @SuppressWarnings("unchecked")
    Object get(Object target) throws Exception 
        Object r = getter.invoke(target);
        return enumClass == null ? r : Enum.valueOf(enumClass, (String) r);
    

    @SuppressWarnings("unchecked")
    void set(Object target, Object value) throws Exception 
        if (enumClass != null && value != null) 
            value = Enum.valueOf(enumClass, (String) value);
        
        //BeanUtils.setProperty(target, fieldName, value);
        try 
            if(value != null)
                setter.invoke(target, setter.getParameterTypes()[0].cast(value));
             
       catch (Exception e) 
         e.printStackTrace();
         /**
          * 出错原因如果是boolean字段、mysql字段类型,设置tinyint(1)
          */
         System.err.println(fieldName + "--" + value);
      

    

3.3 QueryRuleSqlBuilder

QueryRuleSqlBuilder根据用户构建好的QueryRule来自动生成SQL语句,代码如下:


package com.tom.orm.framework;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.lang.ArrayUtils;
import com.tom.orm.framework.QueryRule.Rule;
import javax.core.common.utils.StringUtils;

/**
 * 根据QueryRule自动构建SQL语句
 */
public class QueryRuleSqlBuilder 

   private int CURR_INDEX = 0; //记录参数所在的位置
   private List<String> properties; //保存列名列表
   private List<Object> values; //保存参数值列表
   private List<Order> orders; //保存排序规则列表

   private String whereSql = ""; 
   private String orderSql = "";
   private Object [] valueArr = new Object[];
   private Map<Object,Object> valueMap = new HashMap<Object,Object>();

   /**
    * 获得查询条件
    * @return
    */
   public String getWhereSql()
      return this.whereSql;
   

   /**
    * 获得排序条件
    * @return
    */
   public String getOrderSql()
      return this.orderSql;
   

   /**
    * 获得参数值列表
    * @return
    */
   public Object [] getValues()
      return this.valueArr;
   

   /**
    * 获取参数列表
    * @return
    */
   public Map<Object,Object> getValueMap()
      return this.valueMap;
   

   /**
    * 创建SQL构造器
    * @param queryRule
    */
   public  QueryRuleSqlBuilder(QueryRule queryRule) 
      CURR_INDEX = 0;
      properties = new ArrayList<String>();
      values = new ArrayList<Object>();
      orders = new ArrayList<Order>();
      for (QueryRule.Rule rule : queryRule.getRuleList()) 
         switch (rule.getType()) 
         case QueryRule.BETWEEN:
            processBetween(rule);
            break;
         case QueryRule.EQ:
            processEqual(rule);
            break;
         case QueryRule.LIKE:
            processLike(rule);
            break;
         case QueryRule.NOTEQ:
            processNotEqual(rule);
            break;
         case QueryRule.GT:
            processGreaterThen(rule);
            break;
         case QueryRule.GE:
            processGreaterEqual(rule);
            break;
         case QueryRule.LT:
            processLessThen(rule);
            break;
         case QueryRule.LE:
            processLessEqual(rule);
            break;
         case QueryRule.IN:
            processIN(rule);
            break;
         case QueryRule.NOTIN:
            processNotIN(rule);
            break;
         case QueryRule.ISNULL:
            processIsNull(rule);
            break;
         case QueryRule.ISNOTNULL:
            processIsNotNull(rule);
            break;
         case QueryRule.ISEMPTY:
            processIsEmpty(rule);
            break;
         case QueryRule.ISNOTEMPTY:
            processIsNotEmpty(rule);
            break;
         case QueryRule.ASC_ORDER:
            processOrder(rule);
            break;
         case QueryRule.DESC_ORDER:
            processOrder(rule);
            break;
         default:
            throw new IllegalArgumentException("type " + rule.getType() + " not supported.");
         
      
      //拼装where语句
      appendWhereSql();
      //拼装排序语句
      appendOrderSql();
      //拼装参数值
      appendValues();
   

   /**
    * 去掉order
    * 
    * @param sql
    * @return
    */
   protected String removeOrders(String sql) 
      Pattern p = Pattern.compile("order\\\\s*by[\\\\w|\\\\W|\\\\s|\\\\S]*", Pattern.CASE_INSENSITIVE);
      Matcher m = p.matcher(sql);
      StringBuffer sb = new StringBuffer();
      while (m.find()) 
         m.appendReplacement(sb, "");
      
      m.appendTail(sb);
      return sb.toString();
   

   /**
    * 去掉select
    * 
    * @param sql
    * @return
    */
   protected String removeSelect(String sql) 
      if(sql.toLowerCase().matches("from\\\\s+"))
         int beginPos = sql.toLowerCase().indexOf("from");
         return sql.substring(beginPos);
      else
         return sql;
      
   

   /**
    * 处理like
    * @param rule
    */
   private  void processLike(QueryRule.Rule rule) 
      if (ArrayUtils.isEmpty(rule.getValues())) 
         return;
      
      Object obj = rule.getValues()[0];

      if (obj != null) 
         String value = obj.toString();
         if (!StringUtils.isEmpty(value)) 
            value = value.replace(*, %);
            obj = value;
         
      
      add(rule.getAndOr(),rule.getPropertyName(),"like","%"+rule.getValues()[0]+"%");
   

   /**
    * 处理between
    * @param rule
    */
   private  void processBetween(QueryRule.Rule rule) 
      if ((ArrayUtils.isEmpty(rule.getValues()))
            || (rule.getValues().length < 2)) 
         return;
      
      add(rule.getAndOr(),rule.getPropertyName(),"","between",rule.getValues()[0],"and");
      add(0,"","","",rule.getValues()[1],"");
   

   /**
    * 处理 =
    * @param rule
    */
   private  void processEqual(QueryRule.Rule rule) 
      if (ArrayUtils.isEmpty(rule.getValues())) 
         return;
      
      add(rule.getAndOr(),rule.getPropertyName(),"=",rule.getValues()[0]);
   

   /**
    * 处理 <>
    * @param rule
    */
   private  void processNotEqual(QueryRule.Rule rule) 
      if (ArrayUtils.isEmpty(rule.getValues())) 
         return;
      
      add(rule.getAndOr(),rule.getPropertyName(),"<>",rule.getValues()[0]);
   

   /**
    * 处理 >
    * @param rule
    */
   private  void processGreaterThen(
         QueryRule.Rule rule) 
      if (ArrayUtils.isEmpty(rule.getValues())) 
         return;
      
      add(rule.getAndOr(),rule.getPropertyName(),">",rule.getValues()[0]);
   

   /**
    * 处理>=
    * @param rule
    */
   private  void processGreaterEqual(
         QueryRule.Rule rule) 
      if (ArrayUtils.isEmpty(rule.getValues())) 
         return;
      
      add(rule.getAndOr(),rule.getPropertyName(),">=",rule.getValues()[0]);
   

   /**
    * 处理<
    * @param rule
    */
   private  void processLessThen(QueryRule.Rule rule) 
      if (ArrayUtils.isEmpty(rule.getValues())) 
         return;
      
      add(rule.getAndOr(),rule.getPropertyName(),"<",rule.getValues()[0]);
   

   /**
    * 处理<=
    * @param rule
    */
   private  void processLessEqual(
         QueryRule.Rule rule) 
      if (ArrayUtils.isEmpty(rule.getValues())) 
         return;
      
      add(rule.getAndOr(),rule.getPropertyName(),"<=",rule.getValues()[0]);
   

   /**
    * 处理  is null
    * @param rule
    */
   private  void processIsNull(QueryRule.Rule rule) 
      add(rule.getAndOr(),rule.getPropertyName(),"is null",null);
   

   /**
    * 处理 is not null
    * @param rule
    */
   private  void processIsNotNull(QueryRule.Rule rule) 
      add(rule.getAndOr(),rule.getPropertyName(),"is not null",null);
   

   /**
    * 处理  <>
    * @param rule
    */
   private  void processIsNotEmpty(QueryRule.Rule rule) 
      add(rule.getAndOr(),rule.getPropertyName(),"<>","");
   

   /**
    * 处理 =
    * @param rule
    */
   private  void processIsEmpty(QueryRule.Rule rule) 
      add(rule.getAndOr(),rule.getPropertyName(),"=","");
   

   /**
    * 处理in和not in
    * @param rule
    * @param name
    */
   private void inAndNotIn(QueryRule.Rule rule,String name)
      if (ArrayUtils.isEmpty(rule.getValues())) 
         return;
      
      if ((rule.getValues().length == 1) && (rule.getValues()[0] != null)
            && (rule.getValues()[0] instanceof List)) 
         List<Object> list = (List) rule.getValues()[0];

         if ((list != null) && (list.size() > 0))
            for (int i = 0; i < list.size(); i++) 
               if(i == 0 && i == list.size() - 1)
                  add(rule.getAndOr(),rule.getPropertyName(),"",name + " (",list.get(i),")");
               else if(i == 0 && i < list.size() - 1)
                  add(rule.getAndOr(),rule.getPropertyName(),"",name + " (",list.get(i),"");
               
               if(i > 0 && i < list.size() - 1)
                  add(0,"",",","",list.get(i),"");
               
               if(i == list.size() - 1 && i != 0)
                  add(0,"",",","",list.get(i),")");
               
            
         
       else 
         Object[] list =  rule.getValues();
         for (int i = 0; i < list.length; i++) 
            if(i == 0 && i == list.length - 1)
               add(rule.getAndOr(),rule.getPropertyName(),"",name + " (",list[i],")");
            else if(i == 0 && i < list.length - 1)
               add(rule.getAndOr(),rule.getPropertyName(),"",name + " (",list[i],"");
            
            if(i > 0 && i < list.length - 1)
               add(0,"",",","",list[i],"");
            
            if(i == list.length - 1 && i != 0)
               add(0,"",",","",list[i],")");
            
         
      
   

   /**
    * 处理 not in
    * @param rule
    */
   private void processNotIN(QueryRule.Rule rule)
      inAndNotIn(rule,"not in");
   

   /**
    * 处理 in
    * @param rule
    */
   private  void processIN(QueryRule.Rule rule) 
      inAndNotIn(rule,"in");
   

   /**
    * 处理 order by
    * @param rule 查询规则
    */
   private void processOrder(Rule rule) 
      switch (rule.getType()) 
      case QueryRule.ASC_ORDER:
         //propertyName非空
         if (!StringUtils.isEmpty(rule.getPropertyName())) 
            orders.add(Order.asc(rule.getPropertyName()));
         
         break;
      case QueryRule.DESC_ORDER:
         //propertyName非空
         if (!StringUtils.isEmpty(rule.getPropertyName())) 
            orders.add(Order.desc(rule.getPropertyName()));
         
         break;
      default:
         break;
      
   

   /**
    * 加入SQL查询规则队列
    * @param andOr and 或者 or
    * @param key 列名
    * @param split 列名与值之间的间隔
    * @param value 值
    */
   private  void add(int andOr,String key,String split ,Object value)
      add(andOr,key,split,"",value,"");
   

   /**
    * 加入SQL查询规则队列
    * @param andOr and 或则 or
    * @param key 列名
    * @param split 列名与值之间的间隔
    * @param prefix 值前缀
    * @param value 值
    * @param suffix 值后缀
    */
   private void add(int andOr,String key,String split,String prefix,Object value,String suffix)
      String andOrStr = (0 == andOr ? "" :(QueryRule.AND == andOr ? " and " : " or "));  
      properties.add(CURR_INDEX, andOrStr + key + " " + split + prefix + (null != value ? " ? " : " ") + suffix);
      if(null != value)
         values.add(CURR_INDEX,value);
         CURR_INDEX ++;
      
   

   /**
    * 拼装 where 语句
    */
   private void appendWhereSql()
      StringBuffer whereSql = new StringBuffer();
      for (String p : properties) 
         whereSql.append(p);
      
      this.whereSql = removeSelect(removeOrders(whereSql.toString()));
   

   /**
    * 拼装排序语句
    */
   private void appendOrderSql()
      StringBuffer orderSql = new StringBuffer();
      for (int i = 0 ; i < orders.size(); i ++) 
         if(i > 0 && i < orders.size())
            orderSql.append(",");
         
         orderSql.append(orders.get(i).toString());
      
      this.orderSql = removeSelect(removeOrders(orderSql.toString()));
   

   /**
    * 拼装参数值
    */
   private void appendValues()
      Object [] val = new Object[values.size()];
      for (int i = 0; i < values.size(); i ++) 
         val[i] = values.get(i);
         valueMap.put(i, values.get(i));
      
      this.valueArr = val;
   

3.4 BaseDaoSupport

BaseDaoSupport主要是对JdbcTemplate的包装,下面讲一下其重要代码,请“小伙伴们”关注公众号『Tom弹架构』,回复“Spring”可下载全部源代码。先看全局定义:


package com.tom.orm.framework;

...

/**
 * BaseDao extension class whose main feature is assembling SQL statements automatically.
 * It is abstract and must be subclassed before use.
 * @author Tom
 */
public abstract class BaseDaoSupport<T extends Serializable, PK extends Serializable> implements BaseDao<T,PK> 
   private Logger log = Logger.getLogger(BaseDaoSupport.class);

   // Table targeted by the generated SQL; initialised from the entity mapping in the constructor
   // and switchable at runtime through setTableName/restoreTableName.
   private String tableName = "";

   // Templates over the write and read-only data sources, created by the data-source setters.
   private JdbcTemplate jdbcTemplateWrite;
   private JdbcTemplate jdbcTemplateReadOnly;

   private DataSource dataSourceReadOnly;
   private DataSource dataSourceWrite;

   // Reflection mapping of entity type T: table name, columns, primary-key field and row mapper.
   private EntityOperation<T> op;

   /**
    * Resolves the entity class from this DAO's generic superclass declaration and builds its
    * {@link EntityOperation} mapping, which also initialises the table name.
    *
    * <p>{@link #getPKColumn()} is invoked here, i.e. before the subclass constructor has run, so
    * implementations must not depend on subclass instance state.
    *
    * @throws IllegalStateException if the mapping cannot be built (for example the entity class
    *         lacks {@code @Entity}); failing here avoids a DAO whose every method later throws a
    *         NullPointerException on the missing mapping
    */
   @SuppressWarnings("unchecked")
   protected BaseDaoSupport() {
      try {
         Class<T> entityClass = GenericsUtils.getSuperClassGenricType(getClass(), 0);
         op = new EntityOperation<T>(entityClass, this.getPKColumn());
         this.setTableName(op.tableName);
      } catch (Exception e) {
         log.error("Failed to build ORM mapping for " + getClass().getName(), e);
         throw new IllegalStateException("Failed to build ORM mapping for " + getClass().getName(), e);
      }
   }

   protected String getTableName()  return tableName; 
   protected DataSource getDataSourceReadOnly()  return dataSourceReadOnly;  
   protected DataSource getDataSourceWrite()  return dataSourceWrite;  

   /**
    * 动态切换表名
    */
   protected void setTableName(String tableName) 
      if(StringUtils.isEmpty(tableName))
         this.tableName = op.tableName;
      else
         this.tableName = tableName;
      
   

   protected void setDataSourceWrite(DataSource dataSourceWrite) 
      this.dataSourceWrite = dataSourceWrite;
      jdbcTemplateWrite = new JdbcTemplate(dataSourceWrite);
   

   protected void setDataSourceReadOnly(DataSource dataSourceReadOnly) 
      this.dataSourceReadOnly = dataSourceReadOnly;
      jdbcTemplateReadOnly = new JdbcTemplate(dataSourceReadOnly);
   

   private JdbcTemplate jdbcTemplateReadOnly() 
      return this.jdbcTemplateReadOnly;
   

   private JdbcTemplate jdbcTemplateWrite() 
      return this.jdbcTemplateWrite;
   

   /**
    * 还原默认表名
    */
   protected void restoreTableName() this.setTableName(op.tableName);  

   /**
    * Returns the primary-key column name. It is handed to EntityOperation, which looks it up as a
    * field name, so it must equal the entity's primary-key property name; an empty value lets the
    * mapping use the field annotated with {@code @Id}. Called from the constructor, i.e. before
    * subclass fields are initialised.
    * @return primary-key column name
    */
   protected abstract String getPKColumn();

   // Supplies the data source for this DAO.
   // NOTE(review): presumably implementations forward it to setDataSourceWrite/setDataSourceReadOnly — confirm.
   protected abstract void setDataSource(DataSource dataSource);

//此处有省略


为了照顾程序员的一般使用习惯,查询方法的前缀命名主要有select、get、load,兼顾Hibernate和MyBatis的命名风格。


/**
    * 查询函数,使用查询规则
    * 例如以下代码查询条件为匹配的数据
    *
    * @param queryRule 查询规则
    * @return 查询的结果List
    */
   public List<T> select(QueryRule queryRule) throws Exception
      QueryRuleSqlBuilder bulider = new QueryRuleSqlBuilder(queryRule);
      String ws = removeFirstAnd(bulider.getWhereSql());
      String whereSql = ("".equals(ws) ? ws : (" where " + ws));
      String sql = "select " + op.allColumn + " from " + getTableName() + whereSql;
      Object [] values = bulider.getValues();
      String orderSql = bulider.getOrderSql();
      orderSql = (StringUtils.isEmpty(orderSql) ? " " : (" order by " + orderSql));
      sql += orderSql;
      log.debug(sql);
      return (List<T>) this.jdbcTemplateReadOnly().query(sql, this.op.rowMapper, values);
   

...

   /**
    * 根据SQL语句执行查询,参数为Object数组对象
    * @param sql 查询语句
    * @param args 为Object数组
    * @return 符合条件的所有对象
    */
   public List<Map<String,Object>> selectBySql(String sql,Object... args) throws Exception
      return this.jdbcTemplateReadOnly().queryForList(sql,args);
   

...

   /**
    * 分页查询函数,使用查询规则<br>
    * 例如以下代码查询条件为匹配的数据
    *
    * @param queryRule 查询规则
    * @param pageNo 页号,从1开始
    * @param pageSize 每页的记录条数
    * @return 查询的结果Page
    */
   public Page<T> select(QueryRule queryRule,final int pageNo, final int pageSize) throws Exception
      QueryRuleSqlBuilder bulider = new QueryRuleSqlBuilder(queryRule);
      Object [] values = bulider.getValues();
      String ws = removeFirstAnd(bulider.getWhereSql());
      String whereSql = ("".equals(ws) ? ws : (" where " + ws));
      String countSql = "select count(1) from " + getTableName() + whereSql;
      long count = (Long) this.jdbcTemplateReadOnly().queryForMap(countSql, values).get ("count(1)");
      if (count == 0) 
         return new Page<T>();
      
      long start = (pageNo - 1) * pageSize;
      //在有数据的情况下,继续查询
      String orderSql = bulider.getOrderSql();
      orderSql = (StringUtils.isEmpty(orderSql) ? " " : (" order by " + orderSql));
      String sql = "select " + op.allColumn +" from " + getTableName() + whereSql + orderSql + " limit " + start + "," + pageSize;
      List<T> list = (List<T>) this.jdbcTemplateReadOnly().query(sql, this.op.rowMapper, values);
      log.debug(sql);
      return new Page<T>(start, count, pageSize, list);
   
...

   /**
    * 分页查询特殊SQL语句
    * @param sql 语句
    * @param param  查询条件
    * @param pageNo   页码
    * @param pageSize 每页内容
    * @return
    */
   public Page<Map<String,Object>> selectBySqlToPage(String sql, Object [] param, final int pageNo, final int pageSize) throws Exception 
      String countSql = "select count(1) from (" + sql + ") a";

      long count = (Long) this.jdbcTemplateReadOnly().queryForMap(countSql,param).get("count(1)");
      if (count == 0) 
         return new Page<Map<String,Object>>();
      
      long start = (pageNo - 1) * pageSize;
      sql = sql + " limit " + start + "," + pageSize;
      List<Map<String,Object>> list = (List<Map<String,Object>>) this.jdbcTemplateReadOnly(). queryForList(sql, param);
      log.debug(sql);
      return new Page<Map<String,Object>>(start, count, pageSize, list);
   

/**
    * 获取默认的实例对象
    * @param <T>
    * @param pkValue
    * @param rowMapper
    * @return
    */
   private <T> T doLoad(Object pkValue, RowMapper<T> rowMapper)
      Object obj = this.doLoad(getTableName(), getPKColumn(), pkValue, rowMapper);
      if(obj != null)
         return (T)obj;
      
      return null;
   

插入方法,均以insert开头:


/**
    * 插入并返回ID
    * @param entity
    * @return
    */
   public PK insertAndReturnId(T entity) throws Exception
      return (PK)this.doInsertRuturnKey(parse(entity));
   

   /**
    * 插入一条记录
    * @param entity
    * @return
    */
   public boolean insert(T entity) throws Exception
      return this.doInsert(parse(entity));
   
/**
    * 批量保存对象.<br>
    *
    * @param list 待保存的对象List
    * @throws InvocationTargetException
    * @throws IllegalArgumentException
    * @throws IllegalAccessException
    */
   public int insertAll(List<T> list) throws Exception 
      int count = 0 ,len = list.size(),step = 50000;
      Map<String, PropertyMapping> pm = op.mappings;
      int maxPage = (len % step == 0) ? (len / step) : (len / step + 1);
      for (int i = 1; i <= maxPage; i ++) 
         Page<T> page = pagination(list, i, step);
         String sql = "insert into " + getTableName() + "(" + op.allColumn + ") values ";// (" + valstr.toString() + ")";
         StringBuffer valstr = new StringBuffer();
         Object[] values = new Object[pm.size() * page.getRows().size()];
         for (int j = 0; j < page.getRows().size(); j ++) 
            if(j > 0 && j < page.getRows().size()) valstr.append(","); 
            valstr.append("(");
            int k = 0;
            for (PropertyMapping p : pm.values()) 
               values[(j * pm.size()) + k] = p.getter.invoke(page.getRows().get(j));
               if(k > 0 && k < pm.size()) valstr.append(","); 
               valstr.append("?");
               k ++;
            
            valstr.append(")");
         
         int result = jdbcTemplateWrite().update(sql + valstr.toString(), values);
         count += result;
      

      return count;
   

private Serializable doInsertRuturnKey(Map<String,Object> params)
      final List<Object> values = new ArrayList<Object>();
      final String sql = makeSimpleInsertSql(getTableName(),params,values);
      KeyHolder keyHolder = new GeneratedKeyHolder();
      final JdbcTemplate jdbcTemplate = new JdbcTemplate(getDataSourceWrite());
        try               

             jdbcTemplate.update(new PreparedStatementCreator() 
            public PreparedStatement createPreparedStatement(

                  Connection con) throws SQLException 
               PreparedStatement ps = con.prepareStatement(sql,Statement.RETURN_GENERATED_KEYS);

               for (int i = 0; i < values.size(); i++) 
                  ps.setObject(i+1, values.get(i)==null?null:values.get(i));

               
               return ps;
             

         , keyHolder);
         catch (DataAccessException e) 
           log.error("error",e);
        

      if (keyHolder == null)  return ""; 

      Map<String, Object> keys = keyHolder.getKeys();
      if (keys == null || keys.size() == 0 || keys.values().size() == 0) 
         return "";
      
      Object key = keys.values().toArray()[0];
      if (key == null || !(key instanceof Serializable)) 
         return "";
      
      if (key instanceof Number) 
         //Long k = (Long) key;
         Class clazz = key.getClass();
//       return clazz.cast(key);
         return (clazz == int.class || clazz == Integer.class) ? ((Number) key).intValue() : ((Number)key).longValue();

       else if (key instanceof String) 
         return (String) key;
       else 
         return (Serializable) key;
      

   

/**
    * 插入
    * @param params
    * @return
    */
   private boolean doInsert(Map<String, Object> params) 
      String sql = this.makeSimpleInsertSql(this.getTableName(), params);
      int ret = this.jdbcTemplateWrite().update(sql, params.values().toArray());
      return ret > 0;
   

删除方法,均以delete开头:


/**
    * 删除对象.<br>
    *
    * @param entity 待删除的实体对象
    */
   public boolean delete(T entity) throws Exception 
        return this.doDelete(op.pkField.get(entity)) > 0;
   

   /**
    * 删除对象.<br>
    *
    * @param list 待删除的实体对象列表
    * @throws InvocationTargetException
    * @throws IllegalArgumentException
    * @throws IllegalAccessException
    */
   public int deleteAll(List<T> list) throws Exception 
      String pkName = op.pkField.getName();
      int count = 0 ,len = list.size(),step = 1000;
      Map<String, PropertyMapping> pm = op.mappings;
      int maxPage = (len % step == 0) ? (len / step) : (len / step + 1);
      for (int i = 1; i <= maxPage; i ++) 
         StringBuffer valstr = new StringBuffer();
         Page<T> page = pagination(list, i, step);
         Object[] values = new Object[page.getRows().size()];

         for (int j = 0; j < page.getRows().size(); j ++) 
            if(j > 0 && j < page.getRows().size()) valstr.append(","); 
            values[j] = pm.get(pkName).getter.invoke(page.getRows().get(j));
            valstr.append("?");
         

         String sql = "delete from " + getTableName() + " where " + pkName + " in (" + valstr.toString() + ")";
         int result = jdbcTemplateWrite().update(sql, values);
         count += result;
      
      return count;
   

   /**
    * 根据id删除对象。如果有记录则删之,没有记录也不报异常<br>
    * 例如:删除主键唯一的记录
    *
    * @param id 序列化id
    */
   protected void deleteByPK(PK id)  throws Exception 
      this.doDelete(id);
   

/**
    * 删除实例对象,返回删除记录数
    * @param tableName
    * @param pkName
    * @param pkValue
    * @return
    */
   private int doDelete(String tableName, String pkName, Object pkValue) 
      StringBuffer sb = new StringBuffer();
      sb.append("delete from ").append(tableName).append(" where ").append(pkName).append(" = ?");
      int ret = this.jdbcTemplateWrite().update(sb.toString(), pkValue);
      return ret;
   

修改方法,均以update开头:


/**
    * 更新对象.<br>
    *
    * @param entity 待更新对象
    * @throws IllegalAccessException
    * @throws IllegalArgumentException
    */
   public boolean update(T entity) throws Exception 
      return this.doUpdate(op.pkField.get(entity), parse(entity)) > 0;
   

/**
    * 更新实例对象,返回删除记录数
    * @param pkValue
    * @param params
    * @return
    */
   private int doUpdate(Object pkValue, Map<String, Object> params)
      String sql = this.makeDefaultSimpleUpdateSql(pkValue, params);
      params.put(this.getPKColumn(), pkValue);
      int ret = this.jdbcTemplateWrite().update(sql, params.values().toArray());
      return ret;
   

至此,一个完整的ORM框架就横空出世了。当然,还有很多可以优化的地方,小伙伴们可以继续完善。

关注微信公众号『 Tom弹架构 』回复“Spring”可获取完整源码。

原创不易,坚持很酷,都看到这里了,小伙伴记得点赞、收藏、在看,一键三连加关注!如果你觉得内容太干,可以分享转发给朋友滋润滋润!

以上是关于#yyds干货盘点#30个类手写Spring核心原理之自定义ORM(下)的主要内容,如果未能解决你的问题,请参考以下文章

#yyds干货盘点#30个类手写Spring核心原理之AOP代码织入

#yyds干货盘点#30个类手写Spring核心原理之MVC映射功能

#yyds干货盘点#30个类手写Spring核心原理之自定义ORM(上)

#yyds干货盘点# 爆肝30天,肝出来史上最透彻Spring原理和27道高频面试题总结

单例模式八个例子#yyds干货盘点#

#yyds干货盘点# Spring核心之控制反转(IOC)