深入底层,仿MyBatis自己写框架

Posted Java大联盟

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了深入底层,仿MyBatis自己写框架相关的知识,希望对你有一定的参考价值。


前言:


最近研究了一下Mybatis的底层代码,写了一个操作数据库的小工具,实现了Mybatis的部分功能:

1.SQL语句在mapper.xml中配置。

2.支持int,String,自定义数据类型的入参。

3.根据mapper.xml动态创建接口的代理实现对象。


功能有限,目的是搞清楚MyBatis框架的底层思想,多学习研究优秀框架的实现思路,对提升自己的编码能力大有裨益。


小工具使用到的核心技术点:xml解析+反射+jdk动态代理


接下来,一步一步来实现。


首先来说为什么要使用jdk动态代理。

传统的开发方式:

1.接口定义业务方法。

2.实现类实现业务方法。

3.实例化实现类对象来完成业务操作。


接口:


public interface UserDAO {
   public User get(int id);
}


实现类:


public class UserDAOImpl implements UserDAO{

   @Override
   public User get(int id) {
       Connection conn = JDBCTools.getConnection();
       String sql = "select * from user where id = ?";
       PreparedStatement pstmt = null;
       ResultSet rs = null;
       try {
           pstmt = conn.prepareStatement(sql);
           pstmt.setInt(1, id);
           rs = pstmt.executeQuery();
           if(rs.next()){
               int sid = rs.getInt(1);
               String name = rs.getString(2);
               User user = new User(sid,name);
               return user;
           }
       } catch (Exception e) {
           // TODO Auto-generated catch block
           e.printStackTrace();
       }finally{
           JDBCTools.release(conn, pstmt, rs);
       }
       return null;
   }

}


测试:


public static void main(String[] args) {

       UserDAO userDAO = new UserDAOImpl();
       User user = userDAO.get(1);
       System.out.println(user);

   }


Mybatis的方式:

1.开发者只需要创建接口,定义业务方法。

2.不需要创建实现类。

3.具体的业务操作通过配置xml来完成。


接口:


public interface StudentDAO {
   public Student getById(int id);
   public Student getByStudent(Student student);
   public Student getByName(String name);
   public Student getByStudent2(Student student);
}


StudentDAO.xml:


<?xml version="1.0" encoding="UTF-8" ?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.southwind.dao.StudentDAO">

   <select id="getById" parameterType="int"
       resultType="com.southwind.entity.Student">

       select * from student where id=#{id}
   </select>

   <select id="getByStudent" parameterType="com.southwind.entity.Student"
       resultType="com.southwind.entity.Student">

       select * from student where id=#{id} and name=#{name}
   </select>

   <select id="getByStudent2" parameterType="com.southwind.entity.Student"
       resultType="com.southwind.entity.Student">

       select * from student where name=#{name} and tel=#{tel}
   </select>

   <select id="getByName" parameterType="java.lang.String"
       resultType="com.southwind.entity.Student">

       select * from student where name=#{name}
   </select>

</mapper>


测试:


public static void main(String[] args) {

       StudentDAO studentDAO = (StudentDAO) new MyInvocationHandler().getInstance(StudentDAO.class);
       Student stu = studentDAO.getById(1);
       System.out.println(stu);

   }


通过以上代码可以看到,MyBatis的方式省去了实现类的创建,改为用xml来定义业务方法的具体实现。


那么问题来了。


我们知道Java是面向对象的编程语言,程序在运行时执行业务方法,必须要有实例化的对象。但是,接口是不能被实例化的,而且也没有接口的实现类,那么此时这个对象从哪来呢?


程序在运行时,动态创建代理对象。


即jdk动态代理,运行时结合接口和mapper.xml来动态创建一个代理对象,程序调用该代理对象的方法来完成业务。


如何使用jdk动态代理?


创建一个类,实现InvocationHandler接口,该类就具备了创建动态代理对象的功能。


两个核心方法:

1.自定义getInstance方法:入参为目标对象,通过Proxy.newProxyInstance方法创建代理对象,并返回。


    public Object getInstance(Class cls){
       Object newProxyInstance = Proxy.newProxyInstance(  
               cls.getClassLoader(),  
               new Class[] { cls },
               this);
       return (Object)newProxyInstance;
   }


