The web is a weird place. You go to sleep thinking that you have a perfectly functional web application, and the next morning you might wake up to a sudden, huge spike in the number of requests. Either your app got popular overnight, or you were the victim of a DoS attack trying to bring your app server down. Usually, it's the latter.
There are some popular gems like rack-attack and rack-throttle that work quite well and provide a lot of flexibility. But if you're looking to write your own custom logic with minimal dependencies, read on.
We will create a middleware that intercepts and blocks any host that tries to overload our servers by firing too many requests within a short time span. We will use Redis to store the request count for each IP address.
Let's start by writing the most basic middleware.
# lib/middlewares/custom_rate_limit.rb
class CustomRateLimit
  def initialize(app)
    @app = app
  end

  # Pass every request straight through to the next app in the stack.
  def call(env)
    @app.call(env)
  end
end
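Because a Rack middleware is just an object that responds to call(env) and returns a [status, headers, body] triple, you can exercise it in plain Ruby without booting Rails. A minimal sketch, assuming only the pass-through class above; the downstream lambda is a hypothetical stand-in for the rest of your application:

```ruby
# Same pass-through middleware as above, exercised without Rails.
class CustomRateLimit
  def initialize(app)
    @app = app
  end

  def call(env)
    @app.call(env)
  end
end

# Hypothetical stand-in for the rest of the Rails stack: any object that
# responds to call(env) and returns [status, headers, body] is a valid Rack app.
downstream = ->(_env) { [200, { 'content-type' => 'text/plain' }, ['OK']] }

middleware = CustomRateLimit.new(downstream)
status, _headers, body = middleware.call({ 'REQUEST_METHOD' => 'GET', 'PATH_INFO' => '/' })
puts status      # prints 200
puts body.first  # prints OK
```

Since the middleware only wraps `@app`, whatever the downstream app returns comes back unchanged; the rate-limiting logic we add next decides whether to call `@app` at all.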
Add the following line inside the Application class in your config/application.rb:
# For Rails 5 and later
config.middleware.use CustomRateLimit
If you're using a Rails version older than 5, add the following instead:
# For Rails versions earlier than 5
config.middleware.use 'CustomRateLimit'
Run the bin/rails middleware command and verify that our custom middleware is present in the middleware stack.
If you run into an error saying uninitialized constant ::CustomRateLimit, add the line below at the top of your config/application.rb:
require_relative '../lib/middlewares/custom_rate_limit'
Now, let's install the redis gem by adding it to our Gemfile (then run bundle install):
gem 'redis'
Also, ensure that you have the Redis server installed on your local system. If not, install it by following the installation guide in the official Redis documentation.
Note: Redis keeps its dataset in memory, so if your Redis server goes down due to an outage, you could lose recently written data. By default, Redis periodically saves snapshots of the dataset to disk in a binary file called dump.rdb. Refer to the Redis persistence documentation for more details.
To initialize Redis in your app, add the following file inside config/initializers:
# config/initializers/redis.rb
require 'redis'
REDIS = Redis.new(url: ENV.fetch('REDIS_URL'))
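In the development environment, REDIS_URL will typically be redis://localhost:6379, since a local Redis server listens on port 6379 by default. Note that ENV.fetch raises a KeyError if REDIS_URL is not set, so make sure it is defined in every environment.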
Now, let's edit our middleware and add some logic.
def call(env)
  if should_allow?(env)
    @app.call(env)
  else
    request_quota_exceeded
  end
end
The should_allow? method will look something like this:
def should_allow?(env)
  key = "IP:#{env['action_dispatch.remote_ip']}"
  # Start the counter at 0 only if this IP has no counter in the current window.
  REDIS.set(key, 0, nx: true, ex: TIME_PERIOD)
  REDIS.incr(key) <= LIMIT
end
We will use the user's IP address as the key and store the request count as the value.
The Redis#set method stores the record in Redis. We will pass it the following arguments:
- key: a unique identifier, which in this case is built from the user's IP address
- value: the value to store against the given key (0, the initial request count)
- ex: the expiry time in seconds
- nx: set the key only if it doesn't already exist
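On every request, we increment the count using Redis#incr, and if the count exceeds the predefined limit, should_allow? returns false. Because nx creates the key only on the first request in a window and ex expires it after TIME_PERIOD seconds, the count starts again from zero once each window ends.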
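To see how the SET NX EX plus INCR pair behaves as a fixed-window counter, here is a sketch that replays it against a toy in-memory stand-in for Redis with a controllable clock, so it runs without a Redis server. FakeRedis is hypothetical: it only mimics the two commands used here and is not thread-safe.

```ruby
# Hypothetical in-memory stand-in for the two Redis commands the middleware uses.
class FakeRedis
  def initialize(clock)
    @clock = clock # callable returning the current time in seconds
    @store = {}    # key => [integer value, expires_at or nil]
  end

  # Mimics SET key value NX EX seconds: with nx, only writes if the key is absent.
  def set(key, value, nx: false, ex: nil)
    purge_expired(key)
    return false if nx && @store.key?(key)
    @store[key] = [value.to_i, ex && @clock.call + ex]
    true
  end

  # Mimics INCR: increments the value, creating it (with no expiry) if missing.
  def incr(key)
    purge_expired(key)
    value, expires_at = @store.fetch(key, [0, nil])
    @store[key] = [value + 1, expires_at]
    value + 1
  end

  private

  # Drops the key once its expiry time has been reached.
  def purge_expired(key)
    _, expires_at = @store[key]
    @store.delete(key) if expires_at && @clock.call >= expires_at
  end
end

TIME_PERIOD = 60
LIMIT = 20

now = 0
redis = FakeRedis.new(-> { now })

# Same logic as should_allow?, minus the Rails env lookup.
allowed = lambda do |ip|
  key = "IP:#{ip}"
  redis.set(key, 0, nx: true, ex: TIME_PERIOD)
  redis.incr(key) <= LIMIT
end

first_window = Array.new(21) { allowed.call('203.0.113.7') }
puts first_window.count(true) # prints 20
puts first_window.last        # prints false

now = 61 # move past the 60-second window
puts allowed.call('203.0.113.7') # prints true
```

This is a fixed-window scheme: the count resets only when the key expires, so a client can squeeze up to 2 × LIMIT requests into a short burst that straddles a window boundary. If that matters for your app, a sliding-window approach is the usual alternative.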
Define these constants inside the CustomRateLimit class and adjust the values to suit your needs.
TIME_PERIOD = 60 # no. of seconds
LIMIT = 20 # no. of allowed requests per IP for unauthenticated user
If you are using Warden-based authentication such as Devise and don't want to throttle authenticated requests, add the following guard clause at the top of the should_allow? method.
session = env['rack.session'] # nil if no session middleware runs (e.g. API-only apps)
return true if session && session['warden.user.user.key'].present?
This will allow all authenticated requests to pass through without any rate limits.
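The key above assumes your Devise model is User: Warden stores the serialized record under warden.user.<scope>.key, so an Admin model would use warden.user.admin.key instead. The sketch below uses a hypothetical authenticated_session? helper (not part of Warden or Devise), with plain hashes standing in for the Rack env and session, to show the lookup for an arbitrary scope outside Rails:

```ruby
# Hypothetical helper: true when Warden has serialized a record for `scope`.
def authenticated_session?(env, scope = :user)
  session = env['rack.session']
  return false if session.nil? # no session middleware, e.g. API-only apps

  value = session["warden.user.#{scope}.key"]
  # Rough plain-Ruby stand-in for ActiveSupport's `present?`.
  !(value.nil? || (value.respond_to?(:empty?) && value.empty?))
end

signed_in  = { 'rack.session' => { 'warden.user.user.key' => [[42], 'salt'] } }
anonymous  = { 'rack.session' => {} }
no_session = {}

p authenticated_session?(signed_in)         # prints true
p authenticated_session?(anonymous)         # prints false
p authenticated_session?(no_session)        # prints false
p authenticated_session?(signed_in, :admin) # prints false
```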
The request_quota_exceeded method will look something like this:
def request_quota_exceeded
  [429, { 'content-type' => 'text/plain' }, ['Too many requests fired. Request quota exceeded!']]
end
The HTTP status code 429 (Too Many Requests) indicates that the client has sent too many requests in a given amount of time.
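RFC 6585, which defines 429, says the response MAY include a Retry-After header telling the client how long to wait. A minimal sketch, assuming you pass in the number of seconds left in the current window: inside the middleware, redis-rb's REDIS.ttl(key) is one way to get it, but it returns a negative number when the key is missing or has no expiry, hence the clamp. Note that this variant takes an argument, unlike the method above.

```ruby
# Variant of request_quota_exceeded that takes the seconds left in the window.
def request_quota_exceeded(retry_after)
  seconds = [retry_after.to_i, 1].max # never advertise a zero or negative wait
  [
    429,
    {
      'content-type' => 'text/plain',
      'retry-after' => seconds.to_s # Retry-After in delay-seconds form
    },
    ['Too many requests fired. Request quota exceeded!']
  ]
end

status, headers, _body = request_quota_exceeded(42)
p status                                        # prints 429
p headers['retry-after']                        # prints "42"
p request_quota_exceeded(-2)[1]['retry-after']  # prints "1"
```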
Our middleware will finally look something like this:
# lib/middlewares/custom_rate_limit.rb
class CustomRateLimit
  TIME_PERIOD = 60 # no. of seconds
  LIMIT = 20 # no. of allowed requests per IP for unauthenticated user

  def initialize(app)
    @app = app
  end

  def call(env)
    if should_allow?(env)
      @app.call(env)
    else
      request_quota_exceeded
    end
  end

  private

  def should_allow?(env)
    # Skip throttling for users signed in through Warden/Devise (assumes a `user` scope).
    session = env['rack.session']
    return true if session && session['warden.user.user.key'].present?

    # One counter per client IP; the key is created once per window and expires after TIME_PERIOD.
    key = "IP:#{env['action_dispatch.remote_ip']}"
    REDIS.set(key, 0, nx: true, ex: TIME_PERIOD)
    REDIS.incr(key) <= LIMIT
  end

  def request_quota_exceeded
    [429, { 'content-type' => 'text/plain' }, ['Too many requests fired. Request quota exceeded!']]
  end
end
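One refinement worth considering: SET and INCR are two separate round trips to Redis, so if the key happens to expire between them, INCR recreates it without a TTL. That counter then never expires, and once the IP reaches LIMIT it stays blocked until the key is removed. Sending both commands in a single MULTI transaction runs them back to back with no network round trip in between, which greatly narrows that window. A sketch, assuming a recent redis-rb where multi yields a transaction object and returns the results in order; within_limit? is a hypothetical helper, and InMemoryRedis is a toy stand-in so the example runs without a Redis server:

```ruby
# Hypothetical helper: counts the request inside one MULTI/EXEC transaction.
def within_limit?(redis, key, limit, period)
  results = redis.multi do |transaction|
    transaction.set(key, 0, nx: true, ex: period)
    transaction.incr(key)
  end
  results.last <= limit # last result is the INCR count
end

# Toy stand-in: multi yields itself and collects results like EXEC would.
class InMemoryRedis
  def initialize
    @store = {}
  end

  def multi
    @queued = []
    yield self
    @queued
  ensure
    @queued = nil
  end

  def set(key, value, nx: false, ex: nil) # ex is ignored by this stand-in
    ok = !(nx && @store.key?(key))
    @store[key] = value.to_i if ok
    record(ok)
  end

  def incr(key)
    @store[key] = @store.fetch(key, 0) + 1
    record(@store[key])
  end

  private

  def record(result)
    @queued << result if @queued
    result
  end
end

redis = InMemoryRedis.new
results = Array.new(3) { within_limit?(redis, 'IP:203.0.113.7', 2, 60) }
p results # prints [true, true, false]
```

In the middleware, should_allow? could then end with within_limit?(REDIS, key, LIMIT, TIME_PERIOD).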
I hope this post was helpful.
Thank you!