discourse/script/cache_critical_dns
Michael Fitz-Payne 0553788d3b DEV(cache_critical_dns): improve postgres_healthcheck
The `PG::Connection#ping` method is only reliable for checking if the
given host is accepting connections, and not if the authentication
details are valid.

This extends the healthcheck to confirm that the auth details are
able to both create a connection and execute queries against the
database.

We expect the empty query to return an empty result set, so we can
assert on that. If a failure occurs for any reason, the healthcheck will
return false.
2022-05-24 08:20:10 +10:00

390 lines
9.5 KiB
Ruby
Executable File

#!/usr/bin/env ruby
# frozen_string_literal: true
# Specifying this env var ensures ruby can load gems installed via the Discourse
# project Gemfile (e.g. pg, redis).
ENV['BUNDLE_GEMFILE'] ||= '/var/www/discourse/Gemfile'
require 'bundler/setup'
require 'ipaddr'
require 'pg'
require 'redis'
require 'resolv'
require 'socket'
require 'time'
# Names of the environment variables whose hosts this script keeps resolved.
CRITICAL_HOST_ENV_VARS = %w[
  DISCOURSE_DB_HOST
  DISCOURSE_DB_REPLICA_HOST
  DISCOURSE_REDIS_HOST
  DISCOURSE_REDIS_SLAVE_HOST
  DISCOURSE_REDIS_REPLICA_HOST
]

# Fallback used for both the database user and the database name.
DEFAULT_DB_NAME = "discourse"

# Per-variable caches, filled in lazily and kept across refresh runs.
HOST_RESOLVER_CACHE = {}
HOST_HEALTHY_CACHE = {}

# Hosts file that is read and rewritten; overridable through the environment.
HOSTS_PATH = ENV.fetch('DISCOURSE_DNS_CACHE_HOSTS_FILE', "/etc/hosts")
# Inclusive priority window applied to SRV RR targets.
# min and max are expected to be integers: the lowest and highest priority
# that is still accepted.
PrioFilter = Struct.new(:min, :max) do
  # Returns true when priority p lies between min and max (both bounds
  # included), false otherwise.
  def within_threshold?(p)
    p.between?(min, max)
  end
end
# Lowest and highest SRV priority values accepted when no explicit threshold
# is configured.
SRV_PRIORITY_THRESHOLD_MIN = 0
SRV_PRIORITY_THRESHOLD_MAX = 65535
# Priority filters keyed by SRV environment variable name. Looking up a key
# that was never assigned returns the single shared default filter, which
# spans the whole MIN..MAX range.
SRV_PRIORITY_FILTERS = Hash.new(
PrioFilter.new(SRV_PRIORITY_THRESHOLD_MIN, SRV_PRIORITY_THRESHOLD_MAX))
# Number of seconds to sleep between refresh runs.
REFRESH_SECONDS = 30
# Mixin providing a short-lived DNS client for classes that perform lookups.
module DNSClient
  # Yields a Resolv::DNS client whose timeouts are set to 2 and closes the
  # client once the block has finished (also when it raises).
  # Returns the value of the block.
  def dns_client_with_timeout
    client = Resolv::DNS.new
    client.timeouts = 2
    yield client
  ensure
    client&.close
  end
end
# A plain hostname that is resolved through its A and AAAA records.
class Name
  include DNSClient

  def initialize(hostname)
    @name = hostname
  end

  # Returns the addresses of the hostname as strings: all A record addresses
  # first, followed by all AAAA record addresses.
  def resolve
    record_types = [Resolv::DNS::Resource::IN::A, Resolv::DNS::Resource::IN::AAAA]
    dns_client_with_timeout do |dns_client|
      record_types.flat_map do |record_type|
        dns_client.getresources(@name, record_type).map { |record| record.address.to_s }
      end
    end
  end
end
# A name that is resolved through SRV records: every SRV target accepted by
# the priority filter is resolved in turn as a plain Name.
class SRVName
  include DNSClient

  # srv_hostname is the SRV name to query; prio_filter must respond to
  # within_threshold?(priority).
  def initialize(srv_hostname, prio_filter)
    @name = srv_hostname
    @prio_filter = prio_filter
  end

  # Returns the flat list of addresses of all SRV targets whose priority
  # passes the filter.
  def resolve
    dns_client_with_timeout do |dns_client|
      dns_client
        .getresources(@name, Resolv::DNS::Resource::IN::SRV)
        .select { |record| @prio_filter.within_threshold?(record.priority) }
        .flat_map { |record| Name.new(record.target.to_s).resolve }
    end
  end
