dify 源码分析(六)ratelimiter
文章目录
- 1. 平滑限流策略(RateLimiter)
- 1.1. 核心算法实现
- 1.1.1. 令牌桶算法 (Token Bucket)
- 1.1.2. 漏桶算法 (Leaky Bucket)
- 1.2. 滑动窗口限流
- 1.2.1. 时间片滑动窗口
- 1.2.2. 分布式滑动窗口 (Redis实现)
- 1.3. 自适应限流策略
- 1.3.1. 基于系统负载的自适应限流
- 1.3.2. 梯度限流策略
- 1.4. 流式输出专用限流器
- 1.4.1. 字符级平滑限流
- 1.4.2. 优先级限流策略
- 1.5. 监控和统计
- 1.6. 配置管理
- 2. 代码分析
1. 平滑限流策略(RateLimiter)
1.1. 核心算法实现
1.1.1. 令牌桶算法 (Token Bucket)
/**
 * Token-bucket rate limiter.
 *
 * <p>The bucket starts full with {@code capacity} tokens and is refilled continuously at
 * {@code refillsPerSecond} tokens per second, never exceeding {@code capacity}. Each granted
 * permit consumes one token.
 *
 * <p>All state is guarded by the instance monitor. {@link #acquire(int)} sleeps <em>outside</em>
 * the monitor so that a waiting caller never blocks {@link #tryAcquire(int)} on other threads.
 */
public class TokenBucketRateLimiter {
    private static final double NANOS_PER_MILLI = 1_000_000.0;

    /** Maximum number of tokens the bucket can hold. */
    private final int capacity;
    /** Refill speed in tokens per millisecond. */
    private final double refillRate;
    /** Tokens currently available (fractional while refilling). */
    private double tokens;
    /** {@link System#nanoTime()} reading taken at the last refill. */
    private long lastRefillNanos;

    /**
     * @param capacity         bucket size, must not be negative
     * @param refillsPerSecond tokens added per second, must not be negative
     * @throws IllegalArgumentException if either argument is negative
     */
    public TokenBucketRateLimiter(int capacity, int refillsPerSecond) {
        if (capacity < 0) {
            throw new IllegalArgumentException("capacity must not be negative: " + capacity);
        }
        if (refillsPerSecond < 0) {
            throw new IllegalArgumentException(
                    "refillsPerSecond must not be negative: " + refillsPerSecond);
        }
        this.capacity = capacity;
        this.refillRate = refillsPerSecond / 1000.0;
        this.tokens = capacity;
        // nanoTime is monotonic, so wall-clock adjustments cannot stall or inflate the refill.
        this.lastRefillNanos = System.nanoTime();
    }

    /**
     * Takes {@code permits} tokens if they are available right now.
     *
     * @return {@code true} if the tokens were taken, {@code false} otherwise
     * @throws IllegalArgumentException if {@code permits} is negative
     */
    public synchronized boolean tryAcquire(int permits) {
        requireNonNegative(permits);
        refillTokens();
        if (tokens < permits) {
            return false;
        }
        tokens -= permits;
        return true;
    }

    /**
     * Blocks until {@code permits} tokens have been taken.
     *
     * @return the total number of milliseconds slept, or {@code -1} if the thread was
     *         interrupted while waiting (the interrupt flag is restored)
     * @throws IllegalArgumentException if {@code permits} is negative or larger than the bucket
     *                                  capacity (such a request could never be satisfied)
     * @throws IllegalStateException    if tokens are missing and the refill rate is zero
     */
    public long acquire(int permits) {
        requireNonNegative(permits);
        if (permits > capacity) {
            throw new IllegalArgumentException(
                    "permits (" + permits + ") exceeds bucket capacity (" + capacity + ")");
        }
        long waitedMillis = 0;
        while (true) {
            long requiredWait;
            synchronized (this) {
                refillTokens();
                if (tokens >= permits) {
                    tokens -= permits;
                    return waitedMillis;
                }
                if (refillRate <= 0) {
                    throw new IllegalStateException(
                            "refill rate is zero; the request can never be satisfied");
                }
                // Round up and wait at least 1 ms so the loop never spins on zero-length sleeps.
                requiredWait = Math.max(1L, (long) Math.ceil((permits - tokens) / refillRate));
            }
            try {
                Thread.sleep(requiredWait);
                waitedMillis += requiredWait;
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return -1;
            }
        }
    }

