本文是对之前那篇文章的重新整理。
话不多说，直接上代码。
首先是项目结构：
依赖配置 pom.xml：
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <!-- NOTE(review): groupId/artifactId look swapped relative to the usual convention
         (groupId = reverse domain, artifactId = short name); kept as-is so existing
         coordinates do not change. -->
    <groupId>yebatis</groupId>
    <artifactId>com.yck.yebatis</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <!-- The sources contain non-ASCII (Chinese) string literals and comments; pin the
             source encoding instead of relying on the platform default. -->
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    </properties>

    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <!-- Explicit version so the build does not depend on the Maven default binding. -->
                <version>3.8.1</version>
                <configuration>
                    <!-- Java 7 level: the code uses diamond, multi-catch and switch-on-String,
                         but no lambdas. -->
                    <source>1.7</source>
                    <target>1.7</target>
                </configuration>
            </plugin>
        </plugins>
    </build>

    <dependencies>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>4.12</version>
            <scope>test</scope>
        </dependency>
        <!-- NOTE(review): dom4j 1.6 is an unmaintained 1.x release; the maintained line is
             org.dom4j:dom4j 2.x, which requires Java 8+. -->
        <dependency>
            <groupId>dom4j</groupId>
            <artifactId>dom4j</artifactId>
            <version>1.6</version>
        </dependency>
        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
            <version>5.1.42</version>
        </dependency>
    </dependencies>
</project>
一些基本类
用来封装一个dao信息的Mapper类
package com.yck.yebaitis;

import java.util.List;

/**
 * Everything read from one mapper XML file: the DAO interface it implements and the
 * statements ({@link Function}s) declared for that interface's methods.
 */
public class Mapper {

    /** Fully-qualified name of the DAO interface this mapper implements. */
    private String mapperClass;

    /** Statements declared for the interface's methods, as supplied by the caller. */
    private List<Function> functions;

    /**
     * @param mapperClass fully-qualified name of the DAO interface
     */
    public void setMapperClass(String mapperClass) {
        this.mapperClass = mapperClass;
    }

    /**
     * @return fully-qualified name of the DAO interface, or {@code null} if never set
     */
    public String getMapperClass() {
        return this.mapperClass;
    }

    /**
     * Stores the given list by reference (no copy is made).
     *
     * @param functions statements for this mapper
     */
    public void setFunctions(List<Function> functions) {
        this.functions = functions;
    }

    /**
     * @return the stored statement list (the same instance that was set), or {@code null} if never set
     */
    public List<Function> getFunctions() {
        return this.functions;
    }
}
封装一条方法信息的Function类
package com.yck.yebaitis;

/**
 * One statement from a mapper file, bound to a DAO interface method by name.
 */
public class Function {

    /** Name of the DAO interface method this statement implements (the element's {@code id}). */
    private String name;

    /** Statement kind; the mapper parser stores one of the {@link FunctionConstants} values here. */
    private String type;

    /** Result type handed to the query helpers for select statements; {@code null} unless set. */
    private Class<?> resultClass;

    /** SQL text, possibly containing {@code ?} placeholders. */
    private String sql;

    public void setName(String name) {
        this.name = name;
    }

    public String getName() {
        return this.name;
    }

    public void setType(String type) {
        this.type = type;
    }

    public String getType() {
        return this.type;
    }

    public void setResultClass(Class<?> resultClass) {
        this.resultClass = resultClass;
    }

    public Class<?> getResultClass() {
        return this.resultClass;
    }

    public void setSql(String sql) {
        this.sql = sql;
    }

    public String getSql() {
        return this.sql;
    }
}
常量
package com.yck.yebaitis;

/**
 * Statement kinds a mapper XML file can declare. Each string doubles as the XML element name
 * and as the value stored via {@link Function#setType(String)}.
 */
public class FunctionConstants {

    /** {@code <select>}: a query statement. */
    public static final String SELECT = "select";

    /** {@code <add>}: an insert statement. */
    public static final String ADD = "add";

    /** {@code <update>}: an update statement. */
    public static final String UPDATE = "update";

    /** {@code <delete>}: a delete statement. */
    public static final String DELETE = "delete";
}
实现核心功能的 DaoFactory：
package com.yck.yebaitis; import com.yck.jdbc.DataUtil; import org.dom4j.Document; import org.dom4j.DocumentException; import org.dom4j.Element; import org.dom4j.io.SAXReader; import com.yck.util.StringUtil; import com.yck.exception.NoConfigFileException; import java.io.File; import java.io.FileFilter; import java.lang.reflect.InvocationHandler; import java.lang.reflect.Method; import java.lang.reflect.Proxy; import java.util.*; public class DaoFactory { private static final String configPath = "src/dao/mapper"; private static DaoFactory instance; private DaoFactory() { } public static DaoFactory getInstance() { if (instance == null) { synchronized (DaoFactory.class) { if (instance == null) instance = new DaoFactory(); } } return instance; } public Map<String, Object> getDaoMap() { Map<String, Object> map = null; try { File[] files = getAllFiles(); map = new HashMap<>(files.length); for (File file : files) { Mapper mapper = readerMapper(file); Object obj = implDao(mapper); map.put(mapper.getMapperClass(), obj); } } catch (NoConfigFileException | ClassNotFoundException | DocumentException e) { e.printStackTrace(); } return map; } private Object implDao(Mapper mapper) throws ClassNotFoundException { ClassLoader classLoader = DaoFactory.class.getClassLoader(); final Mapper temp = mapper; //加载一个接口类 Class<?> interfaze; interfaze = classLoader.loadClass(mapper.getMapperClass()); /* 代理实现方法 之前我是理解错了,我以为是在执行下面这个方法时,就已经实现了类似我们自己写一个DaoImpl,其实它就只是返回了一个代理类实例 */ return Proxy.newProxyInstance(classLoader, new Class[]{interfaze}, new InvocationHandler() { @Override public Object invoke(Object proxy, Method method, Object[] args) { List<Function> functions = temp.getFunctions(); for (Function func : functions) { if (func.getName().equals(method.getName())) { if (func.getType().equals(FunctionConstants.SELECT)) { if (method.getReturnType().equals(List.class)) { return DataUtil.queryForList(func.getSql(), func.getResultClass(), args); } else { return 
DataUtil.queryForObject(func.getSql(), func.getResultClass(), args); } } else { return DataUtil.excuteUpdate(func.getSql(), args); } } } return null; } }); } private File[] getAllFiles() throws NoConfigFileException { FileFilter fileFilter = new FileFilter() { public boolean accept(File pathname) { String fileName = pathname.getName().toLowerCase(); return fileName.endsWith(".xml"); } }; File configPath = new File("src/mapper"); File[] files = configPath.listFiles(fileFilter); if (files == null || files.length == 0) { throw new NoConfigFileException("file not find"); } return files; } private Mapper readerMapper(File file) throws DocumentException, ClassNotFoundException { SAXReader reader = new SAXReader(); Mapper mapper = new Mapper(); Document doc = reader.read(file); Element root = doc.getRootElement(); //读取根节点 即dao节点 mapper.setMapperClass(root.attributeValue("class").trim()); //把dao节点的class值存为接口名 List<Function> list = new ArrayList<>(); //用来存储方法的List for (Iterator<?> rootIter = root.elementIterator(); rootIter.hasNext(); ) //遍历根节点下所有子节点 { Function fun = new Function(); //用来存储一条方法的信息 Element e = (Element) rootIter.next(); String type = e.getName().trim(); switch (type) { case FunctionConstants.ADD: fun.setType(FunctionConstants.ADD); break; case FunctionConstants.DELETE: fun.setType(FunctionConstants.DELETE); break; case FunctionConstants.UPDATE: fun.setType(FunctionConstants.UPDATE); break; case FunctionConstants.SELECT: fun.setType(FunctionConstants.SELECT); break; default: continue; } fun.setName(e.attributeValue("id").trim()); fun.setSql(e.getText().trim()); String resultType = e.attributeValue("resultType"); if (!StringUtil.isBlank(resultType)) { fun.setResultClass(Class.forName(resultType)); } list.add(fun); } mapper.setFunctions(list); return mapper; } }
测试用类
实现 IUserDao 接口的 XML 映射文件，就是项目结构中最底下的 userdao.xml：
<?xml version="1.0" encoding="UTF-8"?>
<!-- Mapper for dao.IUserDao: each child element's id must equal a method name on the interface,
     and the element name selects the statement kind (select / update / delete / add).
     Statements are single JDBC statements, so they carry no trailing ';'. -->
<dao id="userdao" class="dao.IUserDao">
    <select id="selectById" resultType="po.User">
        select * from t_user where id = ?
    </select>
    <update id="updateName">
        update t_user set name = ? where id = ?
    </update>
    <delete id="deleteById">
        delete from t_user where id = ?
    </delete>
    <!-- <add> matches FunctionConstants.ADD, one of the element names the mapper parser switches on. -->
    <add id="add">
        insert into t_user(name,age,score,create_time,update_time) values(?,?,?,now(),now())
    </add>
    <select id="getAll" resultType="po.User">
        select * from t_user
    </select>
</dao>
测试代码
import com.yck.yebaitis.DaoFactory;
import dao.IUserDao;
import po.User;

import java.util.List;
import java.util.Map;

/**
 * Manual smoke test for the XML-driven DAO proxies. It loads all mappers, then exercises
 * select-all, select-by-id and update on {@code IUserDao}, printing each result.
 */
public class Test {
    public static void main(String[] args) {
        Map<String, Object> daoMap = DaoFactory.getInstance().getDaoMap();
        // Proxies are keyed by the interface's fully-qualified name ("dao.IUserDao").
        IUserDao dao = (IUserDao) daoMap.get(IUserDao.class.getName());
        if (dao == null) {
            throw new IllegalStateException("no mapper registered for " + IUserDao.class.getName());
        }

        List<User> users = dao.getAll();
        System.out.println("查询多条记录:" + users);
        System.out.println("*******************************************");

        User user = dao.selectById(2);
        System.out.println("查询一条记录:" + user);
        System.out.println("*******************************************");

        int updated = dao.updateName("二傻", 2);
        System.out.println("更新一条记录:" + updated);
        System.out.println("*******************************************");

        // Re-query after the update; the original printed the stale single 'user' here instead.
        List<User> userList = dao.getAll();
        System.out.println("更新一条记录后查询所有记录:" + userList);
        System.out.println("*******************************************");
    }
}
测试结果：