end
# Timestamps (UTC) recording when an address was first and most recently
# returned by a resolution.
CacheMeta = Struct.new(:first_seen, :last_seen)

# Remembers every address a name has resolved to, so that recently seen
# addresses stay available when a later resolution returns fewer results or
# fails altogether.
class ResolverCache
  # Addresses not seen for longer than this many seconds are evicted.
  EVICT_AFTER_SECONDS = 30 * 60

  def initialize(name)
    # instance of Name|SRVName
    @name = name
    # {IPv4/IPv6 address: CacheMeta}
    @cached = {}
  end

  # resolve returns the list of cached addresses ordered by the time they
  # were first seen, most recently first-seen at the head of the list.
  # Addresses last seen more than 30 minutes ago are evicted from the cache on
  # a call to resolve().
  # If an exception occurs during DNS resolution it is reported on stderr and
  # whatever addresses are cached are returned.
  #
  # The list is built in the method body on purpose: the value of an `ensure`
  # clause is discarded by Ruby, so computing it there would neither return
  # the cached list nor stop an exception from propagating.
  def resolve
    begin
      addresses = @name.resolve
      now = Time.now.utc
      addresses.each do |address|
        if (meta = @cached[address])
          meta.last_seen = now
        else
          @cached[address] = CacheMeta.new(now, now)
        end
      end
    rescue StandardError => e
      warn "ResolverCache: resolution failed (#{e.class}: #{e.message}), using cached addresses"
    end

    cutoff = Time.now.utc - EVICT_AFTER_SECONDS
    @cached.delete_if { |_, meta| meta.last_seen < cutoff }

    # Addresses added by the same resolution share a timestamp; the insertion
    # index breaks the tie so the ordering stays deterministic.
    @cached
      .sort_by.with_index { |(_, meta), index| [meta.first_seen, index] }
      .reverse
      .map(&:first)
  end
end
# Remembers the address that most recently passed a health check.
class HealthyCache
  # resolver_cache must respond to resolve and return a list of addresses;
  # check is a callable taking an address and returning truthy when healthy.
  def initialize(resolver_cache, check)
    @resolver_cache = resolver_cache
    @check = check
    # the single address that was most recently found to be healthy
    @cached = nil
  end

  # Runs the check against the resolved addresses in order, stopping at the
  # first one that passes. A non-empty hit replaces the remembered address;
  # otherwise the previously remembered address (possibly nil) is returned.
  def first_healthy
    candidate = @resolver_cache.resolve.find { |addr| @check.call(addr) }
    @cached = candidate unless nilempty(candidate).nil?
    @cached
  end
end
# Returns true when a Redis server at host answers PING with "PONG" using the
# given password (1 second client timeout). Any error yields false.
# When DISCOURSE_REDIS_USE_SSL is set to a non-empty value the connection
# uses SSL without certificate verification.
def redis_healthcheck(host:, password:)
  options = { host: host, password: password, timeout: 1 }
  unless nilempty(ENV['DISCOURSE_REDIS_USE_SSL']).nil?
    options[:ssl] = true
    options[:ssl_params] = { verify_mode: OpenSSL::SSL::VERIFY_NONE }
  end

  client = Redis.new(**options)
  client.ping == "PONG"
rescue
  false
ensure
  # client is nil when the constructor itself raised
  client&.close
end
# Returns true when a connection to the PostgreSQL server at host can be
# opened with the given credentials and the empty statement ';' yields no
# rows. Any error yields false.
def postgres_healthcheck(host:, user:, password:, dbname:)
  settings = {
    host: host,
    user: user,
    password: password,
    dbname: dbname,
    connect_timeout: 2, # minimum
  }
  client = PG::Connection.new(**settings)
  result = client.exec(';')
  result.none?
rescue
  false
ensure
  # client is nil when the connection could not be established
  client&.close