    /** Returns the number of tokens available after applying any pending refill. */
    public synchronized double getAvailableTokens() {
        refillTokens();
        return tokens;
    }

    /** Adds the tokens earned since the last refill; the caller must hold the monitor. */
    private void refillTokens() {
        long now = System.nanoTime();
        long elapsedNanos = now - lastRefillNanos;
        if (elapsedNanos <= 0) {
            return;
        }
        double earned = (elapsedNanos / NANOS_PER_MILLI) * refillRate;
        tokens = Math.min(capacity, tokens + earned);
        lastRefillNanos = now;
    }

    private static void requireNonNegative(int permits) {
        if (permits < 0) {
            // A negative request would otherwise add tokens to the bucket.
            throw new IllegalArgumentException("permits must not be negative: " + permits);
        }
    }
}
1.1.2. 漏桶算法 (Leaky Bucket)
/**
 * Leaky-bucket rate limiter.
 *
 * <p>Every granted permit raises the water level by one; the bucket drains continuously at
 * {@code leaksPerSecond} units per second. A request is admitted only if it fits below
 * {@code capacity}.
 *
 * <p>All state is guarded by the instance monitor. {@link #acquire(int)} sleeps outside the
 * monitor so that waiting callers do not block other threads.
 */
public class LeakyBucketRateLimiter {
    private static final double NANOS_PER_MILLI = 1_000_000.0;

    /** Maximum water level. */
    private final int capacity;
    /**
     * Drain speed in units per millisecond. Kept as a double so that rates above 1000/s do not
     * truncate to a zero interval.
     */
    private final double leakRatePerMilli;
    /** Current water level. */
    private double waterLevel;
    /** {@link System#nanoTime()} reading taken at the last drain. */
    private long lastLeakNanos;

    /**
     * @param capacity       bucket size, must not be negative
     * @param leaksPerSecond units drained per second, must be positive
     * @throws IllegalArgumentException if an argument is out of range
     */
    public LeakyBucketRateLimiter(int capacity, int leaksPerSecond) {
        if (capacity < 0) {
            throw new IllegalArgumentException("capacity must not be negative: " + capacity);
        }
        if (leaksPerSecond <= 0) {
            throw new IllegalArgumentException(
                    "leaksPerSecond must be positive: " + leaksPerSecond);
        }
        this.capacity = capacity;
        this.leakRatePerMilli = leaksPerSecond / 1000.0;
        this.waterLevel = 0;
        this.lastLeakNanos = System.nanoTime();
    }

    /**
     * Adds {@code permits} units if they fit into the bucket right now.
     *
     * @return {@code true} if admitted, {@code false} if the bucket would overflow
     * @throws IllegalArgumentException if {@code permits} is negative
     */
    public synchronized boolean tryAcquire(int permits) {
        requireNonNegative(permits);
        leakWater();
        if (waterLevel + permits > capacity) {
            return false;
        }
        waterLevel += permits;
        return true;
    }

    /**
     * Blocks until {@code permits} units fit into the bucket, then adds them.
     *
     * @return the total number of milliseconds slept ({@code 0} if admitted immediately), or
     *         {@code -1} if the thread was interrupted while waiting (the interrupt flag is
     *         restored)
     * @throws IllegalArgumentException if {@code permits} is negative or larger than the
     *                                  bucket capacity
     */
    public long acquire(int permits) {
        requireNonNegative(permits);
        if (permits > capacity) {
            throw new IllegalArgumentException(
                    "permits (" + permits + ") exceeds bucket capacity (" + capacity + ")");
        }
        long waitedMillis = 0;
        while (true) {
            long requiredWait;
            synchronized (this) {
                leakWater();
                double overflow = waterLevel + permits - capacity;
                if (overflow <= 0) {
                    waterLevel += permits;
                    return waitedMillis;
                }
                // Time for the overflow to drain; re-checked after sleeping because other
                // threads may have filled the bucket in the meantime.
                requiredWait = Math.max(1L, (long) Math.ceil(overflow / leakRatePerMilli));
            }
            try {
                Thread.sleep(requiredWait);
                waitedMillis += requiredWait;
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return -1;
            }
        }
    }

