DEV: API to register custom request rate limiting conditions (#30239)

This commit adds the `add_request_rate_limiter` plugin API which allows plugins to add custom rate limiters on top of the default rate limiters which requests by a user's id or the request's IP address.

Example to add a rate limiter that rate limits all requests from Googlebot under the same rate limit bucket:

```
add_request_rate_limiter(
  identifier: :country,
  key: ->(request) { "country/#{DiscourseIpInfo.get(request.ip)[:country]}" },
  activate_when: ->(request) { DiscourseIpInfo.get(request.ip)[:country].present? },
)
```
This commit is contained in:
Alan Guo Xiang Tan 2024-12-23 09:57:18 +08:00 committed by GitHub
parent 259f537d02
commit 859d61003e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 470 additions and 57 deletions

View File

@ -45,6 +45,7 @@ Rails.autoloaders.each do |autoloader|
"http" => "HTTP", "http" => "HTTP",
"gc_stat_instrumenter" => "GCStatInstrumenter", "gc_stat_instrumenter" => "GCStatInstrumenter",
"chat_sdk" => "ChatSDK", "chat_sdk" => "ChatSDK",
"ip" => "IP",
) )
end end
Rails.autoloaders.main.ignore( Rails.autoloaders.main.ignore(

View File

@ -59,6 +59,25 @@ class Middleware::RequestTracker
@@ip_skipper @@ip_skipper
end end
def self.reset_rate_limiters_stack
@@stack =
begin
# Update the documentation for the `add_request_rate_limiter` plugin API if this list changes.
default_rate_limiters = [
RequestTracker::RateLimiters::User,
RequestTracker::RateLimiters::IP,
]
stack = RequestTracker::RateLimiters::Stack.new
default_rate_limiters.each { |limiter| stack.append(limiter) }
stack
end
end
def self.rate_limiters_stack
@@stack ||= reset_rate_limiters_stack
end
def initialize(app, settings = {}) def initialize(app, settings = {})
@app = app @app = app
end end
@ -317,23 +336,27 @@ class Middleware::RequestTracker
::Middleware::RequestTracker.populate_request_queue_seconds!(env) ::Middleware::RequestTracker.populate_request_queue_seconds!(env)
request = Rack::Request.new(env) request = Rack::Request.new(env)
cookie = find_auth_cookie(env) cookie = find_auth_cookie(env)
if error_details = rate_limit(request, cookie) if error_details = rate_limit(request, cookie)
available_in, error_code, limit_on_id = error_details available_in, error_code = error_details
message = <<~TEXT message = <<~TEXT
Slow down, too many requests from this #{limit_on_id ? "user" : "IP address"}. Slow down, you're making too many requests.
Please retry again in #{available_in} seconds. Please retry again in #{available_in} seconds.
Error code: #{error_code}. Error code: #{error_code}.
TEXT TEXT
headers = { headers = {
"Content-Type" => "text/plain", "Content-Type" => "text/plain",
"Retry-After" => available_in.to_s, "Retry-After" => available_in.to_s,
"Discourse-Rate-Limit-Error-Code" => error_code, "Discourse-Rate-Limit-Error-Code" => error_code,
} }
if username = cookie&.[](:username) if username = cookie&.[](:username)
headers["X-Discourse-Username"] = username headers["X-Discourse-Username"] = username
end end
return 429, headers, [message] return 429, headers, [message]
end end
@ -341,11 +364,13 @@ class Middleware::RequestTracker
if error_details = check_crawler_limits(env) if error_details = check_crawler_limits(env)
available_in, error_code = error_details available_in, error_code = error_details
message = "Too many crawling requests. Error code: #{error_code}." message = "Too many crawling requests. Error code: #{error_code}."
headers = { headers = {
"Content-Type" => "text/plain", "Content-Type" => "text/plain",
"Retry-After" => available_in.to_s, "Retry-After" => available_in.to_s,
"Discourse-Rate-Limit-Error-Code" => error_code, "Discourse-Rate-Limit-Error-Code" => error_code,
} }
return 429, headers, [message] return 429, headers, [message]
end end
end end
@ -371,10 +396,12 @@ class Middleware::RequestTracker
headers["X-Redis-Calls"] = redis[:calls].to_s headers["X-Redis-Calls"] = redis[:calls].to_s
headers["X-Redis-Time"] = "%0.6f" % redis[:duration] headers["X-Redis-Time"] = "%0.6f" % redis[:duration]
end end
if sql = info[:sql] if sql = info[:sql]
headers["X-Sql-Calls"] = sql[:calls].to_s headers["X-Sql-Calls"] = sql[:calls].to_s
headers["X-Sql-Time"] = "%0.6f" % sql[:duration] headers["X-Sql-Time"] = "%0.6f" % sql[:duration]
end end
if queue = env["REQUEST_QUEUE_SECONDS"] if queue = env["REQUEST_QUEUE_SECONDS"]
headers["X-Queue-Time"] = "%0.6f" % queue headers["X-Queue-Time"] = "%0.6f" % queue
end end
@ -389,6 +416,7 @@ class Middleware::RequestTracker
ensure ensure
if (limiters = env["DISCOURSE_RATE_LIMITERS"]) && env["DISCOURSE_IS_ASSET_PATH"] if (limiters = env["DISCOURSE_RATE_LIMITERS"]) && env["DISCOURSE_IS_ASSET_PATH"]
limiters.each(&:rollback!) limiters.each(&:rollback!)
env["DISCOURSE_ASSET_RATE_LIMITERS"].each do |limiter| env["DISCOURSE_ASSET_RATE_LIMITERS"].each do |limiter|
begin begin
limiter.performed! limiter.performed!
@ -431,6 +459,7 @@ class Middleware::RequestTracker
warn = warn =
GlobalSetting.max_reqs_per_ip_mode == "warn" || GlobalSetting.max_reqs_per_ip_mode == "warn" ||
GlobalSetting.max_reqs_per_ip_mode == "warn+block" GlobalSetting.max_reqs_per_ip_mode == "warn+block"
block = block =
GlobalSetting.max_reqs_per_ip_mode == "block" || GlobalSetting.max_reqs_per_ip_mode == "block" ||
GlobalSetting.max_reqs_per_ip_mode == "warn+block" GlobalSetting.max_reqs_per_ip_mode == "warn+block"
@ -446,44 +475,42 @@ class Middleware::RequestTracker
return if @@ip_skipper&.call(ip) return if @@ip_skipper&.call(ip)
return if STATIC_IP_SKIPPER&.any? { |entry| entry.include?(ip) } return if STATIC_IP_SKIPPER&.any? { |entry| entry.include?(ip) }
ip_or_id = ip rate_limiter = self.class.rate_limiters_stack.active_rate_limiter(request, cookie)
limit_on_id = false return nil if rate_limiter.nil?
if cookie && cookie[:user_id] && cookie[:trust_level] && rate_limit_key = rate_limiter.rate_limit_key
cookie[:trust_level] >= GlobalSetting.skip_per_ip_rate_limit_trust_level error_code_identifier = rate_limiter.error_code_identifier
ip_or_id = cookie[:user_id] global = rate_limiter.rate_limit_globally?
limit_on_id = true
end
limiter10 = limiter10 =
RateLimiter.new( RateLimiter.new(
nil, nil,
"global_limit_10_#{ip_or_id}", "global_limit_10_#{rate_limit_key}",
GlobalSetting.max_reqs_per_ip_per_10_seconds, GlobalSetting.max_reqs_per_ip_per_10_seconds,
10, 10,
global: !limit_on_id, global:,
aggressive: true, aggressive: true,
error_code: limit_on_id ? "id_10_secs_limit" : "ip_10_secs_limit", error_code: "#{error_code_identifier}_10_secs_limit",
) )
limiter60 = limiter60 =
RateLimiter.new( RateLimiter.new(
nil, nil,
"global_limit_60_#{ip_or_id}", "global_limit_60_#{rate_limit_key}",
GlobalSetting.max_reqs_per_ip_per_minute, GlobalSetting.max_reqs_per_ip_per_minute,
60, 60,
global: !limit_on_id, global:,
error_code: limit_on_id ? "id_60_secs_limit" : "ip_60_secs_limit", error_code: "#{error_code_identifier}_60_secs_limit",
aggressive: true, aggressive: true,
) )
limiter_assets10 = limiter_assets10 =
RateLimiter.new( RateLimiter.new(
nil, nil,
"global_limit_10_assets_#{ip_or_id}", "global_limit_10_assets_#{rate_limit_key}",
GlobalSetting.max_asset_reqs_per_ip_per_10_seconds, GlobalSetting.max_asset_reqs_per_ip_per_10_seconds,
10, 10,
error_code: limit_on_id ? "id_assets_10_secs_limit" : "ip_assets_10_secs_limit", error_code: "#{error_code_identifier}_assets_10_secs_limit",
global: !limit_on_id, global:,
) )
request.env["DISCOURSE_RATE_LIMITERS"] = [limiter10, limiter60] request.env["DISCOURSE_RATE_LIMITERS"] = [limiter10, limiter60]
@ -491,20 +518,13 @@ class Middleware::RequestTracker
if !limiter_assets10.can_perform? if !limiter_assets10.can_perform?
if warn if warn
limited_on = limit_on_id ? "user_id" : "ip"
Discourse.warn( Discourse.warn(
"Global asset rate limit exceeded for #{limited_on}: #{ip}: 10 second rate limit", "Global asset rate limit exceeded for #{rate_limiter.class.name}: #{rate_limit_key}: 10 second rate limit",
uri: request.env["REQUEST_URI"], uri: request.env["REQUEST_URI"],
) )
end end
if block return limiter_assets10.seconds_to_wait(Time.now.to_i), limiter_assets10.error_code if block
return [
limiter_assets10.seconds_to_wait(Time.now.to_i),
limiter_assets10.error_code,
limit_on_id
]
end
end end
begin begin
@ -517,14 +537,14 @@ class Middleware::RequestTracker
nil nil
rescue RateLimiter::LimitExceeded => e rescue RateLimiter::LimitExceeded => e
if warn if warn
limited_on = limit_on_id ? "user_id" : "ip"
Discourse.warn( Discourse.warn(
"Global rate limit exceeded for #{limited_on}: #{ip}: #{type} second rate limit", "Global rate limit exceeded for #{rate_limiter.class.name}: #{rate_limit_key}: #{type} second rate limit",
uri: request.env["REQUEST_URI"], uri: request.env["REQUEST_URI"],
) )
end end
if block if block
[e.available_in, e.error_code, limit_on_id] [e.available_in, e.error_code]
else else
nil nil
end end

View File

@ -1257,6 +1257,72 @@ class Plugin::Instance
DiscoursePluginRegistry.register_search_groups_set_query_callback(callback, self) DiscoursePluginRegistry.register_search_groups_set_query_callback(callback, self)
end end
# This is an experimental API and may be changed or removed in the future without deprecation.
#
# Adds a custom rate limiter to the request rate limiters stack. Only one rate limiter is used per request and the
# first rate limiter in the stack that is active is used. By default the rate limiters stack contains the following
# rate limiters:
#
# `RequestTracker::RateLimiters::User` - Rate limits authenticated requests based on the user's id
# `RequestTracker::RateLimiters::IP` - Rate limits requests based on the IP address
#
# @param identifier [Symbol] A unique identifier for the rate limiter.
#
# @param key [Proc] A lambda/proc that defines the `rate_limit_key`.
# - Receives `request` (An instance of `Rack::Request`) as argument.
# - Should return a string representing the rate limit key.
#
# @param activate_when [Proc] A lambda/proc that defines when the rate limiter should be used for a request.
# - Receives `request` (An instance of `Rack::Request`) as argument.
# - Should return `true` if the rate limiter is active, otherwise `false`.
#
# @param global [Boolean] Whether the rate limiter applies globally across all sites. Defaults to `false`.
# - Ignored if `klass` is provided.
#
# @param after [Class, nil] The rate limiter class after which the new rate limiter should be added.
#
# @param before [Class, nil] The rate limiter class before which the new rate limiter should be added.
#
# @example Adding a rate limiter that rate limits all requests based on the country of the IP address
#
# add_request_rate_limiter(
# identifier: :country,
# key: ->(request) { "country/#{DiscourseIpInfo.get(request.ip)[:country]}" },
# activate_when: ->(request) { DiscourseIpInfo.get(request.ip)[:country].present? },
# )
def add_request_rate_limiter(
identifier:,
key:,
activate_when:,
global: false,
after: nil,
before: nil
)
raise ArgumentError, "only one of `after` or `before` can be provided" if after && before
stack = Middleware::RequestTracker.rate_limiters_stack
if (reference_klass = after || before) && !stack.include?(reference_klass)
raise ArgumentError, "#{reference_klass} is not a valid value. Must be one of #{stack}"
end
klass =
Class.new(RequestTracker::RateLimiters::Base) do
define_method(:rate_limit_key) { key.call(@request) }
define_method(:rate_limit_globally?) { global }
define_method(:active?) { activate_when.call(@request) }
define_method(:error_code_identifier) { identifier }
end
if after
stack.insert_after(after, klass)
elsif before
stack.insert_before(before, klass)
else
stack.prepend(klass)
end
end
protected protected
def self.js_path def self.js_path

View File

@ -0,0 +1,40 @@
# frozen_string_literal: true
module RequestTracker
module RateLimiters
class Base
# :nodoc:
def initialize(request, cookie)
@request = request
@cookie = cookie
end
# This method is meant to be implemented in subclasses.
#
# @return [String] The key used to identify the rate limiter.
def rate_limit_key
raise NotImplementedError
end
# :nodoc:
def error_code_identifier
self.class.name.underscore.split("/").last
end
# This method is meant to be implemented in subclasses.
#
# @return [Boolean] Indicates if the rate limiter should be used for the request.
def active?
raise NotImplementedError
end
# This method is meant to be implemented in subclasses.
#
# @return [Boolean] Indicates whether the rate limit applies globally across all sites in the cluster or just for
# the current site.
def rate_limit_globally?
raise NotImplementedError
end
end
end
end

View File

@ -0,0 +1,19 @@
# frozen_string_literal: true
module RequestTracker
module RateLimiters
class IP < Base
def rate_limit_key
"ip/#{@request.ip}"
end
def rate_limit_globally?
true
end
def active?
true
end
end
end
end

View File

@ -0,0 +1,67 @@
# frozen_string_literal: true
module RequestTracker
module RateLimiters
# :nodoc:
class Stack
def initialize
@rate_limiter_klasses = []
end
def to_s
@rate_limiter_klasses.map { |klass| klass.to_s }.join(", ")
end
def include?(reference_klass)
@rate_limiter_klasses.include?(reference_klass)
end
def prepend(rate_limiter_klass)
@rate_limiter_klasses.prepend(rate_limiter_klass)
end
def append(rate_limiter_klass)
@rate_limiter_klasses.append(rate_limiter_klass)
end
def insert_before(existing_rate_limiter_klass, new_rate_limiter_klass)
@rate_limiter_klasses.insert(
get_rate_limiter_index(existing_rate_limiter_klass),
new_rate_limiter_klass,
)
end
def insert_after(existing_rate_limiter_klass, new_rate_limiter_klass)
@rate_limiter_klasses.insert(
get_rate_limiter_index(existing_rate_limiter_klass) + 1,
new_rate_limiter_klass,
)
end
def active_rate_limiter(request, cookie)
@rate_limiter_klasses.each do |rate_limiter_klass|
rate_limiter = rate_limiter_klass.new(request, cookie)
return rate_limiter if rate_limiter.active?
end
nil
end
def method_missing(method_name, *args, &block)
if @rate_limiter_klasses.respond_to?(method_name)
@rate_limiter_klasses.send(method_name, *args, &block)
else
super
end
end
private
def get_rate_limiter_index(rate_limiter_klass)
index = @rate_limiter_klasses.index { |klass| klass == rate_limiter_klass }
raise "Rate limiter #{rate_limiter_klass} not found" if index.nil?
index
end
end
end
end

View File

@ -0,0 +1,20 @@
# frozen_string_literal: true
module RequestTracker
module RateLimiters
class User < Base
def rate_limit_key
"user/#{@cookie[:user_id]}"
end
def rate_limit_globally?
false
end
def active?
@cookie && @cookie[:user_id] && @cookie[:trust_level] &&
@cookie[:trust_level] >= GlobalSetting.skip_per_ip_rate_limit_trust_level
end
end
end
end

View File

@ -684,8 +684,9 @@ RSpec.describe Middleware::RequestTracker do
global_setting :max_reqs_per_ip_per_10_seconds, 1 global_setting :max_reqs_per_ip_per_10_seconds, 1
global_setting :max_reqs_per_ip_mode, "warn+block" global_setting :max_reqs_per_ip_mode, "warn+block"
status, _ = middleware.call(env) env1 = env("REMOTE_ADDR" => "192.0.2.42")
status, headers = middleware.call(env) status, _ = middleware.call(env1)
status, headers = middleware.call(env1)
expect(fake_logger.warnings.count { |w| w.include?("Global rate limit exceeded") }).to eq(1) expect(fake_logger.warnings.count { |w| w.include?("Global rate limit exceeded") }).to eq(1)
expect(status).to eq(429) expect(status).to eq(429)
@ -696,8 +697,9 @@ RSpec.describe Middleware::RequestTracker do
global_setting :max_reqs_per_ip_per_10_seconds, 1 global_setting :max_reqs_per_ip_per_10_seconds, 1
global_setting :max_reqs_per_ip_mode, "warn" global_setting :max_reqs_per_ip_mode, "warn"
status, _ = middleware.call(env) env1 = env("REMOTE_ADDR" => "192.0.2.42")
status, _ = middleware.call(env) status, _ = middleware.call(env1)
status, _ = middleware.call(env1)
expect(fake_logger.warnings.count { |w| w.include?("Global rate limit exceeded") }).to eq(1) expect(fake_logger.warnings.count { |w| w.include?("Global rate limit exceeded") }).to eq(1)
expect(status).to eq(200) expect(status).to eq(200)
@ -766,8 +768,12 @@ RSpec.describe Middleware::RequestTracker do
expect(status).to eq(429) expect(status).to eq(429)
expect(called).to eq(1) expect(called).to eq(1)
expect(headers["Discourse-Rate-Limit-Error-Code"]).to eq("ip_10_secs_limit") expect(headers["Discourse-Rate-Limit-Error-Code"]).to eq("ip_10_secs_limit")
expect(response.first).to include("too many requests from this IP address")
expect(response.first).to include("Error code: ip_10_secs_limit.") expect(response.first).to eq(<<~MSG)
Slow down, you're making too many requests.
Please retry again in 10 seconds.
Error code: ip_10_secs_limit.
MSG
end end
it "is included when the requests-per-minute limit is reached" do it "is included when the requests-per-minute limit is reached" do
@ -790,8 +796,12 @@ RSpec.describe Middleware::RequestTracker do
expect(status).to eq(429) expect(status).to eq(429)
expect(called).to eq(1) expect(called).to eq(1)
expect(headers["Discourse-Rate-Limit-Error-Code"]).to eq("ip_60_secs_limit") expect(headers["Discourse-Rate-Limit-Error-Code"]).to eq("ip_60_secs_limit")
expect(response.first).to include("too many requests from this IP address")
expect(response.first).to include("Error code: ip_60_secs_limit.") expect(response.first).to eq(<<~MSG)
Slow down, you're making too many requests.
Please retry again in 60 seconds.
Error code: ip_60_secs_limit.
MSG
end end
it "is included when the assets-requests-per-10-seconds limit is reached" do it "is included when the assets-requests-per-10-seconds limit is reached" do
@ -815,8 +825,12 @@ RSpec.describe Middleware::RequestTracker do
expect(status).to eq(429) expect(status).to eq(429)
expect(called).to eq(1) expect(called).to eq(1)
expect(headers["Discourse-Rate-Limit-Error-Code"]).to eq("ip_assets_10_secs_limit") expect(headers["Discourse-Rate-Limit-Error-Code"]).to eq("ip_assets_10_secs_limit")
expect(response.first).to include("too many requests from this IP address")
expect(response.first).to include("Error code: ip_assets_10_secs_limit.") expect(response.first).to eq(<<~MSG)
Slow down, you're making too many requests.
Please retry again in 10 seconds.
Error code: ip_assets_10_secs_limit.
MSG
end end
end end
@ -855,10 +869,15 @@ RSpec.describe Middleware::RequestTracker do
middleware = Middleware::RequestTracker.new(app) middleware = Middleware::RequestTracker.new(app)
status, headers, response = middleware.call(env) status, headers, response = middleware.call(env)
expect(status).to eq(429) expect(status).to eq(429)
expect(headers["Discourse-Rate-Limit-Error-Code"]).to eq("id_60_secs_limit") expect(headers["Discourse-Rate-Limit-Error-Code"]).to eq("user_60_secs_limit")
expect(response.first).to include("too many requests from this user")
expect(response.first).to include("Error code: id_60_secs_limit.") expect(response.first).to eq(<<~MSG)
Slow down, you're making too many requests.
Please retry again in 60 seconds.
Error code: user_60_secs_limit.
MSG
end end
expect(called).to eq(3) expect(called).to eq(3)
end end
@ -878,11 +897,13 @@ RSpec.describe Middleware::RequestTracker do
env = env("HTTP_COOKIE" => "_t=#{cookie}", "REMOTE_ADDR" => "1.1.1.1") env = env("HTTP_COOKIE" => "_t=#{cookie}", "REMOTE_ADDR" => "1.1.1.1")
called = 0 called = 0
app = app =
lambda do |_| lambda do |_|
called += 1 called += 1
[200, {}, ["OK"]] [200, {}, ["OK"]]
end end
freeze_time(12.minutes.from_now) do freeze_time(12.minutes.from_now) do
middleware = Middleware::RequestTracker.new(app) middleware = Middleware::RequestTracker.new(app)
status, = middleware.call(env) status, = middleware.call(env)
@ -892,8 +913,12 @@ RSpec.describe Middleware::RequestTracker do
status, headers, response = middleware.call(env) status, headers, response = middleware.call(env)
expect(status).to eq(429) expect(status).to eq(429)
expect(headers["Discourse-Rate-Limit-Error-Code"]).to eq("ip_60_secs_limit") expect(headers["Discourse-Rate-Limit-Error-Code"]).to eq("ip_60_secs_limit")
expect(response.first).to include("too many requests from this IP address")
expect(response.first).to include("Error code: ip_60_secs_limit.") expect(response.first).to eq(<<~MSG)
Slow down, you're making too many requests.
Please retry again in 60 seconds.
Error code: ip_60_secs_limit.
MSG
end end
end end
@ -928,8 +953,53 @@ RSpec.describe Middleware::RequestTracker do
status, headers, response = middleware.call(env) status, headers, response = middleware.call(env)
expect(status).to eq(429) expect(status).to eq(429)
expect(headers["Discourse-Rate-Limit-Error-Code"]).to eq("ip_60_secs_limit") expect(headers["Discourse-Rate-Limit-Error-Code"]).to eq("ip_60_secs_limit")
expect(response.first).to include("too many requests from this IP address")
expect(response.first).to include("Error code: ip_60_secs_limit.") expect(response.first).to eq(<<~MSG)
Slow down, you're making too many requests.
Please retry again in 60 seconds.
Error code: ip_60_secs_limit.
MSG
end
context "for `add_request_rate_limiter` plugin API" do
after { described_class.reset_rate_limiters_stack }
it "can be used to add a custom rate limiter" do
global_setting :max_reqs_per_ip_per_minute, 1
plugin = Plugin::Instance.new
plugin.add_request_rate_limiter(
identifier: :crawlers,
key: ->(_request) { "crawlers" },
activate_when: ->(request) { request.user_agent =~ /crawler/ },
)
env1 = env("HTTP_USER_AGENT" => "some crawler")
called = 0
app =
lambda do |_|
called += 1
[200, {}, ["OK"]]
end
middleware = Middleware::RequestTracker.new(app)
status, = middleware.call(env1)
expect(status).to eq(200)
middleware = Middleware::RequestTracker.new(app)
status, headers, response = middleware.call(env1)
expect(status).to eq(429)
expect(headers["Discourse-Rate-Limit-Error-Code"]).to eq("crawlers_60_secs_limit")
expect(response.first).to eq(<<~MSG)
Slow down, you're making too many requests.
Please retry again in 60 seconds.
Error code: crawlers_60_secs_limit.
MSG
end
end end
end end

View File

@ -993,4 +993,114 @@ TEXT
expect(sum).to eq(3) expect(sum).to eq(3)
end end
end end
describe "#add_request_rate_limiter" do
after { Middleware::RequestTracker.reset_rate_limiters_stack }
it "should raise an error if `after` and `before` kwarg are provided" do
plugin = Plugin::Instance.new
expect do
plugin.add_request_rate_limiter(
identifier: :some_identifier,
key: ->(request) { request.ip },
activate_when: ->(request) { request.ip == "1.2.3.4" },
after: 0,
before: 0,
)
end.to raise_error(ArgumentError, "only one of `after` or `before` can be provided")
end
it "should raise an error if value of `after` kwarg is invalid" do
plugin = Plugin::Instance.new
expect {
plugin.add_request_rate_limiter(
identifier: :some_identifier,
key: ->(request) { request.ip },
activate_when: ->(request) { request.ip == "1.2.3.4" },
after: 0,
)
}.to raise_error(
ArgumentError,
"0 is not a valid value. Must be one of RequestTracker::RateLimiters::User, RequestTracker::RateLimiters::IP",
)
end
it "should raise an error if value of `before` kwarg is invalid" do
plugin = Plugin::Instance.new
expect {
plugin.add_request_rate_limiter(
identifier: :some_identifier,
key: ->(request) { request.ip },
activate_when: ->(request) { request.ip == "1.2.3.4" },
before: 0,
)
}.to raise_error(
ArgumentError,
"0 is not a valid value. Must be one of RequestTracker::RateLimiters::User, RequestTracker::RateLimiters::IP",
)
end
it "can prepend a rate limiter to `Middleware::RequestTracker.rate_limiters_stack`" do
plugin = Plugin::Instance.new
plugin.add_request_rate_limiter(
identifier: :some_identifier,
key: ->(request) { "crawlers" },
activate_when: ->(request) { request.user_agent =~ /crawler/ },
)
rate_limiter = Middleware::RequestTracker.rate_limiters_stack[0]
expect(rate_limiter.superclass).to eq(RequestTracker::RateLimiters::Base)
end
it "can insert a rate limiter before a specific rate limiter in `Middleware::RequestTracker.rate_limiters_stack`" do
plugin = Plugin::Instance.new
plugin.add_request_rate_limiter(
identifier: :some_identifier,
key: ->(request) { "crawlers" },
activate_when: ->(request) { request.user_agent =~ /crawler/ },
before: RequestTracker::RateLimiters::IP,
)
expect(Middleware::RequestTracker.rate_limiters_stack[0]).to eq(
RequestTracker::RateLimiters::User,
)
expect(Middleware::RequestTracker.rate_limiters_stack[1].superclass).to eq(
RequestTracker::RateLimiters::Base,
)
expect(Middleware::RequestTracker.rate_limiters_stack[2]).to eq(
RequestTracker::RateLimiters::IP,
)
end
it "can insert a rate limiter after a specific rate limiter in `Middleware::RequestTracker.rate_limiters_stack`" do
plugin = Plugin::Instance.new
plugin.add_request_rate_limiter(
identifier: :some_identifier,
key: ->(request) { "crawlers" },
activate_when: ->(request) { request.user_agent =~ /crawler/ },
after: RequestTracker::RateLimiters::IP,
)
expect(Middleware::RequestTracker.rate_limiters_stack[0]).to eq(
RequestTracker::RateLimiters::User,
)
expect(Middleware::RequestTracker.rate_limiters_stack[1]).to eq(
RequestTracker::RateLimiters::IP,
)
expect(Middleware::RequestTracker.rate_limiters_stack[2].superclass).to eq(
RequestTracker::RateLimiters::Base,
)
end
end
end end

View File

@ -110,14 +110,14 @@ RSpec.describe "RequestTracker in multisite", type: :multisite do
before { global_setting :max_reqs_per_ip_per_10_seconds, 1 } before { global_setting :max_reqs_per_ip_per_10_seconds, 1 }
include_examples "ip rate limiters behavior", "ip_10_secs_limit" include_examples "ip rate limiters behavior", "ip_10_secs_limit"
include_examples "user id rate limiters behavior", "id_10_secs_limit" include_examples "user id rate limiters behavior", "user_10_secs_limit"
end end
context "with a 60 seconds limiter" do context "with a 60 seconds limiter" do
before { global_setting :max_reqs_per_ip_per_minute, 1 } before { global_setting :max_reqs_per_ip_per_minute, 1 }
include_examples "ip rate limiters behavior", "ip_60_secs_limit" include_examples "ip rate limiters behavior", "ip_60_secs_limit"
include_examples "user id rate limiters behavior", "id_60_secs_limit" include_examples "user id rate limiters behavior", "user_60_secs_limit"
end end
context "with assets 10 seconds limiter" do context "with assets 10 seconds limiter" do
@ -125,6 +125,6 @@ RSpec.describe "RequestTracker in multisite", type: :multisite do
app_callback = ->(env) { env["DISCOURSE_IS_ASSET_PATH"] = true } app_callback = ->(env) { env["DISCOURSE_IS_ASSET_PATH"] = true }
include_examples "ip rate limiters behavior", "ip_assets_10_secs_limit", app_callback include_examples "ip rate limiters behavior", "ip_assets_10_secs_limit", app_callback
include_examples "user id rate limiters behavior", "id_assets_10_secs_limit", app_callback include_examples "user id rate limiters behavior", "user_assets_10_secs_limit", app_callback
end end
end end

View File

@ -46,7 +46,7 @@ RSpec.describe "rate limits" do
.expects(:performed!) .expects(:performed!)
.with( .with(
nil, nil,
"global_limit_60_192.0.2.1", "global_limit_60_ip/192.0.2.1",
200, 200,
60, 60,
global: true, global: true,
@ -58,7 +58,7 @@ RSpec.describe "rate limits" do
.expects(:performed!) .expects(:performed!)
.with( .with(
nil, nil,
"global_limit_10_192.0.2.1", "global_limit_10_ip/192.0.2.1",
50, 50,
10, 10,
global: true, global: true,
@ -76,7 +76,7 @@ RSpec.describe "rate limits" do
.expects(:rollback!) .expects(:rollback!)
.with( .with(
nil, nil,
"global_limit_60_192.0.2.1", "global_limit_60_ip/192.0.2.1",
200, 200,
60, 60,
global: true, global: true,
@ -88,7 +88,7 @@ RSpec.describe "rate limits" do
.expects(:rollback!) .expects(:rollback!)
.with( .with(
nil, nil,
"global_limit_10_192.0.2.1", "global_limit_10_ip/192.0.2.1",
50, 50,
10, 10,
global: true, global: true,
@ -116,7 +116,7 @@ RSpec.describe "rate limits" do
.expects(:performed!) .expects(:performed!)
.with( .with(
nil, nil,
"global_limit_60_192.0.2.1", "global_limit_60_ip/192.0.2.1",
200, 200,
60, 60,
global: true, global: true,
@ -128,7 +128,7 @@ RSpec.describe "rate limits" do
.expects(:performed!) .expects(:performed!)
.with( .with(
nil, nil,
"global_limit_10_192.0.2.1", "global_limit_10_ip/192.0.2.1",
50, 50,
10, 10,
global: true, global: true,