diff --git a/lib/final_destination/http.rb b/lib/final_destination/http.rb index ffceff99deb..0a549cdcce1 100644 --- a/lib/final_destination/http.rb +++ b/lib/final_destination/http.rb @@ -1,40 +1,47 @@ # frozen_string_literal: true -class FinalDestination::HTTP < Net::HTTP - def connect - original_open_timeout = @open_timeout - return super if @ipaddr +class FinalDestination + module SSRFSafeNetHTTP + def connect + original_open_timeout = @open_timeout + return super if @ipaddr - timeout_at = current_time + @open_timeout + timeout_at = current_time + @open_timeout - # This iteration through addresses would normally happen in Socket#tcp - # We do it here because we're tightly controlling addresses rather than - # handing Socket#tcp a hostname - ips = FinalDestination::SSRFDetector.lookup_and_filter_ips(@address, timeout: @connect_timeout) + # This iteration through addresses would normally happen in Socket#tcp + # We do it here because we're tightly controlling addresses rather than + # handing Socket#tcp a hostname + ips = + FinalDestination::SSRFDetector.lookup_and_filter_ips(@address, timeout: @connect_timeout) - ips.each_with_index do |ip, index| - debug "[FinalDestination] Attempting connection to #{ip}..." - self.ipaddr = ip + ips.each_with_index do |ip, index| + debug "[FinalDestination] Attempting connection to #{ip}..." + self.ipaddr = ip - remaining_time = timeout_at - current_time - if remaining_time <= 0 - raise Net::OpenTimeout.new("Operation timed out - FinalDestination::HTTP") + remaining_time = timeout_at - current_time + if remaining_time <= 0 + raise Net::OpenTimeout.new("Operation timed out - FinalDestination::HTTP") + end + + @open_timeout = remaining_time + return super + rescue SystemCallError, Net::OpenTimeout => e + debug "[FinalDestination] Error connecting to #{ip}... #{e.message}" + was_last_attempt = index == ips.length - 1 + raise if was_last_attempt end - - @open_timeout = remaining_time - return super - rescue SystemCallError, Net::OpenTimeout => e - debug "[FinalDestination] Error connecting to #{ip}... #{e.message}" - was_last_attempt = index == ips.length - 1 - raise if was_last_attempt + ensure + @open_timeout = original_open_timeout + end + + private + + def current_time + Process.clock_gettime(Process::CLOCK_MONOTONIC) end - ensure - @open_timeout = original_open_timeout end - private - - def current_time - Process.clock_gettime(Process::CLOCK_MONOTONIC) + class HTTP < ::Net::HTTP + include SSRFSafeNetHTTP end end diff --git a/spec/support/final_destination_helper.rb b/spec/support/final_destination_helper.rb index 7b586a3b1f9..dabd1b9a1f6 100644 --- a/spec/support/final_destination_helper.rb +++ b/spec/support/final_destination_helper.rb @@ -7,7 +7,15 @@ WebMock::HttpLibAdapterRegistry.instance.register( def self.enable! FinalDestination.send(:remove_const, :HTTP) - FinalDestination.send(:const_set, :HTTP, Net::HTTP) + + # At this point, `Net::HTTP` has already been patched by WebMock so we need to re-declare `FinalDestination::HTTP` + # but inherit from the patched `Net::HTTP` class. This is to allow requests made using `FinalDestination::HTTP` to be + # intercepted by WebMock. + FinalDestination.send( + :const_set, + :HTTP, + Class.new(Net::HTTP) { include FinalDestination::SSRFSafeNetHTTP }, + ) end def self.disable!