spring+mybatis一个方法执行多条更新语句,实现批量DML

Posted tangtong1

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了spring+mybatis一个方法执行多条更新语句,实现批量DML相关的知识,希望对你有一定的参考价值。

在实际开发中,经常会遇到一个方法里需要更新多张表的情况,在使用spring+mybatis的时候,我们一般写多个dao方法分别与mapper中<update>或<insert>标签对应,然后在service中调用这么多个dao方法,这样不仅浪费人力物力,代码也不好看,全是updateXx1/updateXx2/insertXx3/insertXx4等。

笔者经过查看mybatis的源码,发现mybatis确实不支持一个<update>中写多条语句,这是因为mybatis底层其实是使用PreparedStatement实现的,这个东西在初始化的时候就得传一个sql语句进去,然后对于?参数再使用setXxx()填充进去,mybatis为了偷懒就没有实现多条语句的操作。

笔者经过慎重思考及实验,得出以下三种实现方法:

  1. 使用Statement,把sql语句写在代码里,通过其addBatch(sql)方法(此方法无法在PreparedStatement中使用)把sql语句加进去,然后批量执行,但是这种方法有个问题,你得先把参数写死在sql中,这当然是不建议的啦,sql注入,你懂的~~
  2. 使用存储过程,这当然是最简单的方法啦,但是,你得先能干得过DBA,再说了,你不能每次两次以上的更新都去麻烦DBA,不是么~~
  3. 自己通过mybatis的API改造出来一个可以多次更新的方法,以下介绍的方法就使用这种方法_

笔者把mybatis的API翻了个底朝天,终于写出了以下方法,直接上代码:
(以下代码,为节约空间,把package和import都干掉了)

/**
 * VO 公共数据库操作dao
 * 
 * @version 1.0
 */
public class BaseSqlDaoImpl<T> implements BaseSqlDao<T> 
  private static Logger logger = LoggerFactory.getLogger(BaseSqlDaoImpl.class);
 
  /**
   * Mapper包路径
   */
  private String baseMapperPackage;

  public String getBaseMapperPackage() 
    return baseMapperPackage;
  

  public void setBaseMapperPackage( String baseMapperPackage ) 
    this.baseMapperPackage = baseMapperPackage;
  

  @Resource
  private SqlSessionFactory sqlSessionFactory;

  /**
   * Alan添加,用于一个mapper更新多个表
   * 注:
   * (1)Connection不用close,mybatis自己会回收;
   * (2)使用此方法可以使用spring的事务管理
   * 
   * @param params
   * @return
   * @throws DaoException
   */
  @Override
  public void updateMultiTables( Map<String, Object> params ) throws Exception 
    SqlSession sqlSession = sqlSessionFactory.openSession();
    PreparedStatement ps = null;
    try 
      String selectVoName =
          params.get("selectVoName") == null ? "baseSelectListByVo" : params.get("selectVoName").toString();
          
      //此处getBaseMapperPackage()如果取不到值,可以直接写*Mapper.java类所在的包名
      String mapperName = getBaseMapperPackage() + "." + params.get("mapperClassType");

      Connection connection = sqlSession.getConnection();

      BoundSql boundSql =
          sqlSession.getConfiguration().getMappedStatement(mapperName + "." + selectVoName).getBoundSql(params);

      String[] sqls = boundSql.getSql().split(";");
      List<ParameterMapping> list = boundSql.getParameterMappings();

      int i = 0;
      int j = 0;
      for( ; i < sqls.length; i++ ) 
        String sql = sqls[i].trim();
        if(logger.isDebugEnabled()) 
          logger.debug("Alan print log for you -===> Preparing: " + sql);
        
        ps = connection.prepareStatement(sql);
        int questionMarkCount = ps.getParameterMetaData().getParameterCount();
        StringBuilder sb = null;
        if(logger.isDebugEnabled()) 
          sb = new StringBuilder();
        
        for( int k = 1, length = j + questionMarkCount; j < length; j++, k++ ) 
          Object param = null;
          String propertyName = list.get(j).getProperty();
          if(boundSql.hasAdditionalParameter(propertyName)) 
            // 用于获取xml通过foreach循环的list参数
            param = boundSql.getAdditionalParameter(propertyName);
          
          else 
            // 用于获取正常的参数,包括对象参数
            MetaObject metaObject = sqlSession.getConfiguration().newMetaObject(params);
            param = metaObject.getValue(propertyName);
          
          ps.setObject(k, param);
          if(logger.isDebugEnabled()) 
            sb.append(param==null?"null":param.toString()).append("(").append(param==null?list.get(j).getJavaType().getSimpleName():param.getClass().getSimpleName()).append("),");
          
        
        if(logger.isDebugEnabled() && sb.length() > 0) 
          logger.debug("Alan print log for you -==> Parameters: " + sb.toString().substring(0, sb.length() - 1));
        
        int total = ps.executeUpdate();
        if(logger.isDebugEnabled()) 
          logger.debug("Alan print log for you -<== Total: " + total);
        
        ps.close();
      
    
    catch(Exception e) 
      logger.error("", e);
      throw e;
    
    finally 
      if(ps != null) ps.close();
      sqlSession.close();
    

