利用Redis实现限流

Posted 热爱编程的大忽悠

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了利用Redis实现限流相关的知识,希望对你有一定的参考价值。

利用Redis实现限流


思路

redis实现限流的核心思路是利用redis提供的key过期时间作为限流窗口期,key的值记录该窗口期内已经产生的访问资源次数,key本身记录限流的资源范围。

具体步骤如下:

  • 首先规定资源限制范围,一般都是限制对某个接口的调用频率,因此key使用接口方法名即可
  • 第一次访问资源时,key不存在,那么新创建一个key,并将值设置为1,最后设置key的过期时间,表示开启限流窗口期
  • 每一次访问资源,会首先判断当前是否存在限流窗口期,如果存在,将访问次数加一,并判断是否达到最大资源访问次数限制
  • 如果达到了,则抛出异常,告诉用户访问频繁,请稍后再试
  • 如果没达到,则放行请求
  • 在不是第一次访问资源的前提下,如果发现限流窗口期过了,那么重新开启一个

步骤

1.准备工作

  • 引入redis相关依赖
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
 <dependency>
    <groupId>org.apache.commons</groupId>
    <artifactId>commons-pool2</artifactId>
 </dependency>
  • 添加相关配置信息
spring:
  redis:
    host: xxx
    port: 6379
    password: xxx
    lettuce:
      #只有自动配置连接池的依赖,连接池才会生效
      pool:
        max-active: 8 #最大连接
        max-idle: 8 #最大空闲连接
        min-idle: 0 #最小空闲连接
        max-wait: 100 #连接等待时间
  • 修改redisTemplate的序列化方式为JSON
    @ConditionalOnMissingBean
    @Bean
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory)
    
        //创建template
        RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>();
        //设置连接工厂
        redisTemplate.setConnectionFactory(redisConnectionFactory);
        //设置序列化工具
        GenericJackson2JsonRedisSerializer jsonRedisSerializer = new GenericJackson2JsonRedisSerializer();
        //key和hashKey采用String序列化
        redisTemplate.setKeySerializer(RedisSerializer.string());
        redisTemplate.setHashKeySerializer(RedisSerializer.string());
        //value和hashValue用JSON序列化
        redisTemplate.setValueSerializer(jsonRedisSerializer);
        redisTemplate.setHashValueSerializer(jsonRedisSerializer);
        return redisTemplate;
    

2.限流核心类实现

  • 定义一个顶层的流量控制接口实现,pass方法返回true,表示方向请求,否则表示请求被拦截了
/**
 * 流量控制
 * @author 大忽悠
 * @create 2023/2/6 10:50
 */
public interface RateLimiter 
     /**
      * @param requestInfo 请求信息
      * @return 当前请求是否允许通过
      */
     boolean pass(RequestInformation requestInfo);

  • requestInfo提供当前请求的相关信息
/**
 * 请求信息
 * @author 大忽悠
 * @create 2023/2/6 10:55
 */
@Data
public class RequestInformation 
    /**
     * 限流key
     */
    private String key;
    /**
     * 限流时间
     */
    private int time;
    /**
     * time时间内最大请求资源次数
     */
    private int count;
    /**
     * 限流类型
     */
    private int limitType;
    /**
     * 请求的方法信息
     */
    private Method method;
    /**
     * 方法参数信息
     */
    private Object[] arguments;
    /**
     * 客户端IP地址
     */
    private String ip;
    private HttpServletRequest httpServletRequest;
    private HttpServletResponse httpServletResponse;

    public RequestInformation() 
    

  • 提供一个限流注解,该注解可以标注在方法或者类上,标注在类上,则表示当前类所有方法都需要流量控制
/**
 * 限流注解
 * @author 大忽悠
 * @create 2023/2/6 10:39
 */