end
# Health check to run against a candidate address, keyed by the symbolized
# name of the environment variable the host came from. Every variable listed
# in CRITICAL_HOST_ENV_VARS needs an entry here: a missing entry hands a nil
# check to HealthyCache, which then fails on every call.
# Each lambda takes the address to probe and returns true/false.
HEALTH_CHECKS = {
  "DISCOURSE_DB_HOST": lambda { |addr|
    postgres_healthcheck(
      host: addr,
      user: ENV["DISCOURSE_DB_USERNAME"] || DEFAULT_DB_NAME,
      dbname: ENV["DISCOURSE_DB_NAME"] || DEFAULT_DB_NAME,
      password: ENV["DISCOURSE_DB_PASSWORD"])
  },
  "DISCOURSE_DB_REPLICA_HOST": lambda { |addr|
    postgres_healthcheck(
      host: addr,
      user: ENV["DISCOURSE_DB_USERNAME"] || DEFAULT_DB_NAME,
      dbname: ENV["DISCOURSE_DB_NAME"] || DEFAULT_DB_NAME,
      password: ENV["DISCOURSE_DB_PASSWORD"])
  },
  "DISCOURSE_REDIS_HOST": lambda { |addr|
    redis_healthcheck(
      host: addr,
      password: ENV["DISCOURSE_REDIS_PASSWORD"])
  },
  # Legacy variable name that is still part of CRITICAL_HOST_ENV_VARS.
  "DISCOURSE_REDIS_SLAVE_HOST": lambda { |addr|
    redis_healthcheck(
      host: addr,
      password: ENV["DISCOURSE_REDIS_PASSWORD"])
  },
  "DISCOURSE_REDIS_REPLICA_HOST": lambda { |addr|
    redis_healthcheck(
      host: addr,
      password: ENV["DISCOURSE_REDIS_PASSWORD"])
  },
}
# Writes msg to standard error, prefixed with the current UTC time in
# ISO 8601 format.
def log(msg)
  timestamp = Time.now.utc.iso8601
  STDERR.puts "#{timestamp}: #{msg}"
end
# Reports an error message. Currently identical to log: the message is passed
# through unchanged.
def error(msg)
log(msg)
end
# Returns new hosts file content in which the entries for name are replaced.
#
# hosts - current content of the hosts file (String)
# name  - hostname whose entries are replaced
# ips   - addresses to write for name
#
# Every line is stripped of surrounding whitespace and terminated with "\n".
# Non-comment lines whose second whitespace-separated field equals name are
# dropped; one "AUTO GENERATED" line per address is appended at the end.
def swap_address(hosts, name, ips)
  kept = hosts.split("\n").map(&:strip).reject do |entry|
    entry[0] != '#' && entry.split(/\s+/)[1] == name
  end

  generated = ips.map do |ip|
    "#{ip} #{name} # AUTO GENERATED: #{Time.now.utc.iso8601}\n"
  end

  (kept.map { |entry| "#{entry}\n" } + generated).join