    /** Drains the water leaked since the last call; the caller must hold the monitor. */
    private void leakWater() {
        long now = System.nanoTime();
        long elapsedNanos = now - lastLeakNanos;
        if (elapsedNanos <= 0) {
            return;
        }
        double leaked = (elapsedNanos / NANOS_PER_MILLI) * leakRatePerMilli;
        waterLevel = Math.max(0, waterLevel - leaked);
        lastLeakNanos = now;
    }

    private static void requireNonNegative(int permits) {
        if (permits < 0) {
            // A negative request would otherwise lower the water level.
            throw new IllegalArgumentException("permits must not be negative: " + permits);
        }
    }
}
1.2. 滑动窗口限流
1.2.1. 时间片滑动窗口
/**
 * In-memory sliding-window rate limiter.
 *
 * <p>The window is divided into {@code segments} equally sized time slices stored in a ring
 * buffer. A request is admitted when the sum of all slice counters plus the requested permits
 * does not exceed {@code maxRequests}. As time advances, slices that fall out of the window
 * are reset.
 *
 * <p>Public methods are synchronized on the instance.
 */
public class SlidingWindowRateLimiter {
    /** Maximum number of permits inside one window. */
    private final int maxRequests;
    /** Window length in milliseconds. */
    private final long windowSizeInMillis;
    /** Number of slices in the ring buffer. */
    private final int segments;
    /** Length of one slice in milliseconds ({@code windowSizeInMillis / segments}). */
    private final long segmentSize;
    /** Aligned start time of each slice. */
    private final long[] timestamps;
    /** Permits granted in each slice. */
    private final int[] counters;
    /** Ring-buffer index of the slice that contains "now". */
    private int currentSegment;

    /**
     * @param maxRequests        permits allowed per window, must not be negative
     * @param windowSizeInMillis window length in milliseconds
     * @param segments           number of slices, must be positive and not larger than
     *                           {@code windowSizeInMillis} so every slice is at least 1 ms
     * @throws IllegalArgumentException if an argument is out of range
     */
    public SlidingWindowRateLimiter(int maxRequests, long windowSizeInMillis, int segments) {
        if (maxRequests < 0) {
            throw new IllegalArgumentException("maxRequests must not be negative: " + maxRequests);
        }
        if (segments <= 0) {
            throw new IllegalArgumentException("segments must be positive: " + segments);
        }
        if (windowSizeInMillis < segments) {
            throw new IllegalArgumentException(
                    "windowSizeInMillis (" + windowSizeInMillis
                            + ") must be at least segments (" + segments + ")");
        }
        this.maxRequests = maxRequests;
        this.windowSizeInMillis = windowSizeInMillis;
        this.segments = segments;
        this.segmentSize = windowSizeInMillis / segments;
        this.timestamps = new long[segments];
        this.counters = new int[segments];
        this.currentSegment = 0;

        // Align to the slice boundary so the first slice does not last longer than the others.
        long now = System.currentTimeMillis();
        long alignedStart = now - (now % segmentSize);
        for (int i = 0; i < segments; i++) {
            timestamps[i] = alignedStart;
        }
    }

    /**
     * Records {@code permits} requests if the window still has room for them.
     *
     * @return {@code true} if admitted, {@code false} if the window limit would be exceeded
     * @throws IllegalArgumentException if {@code permits} is negative
     */
    public synchronized boolean tryAcquire(int permits) {
        if (permits < 0) {
            throw new IllegalArgumentException("permits must not be negative: " + permits);
        }
        updateWindow(System.currentTimeMillis());
        // Sum in long so a very large request cannot wrap around and be admitted.
        long total = sumCounters();
        if (total + permits > maxRequests) {
            return false;
        }
        counters[currentSegment] += permits;
        return true;
    }

    /** Returns the number of permits currently counted inside the window. */
    public synchronized int getCurrentRequests() {
        updateWindow(System.currentTimeMillis());
        return (int) sumCounters();
    }

    private long sumCounters() {
        long total = 0;
        for (int counter : counters) {
            total += counter;
        }
        return total;
    }

