diff --git a/lib/middleware/request_tracker.rb b/lib/middleware/request_tracker.rb index 6f802ec16f6..b2d4d7299cb 100644 --- a/lib/middleware/request_tracker.rb +++ b/lib/middleware/request_tracker.rb @@ -6,6 +6,7 @@ require_dependency 'method_profiler' class Middleware::RequestTracker @@detailed_request_loggers = nil + @@ip_skipper = nil # register callbacks for detailed request loggers called on every request # example: @@ -35,7 +36,20 @@ class Middleware::RequestTracker if @@detailed_request_loggers.length == 0 @detailed_request_loggers = nil end + end + # Register a custom `ip_skipper`, a function that will skip rate limiting + # for any IP that returns true. + # + # For example, if you never wanted to rate limit 1.2.3.4 + # + # ``` + # Middleware::RequestTracker.register_ip_skipper do |ip| + # ip == "1.2.3.4" + # end + # ``` + def self.register_ip_skipper(&blk) + @@ip_skipper = blk end def initialize(app, settings = {}) @@ -167,6 +181,8 @@ class Middleware::RequestTracker return false if is_private_ip?(ip) end + return false if @@ip_skipper.try(:call, ip) + limiter10 = RateLimiter.new( nil, "global_ip_limit_10_#{ip}", diff --git a/spec/components/middleware/request_tracker_spec.rb b/spec/components/middleware/request_tracker_spec.rb index e56c708a227..a112d864e60 100644 --- a/spec/components/middleware/request_tracker_spec.rb +++ b/spec/components/middleware/request_tracker_spec.rb @@ -127,6 +127,30 @@ describe Middleware::RequestTracker do expect(status).to eq(429) end + describe "register_ip_skipper" do + before do + Middleware::RequestTracker.register_ip_skipper do |ip| + ip == "1.1.1.2" + end + global_setting :max_reqs_per_ip_per_10_seconds, 1 + global_setting :max_reqs_per_ip_mode, 'block' + end + + it "won't block if the ip is skipped" do + env1 = env("REMOTE_ADDR" => "1.1.1.2") + status, _ = middleware.call(env1) + status, _ = middleware.call(env1) + expect(status).to eq(200) + end + + it "blocks if the ip isn't skipped" do + env1 = env("REMOTE_ADDR" => "1.1.1.1") + status, _ = middleware.call(env1) + status, _ = middleware.call(env1) + expect(status).to eq(429) + end + end + it "does nothing for private IPs if skipped" do global_setting :max_reqs_per_ip_per_10_seconds, 1 global_setting :max_reqs_per_ip_mode, 'warn+block' @@ -206,7 +230,7 @@ describe Middleware::RequestTracker do end after do - Middleware::RequestTracker.register_detailed_request_logger(logger) + Middleware::RequestTracker.unregister_detailed_request_logger(logger) end it "can correctly log detailed data" do