From f26d2c6ba86854bd1e65ebfd460a98b96db6a9e7 Mon Sep 17 00:00:00 2001 From: Saleh Dindar Date: Tue, 24 Oct 2023 22:01:42 -0700 Subject: [PATCH] fs/http: reload client certificates on expiry In corporate environments, client certificates have short life times for added security, and they get renewed automatically. This means that client certificate can expire in the middle of long running command such as `mount`. This commit attempts to reload the client certificates 30s before they expire. This will be active for all backends which use HTTP. --- fs/fshttp/http.go | 60 +++++++++++++++++-- fs/fshttp/http_test.go | 130 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 185 insertions(+), 5 deletions(-) diff --git a/fs/fshttp/http.go b/fs/fshttp/http.go index 498714dfd..be9e3aee3 100644 --- a/fs/fshttp/http.go +++ b/fs/fshttp/http.go @@ -69,6 +69,13 @@ func NewTransportCustom(ctx context.Context, customize func(*http.Transport)) ht if err != nil { log.Fatalf("Failed to load --client-cert/--client-key pair: %v", err) } + if cert.Leaf == nil { + // Leaf is always the first certificate + cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + log.Fatalf("Failed to parse the certificate") + } + } t.TLSClientConfig.Certificates = []tls.Certificate{cert} } @@ -148,17 +155,24 @@ type Transport struct { userAgent string headers []*fs.HTTPOption metrics *Metrics + // Filename of the client cert in case we need to reload it + clientCert string + clientKey string + // Mutex for serializing attempts at reloading the certificates + reloadMutex sync.Mutex } // newTransport wraps the http.Transport passed in and logs all // roundtrips including the body if logBody is set. func newTransport(ci *fs.ConfigInfo, transport *http.Transport) *Transport { return &Transport{ - Transport: transport, - dump: ci.Dump, - userAgent: ci.UserAgent, - headers: ci.Headers, - metrics: DefaultMetrics, + Transport: transport, + dump: ci.Dump, + userAgent: ci.UserAgent, + headers: ci.Headers, + metrics: DefaultMetrics, + clientCert: ci.ClientCert, + clientKey: ci.ClientKey, } } @@ -247,8 +261,44 @@ func cleanAuths(buf []byte) []byte { return buf } +var expireWindow = 30 * time.Second + +func isCertificateExpired(cc *tls.Config) bool { + return len(cc.Certificates) > 0 && cc.Certificates[0].Leaf != nil && time.Until(cc.Certificates[0].Leaf.NotAfter) < expireWindow +} + +func (t *Transport) reloadCertificates() { + t.reloadMutex.Lock() + defer t.reloadMutex.Unlock() + // Check that the certificate is expired before trying to reload it + // it might have been reloaded while we were waiting to lock the mutex + if !isCertificateExpired(t.TLSClientConfig) { + return + } + + cert, err := tls.LoadX509KeyPair(t.clientCert, t.clientKey) + if err != nil { + log.Fatalf("Failed to load --client-cert/--client-key pair: %v", err) + } + // Check if we need to parse the certificate again, we need it + // for checking the expiration date + if cert.Leaf == nil { + // Leaf is always the first certificate + cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + log.Fatalf("Failed to parse the certificate") + } + } + t.TLSClientConfig.Certificates = []tls.Certificate{cert} +} + // RoundTrip implements the RoundTripper interface. func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) { + // Check if certificates are being used and the certificates are expired + if isCertificateExpired(t.TLSClientConfig) { + t.reloadCertificates() + } + // Limit transactions per second if required accounting.LimitTPS(req.Context()) // Force user agent diff --git a/fs/fshttp/http_test.go b/fs/fshttp/http_test.go index 24f440b43..2c10e84a2 100644 --- a/fs/fshttp/http_test.go +++ b/fs/fshttp/http_test.go @@ -1,9 +1,24 @@ package fshttp import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net/http" + "net/http/httptest" + "os" "testing" + "time" + "github.com/rclone/rclone/fs" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCleanAuth(t *testing.T) { @@ -45,3 +60,118 @@ func TestCleanAuths(t *testing.T) { assert.Equal(t, test.want, got, test.in) } } + +var certSerial = int64(0) + +// Create a test certificate and key pair that is valid for a specific +// duration +func createTestCert(validity time.Duration) (keyPEM []byte, certPEM []byte, err error) { + key, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + return + } + keyBytes := x509.MarshalPKCS1PrivateKey(key) + // PEM encoding of private key + keyPEM = pem.EncodeToMemory( + &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: keyBytes, + }, + ) + + // Now create the certificate + notBefore := time.Now() + notAfter := notBefore.Add(validity).Add(expireWindow) + + certSerial += 1 + template := x509.Certificate{ + SerialNumber: big.NewInt(certSerial), + Subject: pkix.Name{CommonName: "localhost"}, + SignatureAlgorithm: x509.SHA256WithRSA, + NotBefore: notBefore, + NotAfter: notAfter, + BasicConstraintsValid: true, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment | x509.KeyUsageDataEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + } + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) + if err != nil { + return + } + + certPEM = pem.EncodeToMemory( + &pem.Block{ + Type: "CERTIFICATE", + Bytes: derBytes, + }, + ) + return +} + +func writeTestCert(t *testing.T, ci *fs.ConfigInfo, validity time.Duration) { + keyPEM, certPEM, err := createTestCert(1 * time.Second) + assert.NoError(t, err, "Cannot create test cert") + err = os.WriteFile(ci.ClientCert, certPEM, 0666) + assert.NoError(t, err, "Failed to write cert") + err = os.WriteFile(ci.ClientKey, keyPEM, 0666) + assert.NoError(t, err, "Failed to write key") +} + +func TestCertificates(t *testing.T) { + startTime := time.Now() + // Starting a TLS server + expectedSerial := int64(0) + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cert := r.TLS.PeerCertificates + require.Greater(t, len(cert), 0, "No certificates received") + expectedSerial += 1 + assert.Equal(t, expectedSerial, cert[0].SerialNumber.Int64(), "Did not get the correct serial number in certificate") + // Check that the certificate hasn't expired. We cannot use cert validation + // functions because those check for signature as well and our certificates + // are not properly signed + if time.Now().After(cert[0].NotAfter) { + assert.Fail(t, "Certificate expired", "Certificate expires at %s, current time is %s", cert[0].NotAfter.Sub(startTime), time.Since(startTime)) + } + + // Write some test data to fullfil the request + w.Header().Set("Content-Type", "text/plain") + _, _ = fmt.Fprintln(w, "test data") + })) + defer ts.Close() + // Modify servers config to request a client certificate + // we cannot validate the certificate since we are not properly signing it + ts.TLS.ClientAuth = tls.RequestClientCert + + // Set --client-cert and --client-key in config to + // a pair of temp files + // create a test cert/key pair and write it to the files + ctx := context.TODO() + ci := fs.GetConfig(ctx) + // Create a test certificate and write it to a temp file + ci.ClientCert = t.TempDir() + "client.cert" + ci.ClientKey = t.TempDir() + "client.key" + validity := 1 * time.Second + writeTestCert(t, ci, validity) + + // Now create the client with the above settings + // we need to disable TLS verification since we don't + // care about server certificate + client := NewClient(ctx) + tt := client.Transport.(*Transport) + tt.TLSClientConfig.InsecureSkipVerify = true + + // Now make requests, the first request should be within + // the valid window + _, err := client.Get(ts.URL) + assert.NoError(t, err) + + // Wait for the 2* valid duration of the certificate so that has definitely expired + time.Sleep(2 * validity) + + // Create a new cert and write it to files + writeTestCert(t, ci, validity) + + // The new cert should be auto-loaded before we make this request + _, err = client.Get(ts.URL) + assert.NoError(t, err) +}