Rate limiting is a technique frequently used in projects. It is generally used to limit how often users can send requests, to prevent system crashes caused by sudden traffic spikes, or to keep message processing at a steady rate. Sometimes the limit must be enforced across several service instances (distributed rate limiting); the common implementation uses Redis as the central store.

This article implements common rate limiting algorithms with Go and Redis. For an explanation of how each algorithm works, see the companion article on implementing common rate limiting algorithms in Go.

The following code uses the go-redis client (github.com/go-redis/redis/v8).

Fixed window

Implementing a fixed window with Redis is relatively easy, because only one window exists at any given time. When the first request of a window arrives, we use the pexpire command to set the key's expiration time to the window size, so the counter disappears when the window ends; the incr command increments the window count.

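Because the expiration time must be set only when the first request of a window arrives (while the counter is still 0), we use a small Lua script so that the check, the increment, and the expiration are applied atomically.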

const fixedWindowLimiterTryAcquireRedisScript = `
-- KEYS[1]: limiter key holding the request count of the current window
-- ARGV[1]: window time size (milliseconds)
-- ARGV[2]: window request upper limit

local window = tonumber(ARGV[1])
local limit = tonumber(ARGV[2])

-- Get the original value; a missing key means a new window has started
local counter = tonumber(redis.call("get", KEYS[1]))
if counter == nil then
   counter = 0
end
-- If the window request upper limit is reached, the request fails
if counter >= limit then
   return 0
end
-- Window value +1
redis.call("incr", KEYS[1])
-- The first request of a window starts its expiration timer
if counter == 0 then
   redis.call("pexpire", KEYS[1], window)
end
return 1
`
package redis

import (
   "context"
   "errors"
   "github.com/go-redis/redis/v8"
   "time"
)

// FixedWindowLimiter FixedWindowLimiter
type FixedWindowLimiter struct {
   limit  int           // Window request upper limit
   window int           // Window time size
   client *redis.Client // The Redis client
   script *redis.Script / / TryAcquire script
}

func NewFixedWindowLimiter(client *redis.Client, limit int, window time.Duration) (*FixedWindowLimiter, error) {
   // Redis expiration time accuracy is up to milliseconds, so the window must be divisible by milliseconds
   ifwindow%time.Millisecond ! =0 {
      return nil, errors.New("the window uint must not be less than millisecond")}return &FixedWindowLimiter{
      limit:  limit,
      window: int(window / time.Millisecond),
      client: client,
      script: redis.NewScript(fixedWindowLimiterTryAcquireRedisScript),
   }, nil
}

func (l *FixedWindowLimiter) TryAcquire(ctx context.Context, resource string) error {
   success, err := l.script.Run(ctx, l.client, []string{resource}, l.window, l.limit).Bool()
   iferr ! =nil {
      return err
   }
   // If the window request limit is reached, the request fails
   if! success {return ErrAcquireFailed
   }
   return nil
}

Sliding window

Hash implementation

We use a Redis hash to store the count of each small window. On each request, we sum the counts of all small windows that are still inside the sliding window, delete expired small windows with hdel, and finally check whether the total has reached the upper limit.

Almost all of the logic lives in the Lua script. The main cost is traversing the hash, which is O(N) where N is the number of small windows, so it is best not to use too many small windows.

const slidingWindowLimiterTryAcquireRedisScriptHashImpl = `
-- KEYS[1]: limiter key, a hash of small window value -> request count
-- ARGV[1]: window time size (milliseconds)
-- ARGV[2]: window request upper limit
-- ARGV[3]: current small window value
-- ARGV[4]: start small window value

local window = tonumber(ARGV[1])
local limit = tonumber(ARGV[2])
-- kept as the exact string the client sent so the hash field name is stable
local currentSmallWindow = ARGV[3]
local startSmallWindow = tonumber(ARGV[4])

-- Count the total number of requests for the current window,
-- deleting small windows that have slid out of it
local counters = redis.call("hgetall", KEYS[1])
local count = 0
for i = 1, #(counters) / 2 do
   local field = counters[i * 2 - 1]
   local smallWindow = tonumber(field)
   local counter = tonumber(counters[i * 2])
   if smallWindow < startSmallWindow then
      -- delete by the original field string so the name always matches
      redis.call("hdel", KEYS[1], field)
   else
      count = count + counter
   end
end

-- If the window request upper limit is reached, the request fails
if count >= limit then
   return 0
end

-- If the window request limit is not reached, the current small window counter +1, the request is successful
redis.call("hincrby", KEYS[1], currentSmallWindow, 1)
redis.call("pexpire", KEYS[1], window)
return 1
`
package redis

import (
   "context"
   "errors"
   "github.com/go-redis/redis/v8"
   "time"
)

// SlidingWindowLimiter
type SlidingWindowLimiter struct {
   limit        int           // Window request upper limit
   window       int64         // Window time size
   smallWindow  int64         // Small window time size
   smallWindows int64         // Number of small Windows
   client       *redis.Client // The Redis client
   script       *redis.Script / / TryAcquire script
}

func NewSlidingWindowLimiter(client *redis.Client, limit int, window, smallWindow time.Duration) ( *SlidingWindowLimiter, error) {
   // Redis expiration time accuracy is up to milliseconds, so the window must be divisible by milliseconds
   ifwindow%time.Millisecond ! =0|| smallWindow%time.Millisecond ! =0 {
      return nil, errors.New("the window uint must not be less than millisecond")}// Window time must be divisible by small window time
   ifwindow%smallWindow ! =0 {
      return nil, errors.New("window cannot be split by integers")}return &SlidingWindowLimiter{
      limit:        limit,
      window:       int64(window / time.Millisecond),
      smallWindow:  int64(smallWindow / time.Millisecond),
      smallWindows: int64(window / smallWindow),
      client:       client,
      script:       redis.NewScript(slidingWindowLimiterTryAcquireRedisScriptHashImpl),
   }, nil
}

func (l *SlidingWindowLimiter) TryAcquire(ctx context.Context, resource string) error {
   // Get the current small window value
   currentSmallWindow := time.Now().UnixMilli() / l.smallWindow * l.smallWindow
   // Get the initial window value
   startSmallWindow := currentSmallWindow - l.smallWindow*(l.smallWindows- 1)

   success, err := l.script.Run(
      ctx, l.client, []string{resource}, l.window, l.limit, currentSmallWindow, startSmallWindow).Bool()
   iferr ! =nil {
      return err
   }
   // If the window request limit is reached, the request fails
   if! success {return ErrAcquireFailed
   }
   return nil
}

List implementation

If the number of small windows is very large, a list can be used to reduce the time complexity. The list has the following structure:

[counter, smallWindow1, count1, smallWindow2, count2, smallWindow3, count3...]

The first element of the list stores the total counter. Each small window is represented by two consecutive elements: the first is the small window value (its start time) and the second is that small window's count. Keeping them as separate elements, rather than packing both into one string, means the Lua script never has to split and parse strings.

Specific operation process:

  • Get the list length.
  • If the length is 0, push the counter (0) and set the expiration time; the length becomes 1.
  • Otherwise, read the counter and drop expired small windows from the head of the list: for an oldest small window (second and third elements) whose value is less than the start small window value, subtract its count (the third element) from counter, remove both elements, and reduce the length by 2.
  • If counter is greater than or equal to limit, the request fails.
  • If the length is greater than 1, read the second-to-last element (the newest small window value):
    • If it is greater than or equal to the current small window value (for example, because this request was delayed by the network and a newer small window already exists), count the request in that newest small window by incrementing the last element by 1.
    • Otherwise, append a new small window value with a count of 1 and refresh the expiration time.
  • Otherwise, append a new small window value with a count of 1 and refresh the expiration time.
  • Increment counter by 1 and store it.
  • Return success.
const slidingWindowLimiterTryAcquireRedisScriptListImpl = `
-- KEYS[1]: limiter key, a list laid out as
--          [counter, smallWindow1, count1, smallWindow2, count2, ...]
--          with the oldest small window first
-- ARGV[1]: window time size (milliseconds)
-- ARGV[2]: window request upper limit
-- ARGV[3]: current small window value
-- ARGV[4]: start small window value

local window = tonumber(ARGV[1])
local limit = tonumber(ARGV[2])
local currentSmallWindow = tonumber(ARGV[3])
local startSmallWindow = tonumber(ARGV[4])

-- Get the length of the list
local len = redis.call("llen", KEYS[1])
local counter = 0
if len == 0 then
   -- Empty list: create the counter element and start the expiration timer
   redis.call("rpush", KEYS[1], 0)
   redis.call("pexpire", KEYS[1], window)
   len = 1
else
   counter = tonumber(redis.call("lindex", KEYS[1], 0))
   -- Drop every small window at the head (second and third elements) that is
   -- older than the start small window, subtracting its count from counter.
   -- Several small windows may have expired since the last request.
   local expired = 0
   while len > 1 and tonumber(redis.call("lindex", KEYS[1], 1)) < startSmallWindow do
      counter = counter - tonumber(redis.call("lindex", KEYS[1], 2))
      -- Mark both elements, then remove exactly those two markers;
      -- LREM scans from the head, so it stops right after them
      redis.call("lset", KEYS[1], 1, "deleted")
      redis.call("lset", KEYS[1], 2, "deleted")
      redis.call("lrem", KEYS[1], 2, "deleted")
      len = len - 2
      expired = expired + 1
   end
   -- Store the reduced counter right away so it stays consistent with the
   -- remaining small windows even if this request is rejected below
   if expired > 0 then
      redis.call("lset", KEYS[1], 0, counter)
   end
end

-- If the window request upper limit is reached, the request fails
if counter >= limit then
   return 0
end

-- If the length is greater than 1, get the second-to-last element (newest small window)
if len > 1 then
   local smallWindowN = tonumber(redis.call("lindex", KEYS[1], -2))
   -- If the newest small window value is greater than or equal to the current
   -- small window value (e.g. this request was delayed by the network), count
   -- the request in that newest small window
   if smallWindowN >= currentSmallWindow then
      local countN = tonumber(redis.call("lindex", KEYS[1], -1))
      redis.call("lset", KEYS[1], -1, countN + 1)
   else
      -- Otherwise, add a new window value, add a new count (1), update the expiration time
      redis.call("rpush", KEYS[1], ARGV[3], 1)
      redis.call("pexpire", KEYS[1], window)
   end
else
   -- Otherwise, add a new window value, add a new count (1), update the expiration time
   redis.call("rpush", KEYS[1], ARGV[3], 1)
   redis.call("pexpire", KEYS[1], window)
end

-- counter + 1 and update
redis.call("lset", KEYS[1], 0, counter + 1)
return 1
`

The script only operates on the head and tail of the list, so each call is close to O(1).

Leaky bucket

The leaky bucket needs to store the current water level and the time of the last drain, so we use a hash to store these two values.

const leakyBucketLimiterTryAcquireRedisScript = `
-- KEYS[1]: limiter key, a hash with fields currentLevel and lastTime
-- ARGV[1]: highest water level
-- ARGV[2]: flow velocity per second
-- ARGV[3]: current time (seconds)

local peakLevel = tonumber(ARGV[1])
local currentVelocity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])

local lastTime = tonumber(redis.call("hget", KEYS[1], "lastTime"))
local currentLevel = tonumber(redis.call("hget", KEYS[1], "currentLevel"))
-- Initialization: a new (or expired) bucket starts empty
if lastTime == nil then
   lastTime = now
   currentLevel = 0
   redis.call("hmset", KEYS[1], "currentLevel", currentLevel, "lastTime", lastTime)
end

-- Try pouring water
-- Time since last water release
local interval = now - lastTime
if interval > 0 then
   -- Current water level - time since last discharge (s) * flow speed
   local newLevel = currentLevel - interval * currentVelocity
   if newLevel < 0 then
      newLevel = 0
   end
   currentLevel = newLevel
   redis.call("hmset", KEYS[1], "currentLevel", newLevel, "lastTime", now)
end

-- If the request reaches the highest watermark, the request fails
if currentLevel >= peakLevel then
   return 0
end
-- If the request does not reach the highest watermark, the current watermark is +1
redis.call("hincrby", KEYS[1], "currentLevel", 1)
-- Keep the key until the bucket would have fully drained. EXPIRE only accepts
-- integers, so round the (possibly fractional) drain time up; with no drain
-- velocity there is no finite drain time, so no TTL is set
if currentVelocity > 0 then
   redis.call("expire", KEYS[1], math.ceil(peakLevel / currentVelocity))
end
return 1
`
package redis

import (
   "context"
   "github.com/go-redis/redis/v8"
   "time"
)

// LeakyBucketLimiter Leak bucket current limiter
type LeakyBucketLimiter struct {
   peakLevel       int           // Maximum water level
   currentVelocity int           // Flow speed per second
   client          *redis.Client // The Redis client
   script          *redis.Script / / TryAcquire script
}

// NewLeakyBucketLimiter builds a leaky bucket limiter whose bucket holds at
// most peakLevel units of water and drains currentVelocity units per second.
func NewLeakyBucketLimiter(client *redis.Client, peakLevel, currentVelocity int) *LeakyBucketLimiter {
	limiter := &LeakyBucketLimiter{client: client}
	limiter.peakLevel = peakLevel
	limiter.currentVelocity = currentVelocity
	limiter.script = redis.NewScript(leakyBucketLimiterTryAcquireRedisScript)
	return limiter
}

func (l *LeakyBucketLimiter) TryAcquire(ctx context.Context, resource string) error {
   // The current time
   now := time.Now().Unix()
   success, err := l.script.Run(ctx, l.client, []string{resource}, l.peakLevel, l.currentVelocity, now).Bool()
   iferr ! =nil {
      return err
   }
   // If the window request limit is reached, the request fails
   if! success {return ErrAcquireFailed
   }
   return nil
}

Token bucket

A token bucket can be thought of as the inverse of a leaky bucket: instead of each request pouring water into a bucket, each request takes a token out of a bucket that is refilled at a fixed rate.

const tokenBucketLimiterTryAcquireRedisScript = `
-- KEYS[1]: limiter key, a hash with fields currentTokens and lastTime
-- ARGV[1]: capacity
-- ARGV[2]: token issuing rate per second
-- ARGV[3]: current time (seconds)

local capacity = tonumber(ARGV[1])
local rate = tonumber(ARGV[2])
local now = tonumber(ARGV[3])

local lastTime = tonumber(redis.call("hget", KEYS[1], "lastTime"))
local currentTokens = tonumber(redis.call("hget", KEYS[1], "currentTokens"))
-- Initialization: a new (or expired) bucket starts full
if lastTime == nil then
   lastTime = now
   currentTokens = capacity
   redis.call("hmset", KEYS[1], "currentTokens", currentTokens, "lastTime", lastTime)
end

-- Attempt to issue tokens
-- Time since the last token was issued
local interval = now - lastTime
if interval > 0 then
   -- Number of current tokens + time since last token issue (s) * token issue rate
   local newTokens = currentTokens + interval * rate
   if newTokens > capacity then
      newTokens = capacity
   end
   currentTokens = newTokens
   redis.call("hmset", KEYS[1], "currentTokens", newTokens, "lastTime", now)
end

-- If there is no whole token, the request fails
if currentTokens < 1 then
   return 0
end
-- If there is a token, current token -1, request successful
redis.call("hincrby", KEYS[1], "currentTokens", -1)
-- Keep the key until the bucket would be full again. EXPIRE only accepts
-- integers, so round the (possibly fractional) refill time up; with no
-- refill rate there is no finite refill time, so no TTL is set
if rate > 0 then
   redis.call("expire", KEYS[1], math.ceil(capacity / rate))
end
return 1
`
package redis

import (
   "context"
   "github.com/go-redis/redis/v8"
   "time"
)

// TokenBucketLimiter token bucket current limiter
type TokenBucketLimiter struct {
   capacity int           / / capacity
   rate     int           // Token issuing rate/second
   client   *redis.Client // The Redis client
   script   *redis.Script / / TryAcquire script
}

// NewTokenBucketLimiter builds a token bucket limiter holding at most capacity
// tokens and issuing rate tokens per second.
func NewTokenBucketLimiter(client *redis.Client, capacity, rate int) *TokenBucketLimiter {
	limiter := new(TokenBucketLimiter)
	limiter.client = client
	limiter.capacity, limiter.rate = capacity, rate
	limiter.script = redis.NewScript(tokenBucketLimiterTryAcquireRedisScript)
	return limiter
}

func (l *TokenBucketLimiter) TryAcquire(ctx context.Context, resource string) error {
   // The current time
   now := time.Now().Unix()
   success, err := l.script.Run(ctx, l.client, []string{resource}, l.capacity, l.rate, now).Bool()
   iferr ! =nil {
      return err
   }
   // If the window request limit is reached, the request fails
   if! success {return ErrAcquireFailed
   }
   return nil
}

Sliding log

The algorithm works like the sliding window, except that several policies (each a window size plus a request limit) can be enforced at once, and when a request is rejected the caller is told which policy blocked it.

const slidingLogLimiterTryAcquireRedisScriptHashImpl = `
-- KEYS[1]: limiter key, a hash of small window value -> request count
-- ARGV[1]: the current small window value
-- ARGV[2]: the window time size of the first policy (used as the key TTL)
-- ARGV[i * 2 + 1]: the start small window value of policy i (1-based)
-- ARGV[i * 2 + 2]: the window request upper limit of policy i (1-based)

local currentSmallWindow = ARGV[1]
-- The window size for the first policy
local window = tonumber(ARGV[2])
-- The starting small window value for the first policy
local startSmallWindow = tonumber(ARGV[3])
local strategiesLen = #(ARGV) / 2 - 1

-- Count the total number of requests for the current window of each policy
local counters = redis.call("hgetall", KEYS[1])
local counts = {}
-- Initialize counts
for j = 1, strategiesLen do
   counts[j] = 0
end

for i = 1, #(counters) / 2 do
   local field = counters[i * 2 - 1]
   local smallWindow = tonumber(field)
   local counter = tonumber(counters[i * 2])
   if smallWindow < startSmallWindow then
      -- delete by the original field string so the name always matches
      redis.call("hdel", KEYS[1], field)
   else
      for j = 1, strategiesLen do
         if smallWindow >= tonumber(ARGV[j * 2 + 1]) then
            counts[j] = counts[j] + counter
         end
      end
   end
end

-- If the request reaches the upper limit of a policy window, the request fails
-- and the 0-based index of the violated policy is returned
for i = 1, strategiesLen do
   if counts[i] >= tonumber(ARGV[i * 2 + 2]) then
      return i - 1
   end
end

-- If no window request limit is reached, the current small window counter +1, the request is successful
redis.call("hincrby", KEYS[1], currentSmallWindow, 1)
redis.call("pexpire", KEYS[1], window)
return -1
`
package redis

import (
   "context"
   "errors"
   "fmt"
   "github.com/go-redis/redis/v8"
   "sort"
   "time"
)

// ViolationStrategyError reports which sliding log strategy rejected a
// request, via that strategy's request limit and window size.
type ViolationStrategyError struct {
	Limit  int           // Window request upper limit
	Window time.Duration // Window time size
}

// Error implements the error interface. The window is formatted as a
// duration (e.g. "1s") rather than as a raw nanosecond count.
func (e *ViolationStrategyError) Error() string {
	return fmt.Sprintf("violation strategy that limit = %d and window = %v", e.Limit, e.Window)
}

// SlidingLogLimiterStrategy is one policy enforced by a sliding log limiter:
// at most limit requests within window.
type SlidingLogLimiterStrategy struct {
	limit        int   // Window request upper limit
	window       int64 // Window time size (nanoseconds here; NewSlidingLogLimiter converts it to milliseconds)
	smallWindows int64 // Number of small windows, filled in by NewSlidingLogLimiter
}

// NewSlidingLogLimiterStrategy creates a policy allowing at most limit
// requests within window. Pass it to NewSlidingLogLimiter to use it.
func NewSlidingLogLimiterStrategy(limit int, window time.Duration) *SlidingLogLimiterStrategy {
	return &SlidingLogLimiterStrategy{
		limit:  limit,
		window: int64(window),
	}
}

// SlidingLogLimiter SlidingLogLimiter
type SlidingLogLimiter struct {
   strategies  []*SlidingLogLimiterStrategy // Slide log Current limiter policy list
   smallWindow int64                        // Small window time size
   client      *redis.Client                // The Redis client
   script      *redis.Script                / / TryAcquire script
}

func NewSlidingLogLimiter(client *redis.Client, smallWindow time.Duration, strategies ... *SlidingLogLimiterStrategy) ( *SlidingLogLimiter, error) {
   // The replication policy avoids being modified
   strategies = append(make([]*SlidingLogLimiterStrategy, 0.len(strategies)), strategies...)

   // The policy cannot be set
   if len(strategies) == 0 {
      return nil, errors.New("must be set strategies")}// Redis expiration time accuracy is up to milliseconds, so the window must be divisible by milliseconds
   ifsmallWindow%time.Millisecond ! =0 {
      return nil, errors.New("the window uint must not be less than millisecond")
   }
   smallWindow = smallWindow / time.Millisecond
   for _, strategy := range strategies {
      if strategy.window%int64(time.Millisecond) ! =0 {
         return nil, errors.New("the window uint must not be less than millisecond")
      }
      strategy.window = strategy.window / int64(time.Millisecond)
   }

   // The sorting strategy, the window with a large time in front of the row, the same window with a large ceiling in front of the row
   sort.Slice(strategies, func(i, j int) bool {
      a, b := strategies[i], strategies[j]
      if a.window == b.window {
         return a.limit > b.limit
      }
      return a.window > b.window
   })

   for i, strategy := range strategies {
      // As the window time gets smaller, the window upper limit should also get smaller
      if i > 0 {
         if strategy.limit >= strategies[i- 1].limit {
            return nil, errors.New("the smaller window should be the smaller limit")}}// Window time must be divisible by small window time
      if strategy.window%int64(smallWindow) ! =0 {
         return nil, errors.New("window cannot be split by integers")
      }
      strategy.smallWindows = strategy.window / int64(smallWindow)
   }

   return &SlidingLogLimiter{
      strategies:  strategies,
      smallWindow: int64(smallWindow),
      client:      client,
      script:      redis.NewScript(slidingLogLimiterTryAcquireRedisScriptHashImpl),
   }, nil
}

func (l *SlidingLogLimiter) TryAcquire(ctx context.Context, resource string) error {
   // Get the current small window value
   currentSmallWindow := time.Now().UnixMilli() / l.smallWindow * l.smallWindow
   args := make([]interface{}, len(l.strategies)*2+2)
   args[0] = currentSmallWindow
   args[1] = l.strategies[0].window
   // Gets the starting small window value for each policy
   for i, strategy := range l.strategies {
      args[i*2+2] = currentSmallWindow - l.smallWindow*(strategy.smallWindows- 1)
      args[i*2+3] = strategy.limit
   }

   index, err := l.script.Run(
      ctx, l.client, []string{resource}, args...) .Int()iferr ! =nil {
      return err
   }
   // If the window request limit is reached, the request fails
   ifindex ! =- 1 {
      return &ViolationStrategyError{
         Limit:  l.strategies[index].limit,
         Window: time.Duration(l.strategies[index].window),
      }
   }
   return nil
}

Conclusion

Because Redis offers rich, high-performance data types, implementing rate limiting algorithms on top of it is not difficult. However, every algorithm here relies on a Lua script for atomicity, so if you are not familiar with Lua you may run into a few pitfalls.

The complete code and test code can be viewed at github.com/jiaxwu/limi…