简单分析:把<update>中的多个语句取出来用分号"; "分割,每一条语句装入到PreparedStatement中,再把params中的参数set到对应占位符 ? 上。

在service的基类中添加以下方法:

public abstract class BaseCrudServiceImpl implements BaseCrudService  
  /**
   * 更新多张表
   */
  @SuppressWarnings( "unchecked" )
  @Override
  public void updateMultiTables( Map<String, Object> params ) throws ServiceException 
    try 
      BaseSqlDao<Object> baseSqlDao = (BaseSqlDao<Object>) SpringComponent.getBean("baseSqlDao");
      baseSqlDao.updateMultiTables(params);
    
    catch(Exception e) 
      logger.error("update multi-tables error");
      throw new ServiceException("", e);
    
  

这样,只要继承了BaseCrudServiceImpl 类的service中直接调用即可。

在service中调用代码如下:

@Service( "testService" )
public class TestServiceImpl extends BaseCrudServiceImpl implements TestService 

  @Override
  @Transactional( propagation = Propagation.REQUIRED, rollbackFor = Exception.class )
  public void testUpdateMultiTables( Map<String, Object> params ) 
    updateMultiTables(params);
  

  @Override
  public BaseCrudDao init() 
    return null;
  


单元测试类如下:

@RunWith( SpringJUnit4ClassRunner.class )
@ContextConfiguration( locations = "classpath*:META-INF/applicationContext.xml" )
public class BaseTest0 extends AbstractJUnit4SpringContextTests 

  @Resource
  TestService testService;

  @Test
  public void test() 
    try 
      Map<String, Object> params = new HashMap<String, Object>();
      BasicUserInfo userInfo = new BasicUserInfo();
      userInfo.setCityId(440300);
      userInfo.setCityName("深圳市2");
      userInfo.setMobile("13265767392");
      userInfo.setNickName("唐彤");
      userInfo.setUserId(6037288328267771904L);
      params.put("selectVoName", "testMultiTables");
      params.put("mapperClassType", "TestMapper");
      params.put("param1", 1);
      params.put("param2", 2);
      params.put("param3", "10.1.20.75");
      params.put("param4", "ttxx");
      params.put("user", userInfo);

      List<Integer> list = new ArrayList<Integer>();
      list.add(1);
      list.add(2);
      list.add(3);
      list.add(4);
      list.add(5);
      params.put("list", list);

      testService.testUpdateMultiTables(params);
    
    catch(Exception e) 
      logger.error("update multi-tables error");
      throw new ServiceException("", e);
    
  


xml中配置如下:

	<update id="testMultiTables">
		update basic_bank_info set name = #param4	where id = #param2;
		delete from basic_user_gift_package_mapping where user_id = #user.userId;
		insert into basic_like_log values(#user.userId,#user.userId,#param1,now(),#user.nickName,#param3);
		update basic_area set name = #user.cityName where area_id = #user.cityId;
		update basic_bank_info set name = #param3 where id in
		<foreach collection="list" item="id" index="index" open="(" close=")" separator=",">
            #id
		</foreach>
	</update>

可见上述一共有五条sql语句,其中参数包括基本类型、对象类型和list类型,执行单元测试,打印日志如下:

可以看见五条sql语句都正确执行了,事务提交之后,connection自动释放(Release)。

以上即为笔者拙见,有不得当的地方欢迎提出来,大家如果有好的方法也欢迎推荐,谢谢~~


欢迎关注我的公众号“彤哥读源码”,查看更多“源码&架构&算法”系列文章, 与彤哥一起畅游源码的海洋。

以上是关于spring+mybatis一个方法执行多条更新语句,实现批量DML的主要内容,如果未能解决你的问题,请参考以下文章

oracle mybatis一次执行多条sql,提示SQL命令未正确结束

mybatis批量更新不同参数多条语句带分号update报错的解决方案

mybatis一次执行多条SQL语句

mybatis源码分析 mybatis与spring事务管理分析

Mybatis+Oracle进行数据的批量插入和更新

mybatis 批量执行多条update语句