2.实现接口的invoke方法,通过反射机制完成业务逻辑代码。


    @Override
   public Object invoke(Object proxy, Method method, Object[] args)
           throws Throwable
{
       // TODO Auto-generated method stub
       return null;
   }


invoke方法是核心代码,在该方法中实现具体的业务需求。接下来我们来看如何实现。

既然是对数据库进行操作,则一定需要数据库连接对象,数据库相关信息配置在config.xml中。

所以invoke方法第一步,就是要解析config.xml,创建数据库连接对象,使用C3P0数据库连接池。


    //读取C3P0数据源配置信息
   public static Map<String,String> getC3P0Properties(){
       Map<String,String> map = new HashMap<String,String>();
       SAXReader reader = new SAXReader();
       try {
           Document document = reader.read("src/config.xml");
           //获取根节点
           Element root = document.getRootElement();
           Iterator iter = root.elementIterator();
           while(iter.hasNext()){
               Element e = (Element) iter.next();
               //解析environments节点
               if("environments".equals(e.getName())){
                   Iterator iter2 = e.elementIterator();
                   while(iter2.hasNext()){
                       //解析environment节点
                       Element e2 = (Element) iter2.next();
                       Iterator iter3 = e2.elementIterator();
                       while(iter3.hasNext()){
                           Element e3 = (Element) iter3.next();
                           //解析dataSource节点
                           if("dataSource".equals(e3.getName())){
                               if("POOLED".equals(e3.attributeValue("type"))){
                                   Iterator iter4 = e3.elementIterator();
                                   //获取数据库连接信息
                                   while(iter4.hasNext()){
                                       Element e4 = (Element) iter4.next();
                                       map.put(e4.attributeValue("name"),e4.attributeValue("value"));
                                   }
                               }
                           }
                       }
                   }
               }
           }
       } catch (Exception e) {
           // TODO Auto-generated catch block
           e.printStackTrace();
       }
       return map;
   }


//获取C3P0信息,创建数据源对象
Map<String,String> map = ParseXML.getC3P0Properties();
ComboPooledDataSource datasource = new ComboPooledDataSource();
datasource.setDriverClass(map.get("driver"));
datasource.setJdbcUrl(map.get("url"));
datasource.setUser(map.get("username"));
datasource.setPassword(map.get("password"));
datasource.setInitialPoolSize(20);
datasource.setMaxPoolSize(40);
datasource.setMinPoolSize(2);
datasource.setAcquireIncrement(5);
Connection conn = datasource.getConnection();


有了数据库连接,接下来就需要获取待执行的SQL语句,SQL的定义全部写在StudentDAO.xml中,继续解析xml,执行SQL语句。


SQL执行完毕,查询结果会保存在ResultSet中,还需要将ResultSet对象中的数据进行解析,封装到JavaBean中返回。

两步完成:

1.反射机制创建Student对象。

2.通过反射动态执行类中所有属性的setter方法,完成赋值。


这样就将ResultSet中的数据封装到JavaBean中了。


//获取sql语句
String sql = element.getText();
//获取参数类型
String parameterType = element.attributeValue("parameterType");
//创建pstmt
PreparedStatement pstmt = createPstmt(sql,parameterType,conn,args);
ResultSet rs = pstmt.executeQuery();
if(rs.next()){
   //读取返回数据类型
   String resultType = element.attributeValue("resultType");  
   //反射创建对象
   Class clazz = Class.forName(resultType);
   obj = clazz.newInstance();
   //获取ResultSet数据
   ResultSetMetaData rsmd = rs.getMetaData();
   //遍历实体类属性集合,依次将结果集中的值赋给属性
   Field[] fields = clazz.getDeclaredFields();
   for(int i = 0; i < fields.length; i++){
       Object value = setFieldValueByResultSet(fields[i],rsmd,rs);
       //通过属性名找到对应的setter方法
       String name = fields[i].getName();
       name = name.substring(0, 1).toUpperCase() + name.substring(1);
       String MethodName = "set"+name;
       Method methodObj = clazz.getMethod(MethodName,fields[i].getType());
       //调用setter方法完成赋值
       methodObj.invoke(obj, value);
       }
}


