#!/usr/bin/env ruby
# frozen_string_literal: true

require 'ipaddr'
require 'resolv'
require 'socket'
require 'time'

# Seconds to sleep between two refresh passes.
REFRESH_SECONDS = 30

# Hosts file that is read and rewritten with the resolved addresses.
HOSTS_PATH = "/etc/hosts"

# Names of the environment variables whose values are the hostnames to cache.
# Frozen so the list cannot be mutated by accident at runtime.
CRITICAL_HOST_ENV_VARS = %w[
  DISCOURSE_DB_HOST
  DISCOURSE_DB_REPLICA_HOST
  DISCOURSE_REDIS_HOST
  DISCOURSE_REDIS_SLAVE_HOST
].freeze

# Writes +msg+ to standard error, prefixed with the current time in
# ISO 8601 format and a colon.
def log(msg)
  timestamp = Time.now.iso8601
  STDERR.puts("#{timestamp}: #{msg}")
end

# Reports a failure message. Currently this simply forwards +msg+ to #log.
def error(msg)
  log msg
end

# Returns new hosts-file content in which the entries for +name+ are replaced.
#
# hosts - current hosts-file content as a String
# name  - hostname whose entries are replaced
# ips   - addresses to add for +name+
#
# Every line is stripped of surrounding whitespace. Non-comment lines whose
# second whitespace-separated field equals +name+ are dropped; all other lines
# are kept. One "AUTO GENERATED" line per address is then appended.
def swap_address(hosts, name, ips)
  kept = hosts.split("\n").map(&:strip).reject do |entry|
    entry[0] != '#' && entry.split(/\s+/)[1] == name
  end

  generated = ips.map do |ip|
    "#{ip} #{name} # AUTO GENERATED: #{Time.now.iso8601}\n"
  end

  (kept.flat_map { |entry| [entry, "\n"] } + generated).join
end

# Looks up the A and AAAA records for the hostname stored in ENV[name].
#
# dns  - resolver object responding to #getresources
# name - name of the environment variable holding the hostname
#
# Returns the record addresses as strings, A records first.
def hosts_entries(dns, name)
  hostname = ENV[name]
  record_types = [Resolv::DNS::Resource::IN::A, Resolv::DNS::Resource::IN::AAAA]

  record_types.flat_map do |type|
    dns.getresources(hostname, type).map { |record| record.address.to_s }
  end
end

# Sends one counter metric as a hand-built HTTP POST to localhost:9405
# (path /send-metrics).
#
# name        - metric name
# description - metric description
# labels      - Hash of label name => value, or nil for no labels
# value       - counter value
#
# Any failure is logged through #error rather than raised. The socket is
# always closed, including when the write or read fails or the response is
# empty.
#
# NOTE: values are interpolated into the JSON body without escaping, so they
# must not contain double quotes or backslashes.
def send_counter(name, description, labels, value)
  host = "localhost"
  port = 9405
  socket = nil

  if labels
    labels = labels.map do |k, v|
      "\"#{k}\": \"#{v}\""
    end.join(",")
  else
    labels = ""
  end

  json = <<~JSON
  {
    "_type": "Custom",
    "type": "Counter",
    "name": "#{name}",
    "description": "#{description}",
    "labels": { #{labels} },
    "value": #{value}
  }
  JSON

  payload = +"POST /send-metrics HTTP/1.1\r\n"
  payload << "Host: #{host}\r\n"
  payload << "Connection: Close\r\n"
  payload << "Content-Type: application/json\r\n"
  payload << "Content-Length: #{json.bytesize}\r\n"
  payload << "\r\n"
  payload << json

  socket = TCPSocket.new(host, port)
  socket.write(payload)
  socket.flush
  result = socket.read

  # An empty response has no first line; treat it as a failed report instead
  # of crashing on nil.
  first_line = result.to_s.split("\n").first.to_s
  if first_line.strip != "HTTP/1.1 200 OK"
    error("Failed to report metric #{result}")
  end
rescue => e
  error("Failed to send metric to Prometheus #{e}")
ensure
  # Close on every path so a failed exchange does not leak the descriptor.
  socket&.close
end

# Increments the success counter by one, with no labels.
def report_success
  send_counter(
    'critical_dns_successes_total',
    'critical DNS resolution success',
    nil,
    1
  )
end

# Sends one failure counter per entry of +errors+.
#
# errors - Hash mapping a host (or nil) to its failure count. A nil host is
#          reported without labels; any other host is sent as the `host` label.
def report_failure(errors)
  errors.each do |host, count|
    labels = host ? { host: host } : nil
    send_counter('critical_dns_failures_total', 'critical DNS resolution failures', labels, count)
  end
end

# Names of the critical environment variables that need DNS resolution:
# those that are set, non-empty, and whose value is not already a literal
# IP address (IPAddr.new rejects anything that is not one).
@vars = CRITICAL_HOST_ENV_VARS.select do |name|
  host = ENV[name]
  next false if host.nil? || host.empty?

  begin
    IPAddr.new(host)
    false
  rescue IPAddr::InvalidAddressError, IPAddr::AddressFamilyError
    true
  end
end

# Runs one refresh pass.
#
# Resolves the hostname behind every variable in @vars, rewrites HOSTS_PATH
# when the addresses recorded there differ from the resolved ones, and finally
# reports either one success or the collected per-host failure counts.
#
# A host that fails to resolve, or resolves to nothing, is counted once under
# its hostname and does not stop the remaining hosts from being processed.
# Any other exception is logged and recorded under the nil key.
#
# NOTE: this method shadows Kernel#loop for the whole script.
def loop
  errors = Hash.new(0)

  Resolv::DNS.open do |dns|
    dns.timeouts = 2

    resolved = {}

    @vars.each do |var|
      host = ENV[var]

      begin
        entries = hosts_entries(dns, var)
      rescue => e
        error("Failed to resolve DNS for #{var} - #{e}")
        errors[host] += 1
        # Already counted; skip the "no entries" branch below.
        next
      end

      if entries.nil? || entries.empty?
        error("Failed to find any DNS entry for #{var} : #{host}")
        errors[host] += 1
      else
        resolved[host] = entries
      end
    end

    hosts_content = File.read(HOSTS_PATH)
    hosts = Resolv::Hosts.new(HOSTS_PATH)

    changed = false
    resolved.each do |name, ips|
      next if hosts.getaddresses(name).map(&:to_s).sort == ips.sort

      log("IP addresses for #{name} changed to #{ips}")
      hosts_content = swap_address(hosts_content, name, ips)
      changed = true
    end

    File.write(HOSTS_PATH, hosts_content) if changed
  end
rescue => e
  error("Failed to access DNS - #{e}")
  errors[nil] = 1
ensure
  if errors.empty?
    report_success
  else
    report_failure(errors)
  end
end

# Run forever: perform one refresh pass, then pause for REFRESH_SECONDS.
# `loop` here is the method defined in this file, which shadows Kernel#loop,
# so the outer iteration is written as a plain `while true`.
while true
  loop
  sleep REFRESH_SECONDS
end