提交 bb84c3ac 编写于 作者: 不合群的混子's avatar 不合群的混子

添加限流切面,IP工具类,lua脚本,redis配置

上级 dab50b4e
package com.xkcoding.ratelimit.redis.aspect;
import cn.hutool.core.util.StrUtil;
import com.xkcoding.ratelimit.redis.annotation.RateLimiter;
import com.xkcoding.ratelimit.redis.util.IpUtil;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;
import java.lang.reflect.Method;
import java.time.Instant;
import java.util.Collections;
import java.util.concurrent.TimeUnit;
/**
* <p>
* 限流切面
* </p>
*
* @author yangkai.shen
* @date Created in 2019/9/30 10:30
*/
@Slf4j
@Aspect
@Component
@RequiredArgsConstructor(onConstructor_ = @Autowired)
public class RateLimiterAspect {
private final static String SEPARATOR = ":";
private final static String REDIS_LIMIT_KEY_PREFIX = "limit:";
private final StringRedisTemplate stringRedisTemplate;
private final RedisScript<Long> limitRedisScript;
@Pointcut("@annotation(com.xkcoding.ratelimit.redis.annotation.RateLimiter)")
public void rateLimit() {
}
@Around("rateLimit()")
public Object pointcut(ProceedingJoinPoint point) throws Throwable {
MethodSignature signature = (MethodSignature) point.getSignature();
Method method = signature.getMethod();
// 通过 AnnotationUtils.findAnnotation 获取 RateLimiter 注解
RateLimiter rateLimiter = AnnotationUtils.findAnnotation(method, RateLimiter.class);
if (rateLimiter != null) {
String key = rateLimiter.key();
// 默认用方法名做限流的 key 前缀
if (StrUtil.isBlank(key)) {
key = method.getName();
}
// 最终限流的 key 为 前缀 + IP地址
// TODO: 此时需要考虑局域网多用户访问的情况,因此 key 后续需要加上方法参数更加合理
key = key + SEPARATOR + IpUtil.getIpAddr();
long max = rateLimiter.max();
long timeout = rateLimiter.timeout();
TimeUnit timeUnit = rateLimiter.timeUnit();
boolean limited = shouldLimited(key, max, timeout, timeUnit);
if (limited) {
throw new RuntimeException("手速太快了,慢点儿吧~");
}
}
return point.proceed();
}
private boolean shouldLimited(String key, long max, long timeout, TimeUnit timeUnit) {
// 最终的 key 格式为:
// limit:自定义key:IP
// limit:方法名:IP
key = REDIS_LIMIT_KEY_PREFIX + key;
// 统一使用单位毫秒
long ttl = timeUnit.toMillis(timeout);
// 当前时间毫秒数
long now = Instant.now().toEpochMilli();
long expired = now - ttl;
// 注意这里必须转为 String,否则会报错 java.lang.Long cannot be cast to java.lang.String
Long executeTimes = stringRedisTemplate.execute(limitRedisScript, Collections.singletonList(key), now + "", ttl + "", expired + "", max + "");
if (executeTimes != null) {
if (executeTimes == 0) {
log.error("【{}】在单位时间 {} 毫秒内已达到访问上限,当前接口上限 {}", key, ttl, max);
return true;
} else {
log.info("【{}】在单位时间 {} 毫秒内访问 {} 次", key, ttl, executeTimes);
return false;
}
}
return false;
}
}
package com.xkcoding.ratelimit.redis.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.scripting.support.ResourceScriptSource;
/**
* <p>
* Redis 配置
* </p>
*
* @author yangkai.shen
* @date Created in 2019/9/30 11:37
*/
@Configuration
public class RedisConfig {
@Bean
@SuppressWarnings("unchecked")
public RedisScript<Long> limitRedisScript() {
DefaultRedisScript redisScript = new DefaultRedisScript<>();
redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("scripts/redis/limit.lua")));
redisScript.setResultType(Long.class);
return redisScript;
}
}
package com.xkcoding.ratelimit.redis.util;
import cn.hutool.core.util.StrUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest;
/**
* <p>
* IP 工具类
* </p>
*
* @author yangkai.shen
* @date Created in 2019/9/30 10:38
*/
@Slf4j
public class IpUtil {
private final static String UNKNOWN = "unknown";
private final static int MAX_LENGTH = 15;
/**
* 获取IP地址
* 使用Nginx等反向代理软件, 则不能通过request.getRemoteAddr()获取IP地址
* 如果使用了多级反向代理的话,X-Forwarded-For的值并不止一个,而是一串IP地址,X-Forwarded-For中第一个非unknown的有效IP字符串,则为真实IP地址
*/
public static String getIpAddr() {
HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
String ip = null;
try {
ip = request.getHeader("x-forwarded-for");
if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getHeader("Proxy-Client-IP");
}
if (StrUtil.isEmpty(ip) || ip.length() == 0 || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getHeader("WL-Proxy-Client-IP");
}
if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getHeader("HTTP_CLIENT_IP");
}
if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getHeader("HTTP_X_FORWARDED_FOR");
}
if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getRemoteAddr();
}
} catch (Exception e) {
log.error("IPUtils ERROR ", e);
}
// 使用代理,则获取第一个IP地址
if (!StrUtil.isEmpty(ip) && ip.length() > MAX_LENGTH) {
if (ip.indexOf(StrUtil.COMMA) > 0) {
ip = ip.substring(0, ip.indexOf(StrUtil.COMMA));
}
}
return ip;
}
}
......@@ -2,3 +2,20 @@ server:
port: 8080
servlet:
context-path: /demo
spring:
redis:
host: localhost
# 连接超时时间(记得添加单位,Duration)
timeout: 10000ms
# Redis默认情况下有16个分片,这里配置具体使用的分片
# database: 0
lettuce:
pool:
# 连接池最大连接数(使用负值表示没有限制) 默认 8
max-active: 8
# 连接池最大阻塞等待时间(使用负值表示没有限制) 默认 -1
max-wait: -1ms
# 连接池中的最大空闲连接 默认 8
max-idle: 8
# 连接池中的最小空闲连接 默认 0
min-idle: 0
-- 下标从 1 开始
local key = KEYS[1]
local now = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local expired = tonumber(ARGV[3])
-- 最大访问量
local max = tonumber(ARGV[4])
-- 清除过期的数据
-- 移除指定分数区间内的所有元素,expired 即已经过期的 score
-- 根据当前时间毫秒数 - 超时毫秒数,得到过期时间 expired
redis.call('zremrangebyscore', key, 0, expired)
-- 获取 zset 中的元素个数
local current = tonumber(redis.call('zcard', key))
--
local next = current + 1
if next > max then
-- 达到限流大小 返回 0
return 0;
else
-- 往 zset 中添加一个值、得分均为当前时间戳的元素,[value,score]
redis.call("zadd", key, now, now)
-- 每次访问均重新设置 zset 的过期时间,单位毫秒
redis.call("pexpire", key, ttl)
return next
end
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册