代码的实现大致思路如上所述,具体实现起来有很多细节需要处理。使用到两个自定义工具类:ParseXML,MyInvocationHandler。


完整代码:

ParseXML


public class ParseXML {

   //读取C3P0数据源配置信息
   public static Map<String,String> getC3P0Properties(){
       Map<String,String> map = new HashMap<String,String>();
       SAXReader reader = new SAXReader();
       try {
           Document document = reader.read("src/config.xml");
           //获取根节点
           Element root = document.getRootElement();
           Iterator iter = root.elementIterator();
           while(iter.hasNext()){
               Element e = (Element) iter.next();
               //解析environments节点
               if("environments".equals(e.getName())){
                   Iterator iter2 = e.elementIterator();
                   while(iter2.hasNext()){
                       //解析environment节点
                       Element e2 = (Element) iter2.next();
                       Iterator iter3 = e2.elementIterator();
                       while(iter3.hasNext()){
                           Element e3 = (Element) iter3.next();
                           //解析dataSource节点
                           if("dataSource".equals(e3.getName())){
                               if("POOLED".equals(e3.attributeValue("type"))){
                                   Iterator iter4 = e3.elementIterator();
                                   //获取数据库连接信息
                                   while(iter4.hasNext()){
                                       Element e4 = (Element) iter4.next();
                                       map.put(e4.attributeValue("name"),e4.attributeValue("value"));
                                   }
                               }
                           }
                       }
                   }
               }
           }
       } catch (Exception e) {
           // TODO Auto-generated catch block
           e.printStackTrace();
       }
       return map;
   }

   //根据接口查找对应的mapper.xml
   public static String getMapperXML(String className){
       //保存xml路径
       String xml = "";
       SAXReader reader = new SAXReader();
       Document document;
       try {
           document = reader.read("src/config.xml");
           Element root = document.getRootElement();
           Iterator iter = root.elementIterator();
           while(iter.hasNext()){
               Element mappersElement = (Element) iter.next();
               if("mappers".equals(mappersElement.getName())){
                   Iterator iter2 = mappersElement.elementIterator();
                   while(iter2.hasNext()){
                       Element mapperElement = (Element) iter2.next();
                       //com.southwin.dao.UserDAO . 替换 #
                       className = className.replace(".", "#");
                       //获取接口结尾名
                       String classNameEnd = className.split("#")[className.split("#").length-1];
                       String resourceName = mapperElement.attributeValue("resource");
                       //获取resource结尾名
                       String resourceName2 = resourceName.split("/")[resourceName.split("/").length-1];
                       //UserDAO.xml . 替换 #
                       resourceName2 = resourceName2.replace(".", "#");
                       String resourceNameEnd = resourceName2.split("#")[0];
                       if(classNameEnd.equals(resourceNameEnd)){
                           xml="src/"+resourceName;
                       }
                   }
               }
           }
       } catch (DocumentException e) {
           // TODO Auto-generated catch block
           e.printStackTrace();
       }
       return xml;
   }
}


MyInvocationHandler:


public class MyInvocationHandler implements InvocationHandler{

   private String className;

   public Object getInstance(Class cls){
       //保存接口类型
       className = cls.getName();
       Object newProxyInstance = Proxy.newProxyInstance(  
               cls.getClassLoader(),  
               new Class[] { cls },
               this);
       return (Object)newProxyInstance;
   }

