diff --git a/script/cache_critical_dns b/script/cache_critical_dns index a0307781a7d..9ed043683e4 100755 --- a/script/cache_critical_dns +++ b/script/cache_critical_dns @@ -23,6 +23,22 @@ CRITICAL_HOST_ENV_VARS = %w{ HOST_RESOLVER_CACHE = {} HOST_HEALTHY_CACHE = {} HOSTS_PATH = ENV['DISCOURSE_DNS_CACHE_HOSTS_FILE'] || "/etc/hosts" + +PrioFilter = Struct.new(:min, :max) do + # min and max must be integers and relate to the minimum or maximum accepted + # priority of an SRV RR target. + # The return value from within_threshold? indicates if the priority is less + # than or equal to the upper threshold, or greater than or equal to the + # lower threshold. + def within_threshold?(p) + p >= min && p <= max + end +end +SRV_PRIORITY_THRESHOLD_MIN = 0 +SRV_PRIORITY_THRESHOLD_MAX = 65535 +SRV_PRIORITY_FILTERS = Hash.new( + PrioFilter.new(SRV_PRIORITY_THRESHOLD_MIN, SRV_PRIORITY_THRESHOLD_MAX)) + REFRESH_SECONDS = 30 module DNSClient @@ -54,14 +70,16 @@ end class SRVName include DNSClient - def initialize(srv_hostname) + def initialize(srv_hostname, prio_filter) @name = srv_hostname + @prio_filter = prio_filter end def resolve dns_client_with_timeout do |dns_client| [].tap do |addresses| targets = dns_client.getresources(@name, Resolv::DNS::Resource::IN::SRV) + targets.delete_if { |t| !@prio_filter.within_threshold?(t.priority) } addresses.concat(targets.map { |t| Name.new(t.target.to_s).resolve }.flatten) end end @@ -264,8 +282,12 @@ def nilempty(v) end end +def env_srv_var(env_name) + "#{env_name}_SRV" +end + def env_srv_name(env_name) - nilempty(ENV["#{env_name}_SRV"]) + nilempty(ENV[env_srv_var(env_name)]) end def run(hostname_vars) @@ -278,7 +300,7 @@ def run(hostname_vars) name = ENV[var] HOST_RESOLVER_CACHE[var] ||= ResolverCache.new( if (srv_name = env_srv_name(var)) - SRVName.new(srv_name) + SRVName.new(srv_name, SRV_PRIORITY_FILTERS[env_srv_var(var)]) else Name.new(name) end @@ -339,6 +361,23 @@ all_hostname_vars = CRITICAL_HOST_ENV_VARS.select do |name| end end +# Populate the SRV_PRIORITY_FILTERS for any name that has a priority present in +# the environment. If no priority thresholds are found for the name, the default +# is that no filtering based on priority will be performed. +CRITICAL_HOST_ENV_VARS.each do |v| + if (name = env_srv_name(v)) + max = ENV.fetch("#{env_srv_var(v)}_PRIORITY_LE", SRV_PRIORITY_THRESHOLD_MAX).to_i + min = ENV.fetch("#{env_srv_var(v)}_PRIORITY_GE", SRV_PRIORITY_THRESHOLD_MIN).to_i + if max > SRV_PRIORITY_THRESHOLD_MAX || + min < SRV_PRIORITY_THRESHOLD_MIN || + min > max + raise "invalid priority threshold set for #{v}" + end + + SRV_PRIORITY_FILTERS[env_srv_var(v)] = PrioFilter.new(min, max) + end +end + while true run(all_hostname_vars) sleep REFRESH_SECONDS