diff --git a/fs/fshttp/http.go b/fs/fshttp/http.go index 498714dfd..be9e3aee3 100644 --- a/fs/fshttp/http.go +++ b/fs/fshttp/http.go @@ -69,6 +69,13 @@ func NewTransportCustom(ctx context.Context, customize func(*http.Transport)) ht if err != nil { log.Fatalf("Failed to load --client-cert/--client-key pair: %v", err) } + if cert.Leaf == nil { + // Leaf is always the first certificate + cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + log.Fatalf("Failed to parse the certificate") + } + } t.TLSClientConfig.Certificates = []tls.Certificate{cert} } @@ -148,17 +155,24 @@ type Transport struct { userAgent string headers []*fs.HTTPOption metrics *Metrics + // Filename of the client cert in case we need to reload it + clientCert string + clientKey string + // Mutex for serializing attempts at reloading the certificates + reloadMutex sync.Mutex } // newTransport wraps the http.Transport passed in and logs all // roundtrips including the body if logBody is set. func newTransport(ci *fs.ConfigInfo, transport *http.Transport) *Transport { return &Transport{ - Transport: transport, - dump: ci.Dump, - userAgent: ci.UserAgent, - headers: ci.Headers, - metrics: DefaultMetrics, + Transport: transport, + dump: ci.Dump, + userAgent: ci.UserAgent, + headers: ci.Headers, + metrics: DefaultMetrics, + clientCert: ci.ClientCert, + clientKey: ci.ClientKey, } } @@ -247,8 +261,44 @@ func cleanAuths(buf []byte) []byte { return buf } +var expireWindow = 30 * time.Second + +func isCertificateExpired(cc *tls.Config) bool { + return len(cc.Certificates) > 0 && cc.Certificates[0].Leaf != nil && time.Until(cc.Certificates[0].Leaf.NotAfter) < expireWindow +} + +func (t *Transport) reloadCertificates() { + t.reloadMutex.Lock() + defer t.reloadMutex.Unlock() + // Check that the certificate is expired before trying to reload it + // it might have been reloaded while we were waiting to lock the mutex + if !isCertificateExpired(t.TLSClientConfig) { + return + } + + cert, err := tls.LoadX509KeyPair(t.clientCert, t.clientKey) + if err != nil { + log.Fatalf("Failed to load --client-cert/--client-key pair: %v", err) + } + // Check if we need to parse the certificate again, we need it + // for checking the expiration date + if cert.Leaf == nil { + // Leaf is always the first certificate + cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + log.Fatalf("Failed to parse the certificate") + } + } + t.TLSClientConfig.Certificates = []tls.Certificate{cert} +} + // RoundTrip implements the RoundTripper interface. func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) { + // Check if certificates are being used and the certificates are expired + if isCertificateExpired(t.TLSClientConfig) { + t.reloadCertificates() + } + // Limit transactions per second if required accounting.LimitTPS(req.Context()) // Force user agent diff --git a/fs/fshttp/http_test.go b/fs/fshttp/http_test.go index 24f440b43..2c10e84a2 100644 --- a/fs/fshttp/http_test.go +++ b/fs/fshttp/http_test.go @@ -1,9 +1,24 @@ package fshttp import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net/http" + "net/http/httptest" + "os" "testing" + "time" + "github.com/rclone/rclone/fs" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCleanAuth(t *testing.T) { @@ -45,3 +60,118 @@ func TestCleanAuths(t *testing.T) { assert.Equal(t, test.want, got, test.in) } } + +var certSerial = int64(0) + +// Create a test certificate and key pair that is valid for a specific +// duration +func createTestCert(validity time.Duration) (keyPEM []byte, certPEM []byte, err error) { + key, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + return + } + keyBytes := x509.MarshalPKCS1PrivateKey(key) + // PEM encoding of private key + keyPEM = pem.EncodeToMemory( + &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: keyBytes, + }, + ) + + // Now create the certificate + notBefore := time.Now() + notAfter := notBefore.Add(validity).Add(expireWindow) + + certSerial += 1 + template := x509.Certificate{ + SerialNumber: big.NewInt(certSerial), + Subject: pkix.Name{CommonName: "localhost"}, + SignatureAlgorithm: x509.SHA256WithRSA, + NotBefore: notBefore, + NotAfter: notAfter, + BasicConstraintsValid: true, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment | x509.KeyUsageDataEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + } + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) + if err != nil { + return + } + + certPEM = pem.EncodeToMemory( + &pem.Block{ + Type: "CERTIFICATE", + Bytes: derBytes, + }, + ) + return +} + +func writeTestCert(t *testing.T, ci *fs.ConfigInfo, validity time.Duration) { + keyPEM, certPEM, err := createTestCert(1 * time.Second) + assert.NoError(t, err, "Cannot create test cert") + err = os.WriteFile(ci.ClientCert, certPEM, 0666) + assert.NoError(t, err, "Failed to write cert") + err = os.WriteFile(ci.ClientKey, keyPEM, 0666) + assert.NoError(t, err, "Failed to write key") +} + +func TestCertificates(t *testing.T) { + startTime := time.Now() + // Starting a TLS server + expectedSerial := int64(0) + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cert := r.TLS.PeerCertificates + require.Greater(t, len(cert), 0, "No certificates received") + expectedSerial += 1 + assert.Equal(t, expectedSerial, cert[0].SerialNumber.Int64(), "Did not get the correct serial number in certificate") + // Check that the certificate hasn't expired. We cannot use cert validation + // functions because those check for signature as well and our certificates + // are not properly signed + if time.Now().After(cert[0].NotAfter) { + assert.Fail(t, "Certificate expired", "Certificate expires at %s, current time is %s", cert[0].NotAfter.Sub(startTime), time.Since(startTime)) + } + + // Write some test data to fullfil the request + w.Header().Set("Content-Type", "text/plain") + _, _ = fmt.Fprintln(w, "test data") + })) + defer ts.Close() + // Modify servers config to request a client certificate + // we cannot validate the certificate since we are not properly signing it + ts.TLS.ClientAuth = tls.RequestClientCert + + // Set --client-cert and --client-key in config to + // a pair of temp files + // create a test cert/key pair and write it to the files + ctx := context.TODO() + ci := fs.GetConfig(ctx) + // Create a test certificate and write it to a temp file + ci.ClientCert = t.TempDir() + "client.cert" + ci.ClientKey = t.TempDir() + "client.key" + validity := 1 * time.Second + writeTestCert(t, ci, validity) + + // Now create the client with the above settings + // we need to disable TLS verification since we don't + // care about server certificate + client := NewClient(ctx) + tt := client.Transport.(*Transport) + tt.TLSClientConfig.InsecureSkipVerify = true + + // Now make requests, the first request should be within + // the valid window + _, err := client.Get(ts.URL) + assert.NoError(t, err) + + // Wait for the 2* valid duration of the certificate so that has definitely expired + time.Sleep(2 * validity) + + // Create a new cert and write it to files + writeTestCert(t, ci, validity) + + // The new cert should be auto-loaded before we make this request + _, err = client.Get(ts.URL) + assert.NoError(t, err) +}