   public Object invoke(Object proxy, Method method, Object[] args)  throws Throwable {        
       SAXReader reader = new SAXReader();
       //返回结果
       Object obj = null;
       try {
           //获取对应的mapper.xml
           String xml = ParseXML.getMapperXML(className);
           Document document = reader.read(xml);
           Element root = document.getRootElement();
           Iterator iter = root.elementIterator();
           while(iter.hasNext()){
               Element element = (Element) iter.next();
               String id = element.attributeValue("id");
               if(method.getName().equals(id)){
                   //获取C3P0信息,创建数据源对象
                   Map<String,String> map = ParseXML.getC3P0Properties();
                   ComboPooledDataSource datasource = new ComboPooledDataSource();
                   datasource.setDriverClass(map.get("driver"));
                   datasource.setJdbcUrl(map.get("url"));
                   datasource.setUser(map.get("username"));
                   datasource.setPassword(map.get("password"));
                   datasource.setInitialPoolSize(20);
                   datasource.setMaxPoolSize(40);
                   datasource.setMinPoolSize(2);
                   datasource.setAcquireIncrement(5);
                   Connection conn = datasource.getConnection();
                   //获取sql语句
                   String sql = element.getText();
                   //获取参数类型
                   String parameterType = element.attributeValue("parameterType");
                   //创建pstmt
                   PreparedStatement pstmt = createPstmt(sql,parameterType,conn,args);
                   ResultSet rs = pstmt.executeQuery();
                   if(rs.next()){
                       //读取返回数据类型
                       String resultType = element.attributeValue("resultType");  
                       //反射创建对象
                       Class clazz = Class.forName(resultType);
                       obj = clazz.newInstance();
                       //获取ResultSet数据
                       ResultSetMetaData rsmd = rs.getMetaData();
                       //遍历实体类属性集合,依次将结果集中的值赋给属性
                       Field[] fields = clazz.getDeclaredFields();
                       for(int i = 0; i < fields.length; i++){
                           Object value = setFieldValueByResultSet(fields[i],rsmd,rs);
                           //通过属性名找到对应的setter方法
                           String name = fields[i].getName();
                           name = name.substring(0, 1).toUpperCase() + name.substring(1);
                           String MethodName = "set"+name;
                           Method methodObj = clazz.getMethod(MethodName,fields[i].getType());
                           //调用setter方法完成赋值
                           methodObj.invoke(obj, value);
                       }
                   }
                   conn.close();
               }
           }
       } catch (Exception e) {
           // TODO Auto-generated catch block
           e.printStackTrace();
       }

      return obj;
   }

   /**
    * 根据条件创建pstmt
    * @param sql
    * @param parameterType
    * @param conn
    * @param args
    * @return
    * @throws Exception
    */

   public PreparedStatement createPstmt(String sql,String parameterType,Connection conn,Object[] args) throws Exception{
       PreparedStatement pstmt = null;
       try {
           switch(parameterType){
               case "int":
                   int start = sql.indexOf("#{");
                   int end = sql.indexOf("}");
                   //获取参数占位符 #{name}
                   String target = sql.substring(start, end+1);
                   //将参数占位符替换为?
                   sql = sql.replace(target, "?");
                   pstmt = conn.prepareStatement(sql);
                   int num = Integer.parseInt(args[0].toString());
                   pstmt.setInt(1, num);
                   break;
               case "java.lang.String":
                   int start2 = sql.indexOf("#{");
                   int end2 = sql.indexOf("}");
                   String target2 = sql.substring(start2, end2+1);
                   sql = sql.replace(target2, "?");
                   pstmt = conn.prepareStatement(sql);
                   String str = args[0].toString();
                   pstmt.setString(1, str);
                   break;
               default:
                   Class clazz = Class.forName(parameterType);
                   Object obj = args[0];
                   boolean flag = true;
                   //存储参数
                   List<Object> values = new ArrayList<Object>();
                   //保存带#的sql
                   String sql2 = "";
                   while(flag){
                       int start3 = sql.indexOf("#{");
                       //判断#{}是否替换完成
                       if(start3<0){
                           flag = false;
                           break;
                       }
                       int end3 = sql.indexOf("}");
                       String target3 = sql.substring(start3, end3+1);
                       //获取#{}的值 如#{name}拿到name
                       String name = sql.substring(start3+2, end3);
                       //通过反射获取对应的getter方法
                       name = name.substring(0, 1).toUpperCase() + name.substring(1);
                       String MethodName = "get"+name;
                       Method methodObj = clazz.getMethod(MethodName);
                       //调用getter方法完成赋值
                       Object value = methodObj.invoke(obj);
                       values.add(value);
                       sql = sql.replace(target3, "?");
                       sql2 = sql.replace("?", "#");
                   }
                   //截取sql2,替换参数
                   String[] sqls = sql2.split("#");
                   pstmt = conn.prepareStatement(sql);
                   for(int i = 0; i < sqls.length-1; i++){
                       Object value = values.get(i);
                       if("java.lang.String".equals(value.getClass().getName())){
                           pstmt.setString(i+1, (String)value);
                       }
                       if("java.lang.Integer".equals(value.getClass().getName())){
                           pstmt.setInt(i+1, (Integer)value);
                       }
                   }
                   break;
               }
       } catch (SQLException e) {
           // TODO Auto-generated catch block
           e.printStackTrace();
       }
       return pstmt;
   }