@Target(ElementType.METHOD,ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Limiter 
    /**
     * @return 限流key--默认为rate_limit:业务名:类名.方法名 ,如果限制了IP类型,则为: rate_limit:业务名:ip:类名.方法名
     */
    String key() default "";
    /**
     * @return 限流时间,单位为s
     */
    int time() default 60;
    /**
     * @return time时间内限制的资源请求次数
     */
    int count() default 100;
    /**
     * @return 限流类型
     */
    int limitType() default LimitType.DEFAULT;

  • redis作为限流器的实现
public class RedisRateLimiterImpl implements RateLimiter
    private static final String RATE_LIMITER_KEY_PREFIX="rate_limiter";
    /**
     * 使用redis做限流处理使用的lua脚本
     */
    private static  final String LIMITER_LUA=
            "local key = KEYS[1]\\n" +
            "local count = tonumber(ARGV[1])\\n" +
            "local time = tonumber(ARGV[2])\\n" +
            "local current = redis.call('get', key)\\n" +
            "if current and tonumber(current) > count then\\n" +
            "    return 1\\n" +
            "end\\n" +
            "current = redis.call('incr', key)\\n" +
            "if tonumber(current) == 1 then\\n" +
            "    redis.call('expire', key, time)\\n" +
            "end\\n" +
            "return 0\\n";
    private RedisTemplate<String, Object> redisTemplate;

    public RedisRateLimiterImpl(RedisTemplate<String, Object> redisTemplate) 
        this.redisTemplate = redisTemplate;
    

    /**
     * @param requestInfo 请求信息
     * @return 当前请求是否允许通过
     */
    @Override
    public boolean pass(RequestInformation requestInfo) 
        //拿到限流key
        String limiterKey=getRateLimiterKey(requestInfo);
        //执行lua脚本
        Long limiterRes = redisTemplate.execute(RedisScript.of(LIMITER_LUA,Long.class), List.of(limiterKey), requestInfo.getCount(), requestInfo.getTime());
        //判断限流结果
        return limiterRes==0L;
    


    private String getRateLimiterKey(RequestInformation requestInfo) 
         return combineKey(RATE_LIMITER_KEY_PREFIX,
                 requestInfo.getKey(),
                 requestInfo.getIp(),
                 requestInfo.getMethod().getClass().getName(),
                 requestInfo.getMethod().getName());
    

    private String combineKey(String ... keys) 
        StringBuilder keyBuilder=new StringBuilder();
        for (int i = 0; i < keys.length; i++) 
              if(StringUtils.isEmpty(keys[i]))
                  continue;
              
              keyBuilder.append(keys[i]);
              if(i==keys.length-1)
                  continue;
              
              keyBuilder.append(":");
        
        return keyBuilder.toString();
    

lua脚本解释:

KEYS 和 ARGV 都是一会调用时候传进来的参数,tonumber 就是把字符串转为数字,redis.call 就是执行具体的 redis 指令,具体流程是这样:

  • 首先获取到传进来的 key 以及 限流的 count 和时间 time。
  • 通过 get 获取到这个 key 对应的值,这个值就是当前时间窗内这个接口可以访问多少次。
  • 如果是第一次访问,此时拿到的结果为 nil,否则拿到的结果应该是一个数字,所以接下来就判断,如果拿到的结果是一个数字,并且这个数字还大于 count,那就说明已经超过流量限制了,那么返回1表示请求拦截。
  • 如果拿到的结果为 nil,说明是第一次访问,此时就给当前 key 自增 1,然后设置一个过期时间。
  • 最后返回0表示请求放行。

注意; lua脚本也可以定义在文件在,然后通过加载文件获取

@Bean
public DefaultRedisScript<Long> limitScript() 
    DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>();
    redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit.lua")));
    redisScript.setResultType(Long.class);
    return redisScript;

或者在 Redis 服务端定义好 Lua 脚本,然后计算出来一个散列值,在 Java 代码中,通过这个散列值锁定要执行哪个 Lua 脚本


3.aop相关逻辑实现

我们需要将限流逻辑在需要流量管控的方法执行前先执行,因此需要拦截目标方法,有两个思路:

  1. 通过@Aspect注解标注一个切面类,用@Before或者@Around注解标注在切面方法上,里面填写限流管控逻辑
  2. 手动编写一个advisor增强器,注入容器,并提供相关拦截器和pointcut实现

这里我采用的是手动编写advisor的方式进行实现,下面演示具体步骤:

  • 编写拦截器
/**
 * 限流方法拦截器
 * @author 大忽悠
 * @create 2023/2/6 11:08
 */
@Slf4j
public class RateLimiterMethodInterceptor implements MethodInterceptor 
    private final RateLimiter rateLimiter;

    public RateLimiterMethodInterceptor(RateLimiter rateLimiter) 
        this.rateLimiter=rateLimiter;
    

    @Override
    public Object invoke(MethodInvocation invocation) throws Throwable 
        try
            RequestInformation requestInformation = new RequestInformation();
            buildMethodInfo(requestInformation,invocation);
            buildLimitInfo(requestInformation);
            buildRequestInfo(requestInformation);
            if (rateLimiter.pass(requestInformation))
                return invocation.proceed();
            
            logWarn(requestInformation);
        catch (Exception e)
           e.printStackTrace();
           throw e;
        
        throw new RateLimiterException("访问过于频繁,请稍后再试!");
    

    private void logWarn(RequestInformation requestInformation) 
        if(requestInformation.getHttpServletRequest()!=null)
            log.warn("rateLimiter拦截了一个请求,该请求信息如下: URI:  ,IP:  ,方法名:  ,方法参数信息:  ",
                    requestInformation.getHttpServletRequest().getRequestURI(),requestInformation.getIp(),requestInformation.getMethod().getName(),
                    Arrays.toString(requestInformation.getArguments()));
        else 
            log.warn("rateLimiter拦截了一个请求,该请求信息如下: 方法名:  ,方法参数信息:  ",
                    requestInformation.getMethod().getName(), Arrays.toString(requestInformation.getArguments()));
        
    

    private void buildLimitInfo(RequestInformation requestInformation) throws RateLimiterException 
        Method method = requestInformation.getMethod();
        Limiter limiter;
        if(method.isAnnotationPresent(Limiter.class))
            limiter = method.getAnnotation(Limiter.class);
        else 
            limiter=method.getClass().getAnnotation(Limiter.class);
        
        if(limiter==null)
            throw new RateLimiterException("无法在当前方法"+method.getName()+"或者类"+method.getClass().getName()+"上寻找到@Limiter注解");
        
        requestInformation.setKey(limiter.key());
        requestInformation.setCount(limiter.count());
        requestInformation.setTime(limiter.time());
        requestInformation.setLimitType(limiter.limitType());
    

    private void buildMethodInfo(RequestInformation requestInformation, MethodInvocation invocation) 
        requestInformation.setMethod(invocation.getMethod());
        if(invocation instanceof ReflectiveMethodInvocation)
            ReflectiveMethodInvocation reflectiveMethodInvocation = (ReflectiveMethodInvocation) invocation;
            requestInformation.setArguments(reflectiveMethodInvocation.getArguments());
        
    

    /**
     * 从线程上下文中取出请求和响应相关信息
     */
    private void buildRequestInfo(RequestInformation requestInformation) 
        RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
        if(requestAttributes instanceof ServletRequestAttributes)
            ServletRequestAttributes sra = (ServletRequestAttributes) requestAttributes;
            requestInformation.setHttpServletRequest(sra.getRequest());
            requestInformation.setHttpServletResponse(sra.getResponse());
        
        if(requestInformation.getHttpServletRequest()!=null && requestInformation.getLimitType()==LimitType.IP)
            requestInformation.setIp(IPUtils.getIpAddress(requestInformation.getHttpServletRequest()));
        
    

  • 编写advisor增强器
/**
 * 限流增强器
 *
 * @author 大忽悠
 * @create 2023/2/6 10:57
 */
public class RateLimiterAdvisor extends AbstractPointcutAdvisor 
    private Pointcut pointcut;
    private RateLimiterMethodInterceptor rateLimiterMethodInterceptor;

    public RateLimiterAdvisor(RateLimiter rateLimiter) 
        pointcut = buildPointCut();
        rateLimiterMethodInterceptor=new RateLimiterMethodInterceptor(rateLimiter);
    

    @Override
    public Pointcut getPointcut() 
        return pointcut;
    

    @Override
    public Advice getAdvice() 
        return rateLimiterMethodInterceptor;
    


    private Pointcut buildPointCut() 
        return new Pointcut() 
            @Override
            public ClassFilter getClassFilter() 
                return (c)-> AnnotationUtils.isCandidateClass(c,Limiter.class);
            

            @Override
            public MethodMatcher getMethodMatcher() 
                return new StaticMethodMatcher() 
                    @Override
                    public boolean matches(Method method, Class<?> targetClass) 
                        return method.isAnnotationPresent(Limiter.class) || targetClass.isAnnotationPresent(Limiter.class);
                    
                ;
            
        ;
    

  • 使用配置类将advisor增强器放入容器中
/**
 * @author 大忽悠
 * @create 2023/2/6 11:14
 */
@Configuration
public class RateLimiterAutoConfiguration 
    @Bean
    @ConditionalOnMissingBean
    public RateLimiterAdvisor rateLimiterAdvisor(RateLimiter rateLimiter) 
        return new RateLimiterAdvisor(rateLimiter);
    

    @Bean
    @ConditionalOnMissingBean
    public RateLimiter rateLimiter(RedisTemplate<String, Object> redisTemplate) 
        return new RedisRateLimiterImpl(redisTemplate);
    
    ...


采用切面进行实现,可以参考江南一点雨大佬给出的实现:

@Aspect
@Component
public class RateLimiterAspect 
    private static final Logger log = LoggerFactory.getLogger(RateLimiterAspect.class);

    @Autowired
    private RedisTemplate<Object, Object> redisTemplate;

    @Autowired
    private RedisScript<Long> limitScript;

    @Before("@annotation(rateLimiter)")
    public void doBefore(JoinPoint point, RateLimiter rateLimiter) throws Throwable 
        String key = rateLimiter.key();
        int time = rateLimiter.time()以上是关于利用Redis实现限流的主要内容,如果未能解决你的问题,请参考以下文章

Redis之zset实现滑动窗口限流

深入Redis简单限流

限流常规设计和实例

8.Redis系列Redis的高级应用-简单限流

[PHP] 基于redis实现滑动窗口式的短信发送接口限流

(十七)ATP应用测试平台——Redis实现API接口访问限流(固定窗口限流算法)