    /** Moves the ring buffer forward to {@code currentTime}, resetting expired slices. */
    private void updateWindow(long currentTime) {
        long currentSegmentStart = currentTime - (currentTime % segmentSize);
        long elapsedSegments = (currentSegmentStart - timestamps[currentSegment]) / segmentSize;
        if (elapsedSegments <= 0) {
            return;
        }
        // After a long idle period every slice is stale; never loop more than once around.
        int slicesToReset = (int) Math.min(elapsedSegments, (long) segments);
        int newCurrent = (int) ((currentSegment + elapsedSegments) % segments);
        for (int back = 0; back < slicesToReset; back++) {
            int slot = Math.floorMod(newCurrent - back, segments);
            timestamps[slot] = currentSegmentStart - back * segmentSize;
            counters[slot] = 0;
        }
        currentSegment = newCurrent;
    }
}
1.2.2. 分布式滑动窗口 (Redis实现)
/**
 * Distributed sliding-window rate limiter backed by a Redis sorted set.
 *
 * <p>Each admitted request is stored as a sorted-set member scored by its timestamp. The Lua
 * script removes members older than the window, counts the remainder and, if the count is
 * below the limit, adds the new request. The script runs atomically on the Redis server.
 */
@Component
public class RedisSlidingWindowRateLimiter {
    /**
     * KEYS[1] = sorted-set key; ARGV[1] = now (ms); ARGV[2] = window (ms); ARGV[3] = limit;
     * ARGV[4] = unique member id for this request.
     *
     * <p>A unique member is required because ZADD overwrites an existing member: using the
     * timestamp itself would collapse all requests of the same millisecond into one entry.
     * PEXPIRE takes milliseconds, so sub-second windows get a valid TTL.
     */
    private static final String LUA_SCRIPT =
            "local key = KEYS[1]\n"
                    + "local now = tonumber(ARGV[1])\n"
                    + "local window = tonumber(ARGV[2])\n"
                    + "local limit = tonumber(ARGV[3])\n"
                    + "local member = ARGV[4]\n"
                    + "local clearBefore = now - window\n"
                    + "\n"
                    + "redis.call('ZREMRANGEBYSCORE', key, 0, clearBefore)\n"
                    + "local current = redis.call('ZCARD', key)\n"
                    + "\n"
                    + "if current < limit then\n"
                    + " redis.call('ZADD', key, now, member)\n"
                    + " redis.call('PEXPIRE', key, window)\n"
                    + " return 1\n"
                    + "else\n"
                    + " return 0\n"
                    + "end";

    /** Built once so the script text is not re-wrapped on every call. */
    private static final DefaultRedisScript<Long> SLIDING_WINDOW_SCRIPT =
            new DefaultRedisScript<>(LUA_SCRIPT, Long.class);

    // Not used by this class; retained so the constructor signature stays unchanged.
    private final RedisTemplate<String, String> redisTemplate;
    private final StringRedisTemplate stringRedisTemplate;

    public RedisSlidingWindowRateLimiter(RedisTemplate<String, String> redisTemplate,
                                         StringRedisTemplate stringRedisTemplate) {
        this.redisTemplate = redisTemplate;
        this.stringRedisTemplate = stringRedisTemplate;
    }

    /**
     * Tries to record one request for {@code key}.
     *
     * @param key            Redis key identifying the limited resource
     * @param maxRequests    maximum number of requests per window
     * @param windowInMillis window length in milliseconds, must be positive
     * @return {@code true} if the request was admitted, {@code false} if the limit is reached
     *         or the script returned no result
     * @throws IllegalArgumentException if {@code windowInMillis} is not positive
     */
    public boolean tryAcquire(String key, int maxRequests, long windowInMillis) {
        if (windowInMillis <= 0) {
            throw new IllegalArgumentException(
                    "windowInMillis must be positive: " + windowInMillis);
        }
        long now = System.currentTimeMillis();
        String member = now + "-" + java.util.UUID.randomUUID();
        Long result = stringRedisTemplate.execute(
                SLIDING_WINDOW_SCRIPT,
                Collections.singletonList(key),
                String.valueOf(now),
                String.valueOf(windowInMillis),
                String.valueOf(maxRequests),
                member);
        return result != null && result == 1L;
    }
}
1.3. 自适应限流策略
1.3.1. 基于系统负载的自适应限流
@Component
public class AdaptiveRateLimiter {private final TokenBucketRateLimiter baseLimiter;private final int baseRate; // 基础限流速率private final int minRate; // 最小限流速率private final int maxRate; // 最大限流速率// 系统指标监控private double systemLoad = 0.0;private long lastUpdateTime = System.currentTimeMillis();private final double loadThreshold = 0.8; // 系统负载阈值public AdaptiveRateLimiter(int baseRate, int minRate, int maxRate) {this.baseRate = baseRate;this.minRate = minRate;this.maxRate = maxRate;this.baseLimiter = new TokenBucketRateLimiter(baseRate, baseRate);}public synchronized boolean tryAcquire(int permits) {updateRateBasedOnSystemLoad();return baseLimiter.tryAcquire(permits);}private void updateRateBasedOnSystemLoad() {long currentTime = System.currentTimeMillis();if (currentTime - lastUpdateTime < 1000) { // 每秒更新一次return;}// 获取系统负载 (需要根据具体环境实现)double currentLoad = getSystemLoad();this.systemLoad = 0.7 * this.systemLoad + 0.3 * currentLoad; // 平滑处理// 根据系统负载调整限流速率int newRate;if (systemLoad > loadThreshold) {// 系统负载高,降低限流速率double reductionFactor = 1.0 - (systemLoad - loadThreshold) / (1.0 - loadThreshold);newRate = (int) (baseRate * Math.max(0.1, reductionFactor));newRate = Math.max(minRate, newRate);} else {// 系统负载低,可以适当提高限流速率double increaseFactor = 1.0 + (loadThreshold - systemLoad) / loadThreshold;newRate = (int) (baseRate * Math.min(2.0, increaseFactor));newRate = Math.min(maxRate, newRate);}// 更新限流器 (需要TokenBucketRateLimiter支持动态调整)updateLimiterRate(newRate);lastUpdateTime = currentTime;}private double getSystemLoad() {// 实现获取系统负载的逻辑// 这里使用CPU负载作为示例OperatingSystemMXBean osBean = ManagementFactory.getOperatingSystemMXBean();if (osBean instanceof com.sun.management.OperatingSystemMXBean) {return ((com.sun.management.OperatingSystemMXBean) osBean).getSystemCpuLoad();}return 0.5; // 默认值}private void updateLimiterRate(int newRate) {// 动态更新限流器速率// 需要TokenBucketRateLimiter支持速率调整}// 基于响应时间的自适应调整public void recordResponseTime(long responseTime, long threshold) {if (responseTime > 
threshold) {// 响应时间过长,触发限流调整adjustRateForSlowResponse();}}private synchronized void adjustRateForSlowResponse() {// 根据响应时间调整限流策略}
}
1.3.2. 梯度限流策略
/**
 * Tiered ("gradient") rate limiter that picks a rate from the caller-reported QPS.
 *
 * <p>{@code thresholds[i]} is the inclusive upper QPS bound of tier {@code i} and
 * {@code rates[i]} is the rate used in that tier. {@code rates} may contain one extra trailing
 * element, used when the QPS exceeds every threshold; without it the last tier's rate is used.
 * The thresholds are assumed to be in ascending order (not validated here).
 */
public class GradientRateLimiter {
    /** One limiter per distinct rate; tiers with the same rate share a limiter. */
    private final Map<Integer, TokenBucketRateLimiter> limiters = new ConcurrentHashMap<>();
    private final int[] thresholds; // inclusive upper QPS bound per tier
    private final int[] rates; // rate per tier, optionally followed by a fallback rate

    /**
     * @param thresholds upper QPS bounds, ascending
     * @param rates      one rate per threshold, optionally plus one fallback rate
     * @throws IllegalArgumentException if {@code rates} is empty or its length is neither
     *                                  {@code thresholds.length} nor
     *                                  {@code thresholds.length + 1}
     */
    public GradientRateLimiter(int[] thresholds, int[] rates) {
        boolean sameLength = rates.length == thresholds.length;
        boolean hasFallback = rates.length == thresholds.length + 1;
        if (!sameLength && !hasFallback) {
            throw new IllegalArgumentException("速率数组长度必须等于阈值数组长度,或比阈值数组多一个兜底速率");
        }
        if (rates.length == 0) {
            throw new IllegalArgumentException("速率数组不能为空");
        }
        // Defensive copies: later changes to the caller's arrays must not affect the limiter.
        this.thresholds = thresholds.clone();
        this.rates = rates.clone();

        for (int rate : this.rates) {
            limiters.computeIfAbsent(rate, r -> new TokenBucketRateLimiter(r, r));
        }
    }