end
# Sends a single counter observation to the Prometheus collector listening
# on localhost (port from DISCOURSE_PROMETHEUS_COLLECTOR_PORT, default 9405)
# as an HTTP POST to /send-metrics.
#
# name        - metric name
# description - metric description
# labels      - Hash of label name => value, or nil for no labels
# value       - numeric value, written into the JSON body as-is
#
# Failures are reported through error() and never raised. The socket is
# closed on every path, including when writing, reading or checking the
# response fails.
#
# NOTE: name, description and labels are interpolated into the JSON body
# verbatim; they are not JSON-escaped.
def send_counter(name, description, labels, value)
  host = "localhost"
  port = ENV.fetch("DISCOURSE_PROMETHEUS_COLLECTOR_PORT", 9405).to_i

  labels = (labels || {}).map { |k, v| "\"#{k}\": \"#{v}\"" }.join(",")

  json = <<~JSON
    {
      "_type": "Custom",
      "type": "Counter",
      "name": "#{name}",
      "description": "#{description}",
      "labels": { #{labels} },
      "value": #{value}
    }
  JSON

  payload = +"POST /send-metrics HTTP/1.1\r\n"
  payload << "Host: #{host}\r\n"
  payload << "Connection: Close\r\n"
  payload << "Content-Type: application/json\r\n"
  payload << "Content-Length: #{json.bytesize}\r\n"
  payload << "\r\n"
  payload << json

  socket = TCPSocket.new(host, port)
  socket.write(payload)
  socket.flush

  # read returns an empty string when the peer closes without answering;
  # treat that as a failed report instead of tripping over a missing line.
  result = socket.read.to_s
  status_line = result.lines.first.to_s.strip
  error("Failed to report metric #{result}") if status_line != "HTTP/1.1 200 OK"
rescue => e
  error("Failed to send metric to Prometheus #{e}")
ensure
  # socket is nil when the connection attempt itself failed
  socket&.close
end
# Increments the 'critical_dns_successes_total' counter by 1, without labels.
def report_success
send_counter('critical_dns_successes_total', 'critical DNS resolution success', nil, 1)
end
# Sends one 'critical_dns_failures_total' counter per entry of errors.
#
# errors - Hash of host => failure count. A nil host is sent without labels,
#          any other host is sent as the "host" label.
def report_failure(errors)
  errors.each do |host, count|
    labels = host ? { host: host } : nil
    send_counter('critical_dns_failures_total', 'critical DNS resolution failures', labels, count)
  end
end
# Normalizes "nothing" to nil: returns nil when v is nil or when v responds
# to empty? and is empty; returns v itself in every other case.
def nilempty(v)
  return nil if v.nil?
  return nil if v.respond_to?(:empty?) && v.empty?

  v
end
# Returns the name of the environment variable holding the SRV name for
# env_name: env_name followed by "_SRV".
def env_srv_var(env_name)
  format("%s_SRV", env_name)
end
# Returns the SRV name configured for env_name, read from the environment
# variable named by env_srv_var(env_name), or nil when it is unset or empty.
def env_srv_name(env_name)
  configured = ENV[env_srv_var(env_name)]
  nilempty(configured)
end
# Performs one refresh cycle for the given environment variable names:
# finds a healthy address for each host, rewrites the hosts file when the
# address of a host changed, and reports the outcome as a metric.
#
# hostname_vars - names of environment variables holding the hostnames
#
# Failures are counted per hostname and logged; an exception escaping the
# cycle is logged and counted under the nil key. A success counter is sent
# only when no failure was recorded at all.
def run(hostname_vars)
  failures = Hash.new(0)
  # HOSTNAME: [IP_ADDRESS, ...]
  # this will usually be a single address
  healthy = {}

  hostname_vars.each do |var|
    name = ENV[var]

    # The resolver and healthy caches are created on first use and then kept
    # for the lifetime of the process.
    HOST_RESOLVER_CACHE[var] ||= begin
      srv_name = env_srv_name(var)
      resolvable =
        if srv_name
          SRVName.new(srv_name, SRV_PRIORITY_FILTERS[env_srv_var(var)])
        else
          Name.new(name)
        end
      ResolverCache.new(resolvable)
    end
    HOST_HEALTHY_CACHE[var] ||= HealthyCache.new(HOST_RESOLVER_CACHE[var], HEALTH_CHECKS[var.to_sym])

    begin
      address = HOST_HEALTHY_CACHE[var].first_healthy
      if address
        healthy[name] = [address]
      else
        error("#{var}: #{name}: no address")
        failures[name] += 1
      end
    rescue => e
      error("#{var}: #{name}: #{e}")
      failures[name] += 1
    end
  end

  hosts_content = File.read(HOSTS_PATH)
  hosts = Resolv::Hosts.new(HOSTS_PATH)
  dirty = false

  healthy.each do |hostname, ips|
    current = hosts.getaddresses(hostname).map(&:to_s).sort
    next if current == ips.sort

    log("IP addresses for #{hostname} changed to #{ips}")
    hosts_content = swap_address(hosts_content, hostname, ips)
    dirty = true
  end

  File.write(HOSTS_PATH, hosts_content) if dirty
rescue => e
  error("DNS lookup failed: #{e}")
  failures[nil] = 1
ensure
  if failures.empty?
    report_success
  else
    report_failure(failures)
  end
end
# Only variables that hold an actual hostname are cached: variables that are
# unset or empty, and variables that already contain a literal IP address,
# are left out.
all_hostname_vars = CRITICAL_HOST_ENV_VARS.reject do |name|
  host = nilempty(ENV[name])
  next true if host.nil?

  begin
    # Parsing succeeds only for literal IP addresses.
    IPAddr.new(host)
    true
  rescue IPAddr::InvalidAddressError, IPAddr::AddressFamilyError
    false
  end
end
# Register a priority filter for every variable that has an SRV name set.
# The bounds come from <VAR>_SRV_PRIORITY_GE and <VAR>_SRV_PRIORITY_LE; a
# bound missing from the environment falls back to the matching threshold
# constant, so no filtering is performed on that side.
# Bounds outside the threshold range, or a lower bound above the upper bound,
# abort the script.
CRITICAL_HOST_ENV_VARS.each do |v|
  next unless env_srv_name(v)

  srv_var = env_srv_var(v)
  upper = ENV.fetch("#{srv_var}_PRIORITY_LE", SRV_PRIORITY_THRESHOLD_MAX).to_i
  lower = ENV.fetch("#{srv_var}_PRIORITY_GE", SRV_PRIORITY_THRESHOLD_MIN).to_i

  acceptable = upper <= SRV_PRIORITY_THRESHOLD_MAX &&
    lower >= SRV_PRIORITY_THRESHOLD_MIN &&
    lower <= upper
  raise "invalid priority threshold set for #{v}" unless acceptable

  SRV_PRIORITY_FILTERS[srv_var] = PrioFilter.new(lower, upper)
end
# Refresh forever, pausing REFRESH_SECONDS between runs.
loop do
  run(all_hostname_vars)
  sleep REFRESH_SECONDS
end