diff --git a/lib/middleware/request_tracker.rb b/lib/middleware/request_tracker.rb index b2d4d7299cb..cee0f5535cf 100644 --- a/lib/middleware/request_tracker.rb +++ b/lib/middleware/request_tracker.rb @@ -38,6 +38,11 @@ class Middleware::RequestTracker end end + # used for testing + def self.unregister_ip_skipper + @@ip_skipper = nil + end + # Register a custom `ip_skipper`, a function that will skip rate limiting # for any IP that returns true. # @@ -49,6 +54,7 @@ class Middleware::RequestTracker # end # ``` def self.register_ip_skipper(&blk) + raise "IP skipper is already registered!" if @@ip_skipper @@ip_skipper = blk end @@ -181,7 +187,7 @@ class Middleware::RequestTracker return false if is_private_ip?(ip) end - return false if @@ip_skipper.try(:call, ip) + return false if @@ip_skipper&.call(ip) limiter10 = RateLimiter.new( nil, diff --git a/spec/components/middleware/request_tracker_spec.rb b/spec/components/middleware/request_tracker_spec.rb index a112d864e60..d4455abbb4d 100644 --- a/spec/components/middleware/request_tracker_spec.rb +++ b/spec/components/middleware/request_tracker_spec.rb @@ -136,6 +136,10 @@ describe Middleware::RequestTracker do global_setting :max_reqs_per_ip_mode, 'block' end + after do + Middleware::RequestTracker.unregister_ip_skipper + end + it "won't block if the ip is skipped" do env1 = env("REMOTE_ADDR" => "1.1.1.2") status, _ = middleware.call(env1)