   /**
    * 根据将结果集中的值赋给对应的属性
    * @param field
    * @param rsmd
    * @param rs
    * @return
    */

   public Object setFieldValueByResultSet(Field field,ResultSetMetaData rsmd,ResultSet rs){
       Object result = null;
       try {
           int count = rsmd.getColumnCount();
           for(int i=1;i<=count;i++){
               if(field.getName().equals(rsmd.getColumnName(i))){
                   String type = field.getType().getName();
                   switch (type) {
                       case "int":
                           result = rs.getInt(field.getName());
                           break;
                       case "java.lang.String":
                           result = rs.getString(field.getName());
                           break;
                   default:
                       break;
                   }
               }
           }
       } catch (SQLException e) {
           // TODO Auto-generated catch block
           e.printStackTrace();
       }
       return result;
   }


}


代码测试:

StudnetDAO.getById


public static void main(String[] args) {

       StudentDAO studentDAO = (StudentDAO) new MyInvocationHandler().getInstance(StudentDAO.class);
       Student stu = studentDAO.getById(1);
       System.out.println(stu);

   }


代码中的studentDAO为动态代理对象,此对象通过 MyInvocationHandler().getInstance(StudentDAO.class)方法动态创建,并且结合StudentDAO.xml实现了StudentDAO接口的全部方法,直接调用studentDAO对象的方法即可完成业务需求。


深入底层,仿MyBatis自己写框架

StudnetDAO.getByName


public static void main(String[] args) {

       StudentDAO studentDAO = (StudentDAO) new MyInvocationHandler().getInstance(StudentDAO.class);
       Student stu = studentDAO.getByName("李四");
       System.out.println(stu);

   }


深入底层,仿MyBatis自己写框架


StudnetDAO.getByStudent(根据id和name查询)


public static void main(String[] args) {

       StudentDAO studentDAO = (StudentDAO) new MyInvocationHandler().getInstance(StudentDAO.class);
       Student student = new Student();
       student.setId(1);
       student.setName("张三");
       Student stu = studentDAO.getByStudent(student);
       System.out.println(stu);

   }


深入底层,仿MyBatis自己写框架


StudnetDAO.getByStudent2(根据name和tel查询)


public static void main(String[] args) {

       StudentDAO studentDAO = (StudentDAO) new MyInvocationHandler().getInstance(StudentDAO.class);
       Student student = new Student();
       student.setName("李四");
       student.setTel("18367895678");
       Student stu = studentDAO.getByStudent2(student);
       System.out.println(stu);

   }


深入底层,仿MyBatis自己写框架


以上就是仿MyBatis实现自定义小工具的大致思路,细节之处还需具体查看源码,最后附上小工具源码链接。



源码:


链接: https://pan.baidu.com/s/1pMz0FDh 

密码: fnjb




专业 热爱 专注

致力于最高效的Java学习

Java大联盟



扫描下方二维码,加入Java大联盟



以上是关于深入底层,仿MyBatis自己写框架的主要内容,如果未能解决你的问题,请参考以下文章

目前主流的java框架有哪些?

二手写MyBatis简易版框架

仿京东开放平台框架,开发自己的开放平台(包含需求,服务端代码,SDK代码)

仿京东开放平台框架,开发自己的开放平台(包含需求,服务端代码,SDK代码)

深入详解Mybatis的架构原理与6大核心流程

通过手写MyBatis带你掌握自己写框架的秘诀