    /**
     * Tries to take one permit from the limiter of the tier matching {@code currentQPS}.
     *
     * @return {@code true} if the permit was granted
     */
    public boolean tryAcquire(int currentQPS) {
        int targetRate = determineRate(currentQPS);
        TokenBucketRateLimiter limiter = limiters.get(targetRate);
        return limiter.tryAcquire(1);
    }

    /** Returns the rate of the first tier whose threshold is not below {@code currentQPS}. */
    private int determineRate(int currentQPS) {
        for (int i = 0; i < thresholds.length; i++) {
            if (currentQPS <= thresholds[i]) {
                return rates[i];
            }
        }
        // Above every threshold: use the last configured rate.
        return rates[rates.length - 1];
    }

    /**
     * Example configuration: rate 100 up to 100 QPS, rate 150 up to 200 QPS, rate 200 above.
     */
    public static GradientRateLimiter createDefault() {
        int[] thresholds = {100, 200};
        int[] rates = {100, 150, 200};
        return new GradientRateLimiter(thresholds, rates);
    }
}
1.4. 流式输出专用限流器
1.4.1. 字符级平滑限流
public class StreamingOutputRateLimiter {private final RateLimiter characterLimiter; // 字符输出限流private final RateLimiter chunkLimiter; // 数据块限流private final ScheduledExecutorService scheduler;// 限流配置private final int maxCharsPerSecond;private final int maxChunksPerSecond;private final int minChunkSize;private final int maxChunkSize;public StreamingOutputRateLimiter(int maxCharsPerSecond, int maxChunksPerSecond) {this.maxCharsPerSecond = maxCharsPerSecond;this.maxChunksPerSecond = maxChunksPerSecond;this.minChunkSize = 1;this.maxChunkSize = Math.max(1, maxCharsPerSecond / maxChunksPerSecond);this.characterLimiter = new TokenBucketRateLimiter(maxCharsPerSecond, maxCharsPerSecond);this.chunkLimiter = new TokenBucketRateLimiter(maxChunksPerSecond, maxChunksPerSecond);this.scheduler = Executors.newScheduledThreadPool(1);}public CompletableFuture<Void> streamOutput(String content, Consumer<String> outputConsumer) {return CompletableFuture.runAsync(() -> {try {streamContentSmoothly(content, outputConsumer);} catch (InterruptedException e) {Thread.currentThread().interrupt();throw new RuntimeException("流式输出被中断", e);}});}private void streamContentSmoothly(String content, Consumer<String> outputConsumer) throws InterruptedException {char[] characters = content.toCharArray();int position = 0;while (position < characters.length) {// 获取数据块许可chunkLimiter.acquire(1);// 确定当前块的大小int chunkSize = calculateChunkSize(characters.length - position);// 获取字符输出许可if (!characterLimiter.tryAcquire(chunkSize)) {// 令牌不足,等待并重试Thread.sleep(calculateWaitTime(chunkSize));continue;}// 输出数据块String chunk = new String(characters, position, chunkSize);outputConsumer.accept(chunk);position += chunkSize;// 添加微小延迟,使输出更平滑if (position < characters.length) {Thread.sleep(10);}}}private int calculateChunkSize(int remainingChars) {// 动态调整块大小,保持输出平滑int idealChunkSize = maxCharsPerSecond / maxChunksPerSecond;return Math.min(Math.min(idealChunkSize, maxChunkSize), remainingChars);}private long calculateWaitTime(int 
requiredTokens) {double missingTokens = requiredTokens - characterLimiter.getAvailableTokens();return (long) (missingTokens / (maxCharsPerSecond / 1000.0));}// 动态调整限流参数public void adjustRate(int newCharsPerSecond, int newChunksPerSecond) {// 实现动态调整逻辑}public void shutdown() {scheduler.shutdown();}
}
1.4.2. 优先级限流策略
public class PriorityRateLimiter {private final Map<Integer, RateLimiter> priorityLimiters;private final int baseRate;public PriorityRateLimiter(int baseRate) {this.baseRate = baseRate;this.priorityLimiters = new ConcurrentHashMap<>();// 初始化不同优先级的限流器// 优先级越高,限流越宽松priorityLimiters.put(1, new TokenBucketRateLimiter(baseRate / 4, baseRate / 4)); // 低优先级priorityLimiters.put(2, new TokenBucketRateLimiter(baseRate / 2, baseRate / 2)); // 中优先级priorityLimiters.put(3, new TokenBucketRateLimiter(baseRate, baseRate)); // 高优先级priorityLimiters.put(4, new TokenBucketRateLimiter(baseRate * 2, baseRate * 2)); // 最高优先级}public boolean tryAcquire(int priority) {RateLimiter limiter = priorityLimiters.get(priority);if (limiter == null) {limiter = priorityLimiters.get(2); // 默认中优先级}return limiter.tryAcquire(1);}public long acquire(int priority) {RateLimiter limiter = priorityLimiters.get(priority);if (limiter == null) {limiter = priorityLimiters.get(2);}return limiter.acquire(1);}// 为重要数据设置高优先级public boolean tryAcquireForImportantData() {return tryAcquire(4);}// 为普通数据设置中优先级public boolean tryAcquireForNormalData() {return tryAcquire(2);}
}
1.5. 监控和统计
@Component
public class RateLimitMonitor {private final MeterRegistry meterRegistry;private final Map<String, Counter> limitCounters = new ConcurrentHashMap<>();private final Map<String, Timer> requestTimers = new ConcurrentHashMap<>();public RateLimitMonitor(MeterRegistry meterRegistry) {this.meterRegistry = meterRegistry;}public void recordRequest(String endpoint, boolean limited, long duration) {// 记录请求指标String status = limited ? "limited" : "allowed";Counter counter = limitCounters.computeIfAbsent(endpoint + "." + status,key -> Counter.builder("ratelimit.requests").tag("endpoint", endpoint).tag("status", status).register(meterRegistry));counter.increment();// 记录响应时间Timer timer = requestTimers.computeIfAbsent(endpoint,key -> Timer.builder("ratelimit.duration").tag("endpoint", endpoint).register(meterRegistry));timer.record(duration, TimeUnit.MILLISECONDS);}public void recordSystemLoad(double load) {Gauge.builder("ratelimit.system.load", () -> load).register(meterRegistry);}public Map<String, Object> getRateLimitStats() {Map<String, Object> stats = new HashMap<>();// 实现统计信息收集return stats;}
}
1.6. 配置管理
/**
 * Rate-limit settings bound from configuration properties under the {@code ratelimit} prefix.
 */
@Configuration
@ConfigurationProperties(prefix = "ratelimit")
@Data
public class RateLimitConfig {
    private boolean enabled = true;
    /** Default rate: 100 requests per second. */
    private int defaultRate = 100;
    /** Burst capacity. */
    private int burstCapacity = 150;
    /** Timeout in milliseconds. */
    private long timeoutMillis = 1000;
    /** Per-endpoint rate overrides, keyed by endpoint. */
    private Map<String, Integer> endpoints = new HashMap<>();
    /** Settings specific to streaming output. */
    private Streaming streaming = new Streaming();

    /**
     * Looks up the rate for {@code endpoint}.
     *
     * @return the configured override if the endpoint has an entry, otherwise
     *         {@code defaultRate}
     */
    public int getRateForEndpoint(String endpoint) {
        if (endpoints.containsKey(endpoint)) {
            return endpoints.get(endpoint);
        }
        return defaultRate;
    }

    /** Streaming-output limits. */
    @Data
    public static class Streaming {
        /** Characters per second. */
        private int charsPerSecond = 50;
        /** Chunks per second. */
        private int chunksPerSecond = 10;
        /** Minimum chunk size. */
        private int minChunkSize = 1;
        /** Maximum chunk size. */
        private int maxChunkSize = 10;
    }
}
2. 代码分析
。。。。。。