redir: Include scheme in redirect rules

And added tests for status code and scheme
This commit is contained in:
Matthew Holt 2015-09-30 08:38:31 -06:00
parent 698399e61f
commit 3f9f675c43
3 changed files with 73 additions and 35 deletions

View File

@ -37,13 +37,13 @@ func redirParse(c *Controller) ([]redirect.Rule, error) {
// checkAndSaveRule checks the rule for validity (except the redir code) // checkAndSaveRule checks the rule for validity (except the redir code)
// and saves it if it's valid, or returns an error. // and saves it if it's valid, or returns an error.
checkAndSaveRule := func(rule redirect.Rule) error { checkAndSaveRule := func(rule redirect.Rule) error {
if rule.From == rule.To { if rule.FromPath == rule.To {
return c.Err("'from' and 'to' values of redirect rule cannot be the same") return c.Err("'from' and 'to' values of redirect rule cannot be the same")
} }
for _, otherRule := range redirects { for _, otherRule := range redirects {
if otherRule.From == rule.From { if otherRule.FromPath == rule.FromPath {
return c.Errf("rule with duplicate 'from' value: %s -> %s", otherRule.From, otherRule.To) return c.Errf("rule with duplicate 'from' value: %s -> %s", otherRule.FromPath, otherRule.To)
} }
} }
@ -60,6 +60,12 @@ func redirParse(c *Controller) ([]redirect.Rule, error) {
var rule redirect.Rule var rule redirect.Rule
if c.Config.TLS.Enabled {
rule.FromScheme = "https"
} else {
rule.FromScheme = "http"
}
// Set initial redirect code // Set initial redirect code
// BUG: If the code is specified for a whole block and that code is invalid, // BUG: If the code is specified for a whole block and that code is invalid,
// the line number will appear on the first line inside the block, even if that // the line number will appear on the first line inside the block, even if that
@ -84,15 +90,15 @@ func redirParse(c *Controller) ([]redirect.Rule, error) {
// To specified (catch-all redirect) // To specified (catch-all redirect)
// Not sure why user is doing this in a table, as it causes all other redirects to be ignored. // Not sure why user is doing this in a table, as it causes all other redirects to be ignored.
// As such, this feature remains undocumented. // As such, this feature remains undocumented.
rule.From = "/" rule.FromPath = "/"
rule.To = insideArgs[0] rule.To = insideArgs[0]
case 2: case 2:
// From and To specified // From and To specified
rule.From = insideArgs[0] rule.FromPath = insideArgs[0]
rule.To = insideArgs[1] rule.To = insideArgs[1]
case 3: case 3:
// From, To, and Code specified // From, To, and Code specified
rule.From = insideArgs[0] rule.FromPath = insideArgs[0]
rule.To = insideArgs[1] rule.To = insideArgs[1]
err := setRedirCode(insideArgs[2], &rule) err := setRedirCode(insideArgs[2], &rule)
if err != nil { if err != nil {
@ -110,16 +116,23 @@ func redirParse(c *Controller) ([]redirect.Rule, error) {
if !hadOptionalBlock { if !hadOptionalBlock {
var rule redirect.Rule var rule redirect.Rule
if c.Config.TLS.Enabled {
rule.FromScheme = "https"
} else {
rule.FromScheme = "http"
}
rule.Code = http.StatusMovedPermanently // default rule.Code = http.StatusMovedPermanently // default
switch len(args) { switch len(args) {
case 1: case 1:
// To specified (catch-all redirect) // To specified (catch-all redirect)
rule.From = "/" rule.FromPath = "/"
rule.To = args[0] rule.To = args[0]
case 2: case 2:
// To and Code specified (catch-all redirect) // To and Code specified (catch-all redirect)
rule.From = "/" rule.FromPath = "/"
rule.To = args[0] rule.To = args[0]
err := setRedirCode(args[1], &rule) err := setRedirCode(args[1], &rule)
if err != nil { if err != nil {
@ -127,7 +140,7 @@ func redirParse(c *Controller) ([]redirect.Rule, error) {
} }
case 3: case 3:
// From, To, and Code specified // From, To, and Code specified
rule.From = args[0] rule.FromPath = args[0]
rule.To = args[1] rule.To = args[1]
err := setRedirCode(args[2], &rule) err := setRedirCode(args[2], &rule)
if err != nil { if err != nil {
@ -149,12 +162,12 @@ func redirParse(c *Controller) ([]redirect.Rule, error) {
// httpRedirs is a list of supported HTTP redirect codes. // httpRedirs is a list of supported HTTP redirect codes.
var httpRedirs = map[string]int{ var httpRedirs = map[string]int{
"300": 300, // Multiple Choices "300": http.StatusMultipleChoices,
"301": 301, // Moved Permanently "301": http.StatusMovedPermanently,
"302": 302, // Found (NOT CORRECT for "Temporary Redirect", see 307) "302": http.StatusFound, // (NOT CORRECT for "Temporary Redirect", see 307)
"303": 303, // See Other "303": http.StatusSeeOther,
"304": 304, // Not Modified "304": http.StatusNotModified,
"305": 305, // Use Proxy "305": http.StatusUseProxy,
"307": 307, // Temporary Redirect "307": http.StatusTemporaryRedirect,
"308": 308, // Permanent Redirect "308": 308, // Permanent Redirect
} }

View File

@ -19,7 +19,7 @@ type Redirect struct {
// ServeHTTP implements the middleware.Handler interface. // ServeHTTP implements the middleware.Handler interface.
func (rd Redirect) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { func (rd Redirect) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
for _, rule := range rd.Rules { for _, rule := range rd.Rules {
if rule.From == "/" || r.URL.Path == rule.From { if (rule.FromPath == "/" || r.URL.Path == rule.FromPath) && schemeMatches(rule, r) {
to := middleware.NewReplacer(r, nil, "").Replace(rule.To) to := middleware.NewReplacer(r, nil, "").Replace(rule.To)
if rule.Meta { if rule.Meta {
safeTo := html.EscapeString(to) safeTo := html.EscapeString(to)
@ -33,9 +33,14 @@ func (rd Redirect) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error
return rd.Next.ServeHTTP(w, r) return rd.Next.ServeHTTP(w, r)
} }
func schemeMatches(rule Rule, req *http.Request) bool {
return (rule.FromScheme == "https" && req.TLS != nil) ||
(rule.FromScheme != "https" && req.TLS == nil)
}
// Rule describes an HTTP redirect rule. // Rule describes an HTTP redirect rule.
type Rule struct { type Rule struct {
From, To string FromScheme, FromPath, To string
Code int Code int
Meta bool Meta bool
} }

View File

@ -2,9 +2,11 @@ package redirect
import ( import (
"bytes" "bytes"
"crypto/tls"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
"github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware"
@ -14,15 +16,22 @@ func TestRedirect(t *testing.T) {
for i, test := range []struct { for i, test := range []struct {
from string from string
expectedLocation string expectedLocation string
expectedCode int
}{ }{
{"/from", "/to"}, {"http://localhost/from", "/to", http.StatusMovedPermanently},
{"/a", "/b"}, {"http://localhost/a", "/b", http.StatusTemporaryRedirect},
{"/aa", ""}, {"http://localhost/aa", "", http.StatusOK},
{"/", ""}, {"http://localhost/", "", http.StatusOK},
{"/a?foo=bar", "/b"}, {"http://localhost/a?foo=bar", "/b", http.StatusTemporaryRedirect},
{"/asdf?foo=bar", ""}, {"http://localhost/asdf?foo=bar", "", http.StatusOK},
{"/foo#bar", ""}, {"http://localhost/foo#bar", "", http.StatusOK},
{"/a#foo", "/b"}, {"http://localhost/a#foo", "/b", http.StatusTemporaryRedirect},
{"http://localhost/scheme", "https://localhost/scheme", http.StatusMovedPermanently},
{"https://localhost/scheme", "", http.StatusOK},
{"https://localhost/scheme2", "http://localhost/scheme2", http.StatusMovedPermanently},
{"http://localhost/scheme2", "", http.StatusOK},
{"http://localhost/scheme3", "https://localhost/scheme3", http.StatusMovedPermanently},
{"https://localhost/scheme3", "", http.StatusOK},
} { } {
var nextCalled bool var nextCalled bool
@ -32,8 +41,11 @@ func TestRedirect(t *testing.T) {
return 0, nil return 0, nil
}), }),
Rules: []Rule{ Rules: []Rule{
{From: "/from", To: "/to"}, {FromPath: "/from", To: "/to", Code: http.StatusMovedPermanently},
{From: "/a", To: "/b"}, {FromPath: "/a", To: "/b", Code: http.StatusTemporaryRedirect},
{FromScheme: "http", FromPath: "/scheme", To: "https://localhost/scheme", Code: http.StatusMovedPermanently},
{FromScheme: "https", FromPath: "/scheme2", To: "http://localhost/scheme2", Code: http.StatusMovedPermanently},
{FromScheme: "", FromPath: "/scheme3", To: "https://localhost/scheme3", Code: http.StatusMovedPermanently},
}, },
} }
@ -41,6 +53,9 @@ func TestRedirect(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Test %d: Could not create HTTP request: %v", i, err) t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
} }
if strings.HasPrefix(test.from, "https://") {
req.TLS = new(tls.ConnectionState) // faux HTTPS
}
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
re.ServeHTTP(rec, req) re.ServeHTTP(rec, req)
@ -50,6 +65,11 @@ func TestRedirect(t *testing.T) {
i, test.expectedLocation, rec.Header().Get("Location")) i, test.expectedLocation, rec.Header().Get("Location"))
} }
if rec.Code != test.expectedCode {
t.Errorf("Test %d: Expected status code to be %d but was %d",
i, test.expectedCode, rec.Code)
}
if nextCalled && test.expectedLocation != "" { if nextCalled && test.expectedLocation != "" {
t.Errorf("Test %d: Next handler was unexpectedly called", i) t.Errorf("Test %d: Next handler was unexpectedly called", i)
} }
@ -59,7 +79,7 @@ func TestRedirect(t *testing.T) {
func TestParametersRedirect(t *testing.T) { func TestParametersRedirect(t *testing.T) {
re := Redirect{ re := Redirect{
Rules: []Rule{ Rules: []Rule{
{From: "/", Meta: false, To: "http://example.com{uri}"}, {FromPath: "/", Meta: false, To: "http://example.com{uri}"},
}, },
} }
@ -77,7 +97,7 @@ func TestParametersRedirect(t *testing.T) {
re = Redirect{ re = Redirect{
Rules: []Rule{ Rules: []Rule{
{From: "/", Meta: false, To: "http://example.com/a{path}?b=c&{query}"}, {FromPath: "/", Meta: false, To: "http://example.com/a{path}?b=c&{query}"},
}, },
} }
@ -96,13 +116,13 @@ func TestParametersRedirect(t *testing.T) {
func TestMetaRedirect(t *testing.T) { func TestMetaRedirect(t *testing.T) {
re := Redirect{ re := Redirect{
Rules: []Rule{ Rules: []Rule{
{From: "/whatever", Meta: true, To: "/something"}, {FromPath: "/whatever", Meta: true, To: "/something"},
{From: "/", Meta: true, To: "https://example.com/"}, {FromPath: "/", Meta: true, To: "https://example.com/"},
}, },
} }
for i, test := range re.Rules { for i, test := range re.Rules {
req, err := http.NewRequest("GET", test.From, nil) req, err := http.NewRequest("GET", test.FromPath, nil)
if err != nil { if err != nil {
t.Fatalf("Test %d: Could not create HTTP request: %v", i, err) t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
} }