
I tried to build a rate-limiter using Go and Redis

21/3/2024 • 153 views • 8 min read


As part of my learning journey, I decided to study rate limiters and try to implement one.

I implemented two of the algorithms, Fixed Bucket and Sliding Window Counter, using Go (Golang) and Redis.

You can learn more about rate limiters and their algorithms from the resources listed in the references at the end.

Here are some screenshots:

Fixed Bucket (without and with the rate limit):

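[Screenshot: Fixed Bucket endpoint, without the rate limit]

[Screenshot: Fixed Bucket endpoint, with the rate limit reached]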

Sliding window counter:

  • Without the rate limit: within the past 5 minutes, 10 requests were made.

[Screenshots: Sliding Window Counter endpoint, 10 requests within the past 5 minutes, not rate limited]

  • With the rate limit: within the past 5 minutes, more than 10 requests were made, so the API is now rate limited.

[Screenshots: Sliding Window Counter endpoint, more than 10 requests within the past 5 minutes, rate limited]

The main function contains an HTTP server with two endpoints, one for each algorithm. I created a RateLimiter middleware for each of these endpoints; if the request is not rate limited, the handler function executes.

```go
func main() {
	mux := http.NewServeMux()
	rl := &RateLimiter{}

	// Method-prefixed patterns such as "GET /fixed" require Go 1.22 or later.
	mux.HandleFunc("GET /fixed", rl.FixedBucket(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		res := APIResponse{
			Message: "Fixed API endpoint",
		}
		json.NewEncoder(w).Encode(res)
	}))

	mux.HandleFunc("GET /window", rl.WindowCounter(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		res := APIResponse{
			Message: "Window API endpoint",
		}
		json.NewEncoder(w).Encode(res)
	}))

	log.Fatal(http.ListenAndServe(":8080", mux))
}
```

RateLimiter is an empty struct; the two rate-limiting middlewares, FixedBucket and WindowCounter, are its methods.

```go
type RateLimiter struct{}
```

Fixed Bucket

This algorithm is pretty simple: you have a bucket of "counts". If the bucket runs out of counts, you are rate limited. The bucket is also refilled at a fixed rate; I used a goroutine to refill it at 1 count per second, never going above LIMIT. (This is essentially the token bucket algorithm.)

For every request, we decrement the count. Oh, and I am using Redis to store the bucket and its count.

```go
func (rl *RateLimiter) FixedBucket(next http.HandlerFunc) http.HandlerFunc {
	// Start with a full bucket
	rdb.Set(ctx, "bucket", LIMIT, redis.KeepTTL)

	// Refill one count every second, never going above LIMIT
	go func() {
		for {
			time.Sleep(time.Second)
			if v, _ := strconv.Atoi(rdb.Get(ctx, "bucket").Val()); v < LIMIT {
				rdb.Incr(ctx, "bucket")
			}
		}
	}()

	return func(w http.ResponseWriter, r *http.Request) {
		val, err := strconv.Atoi(rdb.Get(ctx, "bucket").Val())
		if err != nil {
			// Report the failure instead of panicking inside the handler
			log.Println(err)
			http.Error(w, "internal server error", http.StatusInternalServerError)
			return
		}

		w.Header().Add(HRLIMIT, fmt.Sprint(LIMIT))
		w.Header().Add(HRREM, fmt.Sprint(val))

		// If the bucket is empty, return the rate limit error and status code
		if val <= 0 {
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusTooManyRequests)
			res := APIResponse{
				Message: "Rate Limit reached",
			}
			json.NewEncoder(w).Encode(res)
		} else {
			// For every hit, decrement the count
			rdb.Decr(ctx, "bucket")
			next.ServeHTTP(w, r)
		}
	}
}
```
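One caveat with reading the count and then decrementing it in a separate call: two concurrent requests can both read a count of 1 and both get through. A common way around this is to decrement first and look at the value the decrement returns (Redis `DECR` is atomic and returns the new value), handing the count back if it went below zero. The sketch below shows that pattern in memory, with `sync/atomic` standing in for Redis; `tryTake` is a hypothetical helper.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// tryTake atomically decrements counts and reports whether a count was available.
// If the decrement went below zero, the count is handed back and the request is rejected.
func tryTake(counts *int64) bool {
	if atomic.AddInt64(counts, -1) < 0 {
		atomic.AddInt64(counts, 1) // undo: the bucket was already empty
		return false
	}
	return true
}

func main() {
	counts := int64(10) // the bucket holds 10 counts
	var allowed int64
	var wg sync.WaitGroup

	// 100 concurrent "requests" compete for 10 counts.
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if tryTake(&counts) {
				atomic.AddInt64(&allowed, 1)
			}
		}()
	}
	wg.Wait()

	fmt.Println(allowed, counts) // 10 0
}
```

With Redis the decrement and the undo are two separate commands; wrapping them in a Lua script is another option if the brief dip below zero matters.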

Sliding Window Counter

This is one of the more efficient algorithms. It checks whether the number of requests made within the past time window has gone over the rate limit.

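For example, at 10 requests per 5 minutes, on every hit it checks whether the endpoint has already received 10 requests between 5 minutes ago and now.

If yes, the request is rate limited; otherwise it goes through. In my implementation, the count for the current minute (keyed by its Unix timestamp) is incremented on every hit, and then the total across the window is checked. We also delete all the keys stored in Redis that are older than the window.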

```go
func (rl *RateLimiter) WindowCounter(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		ok, err := rl.isAllowed(rdb)
		if err != nil {
			// Report the failure instead of panicking inside the handler
			log.Println(err)
			http.Error(w, "internal server error", http.StatusInternalServerError)
			return
		}

		if !ok {
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusTooManyRequests)
			res := APIResponse{
				Message: "Rate Limit reached",
			}
			json.NewEncoder(w).Encode(res)
		} else {
			next.ServeHTTP(w, r)
		}
	}
}
```

The isAllowed method takes a Redis client and returns whether the request should be allowed (true) or rate limited (false), along with any Redis error.

```go
func (rl *RateLimiter) isAllowed(r *redis.Client) (bool, error) {
	now := time.Now()

	// We truncate the API hit time to its minute value, so 14:23:43 becomes 14:23:00
	timeStamp := strconv.FormatInt(now.Truncate(time.Minute).Unix(), 10)

	// Then we increment that minute's counter in the Redis hash "window"
	v, err := r.HIncrBy(ctx, "window", timeStamp, 1).Result()
	if err != nil {
		return false, err
	}

	// If the count for the current minute alone passes the limit, we rate limit,
	// i.e. 14:23:00 has a count of more than LIMIT
	if v > LIMIT {
		return false, nil
	}

	// Next, we sum the counts inside the window and delete all keys older than it.
	// Ex: with a 5-minute WINDOW, keys with a timestamp older than 5 minutes ago are deleted
	vals, err := r.HGetAll(ctx, "window").Result()
	if err != nil {
		return false, err
	}

	total := 0

	// Take the time of the API hit, subtract the window, truncate it and get the Unix value,
	// so if the API is hit at 15:00:53, this returns the Unix time of 14:55:00
	threshold := now.Add(-time.Minute * time.Duration(WINDOW)).Truncate(time.Minute).Unix()

	// Loop over all the keys, with some error handling and type conversion
	for k, v := range vals {
		intKey, err := strconv.Atoi(k)
		if err != nil {
			return false, err
		}

		// If the key is still inside the window, add its count to the total; otherwise delete it
		if int64(intKey) > threshold {
			val, err := strconv.Atoi(v)
			if err != nil {
				return false, err
			}
			total += val
		} else {
			r.HDel(ctx, "window", k)
		}
	}

	// If the total within the WINDOW exceeds the limit, we rate limit
	if total > LIMIT {
		return false, nil
	}

	return true, nil
}
```

That's it! These are the two algorithms I implemented in Go using Redis. Rate limiters used in production are more complex than this: they include components such as a configuration store, among other pieces.

They also use dynamic keys, such as IP addresses or user IDs, to track requests per client, unlike this post, which uses fixed key names like "bucket" and "window". You can read more about them in the references.

References:

