mirror of
https://github.com/caddyserver/caddy.git
synced 2025-01-20 02:52:44 +08:00
Migrating more middleware packages
This commit is contained in:
parent
2f92443de7
commit
416af05a00
148
caddyhttp/basicauth/basicauth.go
Normal file
148
caddyhttp/basicauth/basicauth.go
Normal file
|
@ -0,0 +1,148 @@
|
|||
// Package basicauth implements HTTP Basic Authentication.
|
||||
package basicauth
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/subtle"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/jimstudt/http-authentication/basic"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
// BasicAuth is middleware to protect resources with a username and password.
|
||||
// Note that HTTP Basic Authentication is not secure by itself and should
|
||||
// not be used to protect important assets without HTTPS. Even then, the
|
||||
// security of HTTP Basic Auth is disputed. Use discretion when deciding
|
||||
// what to protect with BasicAuth.
|
||||
type BasicAuth struct {
|
||||
Next httpserver.Handler
|
||||
SiteRoot string
|
||||
Rules []Rule
|
||||
}
|
||||
|
||||
// ServeHTTP implements the httpserver.Handler interface.
|
||||
func (a BasicAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
|
||||
var hasAuth bool
|
||||
var isAuthenticated bool
|
||||
|
||||
for _, rule := range a.Rules {
|
||||
for _, res := range rule.Resources {
|
||||
if !httpserver.Path(r.URL.Path).Matches(res) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Path matches; parse auth header
|
||||
username, password, ok := r.BasicAuth()
|
||||
hasAuth = true
|
||||
|
||||
// Check credentials
|
||||
if !ok ||
|
||||
username != rule.Username ||
|
||||
!rule.Password(password) {
|
||||
//subtle.ConstantTimeCompare([]byte(password), []byte(rule.Password)) != 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Flag set only on successful authentication
|
||||
isAuthenticated = true
|
||||
}
|
||||
}
|
||||
|
||||
if hasAuth {
|
||||
if !isAuthenticated {
|
||||
w.Header().Set("WWW-Authenticate", "Basic")
|
||||
return http.StatusUnauthorized, nil
|
||||
}
|
||||
// "It's an older code, sir, but it checks out. I was about to clear them."
|
||||
return a.Next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// Pass-thru when no paths match
|
||||
return a.Next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// Rule represents a BasicAuth rule. A username and password
|
||||
// combination protect the associated resources, which are
|
||||
// file or directory paths.
|
||||
type Rule struct {
|
||||
Username string
|
||||
Password func(string) bool
|
||||
Resources []string
|
||||
}
|
||||
|
||||
// PasswordMatcher determines whether a password matches a rule.
|
||||
type PasswordMatcher func(pw string) bool
|
||||
|
||||
var (
|
||||
htpasswords map[string]map[string]PasswordMatcher
|
||||
htpasswordsMu sync.Mutex
|
||||
)
|
||||
|
||||
// GetHtpasswdMatcher matches password rules.
|
||||
func GetHtpasswdMatcher(filename, username, siteRoot string) (PasswordMatcher, error) {
|
||||
filename = filepath.Join(siteRoot, filename)
|
||||
htpasswordsMu.Lock()
|
||||
if htpasswords == nil {
|
||||
htpasswords = make(map[string]map[string]PasswordMatcher)
|
||||
}
|
||||
pm := htpasswords[filename]
|
||||
if pm == nil {
|
||||
fh, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open %q: %v", filename, err)
|
||||
}
|
||||
defer fh.Close()
|
||||
pm = make(map[string]PasswordMatcher)
|
||||
if err = parseHtpasswd(pm, fh); err != nil {
|
||||
return nil, fmt.Errorf("parsing htpasswd %q: %v", fh.Name(), err)
|
||||
}
|
||||
htpasswords[filename] = pm
|
||||
}
|
||||
htpasswordsMu.Unlock()
|
||||
if pm[username] == nil {
|
||||
return nil, fmt.Errorf("username %q not found in %q", username, filename)
|
||||
}
|
||||
return pm[username], nil
|
||||
}
|
||||
|
||||
func parseHtpasswd(pm map[string]PasswordMatcher, r io.Reader) error {
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" || strings.IndexByte(line, '#') == 0 {
|
||||
continue
|
||||
}
|
||||
i := strings.IndexByte(line, ':')
|
||||
if i <= 0 {
|
||||
return fmt.Errorf("malformed line, no color: %q", line)
|
||||
}
|
||||
user, encoded := line[:i], line[i+1:]
|
||||
for _, p := range basic.DefaultSystems {
|
||||
matcher, err := p(encoded)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if matcher != nil {
|
||||
pm[user] = matcher.MatchesPassword
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
// PlainMatcher returns a PasswordMatcher that does a constant-time
|
||||
// byte-wise comparison.
|
||||
func PlainMatcher(passw string) PasswordMatcher {
|
||||
return func(pw string) bool {
|
||||
return subtle.ConstantTimeCompare([]byte(pw), []byte(passw)) == 1
|
||||
}
|
||||
}
|
146
caddyhttp/basicauth/basicauth_test.go
Normal file
146
caddyhttp/basicauth/basicauth_test.go
Normal file
|
@ -0,0 +1,146 @@
|
|||
package basicauth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func TestBasicAuth(t *testing.T) {
|
||||
rw := BasicAuth{
|
||||
Next: httpserver.HandlerFunc(contentHandler),
|
||||
Rules: []Rule{
|
||||
{Username: "test", Password: PlainMatcher("ttest"), Resources: []string{"/testing"}},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
from string
|
||||
result int
|
||||
cred string
|
||||
}{
|
||||
{"/testing", http.StatusUnauthorized, "ttest:test"},
|
||||
{"/testing", http.StatusOK, "test:ttest"},
|
||||
{"/testing", http.StatusUnauthorized, ""},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
|
||||
req, err := http.NewRequest("GET", test.from, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Test %d: Could not create HTTP request %v", i, err)
|
||||
}
|
||||
auth := "Basic " + base64.StdEncoding.EncodeToString([]byte(test.cred))
|
||||
req.Header.Set("Authorization", auth)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
result, err := rw.ServeHTTP(rec, req)
|
||||
if err != nil {
|
||||
t.Fatalf("Test %d: Could not ServeHTTP %v", i, err)
|
||||
}
|
||||
if result != test.result {
|
||||
t.Errorf("Test %d: Expected Header '%d' but was '%d'",
|
||||
i, test.result, result)
|
||||
}
|
||||
if result == http.StatusUnauthorized {
|
||||
headers := rec.Header()
|
||||
if val, ok := headers["Www-Authenticate"]; ok {
|
||||
if val[0] != "Basic" {
|
||||
t.Errorf("Test %d, Www-Authenticate should be %s provided %s", i, "Basic", val[0])
|
||||
}
|
||||
} else {
|
||||
t.Errorf("Test %d, should provide a header Www-Authenticate", i)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestMultipleOverlappingRules(t *testing.T) {
|
||||
rw := BasicAuth{
|
||||
Next: httpserver.HandlerFunc(contentHandler),
|
||||
Rules: []Rule{
|
||||
{Username: "t", Password: PlainMatcher("p1"), Resources: []string{"/t"}},
|
||||
{Username: "t1", Password: PlainMatcher("p2"), Resources: []string{"/t/t"}},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
from string
|
||||
result int
|
||||
cred string
|
||||
}{
|
||||
{"/t", http.StatusOK, "t:p1"},
|
||||
{"/t/t", http.StatusOK, "t:p1"},
|
||||
{"/t/t", http.StatusOK, "t1:p2"},
|
||||
{"/a", http.StatusOK, "t1:p2"},
|
||||
{"/t/t", http.StatusUnauthorized, "t1:p3"},
|
||||
{"/t", http.StatusUnauthorized, "t1:p2"},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
|
||||
req, err := http.NewRequest("GET", test.from, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Test %d: Could not create HTTP request %v", i, err)
|
||||
}
|
||||
auth := "Basic " + base64.StdEncoding.EncodeToString([]byte(test.cred))
|
||||
req.Header.Set("Authorization", auth)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
result, err := rw.ServeHTTP(rec, req)
|
||||
if err != nil {
|
||||
t.Fatalf("Test %d: Could not ServeHTTP %v", i, err)
|
||||
}
|
||||
if result != test.result {
|
||||
t.Errorf("Test %d: Expected Header '%d' but was '%d'",
|
||||
i, test.result, result)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func contentHandler(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
fmt.Fprintf(w, r.URL.String())
|
||||
return http.StatusOK, nil
|
||||
}
|
||||
|
||||
func TestHtpasswd(t *testing.T) {
|
||||
htpasswdPasswd := "IedFOuGmTpT8"
|
||||
htpasswdFile := `sha1:{SHA}dcAUljwz99qFjYR0YLTXx0RqLww=
|
||||
md5:$apr1$l42y8rex$pOA2VJ0x/0TwaFeAF9nX61`
|
||||
|
||||
htfh, err := ioutil.TempFile("", "basicauth-")
|
||||
if err != nil {
|
||||
t.Skipf("Error creating temp file (%v), will skip htpassword test")
|
||||
return
|
||||
}
|
||||
defer os.Remove(htfh.Name())
|
||||
if _, err = htfh.Write([]byte(htpasswdFile)); err != nil {
|
||||
t.Fatalf("write htpasswd file %q: %v", htfh.Name(), err)
|
||||
}
|
||||
htfh.Close()
|
||||
|
||||
for i, username := range []string{"sha1", "md5"} {
|
||||
rule := Rule{Username: username, Resources: []string{"/testing"}}
|
||||
|
||||
siteRoot := filepath.Dir(htfh.Name())
|
||||
filename := filepath.Base(htfh.Name())
|
||||
if rule.Password, err = GetHtpasswdMatcher(filename, rule.Username, siteRoot); err != nil {
|
||||
t.Fatalf("GetHtpasswdMatcher(%q, %q): %v", htfh.Name(), rule.Username, err)
|
||||
}
|
||||
t.Logf("%d. username=%q", i, rule.Username)
|
||||
if !rule.Password(htpasswdPasswd) || rule.Password(htpasswdPasswd+"!") {
|
||||
t.Errorf("%d (%s) password does not match.", i, rule.Username)
|
||||
}
|
||||
}
|
||||
}
|
83
caddyhttp/basicauth/setup.go
Normal file
83
caddyhttp/basicauth/setup.go
Normal file
|
@ -0,0 +1,83 @@
|
|||
package basicauth
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy.RegisterPlugin(caddy.Plugin{
|
||||
Name: "basicauth",
|
||||
ServerType: "http",
|
||||
Action: setup,
|
||||
})
|
||||
}
|
||||
|
||||
// setup configures a new BasicAuth middleware instance.
|
||||
func setup(c *caddy.Controller) error {
|
||||
cfg := httpserver.GetConfig(c.Key)
|
||||
root := cfg.Root
|
||||
|
||||
rules, err := basicAuthParse(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
basic := BasicAuth{Rules: rules}
|
||||
|
||||
cfg.AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
|
||||
basic.Next = next
|
||||
basic.SiteRoot = root
|
||||
return basic
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func basicAuthParse(c *caddy.Controller) ([]Rule, error) {
|
||||
var rules []Rule
|
||||
cfg := httpserver.GetConfig(c.Key)
|
||||
|
||||
var err error
|
||||
for c.Next() {
|
||||
var rule Rule
|
||||
|
||||
args := c.RemainingArgs()
|
||||
|
||||
switch len(args) {
|
||||
case 2:
|
||||
rule.Username = args[0]
|
||||
if rule.Password, err = passwordMatcher(rule.Username, args[1], cfg.Root); err != nil {
|
||||
return rules, c.Errf("Get password matcher from %s: %v", c.Val(), err)
|
||||
}
|
||||
|
||||
for c.NextBlock() {
|
||||
rule.Resources = append(rule.Resources, c.Val())
|
||||
if c.NextArg() {
|
||||
return rules, c.Errf("Expecting only one resource per line (extra '%s')", c.Val())
|
||||
}
|
||||
}
|
||||
case 3:
|
||||
rule.Resources = append(rule.Resources, args[0])
|
||||
rule.Username = args[1]
|
||||
if rule.Password, err = passwordMatcher(rule.Username, args[2], cfg.Root); err != nil {
|
||||
return rules, c.Errf("Get password matcher from %s: %v", c.Val(), err)
|
||||
}
|
||||
default:
|
||||
return rules, c.ArgErr()
|
||||
}
|
||||
|
||||
rules = append(rules, rule)
|
||||
}
|
||||
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
func passwordMatcher(username, passw, siteRoot string) (PasswordMatcher, error) {
|
||||
if !strings.HasPrefix(passw, "htpasswd=") {
|
||||
return PlainMatcher(passw), nil
|
||||
}
|
||||
return GetHtpasswdMatcher(passw[9:], username, siteRoot)
|
||||
}
|
131
caddyhttp/basicauth/setup_test.go
Normal file
131
caddyhttp/basicauth/setup_test.go
Normal file
|
@ -0,0 +1,131 @@
|
|||
package basicauth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
err := setup(caddy.NewTestController(`basicauth user pwd`))
|
||||
if err != nil {
|
||||
t.Errorf("Expected no errors, but got: %v", err)
|
||||
}
|
||||
mids := httpserver.GetConfig("").Middleware()
|
||||
if len(mids) == 0 {
|
||||
t.Fatal("Expected middleware, got 0 instead")
|
||||
}
|
||||
|
||||
handler := mids[0](httpserver.EmptyNext)
|
||||
myHandler, ok := handler.(BasicAuth)
|
||||
if !ok {
|
||||
t.Fatalf("Expected handler to be type BasicAuth, got: %#v", handler)
|
||||
}
|
||||
|
||||
if !httpserver.SameNext(myHandler.Next, httpserver.EmptyNext) {
|
||||
t.Error("'Next' field of handler was not set properly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBasicAuthParse(t *testing.T) {
|
||||
htpasswdPasswd := "IedFOuGmTpT8"
|
||||
htpasswdFile := `sha1:{SHA}dcAUljwz99qFjYR0YLTXx0RqLww=
|
||||
md5:$apr1$l42y8rex$pOA2VJ0x/0TwaFeAF9nX61`
|
||||
|
||||
var skipHtpassword bool
|
||||
htfh, err := ioutil.TempFile(".", "basicauth-")
|
||||
if err != nil {
|
||||
t.Logf("Error creating temp file (%v), will skip htpassword test", err)
|
||||
skipHtpassword = true
|
||||
} else {
|
||||
if _, err = htfh.Write([]byte(htpasswdFile)); err != nil {
|
||||
t.Fatalf("write htpasswd file %q: %v", htfh.Name(), err)
|
||||
}
|
||||
htfh.Close()
|
||||
defer os.Remove(htfh.Name())
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
shouldErr bool
|
||||
password string
|
||||
expected []Rule
|
||||
}{
|
||||
{`basicauth user pwd`, false, "pwd", []Rule{
|
||||
{Username: "user"},
|
||||
}},
|
||||
{`basicauth user pwd {
|
||||
}`, false, "pwd", []Rule{
|
||||
{Username: "user"},
|
||||
}},
|
||||
{`basicauth user pwd {
|
||||
/resource1
|
||||
/resource2
|
||||
}`, false, "pwd", []Rule{
|
||||
{Username: "user", Resources: []string{"/resource1", "/resource2"}},
|
||||
}},
|
||||
{`basicauth /resource user pwd`, false, "pwd", []Rule{
|
||||
{Username: "user", Resources: []string{"/resource"}},
|
||||
}},
|
||||
{`basicauth /res1 user1 pwd1
|
||||
basicauth /res2 user2 pwd2`, false, "pwd", []Rule{
|
||||
{Username: "user1", Resources: []string{"/res1"}},
|
||||
{Username: "user2", Resources: []string{"/res2"}},
|
||||
}},
|
||||
{`basicauth user`, true, "", []Rule{}},
|
||||
{`basicauth`, true, "", []Rule{}},
|
||||
{`basicauth /resource user pwd asdf`, true, "", []Rule{}},
|
||||
|
||||
{`basicauth sha1 htpasswd=` + htfh.Name(), false, htpasswdPasswd, []Rule{
|
||||
{Username: "sha1"},
|
||||
}},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
actual, err := basicAuthParse(caddy.NewTestController(test.input))
|
||||
|
||||
if err == nil && test.shouldErr {
|
||||
t.Errorf("Test %d didn't error, but it should have", i)
|
||||
} else if err != nil && !test.shouldErr {
|
||||
t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err)
|
||||
}
|
||||
|
||||
if len(actual) != len(test.expected) {
|
||||
t.Fatalf("Test %d expected %d rules, but got %d",
|
||||
i, len(test.expected), len(actual))
|
||||
}
|
||||
|
||||
for j, expectedRule := range test.expected {
|
||||
actualRule := actual[j]
|
||||
|
||||
if actualRule.Username != expectedRule.Username {
|
||||
t.Errorf("Test %d, rule %d: Expected username '%s', got '%s'",
|
||||
i, j, expectedRule.Username, actualRule.Username)
|
||||
}
|
||||
|
||||
if strings.Contains(test.input, "htpasswd=") && skipHtpassword {
|
||||
continue
|
||||
}
|
||||
pwd := test.password
|
||||
if len(actual) > 1 {
|
||||
pwd = fmt.Sprintf("%s%d", pwd, j+1)
|
||||
}
|
||||
if !actualRule.Password(pwd) || actualRule.Password(test.password+"!") {
|
||||
t.Errorf("Test %d, rule %d: Expected password '%v', got '%v'",
|
||||
i, j, test.password, actualRule.Password(""))
|
||||
}
|
||||
|
||||
expectedRes := fmt.Sprintf("%v", expectedRule.Resources)
|
||||
actualRes := fmt.Sprintf("%v", actualRule.Resources)
|
||||
if actualRes != expectedRes {
|
||||
t.Errorf("Test %d, rule %d: Expected resource list %s, but got %s",
|
||||
i, j, expectedRes, actualRes)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
45
caddyhttp/expvar/expvar.go
Normal file
45
caddyhttp/expvar/expvar.go
Normal file
|
@ -0,0 +1,45 @@
|
|||
package expvar
|
||||
|
||||
import (
|
||||
"expvar"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
// ExpVar is a simple struct to hold expvar's configuration
|
||||
type ExpVar struct {
|
||||
Next httpserver.Handler
|
||||
Resource Resource
|
||||
}
|
||||
|
||||
// ServeHTTP handles requests to expvar's configured entry point with
|
||||
// expvar, or passes all other requests up the chain.
|
||||
func (e ExpVar) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
if httpserver.Path(r.URL.Path).Matches(string(e.Resource)) {
|
||||
expvarHandler(w, r)
|
||||
return 0, nil
|
||||
}
|
||||
return e.Next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// expvarHandler returns a JSON object will all the published variables.
|
||||
//
|
||||
// This is lifted straight from the expvar package.
|
||||
func expvarHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
fmt.Fprintf(w, "{\n")
|
||||
first := true
|
||||
expvar.Do(func(kv expvar.KeyValue) {
|
||||
if !first {
|
||||
fmt.Fprintf(w, ",\n")
|
||||
}
|
||||
first = false
|
||||
fmt.Fprintf(w, "%q: %s", kv.Key, kv.Value)
|
||||
})
|
||||
fmt.Fprintf(w, "\n}\n")
|
||||
}
|
||||
|
||||
// Resource contains the path to the expvar entry point
|
||||
type Resource string
|
46
caddyhttp/expvar/expvar_test.go
Normal file
46
caddyhttp/expvar/expvar_test.go
Normal file
|
@ -0,0 +1,46 @@
|
|||
package expvar
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func TestExpVar(t *testing.T) {
|
||||
rw := ExpVar{
|
||||
Next: httpserver.HandlerFunc(contentHandler),
|
||||
Resource: "/d/v",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
from string
|
||||
result int
|
||||
}{
|
||||
{"/d/v", 0},
|
||||
{"/x/y", http.StatusOK},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
req, err := http.NewRequest("GET", test.from, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Test %d: Could not create HTTP request %v", i, err)
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
result, err := rw.ServeHTTP(rec, req)
|
||||
if err != nil {
|
||||
t.Fatalf("Test %d: Could not ServeHTTP %v", i, err)
|
||||
}
|
||||
if result != test.result {
|
||||
t.Errorf("Test %d: Expected Header '%d' but was '%d'",
|
||||
i, test.result, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func contentHandler(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
fmt.Fprintf(w, r.URL.String())
|
||||
return http.StatusOK, nil
|
||||
}
|
70
caddyhttp/expvar/setup.go
Normal file
70
caddyhttp/expvar/setup.go
Normal file
|
@ -0,0 +1,70 @@
|
|||
package expvar
|
||||
|
||||
import (
|
||||
"expvar"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy.RegisterPlugin(caddy.Plugin{
|
||||
Name: "expvar",
|
||||
ServerType: "http",
|
||||
Action: setup,
|
||||
})
|
||||
}
|
||||
|
||||
// setup configures a new ExpVar middleware instance.
|
||||
func setup(c *caddy.Controller) error {
|
||||
resource, err := expVarParse(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// publish any extra information/metrics we may want to capture
|
||||
publishExtraVars()
|
||||
|
||||
ev := ExpVar{Resource: resource}
|
||||
|
||||
httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
|
||||
ev.Next = next
|
||||
return ev
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func expVarParse(c *caddy.Controller) (Resource, error) {
|
||||
var resource Resource
|
||||
var err error
|
||||
|
||||
for c.Next() {
|
||||
args := c.RemainingArgs()
|
||||
switch len(args) {
|
||||
case 0:
|
||||
resource = Resource(defaultExpvarPath)
|
||||
case 1:
|
||||
resource = Resource(args[0])
|
||||
default:
|
||||
return resource, c.ArgErr()
|
||||
}
|
||||
}
|
||||
|
||||
return resource, err
|
||||
}
|
||||
|
||||
func publishExtraVars() {
|
||||
// By using sync.Once instead of an init() function, we don't clutter
|
||||
// the app's expvar export unnecessarily, or risk colliding with it.
|
||||
publishOnce.Do(func() {
|
||||
expvar.Publish("Goroutines", expvar.Func(func() interface{} {
|
||||
return runtime.NumGoroutine()
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
var publishOnce sync.Once // publishing variables should only be done once
|
||||
var defaultExpvarPath = "/debug/vars"
|
40
caddyhttp/expvar/setup_test.go
Normal file
40
caddyhttp/expvar/setup_test.go
Normal file
|
@ -0,0 +1,40 @@
|
|||
package expvar
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
err := setup(caddy.NewTestController(`expvar`))
|
||||
if err != nil {
|
||||
t.Errorf("Expected no errors, got: %v", err)
|
||||
}
|
||||
mids := httpserver.GetConfig("").Middleware()
|
||||
if len(mids) == 0 {
|
||||
t.Fatal("Expected middleware, got 0 instead")
|
||||
}
|
||||
|
||||
err = setup(caddy.NewTestController(`expvar /d/v`))
|
||||
if err != nil {
|
||||
t.Errorf("Expected no errors, got: %v", err)
|
||||
}
|
||||
mids = httpserver.GetConfig("").Middleware()
|
||||
if len(mids) == 0 {
|
||||
t.Fatal("Expected middleware, got 0 instead")
|
||||
}
|
||||
|
||||
handler := mids[1](httpserver.EmptyNext)
|
||||
myHandler, ok := handler.(ExpVar)
|
||||
if !ok {
|
||||
t.Fatalf("Expected handler to be type ExpVar, got: %#v", handler)
|
||||
}
|
||||
if myHandler.Resource != "/d/v" {
|
||||
t.Errorf("Expected /d/v as expvar resource")
|
||||
}
|
||||
if !httpserver.SameNext(myHandler.Next, httpserver.EmptyNext) {
|
||||
t.Error("'Next' field of handler was not set properly")
|
||||
}
|
||||
}
|
52
caddyhttp/extensions/ext.go
Normal file
52
caddyhttp/extensions/ext.go
Normal file
|
@ -0,0 +1,52 @@
|
|||
// Package extensions contains middleware for clean URLs.
|
||||
//
|
||||
// The root path of the site is passed in as well as possible extensions
|
||||
// to try internally for paths requested that don't match an existing
|
||||
// resource. The first path+ext combination that matches a valid file
|
||||
// will be used.
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
// Ext can assume an extension from clean URLs.
|
||||
// It tries extensions in the order listed in Extensions.
|
||||
type Ext struct {
|
||||
// Next handler in the chain
|
||||
Next httpserver.Handler
|
||||
|
||||
// Path to ther root of the site
|
||||
Root string
|
||||
|
||||
// List of extensions to try
|
||||
Extensions []string
|
||||
}
|
||||
|
||||
// ServeHTTP implements the httpserver.Handler interface.
|
||||
func (e Ext) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
urlpath := strings.TrimSuffix(r.URL.Path, "/")
|
||||
if path.Ext(urlpath) == "" && len(r.URL.Path) > 0 && r.URL.Path[len(r.URL.Path)-1] != '/' {
|
||||
for _, ext := range e.Extensions {
|
||||
if resourceExists(e.Root, urlpath+ext) {
|
||||
r.URL.Path = urlpath + ext
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return e.Next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// resourceExists returns true if the file specified at
|
||||
// root + path exists; false otherwise.
|
||||
func resourceExists(root, path string) bool {
|
||||
_, err := os.Stat(root + path)
|
||||
// technically we should use os.IsNotExist(err)
|
||||
// but we don't handle any other kinds of errors anyway
|
||||
return err == nil
|
||||
}
|
54
caddyhttp/extensions/setup.go
Normal file
54
caddyhttp/extensions/setup.go
Normal file
|
@ -0,0 +1,54 @@
|
|||
package extensions
|
||||
|
||||
import (
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy.RegisterPlugin(caddy.Plugin{
|
||||
Name: "ext",
|
||||
ServerType: "http",
|
||||
Action: setup,
|
||||
})
|
||||
}
|
||||
|
||||
// setup configures a new instance of 'extensions' middleware for clean URLs.
|
||||
func setup(c *caddy.Controller) error {
|
||||
cfg := httpserver.GetConfig(c.Key)
|
||||
root := cfg.Root
|
||||
|
||||
exts, err := extParse(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
|
||||
return Ext{
|
||||
Next: next,
|
||||
Extensions: exts,
|
||||
Root: root,
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extParse sets up an instance of extension middleware
|
||||
// from a middleware controller and returns a list of extensions.
|
||||
func extParse(c *caddy.Controller) ([]string, error) {
|
||||
var exts []string
|
||||
|
||||
for c.Next() {
|
||||
// At least one extension is required
|
||||
if !c.NextArg() {
|
||||
return exts, c.ArgErr()
|
||||
}
|
||||
exts = append(exts, c.Val())
|
||||
|
||||
// Tack on any other extensions that may have been listed
|
||||
exts = append(exts, c.RemainingArgs()...)
|
||||
}
|
||||
|
||||
return exts, nil
|
||||
}
|
74
caddyhttp/extensions/setup_test.go
Normal file
74
caddyhttp/extensions/setup_test.go
Normal file
|
@ -0,0 +1,74 @@
|
|||
package extensions
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
err := setup(caddy.NewTestController(`ext .html .htm .php`))
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no errors, got: %v", err)
|
||||
}
|
||||
|
||||
mids := httpserver.GetConfig("").Middleware()
|
||||
if len(mids) == 0 {
|
||||
t.Fatal("Expected middleware, had 0 instead")
|
||||
}
|
||||
|
||||
handler := mids[0](httpserver.EmptyNext)
|
||||
myHandler, ok := handler.(Ext)
|
||||
|
||||
if !ok {
|
||||
t.Fatalf("Expected handler to be type Ext, got: %#v", handler)
|
||||
}
|
||||
|
||||
if myHandler.Extensions[0] != ".html" {
|
||||
t.Errorf("Expected .html in the list of Extensions")
|
||||
}
|
||||
if myHandler.Extensions[1] != ".htm" {
|
||||
t.Errorf("Expected .htm in the list of Extensions")
|
||||
}
|
||||
if myHandler.Extensions[2] != ".php" {
|
||||
t.Errorf("Expected .php in the list of Extensions")
|
||||
}
|
||||
if !httpserver.SameNext(myHandler.Next, httpserver.EmptyNext) {
|
||||
t.Error("'Next' field of handler was not set properly")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestExtParse(t *testing.T) {
|
||||
tests := []struct {
|
||||
inputExts string
|
||||
shouldErr bool
|
||||
expectedExts []string
|
||||
}{
|
||||
{`ext .html .htm .php`, false, []string{".html", ".htm", ".php"}},
|
||||
{`ext .php .html .xml`, false, []string{".php", ".html", ".xml"}},
|
||||
{`ext .txt .php .xml`, false, []string{".txt", ".php", ".xml"}},
|
||||
}
|
||||
for i, test := range tests {
|
||||
actualExts, err := extParse(caddy.NewTestController(test.inputExts))
|
||||
|
||||
if err == nil && test.shouldErr {
|
||||
t.Errorf("Test %d didn't error, but it should have", i)
|
||||
} else if err != nil && !test.shouldErr {
|
||||
t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err)
|
||||
}
|
||||
|
||||
if len(actualExts) != len(test.expectedExts) {
|
||||
t.Fatalf("Test %d expected %d rules, but got %d",
|
||||
i, len(test.expectedExts), len(actualExts))
|
||||
}
|
||||
for j, actualExt := range actualExts {
|
||||
if actualExt != test.expectedExts[j] {
|
||||
t.Fatalf("Test %d expected %dth extension to be %s , but got %s",
|
||||
i, j, test.expectedExts[j], actualExt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
51
caddyhttp/header/header.go
Normal file
51
caddyhttp/header/header.go
Normal file
|
@ -0,0 +1,51 @@
|
|||
// Package header provides middleware that appends headers to
|
||||
// requests based on a set of configuration rules that define
|
||||
// which routes receive which headers.
|
||||
package header
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
// Headers is middleware that adds headers to the responses
|
||||
// for requests matching a certain path.
|
||||
type Headers struct {
|
||||
Next httpserver.Handler
|
||||
Rules []Rule
|
||||
}
|
||||
|
||||
// ServeHTTP implements the httpserver.Handler interface and serves requests,
|
||||
// setting headers on the response according to the configured rules.
|
||||
func (h Headers) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
replacer := httpserver.NewReplacer(r, nil, "")
|
||||
for _, rule := range h.Rules {
|
||||
if httpserver.Path(r.URL.Path).Matches(rule.Path) {
|
||||
for _, header := range rule.Headers {
|
||||
if strings.HasPrefix(header.Name, "-") {
|
||||
w.Header().Del(strings.TrimLeft(header.Name, "-"))
|
||||
} else {
|
||||
w.Header().Set(header.Name, replacer.Replace(header.Value))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return h.Next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
type (
|
||||
// Rule groups a slice of HTTP headers by a URL pattern.
|
||||
// TODO: use http.Header type instead?
|
||||
Rule struct {
|
||||
Path string
|
||||
Headers []Header
|
||||
}
|
||||
|
||||
// Header represents a single HTTP header, simply a name and value.
|
||||
Header struct {
|
||||
Name string
|
||||
Value string
|
||||
}
|
||||
)
|
57
caddyhttp/header/header_test.go
Normal file
57
caddyhttp/header/header_test.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
package header
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func TestHeader(t *testing.T) {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
t.Fatalf("Could not determine hostname: %v", err)
|
||||
}
|
||||
for i, test := range []struct {
|
||||
from string
|
||||
name string
|
||||
value string
|
||||
}{
|
||||
{"/a", "Foo", "Bar"},
|
||||
{"/a", "Bar", ""},
|
||||
{"/a", "Baz", ""},
|
||||
{"/a", "ServerName", hostname},
|
||||
{"/b", "Foo", ""},
|
||||
{"/b", "Bar", "Removed in /a"},
|
||||
} {
|
||||
he := Headers{
|
||||
Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
return 0, nil
|
||||
}),
|
||||
Rules: []Rule{
|
||||
{Path: "/a", Headers: []Header{
|
||||
{Name: "Foo", Value: "Bar"},
|
||||
{Name: "ServerName", Value: "{hostname}"},
|
||||
{Name: "-Bar"},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", test.from, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
rec.Header().Set("Bar", "Removed in /a")
|
||||
|
||||
he.ServeHTTP(rec, req)
|
||||
|
||||
if got := rec.Header().Get(test.name); got != test.value {
|
||||
t.Errorf("Test %d: Expected %s header to be %q but was %q",
|
||||
i, test.name, test.value, got)
|
||||
}
|
||||
}
|
||||
}
|
94
caddyhttp/header/setup.go
Normal file
94
caddyhttp/header/setup.go
Normal file
|
@ -0,0 +1,94 @@
|
|||
package header
|
||||
|
||||
import (
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy.RegisterPlugin(caddy.Plugin{
|
||||
Name: "header",
|
||||
ServerType: "http",
|
||||
Action: setup,
|
||||
})
|
||||
}
|
||||
|
||||
// setup configures a new Headers middleware instance.
|
||||
func setup(c *caddy.Controller) error {
|
||||
rules, err := headersParse(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
|
||||
return Headers{Next: next, Rules: rules}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func headersParse(c *caddy.Controller) ([]Rule, error) {
|
||||
var rules []Rule
|
||||
|
||||
for c.NextLine() {
|
||||
var head Rule
|
||||
var isNewPattern bool
|
||||
|
||||
if !c.NextArg() {
|
||||
return rules, c.ArgErr()
|
||||
}
|
||||
pattern := c.Val()
|
||||
|
||||
// See if we already have a definition for this Path pattern...
|
||||
for _, h := range rules {
|
||||
if h.Path == pattern {
|
||||
head = h
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// ...otherwise, this is a new pattern
|
||||
if head.Path == "" {
|
||||
head.Path = pattern
|
||||
isNewPattern = true
|
||||
}
|
||||
|
||||
for c.NextBlock() {
|
||||
// A block of headers was opened...
|
||||
|
||||
h := Header{Name: c.Val()}
|
||||
|
||||
if c.NextArg() {
|
||||
h.Value = c.Val()
|
||||
}
|
||||
|
||||
head.Headers = append(head.Headers, h)
|
||||
}
|
||||
if c.NextArg() {
|
||||
// ... or single header was defined as an argument instead.
|
||||
|
||||
h := Header{Name: c.Val()}
|
||||
|
||||
h.Value = c.Val()
|
||||
|
||||
if c.NextArg() {
|
||||
h.Value = c.Val()
|
||||
}
|
||||
|
||||
head.Headers = append(head.Headers, h)
|
||||
}
|
||||
|
||||
if isNewPattern {
|
||||
rules = append(rules, head)
|
||||
} else {
|
||||
for i := 0; i < len(rules); i++ {
|
||||
if rules[i].Path == pattern {
|
||||
rules[i] = head
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return rules, nil
|
||||
}
|
85
caddyhttp/header/setup_test.go
Normal file
85
caddyhttp/header/setup_test.go
Normal file
|
@ -0,0 +1,85 @@
|
|||
package header
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
err := setup(caddy.NewTestController(`header / Foo Bar`))
|
||||
if err != nil {
|
||||
t.Errorf("Expected no errors, but got: %v", err)
|
||||
}
|
||||
|
||||
mids := httpserver.GetConfig("").Middleware()
|
||||
if len(mids) == 0 {
|
||||
t.Fatal("Expected middleware, had 0 instead")
|
||||
}
|
||||
|
||||
handler := mids[0](httpserver.EmptyNext)
|
||||
myHandler, ok := handler.(Headers)
|
||||
if !ok {
|
||||
t.Fatalf("Expected handler to be type Headers, got: %#v", handler)
|
||||
}
|
||||
|
||||
if !httpserver.SameNext(myHandler.Next, httpserver.EmptyNext) {
|
||||
t.Error("'Next' field of handler was not set properly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadersParse(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
shouldErr bool
|
||||
expected []Rule
|
||||
}{
|
||||
{`header /foo Foo "Bar Baz"`,
|
||||
false, []Rule{
|
||||
{Path: "/foo", Headers: []Header{
|
||||
{Name: "Foo", Value: "Bar Baz"},
|
||||
}},
|
||||
}},
|
||||
{`header /bar { Foo "Bar Baz" Baz Qux }`,
|
||||
false, []Rule{
|
||||
{Path: "/bar", Headers: []Header{
|
||||
{Name: "Foo", Value: "Bar Baz"},
|
||||
{Name: "Baz", Value: "Qux"},
|
||||
}},
|
||||
}},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
actual, err := headersParse(caddy.NewTestController(test.input))
|
||||
|
||||
if err == nil && test.shouldErr {
|
||||
t.Errorf("Test %d didn't error, but it should have", i)
|
||||
} else if err != nil && !test.shouldErr {
|
||||
t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err)
|
||||
}
|
||||
|
||||
if len(actual) != len(test.expected) {
|
||||
t.Fatalf("Test %d expected %d rules, but got %d",
|
||||
i, len(test.expected), len(actual))
|
||||
}
|
||||
|
||||
for j, expectedRule := range test.expected {
|
||||
actualRule := actual[j]
|
||||
|
||||
if actualRule.Path != expectedRule.Path {
|
||||
t.Errorf("Test %d, rule %d: Expected path %s, but got %s",
|
||||
i, j, expectedRule.Path, actualRule.Path)
|
||||
}
|
||||
|
||||
expectedHeaders := fmt.Sprintf("%v", expectedRule.Headers)
|
||||
actualHeaders := fmt.Sprintf("%v", actualRule.Headers)
|
||||
|
||||
if actualHeaders != expectedHeaders {
|
||||
t.Errorf("Test %d, rule %d: Expected headers %s, but got %s",
|
||||
i, j, expectedHeaders, actualHeaders)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
93
caddyhttp/internalsrv/internal.go
Normal file
93
caddyhttp/internalsrv/internal.go
Normal file
|
@ -0,0 +1,93 @@
|
|||
// Package internalsrv provides a simple middleware that (a) prevents access
|
||||
// to internal locations and (b) allows to return files from internal location
|
||||
// by setting a special header, e.g. in a proxy response.
|
||||
//
|
||||
// The package is named internalsrv so as not to conflict with Go tooling
|
||||
// convention which treats folders called "internal" differently.
|
||||
package internalsrv
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
// Internal middleware protects internal locations from external requests -
|
||||
// but allows access from the inside by using a special HTTP header.
|
||||
type Internal struct {
|
||||
Next httpserver.Handler
|
||||
Paths []string
|
||||
}
|
||||
|
||||
const (
|
||||
redirectHeader string = "X-Accel-Redirect"
|
||||
maxRedirectCount int = 10
|
||||
)
|
||||
|
||||
func isInternalRedirect(w http.ResponseWriter) bool {
|
||||
return w.Header().Get(redirectHeader) != ""
|
||||
}
|
||||
|
||||
// ServeHTTP implements the httpserver.Handler interface.
|
||||
func (i Internal) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
|
||||
// Internal location requested? -> Not found.
|
||||
for _, prefix := range i.Paths {
|
||||
if httpserver.Path(r.URL.Path).Matches(prefix) {
|
||||
return http.StatusNotFound, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Use internal response writer to ignore responses that will be
|
||||
// redirected to internal locations
|
||||
iw := internalResponseWriter{ResponseWriter: w}
|
||||
status, err := i.Next.ServeHTTP(iw, r)
|
||||
|
||||
for c := 0; c < maxRedirectCount && isInternalRedirect(iw); c++ {
|
||||
// Redirect - adapt request URL path and send it again
|
||||
// "down the chain"
|
||||
r.URL.Path = iw.Header().Get(redirectHeader)
|
||||
iw.ClearHeader()
|
||||
|
||||
status, err = i.Next.ServeHTTP(iw, r)
|
||||
}
|
||||
|
||||
if isInternalRedirect(iw) {
|
||||
// Too many redirect cycles
|
||||
iw.ClearHeader()
|
||||
return http.StatusInternalServerError, nil
|
||||
}
|
||||
|
||||
return status, err
|
||||
}
|
||||
|
||||
// internalResponseWriter wraps the underlying http.ResponseWriter and ignores
|
||||
// calls to Write and WriteHeader if the response should be redirected to an
|
||||
// internal location.
|
||||
type internalResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
}
|
||||
|
||||
// ClearHeader removes all header fields that are already set.
|
||||
func (w internalResponseWriter) ClearHeader() {
|
||||
for k := range w.Header() {
|
||||
w.Header().Del(k)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteHeader ignores the call if the response should be redirected to an
|
||||
// internal location.
|
||||
func (w internalResponseWriter) WriteHeader(code int) {
|
||||
if !isInternalRedirect(w) {
|
||||
w.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
}
|
||||
|
||||
// Write ignores the call if the response should be redirected to an internal
|
||||
// location.
|
||||
func (w internalResponseWriter) Write(b []byte) (int, error) {
|
||||
if isInternalRedirect(w) {
|
||||
return 0, nil
|
||||
}
|
||||
return w.ResponseWriter.Write(b)
|
||||
}
|
64
caddyhttp/internalsrv/internal_test.go
Normal file
64
caddyhttp/internalsrv/internal_test.go
Normal file
|
@ -0,0 +1,64 @@
|
|||
package internalsrv
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func TestInternal(t *testing.T) {
|
||||
im := Internal{
|
||||
Next: httpserver.HandlerFunc(internalTestHandlerFunc),
|
||||
Paths: []string{"/internal"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
url string
|
||||
expectedCode int
|
||||
expectedBody string
|
||||
}{
|
||||
{"/internal", http.StatusNotFound, ""},
|
||||
|
||||
{"/public", 0, "/public"},
|
||||
{"/public/internal", 0, "/public/internal"},
|
||||
|
||||
{"/redirect", 0, "/internal"},
|
||||
|
||||
{"/cycle", http.StatusInternalServerError, ""},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
req, err := http.NewRequest("GET", test.url, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
code, _ := im.ServeHTTP(rec, req)
|
||||
|
||||
if code != test.expectedCode {
|
||||
t.Errorf("Test %d: Expected status code %d for %s, but got %d",
|
||||
i, test.expectedCode, test.url, code)
|
||||
}
|
||||
if rec.Body.String() != test.expectedBody {
|
||||
t.Errorf("Test %d: Expected body '%s' for %s, but got '%s'",
|
||||
i, test.expectedBody, test.url, rec.Body.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func internalTestHandlerFunc(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
switch r.URL.Path {
|
||||
case "/redirect":
|
||||
w.Header().Set("X-Accel-Redirect", "/internal")
|
||||
case "/cycle":
|
||||
w.Header().Set("X-Accel-Redirect", "/cycle")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, r.URL.String())
|
||||
|
||||
return 0, nil
|
||||
}
|
41
caddyhttp/internalsrv/setup.go
Normal file
41
caddyhttp/internalsrv/setup.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
package internalsrv
|
||||
|
||||
import (
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy.RegisterPlugin(caddy.Plugin{
|
||||
Name: "internal",
|
||||
ServerType: "http",
|
||||
Action: setup,
|
||||
})
|
||||
}
|
||||
|
||||
// Internal configures a new Internal middleware instance.
|
||||
func setup(c *caddy.Controller) error {
|
||||
paths, err := internalParse(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
|
||||
return Internal{Next: next, Paths: paths}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func internalParse(c *caddy.Controller) ([]string, error) {
|
||||
var paths []string
|
||||
|
||||
for c.Next() {
|
||||
if !c.NextArg() {
|
||||
return paths, c.ArgErr()
|
||||
}
|
||||
paths = append(paths, c.Val())
|
||||
}
|
||||
|
||||
return paths, nil
|
||||
}
|
69
caddyhttp/internalsrv/setup_test.go
Normal file
69
caddyhttp/internalsrv/setup_test.go
Normal file
|
@ -0,0 +1,69 @@
|
|||
package internalsrv
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
err := setup(caddy.NewTestController(`internal /internal`))
|
||||
if err != nil {
|
||||
t.Errorf("Expected no errors, got: %v", err)
|
||||
}
|
||||
mids := httpserver.GetConfig("").Middleware()
|
||||
if len(mids) == 0 {
|
||||
t.Fatal("Expected middleware, got 0 instead")
|
||||
}
|
||||
|
||||
handler := mids[0](httpserver.EmptyNext)
|
||||
myHandler, ok := handler.(Internal)
|
||||
|
||||
if !ok {
|
||||
t.Fatalf("Expected handler to be type Internal, got: %#v", handler)
|
||||
}
|
||||
|
||||
if myHandler.Paths[0] != "/internal" {
|
||||
t.Errorf("Expected internal in the list of internal Paths")
|
||||
}
|
||||
|
||||
if !httpserver.SameNext(myHandler.Next, httpserver.EmptyNext) {
|
||||
t.Error("'Next' field of handler was not set properly")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestInternalParse(t *testing.T) {
|
||||
tests := []struct {
|
||||
inputInternalPaths string
|
||||
shouldErr bool
|
||||
expectedInternalPaths []string
|
||||
}{
|
||||
{`internal /internal`, false, []string{"/internal"}},
|
||||
|
||||
{`internal /internal1
|
||||
internal /internal2`, false, []string{"/internal1", "/internal2"}},
|
||||
}
|
||||
for i, test := range tests {
|
||||
actualInternalPaths, err := internalParse(caddy.NewTestController(test.inputInternalPaths))
|
||||
|
||||
if err == nil && test.shouldErr {
|
||||
t.Errorf("Test %d didn't error, but it should have", i)
|
||||
} else if err != nil && !test.shouldErr {
|
||||
t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err)
|
||||
}
|
||||
|
||||
if len(actualInternalPaths) != len(test.expectedInternalPaths) {
|
||||
t.Fatalf("Test %d expected %d InternalPaths, but got %d",
|
||||
i, len(test.expectedInternalPaths), len(actualInternalPaths))
|
||||
}
|
||||
for j, actualInternalPath := range actualInternalPaths {
|
||||
if actualInternalPath != test.expectedInternalPaths[j] {
|
||||
t.Fatalf("Test %d expected %dth Internal Path to be %s , but got %s",
|
||||
i, j, test.expectedInternalPaths[j], actualInternalPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
31
caddyhttp/mime/mime.go
Normal file
31
caddyhttp/mime/mime.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
package mime
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"path"
|
||||
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
// Config represent a mime config. Map from extension to mime-type.
|
||||
// Note, this should be safe with concurrent read access, as this is
|
||||
// not modified concurrently.
|
||||
type Config map[string]string
|
||||
|
||||
// Mime sets Content-Type header of requests based on configurations.
|
||||
type Mime struct {
|
||||
Next httpserver.Handler
|
||||
Configs Config
|
||||
}
|
||||
|
||||
// ServeHTTP implements the httpserver.Handler interface.
|
||||
func (e Mime) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
// Get a clean /-path, grab the extension
|
||||
ext := path.Ext(path.Clean(r.URL.Path))
|
||||
|
||||
if contentType, ok := e.Configs[ext]; ok {
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
}
|
||||
|
||||
return e.Next.ServeHTTP(w, r)
|
||||
}
|
69
caddyhttp/mime/mime_test.go
Normal file
69
caddyhttp/mime/mime_test.go
Normal file
|
@ -0,0 +1,69 @@
|
|||
package mime
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func TestMimeHandler(t *testing.T) {
|
||||
mimes := Config{
|
||||
".html": "text/html",
|
||||
".txt": "text/plain",
|
||||
".swf": "application/x-shockwave-flash",
|
||||
}
|
||||
|
||||
m := Mime{Configs: mimes}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
exts := []string{
|
||||
".html", ".txt", ".swf",
|
||||
}
|
||||
for _, e := range exts {
|
||||
url := "/file" + e
|
||||
r, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
m.Next = nextFunc(true, mimes[e])
|
||||
_, err = m.ServeHTTP(w, r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
exts = []string{
|
||||
".htm1", ".abc", ".mdx",
|
||||
}
|
||||
for _, e := range exts {
|
||||
url := "/file" + e
|
||||
r, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
m.Next = nextFunc(false, "")
|
||||
_, err = m.ServeHTTP(w, r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func nextFunc(shouldMime bool, contentType string) httpserver.Handler {
|
||||
return httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
if shouldMime {
|
||||
if w.Header().Get("Content-Type") != contentType {
|
||||
return 0, fmt.Errorf("expected Content-Type: %v, found %v", contentType, r.Header.Get("Content-Type"))
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
if w.Header().Get("Content-Type") != "" {
|
||||
return 0, fmt.Errorf("Content-Type header not expected")
|
||||
}
|
||||
return 0, nil
|
||||
})
|
||||
}
|
75
caddyhttp/mime/setup.go
Normal file
75
caddyhttp/mime/setup.go
Normal file
|
@ -0,0 +1,75 @@
|
|||
package mime
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy.RegisterPlugin(caddy.Plugin{
|
||||
Name: "mime",
|
||||
ServerType: "http",
|
||||
Action: setup,
|
||||
})
|
||||
}
|
||||
|
||||
// setup configures a new mime middleware instance.
|
||||
func setup(c *caddy.Controller) error {
|
||||
configs, err := mimeParse(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
|
||||
return Mime{Next: next, Configs: configs}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func mimeParse(c *caddy.Controller) (Config, error) {
|
||||
configs := Config{}
|
||||
|
||||
for c.Next() {
|
||||
// At least one extension is required
|
||||
|
||||
args := c.RemainingArgs()
|
||||
switch len(args) {
|
||||
case 2:
|
||||
if err := validateExt(configs, args[0]); err != nil {
|
||||
return configs, err
|
||||
}
|
||||
configs[args[0]] = args[1]
|
||||
case 1:
|
||||
return configs, c.ArgErr()
|
||||
case 0:
|
||||
for c.NextBlock() {
|
||||
ext := c.Val()
|
||||
if err := validateExt(configs, ext); err != nil {
|
||||
return configs, err
|
||||
}
|
||||
if !c.NextArg() {
|
||||
return configs, c.ArgErr()
|
||||
}
|
||||
configs[ext] = c.Val()
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return configs, nil
|
||||
}
|
||||
|
||||
// validateExt checks for valid file name extension.
|
||||
func validateExt(configs Config, ext string) error {
|
||||
if !strings.HasPrefix(ext, ".") {
|
||||
return fmt.Errorf(`mime: invalid extension "%v" (must start with dot)`, ext)
|
||||
}
|
||||
if _, ok := configs[ext]; ok {
|
||||
return fmt.Errorf(`mime: duplicate extension "%v" found`, ext)
|
||||
}
|
||||
return nil
|
||||
}
|
62
caddyhttp/mime/setup_test.go
Normal file
62
caddyhttp/mime/setup_test.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
package mime
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
err := setup(caddy.NewTestController(`mime .txt text/plain`))
|
||||
if err != nil {
|
||||
t.Errorf("Expected no errors, but got: %v", err)
|
||||
}
|
||||
mids := httpserver.GetConfig("").Middleware()
|
||||
if len(mids) == 0 {
|
||||
t.Fatal("Expected middleware, but had 0 instead")
|
||||
}
|
||||
|
||||
handler := mids[0](httpserver.EmptyNext)
|
||||
myHandler, ok := handler.(Mime)
|
||||
if !ok {
|
||||
t.Fatalf("Expected handler to be type Mime, got: %#v", handler)
|
||||
}
|
||||
|
||||
if !httpserver.SameNext(myHandler.Next, httpserver.EmptyNext) {
|
||||
t.Error("'Next' field of handler was not set properly")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
shouldErr bool
|
||||
}{
|
||||
{`mime {`, true},
|
||||
{`mime {}`, true},
|
||||
{`mime a b`, true},
|
||||
{`mime a {`, true},
|
||||
{`mime { txt f } `, true},
|
||||
{`mime { html } `, true},
|
||||
{`mime {
|
||||
.html text/html
|
||||
.txt text/plain
|
||||
} `, false},
|
||||
{`mime {
|
||||
.foo text/foo
|
||||
.bar text/bar
|
||||
.foo text/foobar
|
||||
} `, true},
|
||||
{`mime { .html text/html } `, false},
|
||||
{`mime { .html
|
||||
} `, true},
|
||||
{`mime .txt text/plain`, false},
|
||||
}
|
||||
for i, test := range tests {
|
||||
m, err := mimeParse(caddy.NewTestController(test.input))
|
||||
if test.shouldErr && err == nil {
|
||||
t.Errorf("Test %v: Expected error but found nil %v", i, m)
|
||||
} else if !test.shouldErr && err != nil {
|
||||
t.Errorf("Test %v: Expected no error but found error: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
41
caddyhttp/pprof/pprof.go
Normal file
41
caddyhttp/pprof/pprof.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
package pprof
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
pp "net/http/pprof"
|
||||
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
// BasePath is the base path to match for all pprof requests.
|
||||
const BasePath = "/debug/pprof"
|
||||
|
||||
// Handler is a simple struct whose ServeHTTP will delegate pprof
|
||||
// endpoints to their equivalent net/http/pprof handlers.
|
||||
type Handler struct {
|
||||
Next httpserver.Handler
|
||||
Mux *http.ServeMux
|
||||
}
|
||||
|
||||
// ServeHTTP handles requests to BasePath with pprof, or passes
|
||||
// all other requests up the chain.
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
if httpserver.Path(r.URL.Path).Matches(BasePath) {
|
||||
h.Mux.ServeHTTP(w, r)
|
||||
return 0, nil
|
||||
}
|
||||
return h.Next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// NewMux returns a new http.ServeMux that routes pprof requests.
|
||||
// It pretty much copies what the std lib pprof does on init:
|
||||
// https://golang.org/src/net/http/pprof/pprof.go#L67
|
||||
func NewMux() *http.ServeMux {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc(BasePath+"/", pp.Index)
|
||||
mux.HandleFunc(BasePath+"/cmdline", pp.Cmdline)
|
||||
mux.HandleFunc(BasePath+"/profile", pp.Profile)
|
||||
mux.HandleFunc(BasePath+"/symbol", pp.Symbol)
|
||||
mux.HandleFunc(BasePath+"/trace", pp.Trace)
|
||||
return mux
|
||||
}
|
55
caddyhttp/pprof/pprof_test.go
Normal file
55
caddyhttp/pprof/pprof_test.go
Normal file
|
@ -0,0 +1,55 @@
|
|||
package pprof
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func TestServeHTTP(t *testing.T) {
|
||||
h := Handler{
|
||||
Next: httpserver.HandlerFunc(nextHandler),
|
||||
Mux: NewMux(),
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r, err := http.NewRequest("GET", "/debug/pprof", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
status, err := h.ServeHTTP(w, r)
|
||||
|
||||
if status != 0 {
|
||||
t.Errorf("Expected status %d but got %d", 0, status)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil error, but got: %v", err)
|
||||
}
|
||||
if w.Body.String() == "content" {
|
||||
t.Errorf("Expected pprof to handle request, but it didn't")
|
||||
}
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
r, err = http.NewRequest("GET", "/foo", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
status, err = h.ServeHTTP(w, r)
|
||||
if status != http.StatusNotFound {
|
||||
t.Errorf("Test two: Expected status %d but got %d", http.StatusNotFound, status)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Test two: Expected nil error, but got: %v", err)
|
||||
}
|
||||
if w.Body.String() != "content" {
|
||||
t.Errorf("Expected pprof to pass the request thru, but it didn't; got: %s", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func nextHandler(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
fmt.Fprintf(w, "content")
|
||||
return http.StatusNotFound, nil
|
||||
}
|
38
caddyhttp/pprof/setup.go
Normal file
38
caddyhttp/pprof/setup.go
Normal file
|
@ -0,0 +1,38 @@
|
|||
package pprof
|
||||
|
||||
import (
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy.RegisterPlugin(caddy.Plugin{
|
||||
Name: "pprof",
|
||||
ServerType: "http",
|
||||
Action: setup,
|
||||
})
|
||||
}
|
||||
|
||||
// setup returns a new instance of a pprof handler. It accepts no arguments or options.
|
||||
func setup(c *caddy.Controller) error {
|
||||
found := false
|
||||
|
||||
for c.Next() {
|
||||
if found {
|
||||
return c.Err("pprof can only be specified once")
|
||||
}
|
||||
if len(c.RemainingArgs()) != 0 {
|
||||
return c.ArgErr()
|
||||
}
|
||||
if c.NextBlock() {
|
||||
return c.ArgErr()
|
||||
}
|
||||
found = true
|
||||
}
|
||||
|
||||
httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
|
||||
return &Handler{Next: next, Mux: NewMux()}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
31
caddyhttp/pprof/setup_test.go
Normal file
31
caddyhttp/pprof/setup_test.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
package pprof
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
)
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
shouldErr bool
|
||||
}{
|
||||
{`pprof`, false},
|
||||
{`pprof {}`, true},
|
||||
{`pprof /foo`, true},
|
||||
{`pprof {
|
||||
a b
|
||||
}`, true},
|
||||
{`pprof
|
||||
pprof`, true},
|
||||
}
|
||||
for i, test := range tests {
|
||||
err := setup(caddy.NewTestController(test.input))
|
||||
if test.shouldErr && err == nil {
|
||||
t.Errorf("Test %v: Expected error but found nil", i)
|
||||
} else if !test.shouldErr && err != nil {
|
||||
t.Errorf("Test %v: Expected no error but found error: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
101
caddyhttp/proxy/policy.go
Normal file
101
caddyhttp/proxy/policy.go
Normal file
|
@ -0,0 +1,101 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// HostPool is a collection of UpstreamHosts.
|
||||
type HostPool []*UpstreamHost
|
||||
|
||||
// Policy decides how a host will be selected from a pool.
|
||||
type Policy interface {
|
||||
Select(pool HostPool) *UpstreamHost
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterPolicy("random", func() Policy { return &Random{} })
|
||||
RegisterPolicy("least_conn", func() Policy { return &LeastConn{} })
|
||||
RegisterPolicy("round_robin", func() Policy { return &RoundRobin{} })
|
||||
}
|
||||
|
||||
// Random is a policy that selects up hosts from a pool at random.
|
||||
type Random struct{}
|
||||
|
||||
// Select selects an up host at random from the specified pool.
|
||||
func (r *Random) Select(pool HostPool) *UpstreamHost {
|
||||
// instead of just generating a random index
|
||||
// this is done to prevent selecting a unavailable host
|
||||
var randHost *UpstreamHost
|
||||
count := 0
|
||||
for _, host := range pool {
|
||||
if !host.Available() {
|
||||
continue
|
||||
}
|
||||
count++
|
||||
if count == 1 {
|
||||
randHost = host
|
||||
} else {
|
||||
r := rand.Int() % count
|
||||
if r == (count - 1) {
|
||||
randHost = host
|
||||
}
|
||||
}
|
||||
}
|
||||
return randHost
|
||||
}
|
||||
|
||||
// LeastConn is a policy that selects the host with the least connections.
|
||||
type LeastConn struct{}
|
||||
|
||||
// Select selects the up host with the least number of connections in the
|
||||
// pool. If more than one host has the same least number of connections,
|
||||
// one of the hosts is chosen at random.
|
||||
func (r *LeastConn) Select(pool HostPool) *UpstreamHost {
|
||||
var bestHost *UpstreamHost
|
||||
count := 0
|
||||
leastConn := int64(1<<63 - 1)
|
||||
for _, host := range pool {
|
||||
if !host.Available() {
|
||||
continue
|
||||
}
|
||||
hostConns := host.Conns
|
||||
if hostConns < leastConn {
|
||||
bestHost = host
|
||||
leastConn = hostConns
|
||||
count = 1
|
||||
} else if hostConns == leastConn {
|
||||
// randomly select host among hosts with least connections
|
||||
count++
|
||||
if count == 1 {
|
||||
bestHost = host
|
||||
} else {
|
||||
r := rand.Int() % count
|
||||
if r == (count - 1) {
|
||||
bestHost = host
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return bestHost
|
||||
}
|
||||
|
||||
// RoundRobin is a policy that selects hosts based on round robin ordering.
|
||||
type RoundRobin struct {
|
||||
Robin uint32
|
||||
}
|
||||
|
||||
// Select selects an up host from the pool using a round robin ordering scheme.
|
||||
func (r *RoundRobin) Select(pool HostPool) *UpstreamHost {
|
||||
poolLen := uint32(len(pool))
|
||||
selection := atomic.AddUint32(&r.Robin, 1) % poolLen
|
||||
host := pool[selection]
|
||||
// if the currently selected host is not available, just ffwd to up host
|
||||
for i := uint32(1); !host.Available() && i < poolLen; i++ {
|
||||
host = pool[(selection+i)%poolLen]
|
||||
}
|
||||
if !host.Available() {
|
||||
return nil
|
||||
}
|
||||
return host
|
||||
}
|
98
caddyhttp/proxy/policy_test.go
Normal file
98
caddyhttp/proxy/policy_test.go
Normal file
|
@ -0,0 +1,98 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var workableServer *httptest.Server
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
workableServer = httptest.NewServer(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
// do nothing
|
||||
}))
|
||||
r := m.Run()
|
||||
workableServer.Close()
|
||||
os.Exit(r)
|
||||
}
|
||||
|
||||
type customPolicy struct{}
|
||||
|
||||
func (r *customPolicy) Select(pool HostPool) *UpstreamHost {
|
||||
return pool[0]
|
||||
}
|
||||
|
||||
func testPool() HostPool {
|
||||
pool := []*UpstreamHost{
|
||||
{
|
||||
Name: workableServer.URL, // this should resolve (healthcheck test)
|
||||
},
|
||||
{
|
||||
Name: "http://shouldnot.resolve", // this shouldn't
|
||||
},
|
||||
{
|
||||
Name: "http://C",
|
||||
},
|
||||
}
|
||||
return HostPool(pool)
|
||||
}
|
||||
|
||||
func TestRoundRobinPolicy(t *testing.T) {
|
||||
pool := testPool()
|
||||
rrPolicy := &RoundRobin{}
|
||||
h := rrPolicy.Select(pool)
|
||||
// First selected host is 1, because counter starts at 0
|
||||
// and increments before host is selected
|
||||
if h != pool[1] {
|
||||
t.Error("Expected first round robin host to be second host in the pool.")
|
||||
}
|
||||
h = rrPolicy.Select(pool)
|
||||
if h != pool[2] {
|
||||
t.Error("Expected second round robin host to be third host in the pool.")
|
||||
}
|
||||
h = rrPolicy.Select(pool)
|
||||
if h != pool[0] {
|
||||
t.Error("Expected third round robin host to be first host in the pool.")
|
||||
}
|
||||
// mark host as down
|
||||
pool[1].Unhealthy = true
|
||||
h = rrPolicy.Select(pool)
|
||||
if h != pool[2] {
|
||||
t.Error("Expected to skip down host.")
|
||||
}
|
||||
// mark host as full
|
||||
pool[2].Conns = 1
|
||||
pool[2].MaxConns = 1
|
||||
h = rrPolicy.Select(pool)
|
||||
if h != pool[0] {
|
||||
t.Error("Expected to skip full host.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLeastConnPolicy(t *testing.T) {
|
||||
pool := testPool()
|
||||
lcPolicy := &LeastConn{}
|
||||
pool[0].Conns = 10
|
||||
pool[1].Conns = 10
|
||||
h := lcPolicy.Select(pool)
|
||||
if h != pool[2] {
|
||||
t.Error("Expected least connection host to be third host.")
|
||||
}
|
||||
pool[2].Conns = 100
|
||||
h = lcPolicy.Select(pool)
|
||||
if h != pool[0] && h != pool[1] {
|
||||
t.Error("Expected least connection host to be first or second host.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomPolicy(t *testing.T) {
|
||||
pool := testPool()
|
||||
customPolicy := &customPolicy{}
|
||||
h := customPolicy.Select(pool)
|
||||
if h != pool[0] {
|
||||
t.Error("Expected custom policy host to be the first host.")
|
||||
}
|
||||
}
|
242
caddyhttp/proxy/proxy.go
Normal file
242
caddyhttp/proxy/proxy.go
Normal file
|
@ -0,0 +1,242 @@
|
|||
// Package proxy is middleware that proxies HTTP requests.
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
var errUnreachable = errors.New("unreachable backend")
|
||||
|
||||
// Proxy represents a middleware instance that can proxy requests.
|
||||
type Proxy struct {
|
||||
Next httpserver.Handler
|
||||
Upstreams []Upstream
|
||||
}
|
||||
|
||||
// Upstream manages a pool of proxy upstream hosts. Select should return a
|
||||
// suitable upstream host, or nil if no such hosts are available.
|
||||
type Upstream interface {
|
||||
// The path this upstream host should be routed on
|
||||
From() string
|
||||
// Selects an upstream host to be routed to.
|
||||
Select() *UpstreamHost
|
||||
// Checks if subpath is not an ignored path
|
||||
AllowedPath(string) bool
|
||||
}
|
||||
|
||||
// UpstreamHostDownFunc can be used to customize how Down behaves.
|
||||
type UpstreamHostDownFunc func(*UpstreamHost) bool
|
||||
|
||||
// UpstreamHost represents a single proxy upstream
|
||||
type UpstreamHost struct {
|
||||
Conns int64 // must be first field to be 64-bit aligned on 32-bit systems
|
||||
Name string // hostname of this upstream host
|
||||
ReverseProxy *ReverseProxy
|
||||
Fails int32
|
||||
FailTimeout time.Duration
|
||||
Unhealthy bool
|
||||
UpstreamHeaders http.Header
|
||||
DownstreamHeaders http.Header
|
||||
CheckDown UpstreamHostDownFunc
|
||||
WithoutPathPrefix string
|
||||
MaxConns int64
|
||||
}
|
||||
|
||||
// Down checks whether the upstream host is down or not.
|
||||
// Down will try to use uh.CheckDown first, and will fall
|
||||
// back to some default criteria if necessary.
|
||||
func (uh *UpstreamHost) Down() bool {
|
||||
if uh.CheckDown == nil {
|
||||
// Default settings
|
||||
return uh.Unhealthy || uh.Fails > 0
|
||||
}
|
||||
return uh.CheckDown(uh)
|
||||
}
|
||||
|
||||
// Full checks whether the upstream host has reached its maximum connections
|
||||
func (uh *UpstreamHost) Full() bool {
|
||||
return uh.MaxConns > 0 && uh.Conns >= uh.MaxConns
|
||||
}
|
||||
|
||||
// Available checks whether the upstream host is available for proxying to
|
||||
func (uh *UpstreamHost) Available() bool {
|
||||
return !uh.Down() && !uh.Full()
|
||||
}
|
||||
|
||||
// tryDuration is how long to try upstream hosts; failures result in
|
||||
// immediate retries until this duration ends or we get a nil host.
|
||||
var tryDuration = 60 * time.Second
|
||||
|
||||
// ServeHTTP satisfies the httpserver.Handler interface.
|
||||
func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
for _, upstream := range p.Upstreams {
|
||||
if !httpserver.Path(r.URL.Path).Matches(upstream.From()) ||
|
||||
!upstream.AllowedPath(r.URL.Path) {
|
||||
continue
|
||||
}
|
||||
|
||||
var replacer httpserver.Replacer
|
||||
start := time.Now()
|
||||
|
||||
outreq := createUpstreamRequest(r)
|
||||
|
||||
// Since Select() should give us "up" hosts, keep retrying
|
||||
// hosts until timeout (or until we get a nil host).
|
||||
for time.Now().Sub(start) < tryDuration {
|
||||
host := upstream.Select()
|
||||
if host == nil {
|
||||
return http.StatusBadGateway, errUnreachable
|
||||
}
|
||||
if rr, ok := w.(*httpserver.ResponseRecorder); ok && rr.Replacer != nil {
|
||||
rr.Replacer.Set("upstream", host.Name)
|
||||
}
|
||||
|
||||
outreq.Host = host.Name
|
||||
if host.UpstreamHeaders != nil {
|
||||
if replacer == nil {
|
||||
rHost := r.Host
|
||||
replacer = httpserver.NewReplacer(r, nil, "")
|
||||
outreq.Host = rHost
|
||||
}
|
||||
if v, ok := host.UpstreamHeaders["Host"]; ok {
|
||||
outreq.Host = replacer.Replace(v[len(v)-1])
|
||||
}
|
||||
// Modify headers for request that will be sent to the upstream host
|
||||
upHeaders := createHeadersByRules(host.UpstreamHeaders, r.Header, replacer)
|
||||
for k, v := range upHeaders {
|
||||
outreq.Header[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
var downHeaderUpdateFn respUpdateFn
|
||||
if host.DownstreamHeaders != nil {
|
||||
if replacer == nil {
|
||||
rHost := r.Host
|
||||
replacer = httpserver.NewReplacer(r, nil, "")
|
||||
outreq.Host = rHost
|
||||
}
|
||||
//Creates a function that is used to update headers the response received by the reverse proxy
|
||||
downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer)
|
||||
}
|
||||
|
||||
proxy := host.ReverseProxy
|
||||
if baseURL, err := url.Parse(host.Name); err == nil {
|
||||
r.Host = baseURL.Host
|
||||
if proxy == nil {
|
||||
proxy = NewSingleHostReverseProxy(baseURL, host.WithoutPathPrefix)
|
||||
}
|
||||
} else if proxy == nil {
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
|
||||
atomic.AddInt64(&host.Conns, 1)
|
||||
backendErr := proxy.ServeHTTP(w, outreq, downHeaderUpdateFn)
|
||||
atomic.AddInt64(&host.Conns, -1)
|
||||
if backendErr == nil {
|
||||
return 0, nil
|
||||
}
|
||||
timeout := host.FailTimeout
|
||||
if timeout == 0 {
|
||||
timeout = 10 * time.Second
|
||||
}
|
||||
atomic.AddInt32(&host.Fails, 1)
|
||||
go func(host *UpstreamHost, timeout time.Duration) {
|
||||
time.Sleep(timeout)
|
||||
atomic.AddInt32(&host.Fails, -1)
|
||||
}(host, timeout)
|
||||
}
|
||||
return http.StatusBadGateway, errUnreachable
|
||||
}
|
||||
|
||||
return p.Next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// createUpstremRequest shallow-copies r into a new request
|
||||
// that can be sent upstream.
|
||||
func createUpstreamRequest(r *http.Request) *http.Request {
|
||||
outreq := new(http.Request)
|
||||
*outreq = *r // includes shallow copies of maps, but okay
|
||||
|
||||
// Restore URL Path if it has been modified
|
||||
if outreq.URL.RawPath != "" {
|
||||
outreq.URL.Opaque = outreq.URL.RawPath
|
||||
}
|
||||
|
||||
// Remove hop-by-hop headers to the backend. Especially
|
||||
// important is "Connection" because we want a persistent
|
||||
// connection, regardless of what the client sent to us. This
|
||||
// is modifying the same underlying map from r (shallow
|
||||
// copied above) so we only copy it if necessary.
|
||||
for _, h := range hopHeaders {
|
||||
if outreq.Header.Get(h) != "" {
|
||||
outreq.Header = make(http.Header)
|
||||
copyHeader(outreq.Header, r.Header)
|
||||
outreq.Header.Del(h)
|
||||
}
|
||||
}
|
||||
|
||||
if clientIP, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
||||
// If we aren't the first proxy, retain prior
|
||||
// X-Forwarded-For information as a comma+space
|
||||
// separated list and fold multiple headers into one.
|
||||
if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
|
||||
clientIP = strings.Join(prior, ", ") + ", " + clientIP
|
||||
}
|
||||
outreq.Header.Set("X-Forwarded-For", clientIP)
|
||||
}
|
||||
|
||||
return outreq
|
||||
}
|
||||
|
||||
func createRespHeaderUpdateFn(rules http.Header, replacer httpserver.Replacer) respUpdateFn {
|
||||
return func(resp *http.Response) {
|
||||
newHeaders := createHeadersByRules(rules, resp.Header, replacer)
|
||||
for h, v := range newHeaders {
|
||||
resp.Header[h] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func createHeadersByRules(rules http.Header, base http.Header, repl httpserver.Replacer) http.Header {
|
||||
newHeaders := make(http.Header)
|
||||
for header, values := range rules {
|
||||
if strings.HasPrefix(header, "+") {
|
||||
header = strings.TrimLeft(header, "+")
|
||||
add(newHeaders, header, base[header])
|
||||
applyEach(values, repl.Replace)
|
||||
add(newHeaders, header, values)
|
||||
} else if strings.HasPrefix(header, "-") {
|
||||
base.Del(strings.TrimLeft(header, "-"))
|
||||
} else if _, ok := base[header]; ok {
|
||||
applyEach(values, repl.Replace)
|
||||
for _, v := range values {
|
||||
newHeaders.Set(header, v)
|
||||
}
|
||||
} else {
|
||||
applyEach(values, repl.Replace)
|
||||
add(newHeaders, header, values)
|
||||
add(newHeaders, header, base[header])
|
||||
}
|
||||
}
|
||||
return newHeaders
|
||||
}
|
||||
|
||||
func applyEach(values []string, mapFn func(string) string) {
|
||||
for i, v := range values {
|
||||
values[i] = mapFn(v)
|
||||
}
|
||||
}
|
||||
|
||||
func add(base http.Header, header string, values []string) {
|
||||
for _, v := range values {
|
||||
base.Add(header, v)
|
||||
}
|
||||
}
|
583
caddyhttp/proxy/proxy_test.go
Normal file
583
caddyhttp/proxy/proxy_test.go
Normal file
|
@ -0,0 +1,583 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
func init() {
|
||||
tryDuration = 50 * time.Millisecond // prevent tests from hanging
|
||||
}
|
||||
|
||||
func TestReverseProxy(t *testing.T) {
|
||||
log.SetOutput(ioutil.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
var requestReceived bool
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestReceived = true
|
||||
w.Write([]byte("Hello, client"))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
// set up proxy
|
||||
p := &Proxy{
|
||||
Upstreams: []Upstream{newFakeUpstream(backend.URL, false)},
|
||||
}
|
||||
|
||||
// create request and response recorder
|
||||
r, err := http.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
p.ServeHTTP(w, r)
|
||||
|
||||
if !requestReceived {
|
||||
t.Error("Expected backend to receive request, but it didn't")
|
||||
}
|
||||
|
||||
// Make sure {upstream} placeholder is set
|
||||
rr := httpserver.NewResponseRecorder(httptest.NewRecorder())
|
||||
rr.Replacer = httpserver.NewReplacer(r, rr, "-")
|
||||
|
||||
p.ServeHTTP(rr, r)
|
||||
|
||||
if got, want := rr.Replacer.Replace("{upstream}"), backend.URL; got != want {
|
||||
t.Errorf("Expected custom placeholder {upstream} to be set (%s), but it wasn't; got: %s", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyInsecureSkipVerify(t *testing.T) {
|
||||
log.SetOutput(ioutil.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
var requestReceived bool
|
||||
backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestReceived = true
|
||||
w.Write([]byte("Hello, client"))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
// set up proxy
|
||||
p := &Proxy{
|
||||
Upstreams: []Upstream{newFakeUpstream(backend.URL, true)},
|
||||
}
|
||||
|
||||
// create request and response recorder
|
||||
r, err := http.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
p.ServeHTTP(w, r)
|
||||
|
||||
if !requestReceived {
|
||||
t.Error("Even with insecure HTTPS, expected backend to receive request, but it didn't")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
|
||||
// No-op websocket backend simply allows the WS connection to be
|
||||
// accepted then it will be immediately closed. Perfect for testing.
|
||||
wsNop := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {}))
|
||||
defer wsNop.Close()
|
||||
|
||||
// Get proxy to use for the test
|
||||
p := newWebSocketTestProxy(wsNop.URL)
|
||||
|
||||
// Create client request
|
||||
r, err := http.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
r.Header = http.Header{
|
||||
"Connection": {"Upgrade"},
|
||||
"Upgrade": {"websocket"},
|
||||
"Origin": {wsNop.URL},
|
||||
"Sec-WebSocket-Key": {"x3JJHMbDL1EzLkh9GBhXDw=="},
|
||||
"Sec-WebSocket-Version": {"13"},
|
||||
}
|
||||
|
||||
// Capture the request
|
||||
w := &recorderHijacker{httptest.NewRecorder(), new(fakeConn)}
|
||||
|
||||
// Booya! Do the test.
|
||||
p.ServeHTTP(w, r)
|
||||
|
||||
// Make sure the backend accepted the WS connection.
|
||||
// Mostly interested in the Upgrade and Connection response headers
|
||||
// and the 101 status code.
|
||||
expected := []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n\r\n")
|
||||
actual := w.fakeConn.writeBuf.Bytes()
|
||||
if !bytes.Equal(actual, expected) {
|
||||
t.Errorf("Expected backend to accept response:\n'%s'\nActually got:\n'%s'", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
|
||||
// Echo server allows us to test that socket bytes are properly
|
||||
// being proxied.
|
||||
wsEcho := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {
|
||||
io.Copy(ws, ws)
|
||||
}))
|
||||
defer wsEcho.Close()
|
||||
|
||||
// Get proxy to use for the test
|
||||
p := newWebSocketTestProxy(wsEcho.URL)
|
||||
|
||||
// This is a full end-end test, so the proxy handler
|
||||
// has to be part of a server listening on a port. Our
|
||||
// WS client will connect to this test server, not
|
||||
// the echo client directly.
|
||||
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
p.ServeHTTP(w, r)
|
||||
}))
|
||||
defer echoProxy.Close()
|
||||
|
||||
// Set up WebSocket client
|
||||
url := strings.Replace(echoProxy.URL, "http://", "ws://", 1)
|
||||
ws, err := websocket.Dial(url, "", echoProxy.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ws.Close()
|
||||
|
||||
// Send test message
|
||||
trialMsg := "Is it working?"
|
||||
websocket.Message.Send(ws, trialMsg)
|
||||
|
||||
// It should be echoed back to us
|
||||
var actualMsg string
|
||||
websocket.Message.Receive(ws, &actualMsg)
|
||||
if actualMsg != trialMsg {
|
||||
t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnixSocketProxy(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
return
|
||||
}
|
||||
|
||||
trialMsg := "Is it working?"
|
||||
|
||||
var proxySuccess bool
|
||||
|
||||
// This is our fake "application" we want to proxy to
|
||||
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Request was proxied when this is called
|
||||
proxySuccess = true
|
||||
|
||||
fmt.Fprint(w, trialMsg)
|
||||
}))
|
||||
|
||||
// Get absolute path for unix: socket
|
||||
socketPath, err := filepath.Abs("./test_socket")
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to get absolute path: %v", err)
|
||||
}
|
||||
|
||||
// Change httptest.Server listener to listen to unix: socket
|
||||
ln, err := net.Listen("unix", socketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to listen: %v", err)
|
||||
}
|
||||
ts.Listener = ln
|
||||
|
||||
ts.Start()
|
||||
defer ts.Close()
|
||||
|
||||
url := strings.Replace(ts.URL, "http://", "unix:", 1)
|
||||
p := newWebSocketTestProxy(url)
|
||||
|
||||
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
p.ServeHTTP(w, r)
|
||||
}))
|
||||
defer echoProxy.Close()
|
||||
|
||||
res, err := http.Get(echoProxy.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to GET: %v", err)
|
||||
}
|
||||
|
||||
greeting, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to GET: %v", err)
|
||||
}
|
||||
|
||||
actualMsg := fmt.Sprintf("%s", greeting)
|
||||
|
||||
if !proxySuccess {
|
||||
t.Errorf("Expected request to be proxied, but it wasn't")
|
||||
}
|
||||
|
||||
if actualMsg != trialMsg {
|
||||
t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func GetHTTPProxy(messageFormat string, prefix string) (*Proxy, *httptest.Server) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, messageFormat, r.URL.String())
|
||||
}))
|
||||
|
||||
return newPrefixedWebSocketTestProxy(ts.URL, prefix), ts
|
||||
}
|
||||
|
||||
func GetSocketProxy(messageFormat string, prefix string) (*Proxy, *httptest.Server, error) {
|
||||
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, messageFormat, r.URL.String())
|
||||
}))
|
||||
|
||||
socketPath, err := filepath.Abs("./test_socket")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("Unable to get absolute path: %v", err)
|
||||
}
|
||||
|
||||
ln, err := net.Listen("unix", socketPath)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("Unable to listen: %v", err)
|
||||
}
|
||||
ts.Listener = ln
|
||||
|
||||
ts.Start()
|
||||
|
||||
tsURL := strings.Replace(ts.URL, "http://", "unix:", 1)
|
||||
|
||||
return newPrefixedWebSocketTestProxy(tsURL, prefix), ts, nil
|
||||
}
|
||||
|
||||
func GetTestServerMessage(p *Proxy, ts *httptest.Server, path string) (string, error) {
|
||||
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
p.ServeHTTP(w, r)
|
||||
}))
|
||||
|
||||
// *httptest.Server is passed so it can be `defer`red properly
|
||||
defer ts.Close()
|
||||
defer echoProxy.Close()
|
||||
|
||||
res, err := http.Get(echoProxy.URL + path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Unable to GET: %v", err)
|
||||
}
|
||||
|
||||
greeting, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Unable to read body: %v", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s", greeting), nil
|
||||
}
|
||||
|
||||
func TestUnixSocketProxyPaths(t *testing.T) {
|
||||
greeting := "Hello route %s"
|
||||
|
||||
tests := []struct {
|
||||
url string
|
||||
prefix string
|
||||
expected string
|
||||
}{
|
||||
{"", "", fmt.Sprintf(greeting, "/")},
|
||||
{"/hello", "", fmt.Sprintf(greeting, "/hello")},
|
||||
{"/foo/bar", "", fmt.Sprintf(greeting, "/foo/bar")},
|
||||
{"/foo?bar", "", fmt.Sprintf(greeting, "/foo?bar")},
|
||||
{"/greet?name=john", "", fmt.Sprintf(greeting, "/greet?name=john")},
|
||||
{"/world?wonderful&colorful", "", fmt.Sprintf(greeting, "/world?wonderful&colorful")},
|
||||
{"/proxy/hello", "/proxy", fmt.Sprintf(greeting, "/hello")},
|
||||
{"/proxy/foo/bar", "/proxy", fmt.Sprintf(greeting, "/foo/bar")},
|
||||
{"/proxy/?foo=bar", "/proxy", fmt.Sprintf(greeting, "/?foo=bar")},
|
||||
{"/queues/%2F/fetchtasks", "", fmt.Sprintf(greeting, "/queues/%2F/fetchtasks")},
|
||||
{"/queues/%2F/fetchtasks?foo=bar", "", fmt.Sprintf(greeting, "/queues/%2F/fetchtasks?foo=bar")},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
p, ts := GetHTTPProxy(greeting, test.prefix)
|
||||
|
||||
actualMsg, err := GetTestServerMessage(p, ts, test.url)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Getting server message failed - %v", err)
|
||||
}
|
||||
|
||||
if actualMsg != test.expected {
|
||||
t.Errorf("Expected '%s' but got '%s' instead", test.expected, actualMsg)
|
||||
}
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
return
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
p, ts, err := GetSocketProxy(greeting, test.prefix)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Getting socket proxy failed - %v", err)
|
||||
}
|
||||
|
||||
actualMsg, err := GetTestServerMessage(p, ts, test.url)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Getting server message failed - %v", err)
|
||||
}
|
||||
|
||||
if actualMsg != test.expected {
|
||||
t.Errorf("Expected '%s' but got '%s' instead", test.expected, actualMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpstreamHeadersUpdate(t *testing.T) {
|
||||
log.SetOutput(ioutil.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
var actualHeaders http.Header
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("Hello, client"))
|
||||
actualHeaders = r.Header
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
upstream := newFakeUpstream(backend.URL, false)
|
||||
upstream.host.UpstreamHeaders = http.Header{
|
||||
"Connection": {"{>Connection}"},
|
||||
"Upgrade": {"{>Upgrade}"},
|
||||
"+Merge-Me": {"Merge-Value"},
|
||||
"+Add-Me": {"Add-Value"},
|
||||
"-Remove-Me": {""},
|
||||
"Replace-Me": {"{hostname}"},
|
||||
}
|
||||
// set up proxy
|
||||
p := &Proxy{
|
||||
Upstreams: []Upstream{upstream},
|
||||
}
|
||||
|
||||
// create request and response recorder
|
||||
r, err := http.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
//add initial headers
|
||||
r.Header.Add("Merge-Me", "Initial")
|
||||
r.Header.Add("Remove-Me", "Remove-Value")
|
||||
r.Header.Add("Replace-Me", "Replace-Value")
|
||||
|
||||
p.ServeHTTP(w, r)
|
||||
|
||||
replacer := httpserver.NewReplacer(r, nil, "")
|
||||
|
||||
headerKey := "Merge-Me"
|
||||
values, ok := actualHeaders[headerKey]
|
||||
if !ok {
|
||||
t.Errorf("Request sent to upstream backend does not contain expected %v header. Expected header to be added", headerKey)
|
||||
} else if len(values) < 2 && (values[0] != "Initial" || values[1] != replacer.Replace("{hostname}")) {
|
||||
t.Errorf("Values for proxy header `+Merge-Me` should be merged. Got %v", values)
|
||||
}
|
||||
|
||||
headerKey = "Add-Me"
|
||||
if _, ok := actualHeaders[headerKey]; !ok {
|
||||
t.Errorf("Request sent to upstream backend does not contain expected %v header", headerKey)
|
||||
}
|
||||
|
||||
headerKey = "Remove-Me"
|
||||
if _, ok := actualHeaders[headerKey]; ok {
|
||||
t.Errorf("Request sent to upstream backend should not contain %v header", headerKey)
|
||||
}
|
||||
|
||||
headerKey = "Replace-Me"
|
||||
headerValue := replacer.Replace("{hostname}")
|
||||
value, ok := actualHeaders[headerKey]
|
||||
if !ok {
|
||||
t.Errorf("Request sent to upstream backend should not remove %v header", headerKey)
|
||||
} else if len(value) > 0 && headerValue != value[0] {
|
||||
t.Errorf("Request sent to upstream backend should replace value of %v header with %v. Instead value was %v", headerKey, headerValue, value)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestDownstreamHeadersUpdate(t *testing.T) {
|
||||
log.SetOutput(ioutil.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Add("Merge-Me", "Initial")
|
||||
w.Header().Add("Remove-Me", "Remove-Value")
|
||||
w.Header().Add("Replace-Me", "Replace-Value")
|
||||
w.Write([]byte("Hello, client"))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
upstream := newFakeUpstream(backend.URL, false)
|
||||
upstream.host.DownstreamHeaders = http.Header{
|
||||
"+Merge-Me": {"Merge-Value"},
|
||||
"+Add-Me": {"Add-Value"},
|
||||
"-Remove-Me": {""},
|
||||
"Replace-Me": {"{hostname}"},
|
||||
}
|
||||
// set up proxy
|
||||
p := &Proxy{
|
||||
Upstreams: []Upstream{upstream},
|
||||
}
|
||||
|
||||
// create request and response recorder
|
||||
r, err := http.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
p.ServeHTTP(w, r)
|
||||
|
||||
replacer := httpserver.NewReplacer(r, nil, "")
|
||||
actualHeaders := w.Header()
|
||||
|
||||
headerKey := "Merge-Me"
|
||||
values, ok := actualHeaders[headerKey]
|
||||
if !ok {
|
||||
t.Errorf("Downstream response does not contain expected %v header. Expected header should be added", headerKey)
|
||||
} else if len(values) < 2 && (values[0] != "Initial" || values[1] != replacer.Replace("{hostname}")) {
|
||||
t.Errorf("Values for header `+Merge-Me` should be merged. Got %v", values)
|
||||
}
|
||||
|
||||
headerKey = "Add-Me"
|
||||
if _, ok := actualHeaders[headerKey]; !ok {
|
||||
t.Errorf("Downstream response does not contain expected %v header", headerKey)
|
||||
}
|
||||
|
||||
headerKey = "Remove-Me"
|
||||
if _, ok := actualHeaders[headerKey]; ok {
|
||||
t.Errorf("Downstream response should not contain %v header received from upstream", headerKey)
|
||||
}
|
||||
|
||||
headerKey = "Replace-Me"
|
||||
headerValue := replacer.Replace("{hostname}")
|
||||
value, ok := actualHeaders[headerKey]
|
||||
if !ok {
|
||||
t.Errorf("Downstream response should contain %v header and not remove it", headerKey)
|
||||
} else if len(value) > 0 && headerValue != value[0] {
|
||||
t.Errorf("Downstream response should have header %v with value %v. Instead value was %v", headerKey, headerValue, value)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func newFakeUpstream(name string, insecure bool) *fakeUpstream {
|
||||
uri, _ := url.Parse(name)
|
||||
u := &fakeUpstream{
|
||||
name: name,
|
||||
host: &UpstreamHost{
|
||||
Name: name,
|
||||
ReverseProxy: NewSingleHostReverseProxy(uri, ""),
|
||||
},
|
||||
}
|
||||
if insecure {
|
||||
u.host.ReverseProxy.Transport = InsecureTransport
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
type fakeUpstream struct {
|
||||
name string
|
||||
host *UpstreamHost
|
||||
}
|
||||
|
||||
func (u *fakeUpstream) From() string {
|
||||
return "/"
|
||||
}
|
||||
|
||||
func (u *fakeUpstream) Select() *UpstreamHost {
|
||||
return u.host
|
||||
}
|
||||
|
||||
func (u *fakeUpstream) AllowedPath(requestPath string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// newWebSocketTestProxy returns a test proxy that will
|
||||
// redirect to the specified backendAddr. The function
|
||||
// also sets up the rules/environment for testing WebSocket
|
||||
// proxy.
|
||||
func newWebSocketTestProxy(backendAddr string) *Proxy {
|
||||
return &Proxy{
|
||||
Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: ""}},
|
||||
}
|
||||
}
|
||||
|
||||
func newPrefixedWebSocketTestProxy(backendAddr string, prefix string) *Proxy {
|
||||
return &Proxy{
|
||||
Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: prefix}},
|
||||
}
|
||||
}
|
||||
|
||||
type fakeWsUpstream struct {
|
||||
name string
|
||||
without string
|
||||
}
|
||||
|
||||
func (u *fakeWsUpstream) From() string {
|
||||
return "/"
|
||||
}
|
||||
|
||||
func (u *fakeWsUpstream) Select() *UpstreamHost {
|
||||
uri, _ := url.Parse(u.name)
|
||||
return &UpstreamHost{
|
||||
Name: u.name,
|
||||
ReverseProxy: NewSingleHostReverseProxy(uri, u.without),
|
||||
UpstreamHeaders: http.Header{
|
||||
"Connection": {"{>Connection}"},
|
||||
"Upgrade": {"{>Upgrade}"}},
|
||||
}
|
||||
}
|
||||
|
||||
func (u *fakeWsUpstream) AllowedPath(requestPath string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// recorderHijacker is a ResponseRecorder that can
|
||||
// be hijacked.
|
||||
type recorderHijacker struct {
|
||||
*httptest.ResponseRecorder
|
||||
fakeConn *fakeConn
|
||||
}
|
||||
|
||||
func (rh *recorderHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return rh.fakeConn, nil, nil
|
||||
}
|
||||
|
||||
type fakeConn struct {
|
||||
readBuf bytes.Buffer
|
||||
writeBuf bytes.Buffer
|
||||
}
|
||||
|
||||
func (c *fakeConn) LocalAddr() net.Addr { return nil }
|
||||
func (c *fakeConn) RemoteAddr() net.Addr { return nil }
|
||||
func (c *fakeConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (c *fakeConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (c *fakeConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
func (c *fakeConn) Close() error { return nil }
|
||||
func (c *fakeConn) Read(b []byte) (int, error) { return c.readBuf.Read(b) }
|
||||
func (c *fakeConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) }
|
269
caddyhttp/proxy/reverseproxy.go
Normal file
269
caddyhttp/proxy/reverseproxy.go
Normal file
|
@ -0,0 +1,269 @@
|
|||
// This file is adapted from code in the net/http/httputil
|
||||
// package of the Go standard library, which is by the
|
||||
// Go Authors, and bears this copyright and license info:
|
||||
//
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
//
|
||||
// This file has been modified from the standard lib to
|
||||
// meet the needs of the application.
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// onExitFlushLoop is a callback set by tests to detect the state of the
|
||||
// flushLoop() goroutine.
|
||||
var onExitFlushLoop func()
|
||||
|
||||
// ReverseProxy is an HTTP Handler that takes an incoming request and
|
||||
// sends it to another server, proxying the response back to the
|
||||
// client.
|
||||
type ReverseProxy struct {
|
||||
// Director must be a function which modifies
|
||||
// the request into a new request to be sent
|
||||
// using Transport. Its response is then copied
|
||||
// back to the original client unmodified.
|
||||
Director func(*http.Request)
|
||||
|
||||
// The transport used to perform proxy requests.
|
||||
// If nil, http.DefaultTransport is used.
|
||||
Transport http.RoundTripper
|
||||
|
||||
// FlushInterval specifies the flush interval
|
||||
// to flush to the client while copying the
|
||||
// response body.
|
||||
// If zero, no periodic flushing is done.
|
||||
FlushInterval time.Duration
|
||||
}
|
||||
|
||||
func singleJoiningSlash(a, b string) string {
|
||||
aslash := strings.HasSuffix(a, "/")
|
||||
bslash := strings.HasPrefix(b, "/")
|
||||
switch {
|
||||
case aslash && bslash:
|
||||
return a + b[1:]
|
||||
case !aslash && !bslash:
|
||||
return a + "/" + b
|
||||
}
|
||||
return a + b
|
||||
}
|
||||
|
||||
// Though the relevant directive prefix is just "unix:", url.Parse
|
||||
// will - assuming the regular URL scheme - add additional slashes
|
||||
// as if "unix" was a request protocol.
|
||||
// What we need is just the path, so if "unix:/var/run/www.socket"
|
||||
// was the proxy directive, the parsed hostName would be
|
||||
// "unix:///var/run/www.socket", hence the ambiguous trimming.
|
||||
func socketDial(hostName string) func(network, addr string) (conn net.Conn, err error) {
|
||||
return func(network, addr string) (conn net.Conn, err error) {
|
||||
return net.Dial("unix", hostName[len("unix://"):])
|
||||
}
|
||||
}
|
||||
|
||||
// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
|
||||
// URLs to the scheme, host, and base path provided in target. If the
|
||||
// target's path is "/base" and the incoming request was for "/dir",
|
||||
// the target request will be for /base/dir.
|
||||
// Without logic: target's path is "/", incoming is "/api/messages",
|
||||
// without is "/api", then the target request will be for /messages.
|
||||
func NewSingleHostReverseProxy(target *url.URL, without string) *ReverseProxy {
|
||||
targetQuery := target.RawQuery
|
||||
director := func(req *http.Request) {
|
||||
if target.Scheme == "unix" {
|
||||
// to make Dial work with unix URL,
|
||||
// scheme and host have to be faked
|
||||
req.URL.Scheme = "http"
|
||||
req.URL.Host = "socket"
|
||||
} else {
|
||||
req.URL.Scheme = target.Scheme
|
||||
req.URL.Host = target.Host
|
||||
}
|
||||
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
|
||||
if targetQuery == "" || req.URL.RawQuery == "" {
|
||||
req.URL.RawQuery = targetQuery + req.URL.RawQuery
|
||||
} else {
|
||||
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
|
||||
}
|
||||
// Trims the path of the socket from the URL path.
|
||||
// This is done because req.URL passed to your proxied service
|
||||
// will have the full path of the socket file prefixed to it.
|
||||
// Calling /test on a server that proxies requests to
|
||||
// unix:/var/run/www.socket will thus set the requested path
|
||||
// to /var/run/www.socket/test, rendering paths useless.
|
||||
if target.Scheme == "unix" {
|
||||
// See comment on socketDial for the trim
|
||||
socketPrefix := target.String()[len("unix://"):]
|
||||
req.URL.Path = strings.TrimPrefix(req.URL.Path, socketPrefix)
|
||||
}
|
||||
// We are then safe to remove the `without` prefix.
|
||||
if without != "" {
|
||||
req.URL.Path = strings.TrimPrefix(req.URL.Path, without)
|
||||
}
|
||||
}
|
||||
rp := &ReverseProxy{Director: director, FlushInterval: 250 * time.Millisecond} // flushing good for streaming & server-sent events
|
||||
if target.Scheme == "unix" {
|
||||
rp.Transport = &http.Transport{
|
||||
Dial: socketDial(target.String()),
|
||||
}
|
||||
}
|
||||
return rp
|
||||
}
|
||||
|
||||
func copyHeader(dst, src http.Header) {
|
||||
for k, vv := range src {
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Hop-by-hop headers. These are removed when sent to the backend.
|
||||
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
|
||||
var hopHeaders = []string{
|
||||
"Connection",
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"Te", // canonicalized version of "TE"
|
||||
"Trailers",
|
||||
"Transfer-Encoding",
|
||||
"Upgrade",
|
||||
}
|
||||
|
||||
// InsecureTransport is used to facilitate HTTPS proxying
|
||||
// when it is OK for upstream to be using a bad certificate,
|
||||
// since this transport skips verification.
|
||||
var InsecureTransport http.RoundTripper = &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
Dial: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).Dial,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
|
||||
type respUpdateFn func(resp *http.Response)
|
||||
|
||||
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, respUpdateFn respUpdateFn) error {
|
||||
transport := p.Transport
|
||||
if transport == nil {
|
||||
transport = http.DefaultTransport
|
||||
}
|
||||
|
||||
p.Director(outreq)
|
||||
outreq.Proto = "HTTP/1.1"
|
||||
outreq.ProtoMajor = 1
|
||||
outreq.ProtoMinor = 1
|
||||
outreq.Close = false
|
||||
|
||||
res, err := transport.RoundTrip(outreq)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if respUpdateFn != nil {
|
||||
respUpdateFn(res)
|
||||
}
|
||||
|
||||
if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" {
|
||||
res.Body.Close()
|
||||
hj, ok := rw.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
conn, _, err := hj.Hijack()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
backendConn, err := net.Dial("tcp", outreq.URL.Host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer backendConn.Close()
|
||||
|
||||
outreq.Write(backendConn)
|
||||
|
||||
go func() {
|
||||
io.Copy(backendConn, conn) // write tcp stream to backend.
|
||||
}()
|
||||
io.Copy(conn, backendConn) // read tcp stream from backend.
|
||||
} else {
|
||||
defer res.Body.Close()
|
||||
for _, h := range hopHeaders {
|
||||
res.Header.Del(h)
|
||||
}
|
||||
copyHeader(rw.Header(), res.Header)
|
||||
rw.WriteHeader(res.StatusCode)
|
||||
p.copyResponse(rw, res.Body)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
|
||||
if p.FlushInterval != 0 {
|
||||
if wf, ok := dst.(writeFlusher); ok {
|
||||
mlw := &maxLatencyWriter{
|
||||
dst: wf,
|
||||
latency: p.FlushInterval,
|
||||
done: make(chan bool),
|
||||
}
|
||||
go mlw.flushLoop()
|
||||
defer mlw.stop()
|
||||
dst = mlw
|
||||
}
|
||||
}
|
||||
io.Copy(dst, src)
|
||||
}
|
||||
|
||||
type writeFlusher interface {
|
||||
io.Writer
|
||||
http.Flusher
|
||||
}
|
||||
|
||||
type maxLatencyWriter struct {
|
||||
dst writeFlusher
|
||||
latency time.Duration
|
||||
|
||||
lk sync.Mutex // protects Write + Flush
|
||||
done chan bool
|
||||
}
|
||||
|
||||
func (m *maxLatencyWriter) Write(p []byte) (int, error) {
|
||||
m.lk.Lock()
|
||||
defer m.lk.Unlock()
|
||||
return m.dst.Write(p)
|
||||
}
|
||||
|
||||
func (m *maxLatencyWriter) flushLoop() {
|
||||
t := time.NewTicker(m.latency)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-m.done:
|
||||
if onExitFlushLoop != nil {
|
||||
onExitFlushLoop()
|
||||
}
|
||||
return
|
||||
case <-t.C:
|
||||
m.lk.Lock()
|
||||
m.dst.Flush()
|
||||
m.lk.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *maxLatencyWriter) stop() { m.done <- true }
|
26
caddyhttp/proxy/setup.go
Normal file
26
caddyhttp/proxy/setup.go
Normal file
|
@ -0,0 +1,26 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy.RegisterPlugin(caddy.Plugin{
|
||||
Name: "proxy",
|
||||
ServerType: "http",
|
||||
Action: setup,
|
||||
})
|
||||
}
|
||||
|
||||
// setup configures a new Proxy middleware instance.
|
||||
func setup(c *caddy.Controller) error {
|
||||
upstreams, err := NewStaticUpstreams(c.Dispenser)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
|
||||
return Proxy{Next: next, Upstreams: upstreams}
|
||||
})
|
||||
return nil
|
||||
}
|
140
caddyhttp/proxy/setup_test.go
Normal file
140
caddyhttp/proxy/setup_test.go
Normal file
|
@ -0,0 +1,140 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
for i, test := range []struct {
|
||||
input string
|
||||
shouldErr bool
|
||||
expectedHosts map[string]struct{}
|
||||
}{
|
||||
// test #0 test usual to destination still works normally
|
||||
{
|
||||
"proxy / localhost:80",
|
||||
false,
|
||||
map[string]struct{}{
|
||||
"http://localhost:80": {},
|
||||
},
|
||||
},
|
||||
|
||||
// test #1 test usual to destination with port range
|
||||
{
|
||||
"proxy / localhost:8080-8082",
|
||||
false,
|
||||
map[string]struct{}{
|
||||
"http://localhost:8080": {},
|
||||
"http://localhost:8081": {},
|
||||
"http://localhost:8082": {},
|
||||
},
|
||||
},
|
||||
|
||||
// test #2 test upstream directive
|
||||
{
|
||||
"proxy / {\n upstream localhost:8080\n}",
|
||||
false,
|
||||
map[string]struct{}{
|
||||
"http://localhost:8080": {},
|
||||
},
|
||||
},
|
||||
|
||||
// test #3 test upstream directive with port range
|
||||
{
|
||||
"proxy / {\n upstream localhost:8080-8081\n}",
|
||||
false,
|
||||
map[string]struct{}{
|
||||
"http://localhost:8080": {},
|
||||
"http://localhost:8081": {},
|
||||
},
|
||||
},
|
||||
|
||||
// test #4 test to destination with upstream directive
|
||||
{
|
||||
"proxy / localhost:8080 {\n upstream localhost:8081-8082\n}",
|
||||
false,
|
||||
map[string]struct{}{
|
||||
"http://localhost:8080": {},
|
||||
"http://localhost:8081": {},
|
||||
"http://localhost:8082": {},
|
||||
},
|
||||
},
|
||||
|
||||
// test #5 test with unix sockets
|
||||
{
|
||||
"proxy / localhost:8080 {\n upstream unix:/var/foo\n}",
|
||||
false,
|
||||
map[string]struct{}{
|
||||
"http://localhost:8080": {},
|
||||
"unix:/var/foo": {},
|
||||
},
|
||||
},
|
||||
|
||||
// test #6 test fail on malformed port range
|
||||
{
|
||||
"proxy / localhost:8090-8080",
|
||||
true,
|
||||
nil,
|
||||
},
|
||||
|
||||
// test #7 test fail on malformed port range 2
|
||||
{
|
||||
"proxy / {\n upstream localhost:80-A\n}",
|
||||
true,
|
||||
nil,
|
||||
},
|
||||
|
||||
// test #8 test upstreams without ports work correctly
|
||||
{
|
||||
"proxy / http://localhost {\n upstream testendpoint\n}",
|
||||
false,
|
||||
map[string]struct{}{
|
||||
"http://localhost": {},
|
||||
"http://testendpoint": {},
|
||||
},
|
||||
},
|
||||
|
||||
// test #9 test several upstream directives
|
||||
{
|
||||
"proxy / localhost:8080 {\n upstream localhost:8081-8082\n upstream localhost:8083-8085\n}",
|
||||
false,
|
||||
map[string]struct{}{
|
||||
"http://localhost:8080": {},
|
||||
"http://localhost:8081": {},
|
||||
"http://localhost:8082": {},
|
||||
"http://localhost:8083": {},
|
||||
"http://localhost:8084": {},
|
||||
"http://localhost:8085": {},
|
||||
},
|
||||
},
|
||||
} {
|
||||
err := setup(caddy.NewTestController(test.input))
|
||||
if err != nil && !test.shouldErr {
|
||||
t.Errorf("Test case #%d received an error of %v", i, err)
|
||||
} else if test.shouldErr {
|
||||
continue
|
||||
}
|
||||
|
||||
mids := httpserver.GetConfig("").Middleware()
|
||||
mid := mids[len(mids)-1]
|
||||
|
||||
upstreams := mid(nil).(Proxy).Upstreams
|
||||
for _, upstream := range upstreams {
|
||||
val := reflect.ValueOf(upstream).Elem()
|
||||
hosts := val.FieldByName("Hosts").Interface().(HostPool)
|
||||
if len(hosts) != len(test.expectedHosts) {
|
||||
t.Errorf("Test case #%d expected %d hosts but received %d", i, len(test.expectedHosts), len(hosts))
|
||||
} else {
|
||||
for _, host := range hosts {
|
||||
if _, found := test.expectedHosts[host.Name]; !found {
|
||||
t.Errorf("Test case #%d has an unexpected host %s", i, host.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
345
caddyhttp/proxy/upstream.go
Normal file
345
caddyhttp/proxy/upstream.go
Normal file
|
@ -0,0 +1,345 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mholt/caddy/caddyfile"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
var (
|
||||
supportedPolicies = make(map[string]func() Policy)
|
||||
)
|
||||
|
||||
type staticUpstream struct {
|
||||
from string
|
||||
upstreamHeaders http.Header
|
||||
downstreamHeaders http.Header
|
||||
Hosts HostPool
|
||||
Policy Policy
|
||||
insecureSkipVerify bool
|
||||
|
||||
FailTimeout time.Duration
|
||||
MaxFails int32
|
||||
MaxConns int64
|
||||
HealthCheck struct {
|
||||
Path string
|
||||
Interval time.Duration
|
||||
}
|
||||
WithoutPathPrefix string
|
||||
IgnoredSubPaths []string
|
||||
}
|
||||
|
||||
// NewStaticUpstreams parses the configuration input and sets up
|
||||
// static upstreams for the proxy middleware.
|
||||
func NewStaticUpstreams(c caddyfile.Dispenser) ([]Upstream, error) {
|
||||
var upstreams []Upstream
|
||||
for c.Next() {
|
||||
upstream := &staticUpstream{
|
||||
from: "",
|
||||
upstreamHeaders: make(http.Header),
|
||||
downstreamHeaders: make(http.Header),
|
||||
Hosts: nil,
|
||||
Policy: &Random{},
|
||||
FailTimeout: 10 * time.Second,
|
||||
MaxFails: 1,
|
||||
MaxConns: 0,
|
||||
}
|
||||
|
||||
if !c.Args(&upstream.from) {
|
||||
return upstreams, c.ArgErr()
|
||||
}
|
||||
|
||||
var to []string
|
||||
for _, t := range c.RemainingArgs() {
|
||||
parsed, err := parseUpstream(t)
|
||||
if err != nil {
|
||||
return upstreams, err
|
||||
}
|
||||
to = append(to, parsed...)
|
||||
}
|
||||
|
||||
for c.NextBlock() {
|
||||
switch c.Val() {
|
||||
case "upstream":
|
||||
if !c.NextArg() {
|
||||
return upstreams, c.ArgErr()
|
||||
}
|
||||
parsed, err := parseUpstream(c.Val())
|
||||
if err != nil {
|
||||
return upstreams, err
|
||||
}
|
||||
to = append(to, parsed...)
|
||||
default:
|
||||
if err := parseBlock(&c, upstream); err != nil {
|
||||
return upstreams, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(to) == 0 {
|
||||
return upstreams, c.ArgErr()
|
||||
}
|
||||
|
||||
upstream.Hosts = make([]*UpstreamHost, len(to))
|
||||
for i, host := range to {
|
||||
uh, err := upstream.NewHost(host)
|
||||
if err != nil {
|
||||
return upstreams, err
|
||||
}
|
||||
upstream.Hosts[i] = uh
|
||||
}
|
||||
|
||||
if upstream.HealthCheck.Path != "" {
|
||||
go upstream.HealthCheckWorker(nil)
|
||||
}
|
||||
upstreams = append(upstreams, upstream)
|
||||
}
|
||||
return upstreams, nil
|
||||
}
|
||||
|
||||
// RegisterPolicy adds a custom policy to the proxy.
|
||||
func RegisterPolicy(name string, policy func() Policy) {
|
||||
supportedPolicies[name] = policy
|
||||
}
|
||||
|
||||
func (u *staticUpstream) From() string {
|
||||
return u.from
|
||||
}
|
||||
|
||||
func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
|
||||
if !strings.HasPrefix(host, "http") &&
|
||||
!strings.HasPrefix(host, "unix:") {
|
||||
host = "http://" + host
|
||||
}
|
||||
uh := &UpstreamHost{
|
||||
Name: host,
|
||||
Conns: 0,
|
||||
Fails: 0,
|
||||
FailTimeout: u.FailTimeout,
|
||||
Unhealthy: false,
|
||||
UpstreamHeaders: u.upstreamHeaders,
|
||||
DownstreamHeaders: u.downstreamHeaders,
|
||||
CheckDown: func(u *staticUpstream) UpstreamHostDownFunc {
|
||||
return func(uh *UpstreamHost) bool {
|
||||
if uh.Unhealthy {
|
||||
return true
|
||||
}
|
||||
if uh.Fails >= u.MaxFails &&
|
||||
u.MaxFails != 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
}(u),
|
||||
WithoutPathPrefix: u.WithoutPathPrefix,
|
||||
MaxConns: u.MaxConns,
|
||||
}
|
||||
|
||||
baseURL, err := url.Parse(uh.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix)
|
||||
if u.insecureSkipVerify {
|
||||
uh.ReverseProxy.Transport = InsecureTransport
|
||||
}
|
||||
return uh, nil
|
||||
}
|
||||
|
||||
func parseUpstream(u string) ([]string, error) {
|
||||
if !strings.HasPrefix(u, "unix:") {
|
||||
colonIdx := strings.LastIndex(u, ":")
|
||||
protoIdx := strings.Index(u, "://")
|
||||
|
||||
if colonIdx != -1 && colonIdx != protoIdx {
|
||||
us := u[:colonIdx]
|
||||
ports := u[len(us)+1:]
|
||||
if separators := strings.Count(ports, "-"); separators > 1 {
|
||||
return nil, fmt.Errorf("port range [%s] is invalid", ports)
|
||||
} else if separators == 1 {
|
||||
portsStr := strings.Split(ports, "-")
|
||||
pIni, err := strconv.Atoi(portsStr[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pEnd, err := strconv.Atoi(portsStr[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if pEnd <= pIni {
|
||||
return nil, fmt.Errorf("port range [%s] is invalid", ports)
|
||||
}
|
||||
|
||||
hosts := []string{}
|
||||
for p := pIni; p <= pEnd; p++ {
|
||||
hosts = append(hosts, fmt.Sprintf("%s:%d", us, p))
|
||||
}
|
||||
return hosts, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return []string{u}, nil
|
||||
|
||||
}
|
||||
|
||||
func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error {
|
||||
switch c.Val() {
|
||||
case "policy":
|
||||
if !c.NextArg() {
|
||||
return c.ArgErr()
|
||||
}
|
||||
policyCreateFunc, ok := supportedPolicies[c.Val()]
|
||||
if !ok {
|
||||
return c.ArgErr()
|
||||
}
|
||||
u.Policy = policyCreateFunc()
|
||||
case "fail_timeout":
|
||||
if !c.NextArg() {
|
||||
return c.ArgErr()
|
||||
}
|
||||
dur, err := time.ParseDuration(c.Val())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.FailTimeout = dur
|
||||
case "max_fails":
|
||||
if !c.NextArg() {
|
||||
return c.ArgErr()
|
||||
}
|
||||
n, err := strconv.Atoi(c.Val())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.MaxFails = int32(n)
|
||||
case "max_conns":
|
||||
if !c.NextArg() {
|
||||
return c.ArgErr()
|
||||
}
|
||||
n, err := strconv.ParseInt(c.Val(), 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.MaxConns = n
|
||||
case "health_check":
|
||||
if !c.NextArg() {
|
||||
return c.ArgErr()
|
||||
}
|
||||
u.HealthCheck.Path = c.Val()
|
||||
u.HealthCheck.Interval = 30 * time.Second
|
||||
if c.NextArg() {
|
||||
dur, err := time.ParseDuration(c.Val())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.HealthCheck.Interval = dur
|
||||
}
|
||||
case "header_upstream":
|
||||
fallthrough
|
||||
case "proxy_header":
|
||||
var header, value string
|
||||
if !c.Args(&header, &value) {
|
||||
return c.ArgErr()
|
||||
}
|
||||
u.upstreamHeaders.Add(header, value)
|
||||
case "header_downstream":
|
||||
var header, value string
|
||||
if !c.Args(&header, &value) {
|
||||
return c.ArgErr()
|
||||
}
|
||||
u.downstreamHeaders.Add(header, value)
|
||||
case "websocket":
|
||||
u.upstreamHeaders.Add("Connection", "{>Connection}")
|
||||
u.upstreamHeaders.Add("Upgrade", "{>Upgrade}")
|
||||
case "without":
|
||||
if !c.NextArg() {
|
||||
return c.ArgErr()
|
||||
}
|
||||
u.WithoutPathPrefix = c.Val()
|
||||
case "except":
|
||||
ignoredPaths := c.RemainingArgs()
|
||||
if len(ignoredPaths) == 0 {
|
||||
return c.ArgErr()
|
||||
}
|
||||
u.IgnoredSubPaths = ignoredPaths
|
||||
case "insecure_skip_verify":
|
||||
u.insecureSkipVerify = true
|
||||
default:
|
||||
return c.Errf("unknown property '%s'", c.Val())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *staticUpstream) healthCheck() {
|
||||
for _, host := range u.Hosts {
|
||||
hostURL := host.Name + u.HealthCheck.Path
|
||||
if r, err := http.Get(hostURL); err == nil {
|
||||
io.Copy(ioutil.Discard, r.Body)
|
||||
r.Body.Close()
|
||||
host.Unhealthy = r.StatusCode < 200 || r.StatusCode >= 400
|
||||
} else {
|
||||
host.Unhealthy = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) {
|
||||
ticker := time.NewTicker(u.HealthCheck.Interval)
|
||||
u.healthCheck()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
u.healthCheck()
|
||||
case <-stop:
|
||||
// TODO: the library should provide a stop channel and global
|
||||
// waitgroup to allow goroutines started by plugins a chance
|
||||
// to clean themselves up.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *staticUpstream) Select() *UpstreamHost {
|
||||
pool := u.Hosts
|
||||
if len(pool) == 1 {
|
||||
if !pool[0].Available() {
|
||||
return nil
|
||||
}
|
||||
return pool[0]
|
||||
}
|
||||
allUnavailable := true
|
||||
for _, host := range pool {
|
||||
if host.Available() {
|
||||
allUnavailable = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allUnavailable {
|
||||
return nil
|
||||
}
|
||||
|
||||
if u.Policy == nil {
|
||||
return (&Random{}).Select(pool)
|
||||
}
|
||||
return u.Policy.Select(pool)
|
||||
}
|
||||
|
||||
func (u *staticUpstream) AllowedPath(requestPath string) bool {
|
||||
for _, ignoredSubPath := range u.IgnoredSubPaths {
|
||||
if httpserver.Path(path.Clean(requestPath)).Matches(path.Join(u.From(), ignoredSubPath)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
135
caddyhttp/proxy/upstream_test.go
Normal file
135
caddyhttp/proxy/upstream_test.go
Normal file
|
@ -0,0 +1,135 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewHost(t *testing.T) {
|
||||
upstream := &staticUpstream{
|
||||
FailTimeout: 10 * time.Second,
|
||||
MaxConns: 1,
|
||||
MaxFails: 1,
|
||||
}
|
||||
|
||||
uh, err := upstream.NewHost("example.com")
|
||||
if err != nil {
|
||||
t.Error("Expected no error")
|
||||
}
|
||||
if uh.Name != "http://example.com" {
|
||||
t.Error("Expected default schema to be added to Name.")
|
||||
}
|
||||
if uh.FailTimeout != upstream.FailTimeout {
|
||||
t.Error("Expected default FailTimeout to be set.")
|
||||
}
|
||||
if uh.MaxConns != upstream.MaxConns {
|
||||
t.Error("Expected default MaxConns to be set.")
|
||||
}
|
||||
if uh.CheckDown == nil {
|
||||
t.Error("Expected default CheckDown to be set.")
|
||||
}
|
||||
if uh.CheckDown(uh) {
|
||||
t.Error("Expected new host not to be down.")
|
||||
}
|
||||
// mark Unhealthy
|
||||
uh.Unhealthy = true
|
||||
if !uh.CheckDown(uh) {
|
||||
t.Error("Expected unhealthy host to be down.")
|
||||
}
|
||||
// mark with Fails
|
||||
uh.Unhealthy = false
|
||||
uh.Fails = 1
|
||||
if !uh.CheckDown(uh) {
|
||||
t.Error("Expected failed host to be down.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthCheck(t *testing.T) {
|
||||
upstream := &staticUpstream{
|
||||
from: "",
|
||||
Hosts: testPool(),
|
||||
Policy: &Random{},
|
||||
FailTimeout: 10 * time.Second,
|
||||
MaxFails: 1,
|
||||
}
|
||||
upstream.healthCheck()
|
||||
if upstream.Hosts[0].Down() {
|
||||
t.Error("Expected first host in testpool to not fail healthcheck.")
|
||||
}
|
||||
if !upstream.Hosts[1].Down() {
|
||||
t.Error("Expected second host in testpool to fail healthcheck.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelect(t *testing.T) {
|
||||
upstream := &staticUpstream{
|
||||
from: "",
|
||||
Hosts: testPool()[:3],
|
||||
Policy: &Random{},
|
||||
FailTimeout: 10 * time.Second,
|
||||
MaxFails: 1,
|
||||
}
|
||||
upstream.Hosts[0].Unhealthy = true
|
||||
upstream.Hosts[1].Unhealthy = true
|
||||
upstream.Hosts[2].Unhealthy = true
|
||||
if h := upstream.Select(); h != nil {
|
||||
t.Error("Expected select to return nil as all host are down")
|
||||
}
|
||||
upstream.Hosts[2].Unhealthy = false
|
||||
if h := upstream.Select(); h == nil {
|
||||
t.Error("Expected select to not return nil")
|
||||
}
|
||||
upstream.Hosts[0].Conns = 1
|
||||
upstream.Hosts[0].MaxConns = 1
|
||||
upstream.Hosts[1].Conns = 1
|
||||
upstream.Hosts[1].MaxConns = 1
|
||||
upstream.Hosts[2].Conns = 1
|
||||
upstream.Hosts[2].MaxConns = 1
|
||||
if h := upstream.Select(); h != nil {
|
||||
t.Error("Expected select to return nil as all hosts are full")
|
||||
}
|
||||
upstream.Hosts[2].Conns = 0
|
||||
if h := upstream.Select(); h == nil {
|
||||
t.Error("Expected select to not return nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterPolicy(t *testing.T) {
|
||||
name := "custom"
|
||||
customPolicy := &customPolicy{}
|
||||
RegisterPolicy(name, func() Policy { return customPolicy })
|
||||
if _, ok := supportedPolicies[name]; !ok {
|
||||
t.Error("Expected supportedPolicies to have a custom policy.")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestAllowedPaths(t *testing.T) {
|
||||
upstream := &staticUpstream{
|
||||
from: "/proxy",
|
||||
IgnoredSubPaths: []string{"/download", "/static"},
|
||||
}
|
||||
tests := []struct {
|
||||
url string
|
||||
expected bool
|
||||
}{
|
||||
{"/proxy", true},
|
||||
{"/proxy/dl", true},
|
||||
{"/proxy/download", false},
|
||||
{"/proxy/download/static", false},
|
||||
{"/proxy/static", false},
|
||||
{"/proxy/static/download", false},
|
||||
{"/proxy/something/download", true},
|
||||
{"/proxy/something/static", true},
|
||||
{"/proxy//static", false},
|
||||
{"/proxy//static//download", false},
|
||||
{"/proxy//download", false},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
allowed := upstream.AllowedPath(test.url)
|
||||
if test.expected != allowed {
|
||||
t.Errorf("Test %d: expected %v found %v", i+1, test.expected, allowed)
|
||||
}
|
||||
}
|
||||
}
|
58
caddyhttp/redirect/redirect.go
Normal file
58
caddyhttp/redirect/redirect.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
// Package redirect is middleware for redirecting certain requests
|
||||
// to other locations.
|
||||
package redirect
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"html"
|
||||
"net/http"
|
||||
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
// Redirect is middleware to respond with HTTP redirects
|
||||
type Redirect struct {
|
||||
Next httpserver.Handler
|
||||
Rules []Rule
|
||||
}
|
||||
|
||||
// ServeHTTP implements the httpserver.Handler interface.
|
||||
func (rd Redirect) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
for _, rule := range rd.Rules {
|
||||
if (rule.FromPath == "/" || r.URL.Path == rule.FromPath) && schemeMatches(rule, r) {
|
||||
to := httpserver.NewReplacer(r, nil, "").Replace(rule.To)
|
||||
if rule.Meta {
|
||||
safeTo := html.EscapeString(to)
|
||||
fmt.Fprintf(w, metaRedir, safeTo, safeTo)
|
||||
} else {
|
||||
http.Redirect(w, r, to, rule.Code)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
}
|
||||
return rd.Next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func schemeMatches(rule Rule, req *http.Request) bool {
|
||||
return (rule.FromScheme == "https" && req.TLS != nil) ||
|
||||
(rule.FromScheme != "https" && req.TLS == nil)
|
||||
}
|
||||
|
||||
// Rule describes an HTTP redirect rule.
|
||||
type Rule struct {
|
||||
FromScheme, FromPath, To string
|
||||
Code int
|
||||
Meta bool
|
||||
}
|
||||
|
||||
// Script tag comes first since that will better imitate a redirect in the browser's
|
||||
// history, but the meta tag is a fallback for most non-JS clients.
|
||||
const metaRedir = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<script>window.location.replace("%s");</script>
|
||||
<meta http-equiv="refresh" content="0; URL='%s'">
|
||||
</head>
|
||||
<body>Redirecting...</body>
|
||||
</html>
|
||||
`
|
154
caddyhttp/redirect/redirect_test.go
Normal file
154
caddyhttp/redirect/redirect_test.go
Normal file
|
@ -0,0 +1,154 @@
|
|||
package redirect
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func TestRedirect(t *testing.T) {
|
||||
for i, test := range []struct {
|
||||
from string
|
||||
expectedLocation string
|
||||
expectedCode int
|
||||
}{
|
||||
{"http://localhost/from", "/to", http.StatusMovedPermanently},
|
||||
{"http://localhost/a", "/b", http.StatusTemporaryRedirect},
|
||||
{"http://localhost/aa", "", http.StatusOK},
|
||||
{"http://localhost/", "", http.StatusOK},
|
||||
{"http://localhost/a?foo=bar", "/b", http.StatusTemporaryRedirect},
|
||||
{"http://localhost/asdf?foo=bar", "", http.StatusOK},
|
||||
{"http://localhost/foo#bar", "", http.StatusOK},
|
||||
{"http://localhost/a#foo", "/b", http.StatusTemporaryRedirect},
|
||||
|
||||
// The scheme checks that were added to this package don't actually
|
||||
// help with redirects because of Caddy's design: a redirect middleware
|
||||
// for http will always be different than the redirect middleware for
|
||||
// https because they have to be on different listeners. These tests
|
||||
// just go to show extra bulletproofing, I guess.
|
||||
{"http://localhost/scheme", "https://localhost/scheme", http.StatusMovedPermanently},
|
||||
{"https://localhost/scheme", "", http.StatusOK},
|
||||
{"https://localhost/scheme2", "http://localhost/scheme2", http.StatusMovedPermanently},
|
||||
{"http://localhost/scheme2", "", http.StatusOK},
|
||||
{"http://localhost/scheme3", "https://localhost/scheme3", http.StatusMovedPermanently},
|
||||
{"https://localhost/scheme3", "", http.StatusOK},
|
||||
} {
|
||||
var nextCalled bool
|
||||
|
||||
re := Redirect{
|
||||
Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
nextCalled = true
|
||||
return 0, nil
|
||||
}),
|
||||
Rules: []Rule{
|
||||
{FromPath: "/from", To: "/to", Code: http.StatusMovedPermanently},
|
||||
{FromPath: "/a", To: "/b", Code: http.StatusTemporaryRedirect},
|
||||
|
||||
// These http and https schemes would never actually be mixed in the same
|
||||
// redirect rule with Caddy because http and https schemes have different listeners,
|
||||
// so they don't share a redirect rule. So although these tests prove something
|
||||
// impossible with Caddy, it's extra bulletproofing at very little cost.
|
||||
{FromScheme: "http", FromPath: "/scheme", To: "https://localhost/scheme", Code: http.StatusMovedPermanently},
|
||||
{FromScheme: "https", FromPath: "/scheme2", To: "http://localhost/scheme2", Code: http.StatusMovedPermanently},
|
||||
{FromScheme: "", FromPath: "/scheme3", To: "https://localhost/scheme3", Code: http.StatusMovedPermanently},
|
||||
},
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", test.from, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
|
||||
}
|
||||
if strings.HasPrefix(test.from, "https://") {
|
||||
req.TLS = new(tls.ConnectionState) // faux HTTPS
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
re.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Header().Get("Location") != test.expectedLocation {
|
||||
t.Errorf("Test %d: Expected Location header to be %q but was %q",
|
||||
i, test.expectedLocation, rec.Header().Get("Location"))
|
||||
}
|
||||
|
||||
if rec.Code != test.expectedCode {
|
||||
t.Errorf("Test %d: Expected status code to be %d but was %d",
|
||||
i, test.expectedCode, rec.Code)
|
||||
}
|
||||
|
||||
if nextCalled && test.expectedLocation != "" {
|
||||
t.Errorf("Test %d: Next handler was unexpectedly called", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParametersRedirect(t *testing.T) {
|
||||
re := Redirect{
|
||||
Rules: []Rule{
|
||||
{FromPath: "/", Meta: false, To: "http://example.com{uri}"},
|
||||
},
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", "/a?b=c", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Test: Could not create HTTP request: %v", err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
re.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Header().Get("Location") != "http://example.com/a?b=c" {
|
||||
t.Fatalf("Test: expected location header %q but was %q", "http://example.com/a?b=c", rec.Header().Get("Location"))
|
||||
}
|
||||
|
||||
re = Redirect{
|
||||
Rules: []Rule{
|
||||
{FromPath: "/", Meta: false, To: "http://example.com/a{path}?b=c&{query}"},
|
||||
},
|
||||
}
|
||||
|
||||
req, err = http.NewRequest("GET", "/d?e=f", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Test: Could not create HTTP request: %v", err)
|
||||
}
|
||||
|
||||
re.ServeHTTP(rec, req)
|
||||
|
||||
if "http://example.com/a/d?b=c&e=f" != rec.Header().Get("Location") {
|
||||
t.Fatalf("Test: expected location header %q but was %q", "http://example.com/a/d?b=c&e=f", rec.Header().Get("Location"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetaRedirect(t *testing.T) {
|
||||
re := Redirect{
|
||||
Rules: []Rule{
|
||||
{FromPath: "/whatever", Meta: true, To: "/something"},
|
||||
{FromPath: "/", Meta: true, To: "https://example.com/"},
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range re.Rules {
|
||||
req, err := http.NewRequest("GET", test.FromPath, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
re.ServeHTTP(rec, req)
|
||||
|
||||
body, err := ioutil.ReadAll(rec.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Test %d: Could not read HTTP response body: %v", i, err)
|
||||
}
|
||||
expectedSnippet := `<meta http-equiv="refresh" content="0; URL='` + test.To + `'">`
|
||||
if !bytes.Contains(body, []byte(expectedSnippet)) {
|
||||
t.Errorf("Test %d: Expected Response Body to contain %q but was %q",
|
||||
i, expectedSnippet, body)
|
||||
}
|
||||
}
|
||||
}
|
185
caddyhttp/redirect/setup.go
Normal file
185
caddyhttp/redirect/setup.go
Normal file
|
@ -0,0 +1,185 @@
|
|||
package redirect
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy.RegisterPlugin(caddy.Plugin{
|
||||
Name: "redir",
|
||||
ServerType: "http",
|
||||
Action: setup,
|
||||
})
|
||||
}
|
||||
|
||||
// setup configures a new Redirect middleware instance.
|
||||
func setup(c *caddy.Controller) error {
|
||||
rules, err := redirParse(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
|
||||
return Redirect{Next: next, Rules: rules}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func redirParse(c *caddy.Controller) ([]Rule, error) {
|
||||
var redirects []Rule
|
||||
|
||||
cfg := httpserver.GetConfig(c.Key)
|
||||
|
||||
// setRedirCode sets the redirect code for rule if it can, or returns an error
|
||||
setRedirCode := func(code string, rule *Rule) error {
|
||||
if code == "meta" {
|
||||
rule.Meta = true
|
||||
} else if codeNumber, ok := httpRedirs[code]; ok {
|
||||
rule.Code = codeNumber
|
||||
} else {
|
||||
return c.Errf("Invalid redirect code '%v'", code)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkAndSaveRule checks the rule for validity (except the redir code)
|
||||
// and saves it if it's valid, or returns an error.
|
||||
checkAndSaveRule := func(rule Rule) error {
|
||||
if rule.FromPath == rule.To {
|
||||
return c.Err("'from' and 'to' values of redirect rule cannot be the same")
|
||||
}
|
||||
|
||||
for _, otherRule := range redirects {
|
||||
if otherRule.FromPath == rule.FromPath {
|
||||
return c.Errf("rule with duplicate 'from' value: %s -> %s", otherRule.FromPath, otherRule.To)
|
||||
}
|
||||
}
|
||||
|
||||
redirects = append(redirects, rule)
|
||||
return nil
|
||||
}
|
||||
|
||||
for c.Next() {
|
||||
args := c.RemainingArgs()
|
||||
|
||||
var hadOptionalBlock bool
|
||||
for c.NextBlock() {
|
||||
hadOptionalBlock = true
|
||||
|
||||
var rule Rule
|
||||
|
||||
if cfg.TLS.Enabled {
|
||||
rule.FromScheme = "https"
|
||||
} else {
|
||||
rule.FromScheme = "http"
|
||||
}
|
||||
|
||||
// Set initial redirect code
|
||||
// BUG: If the code is specified for a whole block and that code is invalid,
|
||||
// the line number will appear on the first line inside the block, even if that
|
||||
// line overwrites the block-level code with a valid redirect code. The program
|
||||
// still functions correctly, but the line number in the error reporting is
|
||||
// misleading to the user.
|
||||
if len(args) == 1 {
|
||||
err := setRedirCode(args[0], &rule)
|
||||
if err != nil {
|
||||
return redirects, err
|
||||
}
|
||||
} else {
|
||||
rule.Code = http.StatusMovedPermanently // default code
|
||||
}
|
||||
|
||||
// RemainingArgs only gets the values after the current token, but in our
|
||||
// case we want to include the current token to get an accurate count.
|
||||
insideArgs := append([]string{c.Val()}, c.RemainingArgs()...)
|
||||
|
||||
switch len(insideArgs) {
|
||||
case 1:
|
||||
// To specified (catch-all redirect)
|
||||
// Not sure why user is doing this in a table, as it causes all other redirects to be ignored.
|
||||
// As such, this feature remains undocumented.
|
||||
rule.FromPath = "/"
|
||||
rule.To = insideArgs[0]
|
||||
case 2:
|
||||
// From and To specified
|
||||
rule.FromPath = insideArgs[0]
|
||||
rule.To = insideArgs[1]
|
||||
case 3:
|
||||
// From, To, and Code specified
|
||||
rule.FromPath = insideArgs[0]
|
||||
rule.To = insideArgs[1]
|
||||
err := setRedirCode(insideArgs[2], &rule)
|
||||
if err != nil {
|
||||
return redirects, err
|
||||
}
|
||||
default:
|
||||
return redirects, c.ArgErr()
|
||||
}
|
||||
|
||||
err := checkAndSaveRule(rule)
|
||||
if err != nil {
|
||||
return redirects, err
|
||||
}
|
||||
}
|
||||
|
||||
if !hadOptionalBlock {
|
||||
var rule Rule
|
||||
|
||||
if cfg.TLS.Enabled {
|
||||
rule.FromScheme = "https"
|
||||
} else {
|
||||
rule.FromScheme = "http"
|
||||
}
|
||||
|
||||
rule.Code = http.StatusMovedPermanently // default
|
||||
|
||||
switch len(args) {
|
||||
case 1:
|
||||
// To specified (catch-all redirect)
|
||||
rule.FromPath = "/"
|
||||
rule.To = args[0]
|
||||
case 2:
|
||||
// To and Code specified (catch-all redirect)
|
||||
rule.FromPath = "/"
|
||||
rule.To = args[0]
|
||||
err := setRedirCode(args[1], &rule)
|
||||
if err != nil {
|
||||
return redirects, err
|
||||
}
|
||||
case 3:
|
||||
// From, To, and Code specified
|
||||
rule.FromPath = args[0]
|
||||
rule.To = args[1]
|
||||
err := setRedirCode(args[2], &rule)
|
||||
if err != nil {
|
||||
return redirects, err
|
||||
}
|
||||
default:
|
||||
return redirects, c.ArgErr()
|
||||
}
|
||||
|
||||
err := checkAndSaveRule(rule)
|
||||
if err != nil {
|
||||
return redirects, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return redirects, nil
|
||||
}
|
||||
|
||||
// httpRedirs is a list of supported HTTP redirect codes.
|
||||
var httpRedirs = map[string]int{
|
||||
"300": http.StatusMultipleChoices,
|
||||
"301": http.StatusMovedPermanently,
|
||||
"302": http.StatusFound, // (NOT CORRECT for "Temporary Redirect", see 307)
|
||||
"303": http.StatusSeeOther,
|
||||
"304": http.StatusNotModified,
|
||||
"305": http.StatusUseProxy,
|
||||
"307": http.StatusTemporaryRedirect,
|
||||
"308": 308, // Permanent Redirect
|
||||
}
|
69
caddyhttp/redirect/setup_test.go
Normal file
69
caddyhttp/redirect/setup_test.go
Normal file
|
@ -0,0 +1,69 @@
|
|||
package redirect
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
|
||||
for j, test := range []struct {
|
||||
input string
|
||||
shouldErr bool
|
||||
expectedRules []Rule
|
||||
}{
|
||||
// test case #0 tests the recognition of a valid HTTP status code defined outside of block statement
|
||||
{"redir 300 {\n/ /foo\n}", false, []Rule{{FromPath: "/", To: "/foo", Code: 300}}},
|
||||
|
||||
// test case #1 tests the recognition of an invalid HTTP status code defined outside of block statement
|
||||
{"redir 9000 {\n/ /foo\n}", true, []Rule{{}}},
|
||||
|
||||
// test case #2 tests the detection of a valid HTTP status code outside of a block statement being overriden by an invalid HTTP status code inside statement of a block statement
|
||||
{"redir 300 {\n/ /foo 9000\n}", true, []Rule{{}}},
|
||||
|
||||
// test case #3 tests the detection of an invalid HTTP status code outside of a block statement being overriden by a valid HTTP status code inside statement of a block statement
|
||||
{"redir 9000 {\n/ /foo 300\n}", true, []Rule{{}}},
|
||||
|
||||
// test case #4 tests the recognition of a TO redirection in a block statement.The HTTP status code is set to the default of 301 - MovedPermanently
|
||||
{"redir 302 {\n/foo\n}", false, []Rule{{FromPath: "/", To: "/foo", Code: 302}}},
|
||||
|
||||
// test case #5 tests the recognition of a TO and From redirection in a block statement
|
||||
{"redir {\n/bar /foo 303\n}", false, []Rule{{FromPath: "/bar", To: "/foo", Code: 303}}},
|
||||
|
||||
// test case #6 tests the recognition of a TO redirection in a non-block statement. The HTTP status code is set to the default of 301 - MovedPermanently
|
||||
{"redir /foo", false, []Rule{{FromPath: "/", To: "/foo", Code: 301}}},
|
||||
|
||||
// test case #7 tests the recognition of a TO and From redirection in a non-block statement
|
||||
{"redir /bar /foo 303", false, []Rule{{FromPath: "/bar", To: "/foo", Code: 303}}},
|
||||
|
||||
// test case #8 tests the recognition of multiple redirections
|
||||
{"redir {\n / /foo 304 \n} \n redir {\n /bar /foobar 305 \n}", false, []Rule{{FromPath: "/", To: "/foo", Code: 304}, {FromPath: "/bar", To: "/foobar", Code: 305}}},
|
||||
|
||||
// test case #9 tests the detection of duplicate redirections
|
||||
{"redir {\n /bar /foo 304 \n} redir {\n /bar /foo 304 \n}", true, []Rule{{}}},
|
||||
} {
|
||||
err := setup(caddy.NewTestController(test.input))
|
||||
if err != nil && !test.shouldErr {
|
||||
t.Errorf("Test case #%d recieved an error of %v", j, err)
|
||||
} else if test.shouldErr {
|
||||
continue
|
||||
}
|
||||
mids := httpserver.GetConfig("").Middleware()
|
||||
recievedRules := mids[len(mids)-1](nil).(Redirect).Rules
|
||||
|
||||
for i, recievedRule := range recievedRules {
|
||||
if recievedRule.FromPath != test.expectedRules[i].FromPath {
|
||||
t.Errorf("Test case #%d.%d expected a from path of %s, but recieved a from path of %s", j, i, test.expectedRules[i].FromPath, recievedRule.FromPath)
|
||||
}
|
||||
if recievedRule.To != test.expectedRules[i].To {
|
||||
t.Errorf("Test case #%d.%d expected a TO path of %s, but recieved a TO path of %s", j, i, test.expectedRules[i].To, recievedRule.To)
|
||||
}
|
||||
if recievedRule.Code != test.expectedRules[i].Code {
|
||||
t.Errorf("Test case #%d.%d expected a HTTP status code of %d, but recieved a code of %d", j, i, test.expectedRules[i].Code, recievedRule.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
130
caddyhttp/rewrite/condition.go
Normal file
130
caddyhttp/rewrite/condition.go
Normal file
|
@ -0,0 +1,130 @@
|
|||
package rewrite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
// Operators
|
||||
const (
|
||||
Is = "is"
|
||||
Not = "not"
|
||||
Has = "has"
|
||||
NotHas = "not_has"
|
||||
StartsWith = "starts_with"
|
||||
EndsWith = "ends_with"
|
||||
Match = "match"
|
||||
NotMatch = "not_match"
|
||||
)
|
||||
|
||||
func operatorError(operator string) error {
|
||||
return fmt.Errorf("Invalid operator %v", operator)
|
||||
}
|
||||
|
||||
func newReplacer(r *http.Request) httpserver.Replacer {
|
||||
return httpserver.NewReplacer(r, nil, "")
|
||||
}
|
||||
|
||||
// condition is a rewrite condition.
|
||||
type condition func(string, string) bool
|
||||
|
||||
var conditions = map[string]condition{
|
||||
Is: isFunc,
|
||||
Not: notFunc,
|
||||
Has: hasFunc,
|
||||
NotHas: notHasFunc,
|
||||
StartsWith: startsWithFunc,
|
||||
EndsWith: endsWithFunc,
|
||||
Match: matchFunc,
|
||||
NotMatch: notMatchFunc,
|
||||
}
|
||||
|
||||
// isFunc is condition for Is operator.
|
||||
// It checks for equality.
|
||||
func isFunc(a, b string) bool {
|
||||
return a == b
|
||||
}
|
||||
|
||||
// notFunc is condition for Not operator.
|
||||
// It checks for inequality.
|
||||
func notFunc(a, b string) bool {
|
||||
return a != b
|
||||
}
|
||||
|
||||
// hasFunc is condition for Has operator.
|
||||
// It checks if b is a substring of a.
|
||||
func hasFunc(a, b string) bool {
|
||||
return strings.Contains(a, b)
|
||||
}
|
||||
|
||||
// notHasFunc is condition for NotHas operator.
|
||||
// It checks if b is not a substring of a.
|
||||
func notHasFunc(a, b string) bool {
|
||||
return !strings.Contains(a, b)
|
||||
}
|
||||
|
||||
// startsWithFunc is condition for StartsWith operator.
|
||||
// It checks if b is a prefix of a.
|
||||
func startsWithFunc(a, b string) bool {
|
||||
return strings.HasPrefix(a, b)
|
||||
}
|
||||
|
||||
// endsWithFunc is condition for EndsWith operator.
|
||||
// It checks if b is a suffix of a.
|
||||
func endsWithFunc(a, b string) bool {
|
||||
return strings.HasSuffix(a, b)
|
||||
}
|
||||
|
||||
// matchFunc is condition for Match operator.
|
||||
// It does regexp matching of a against pattern in b
|
||||
// and returns if they match.
|
||||
func matchFunc(a, b string) bool {
|
||||
matched, _ := regexp.MatchString(b, a)
|
||||
return matched
|
||||
}
|
||||
|
||||
// notMatchFunc is condition for NotMatch operator.
|
||||
// It does regexp matching of a against pattern in b
|
||||
// and returns if they do not match.
|
||||
func notMatchFunc(a, b string) bool {
|
||||
matched, _ := regexp.MatchString(b, a)
|
||||
return !matched
|
||||
}
|
||||
|
||||
// If is statement for a rewrite condition.
|
||||
type If struct {
|
||||
A string
|
||||
Operator string
|
||||
B string
|
||||
}
|
||||
|
||||
// True returns true if the condition is true and false otherwise.
|
||||
// If r is not nil, it replaces placeholders before comparison.
|
||||
func (i If) True(r *http.Request) bool {
|
||||
if c, ok := conditions[i.Operator]; ok {
|
||||
a, b := i.A, i.B
|
||||
if r != nil {
|
||||
replacer := newReplacer(r)
|
||||
a = replacer.Replace(i.A)
|
||||
b = replacer.Replace(i.B)
|
||||
}
|
||||
return c(a, b)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// NewIf creates a new If condition.
|
||||
func NewIf(a, operator, b string) (If, error) {
|
||||
if _, ok := conditions[operator]; !ok {
|
||||
return If{}, operatorError(operator)
|
||||
}
|
||||
return If{
|
||||
A: a,
|
||||
Operator: operator,
|
||||
B: b,
|
||||
}, nil
|
||||
}
|
106
caddyhttp/rewrite/condition_test.go
Normal file
106
caddyhttp/rewrite/condition_test.go
Normal file
|
@ -0,0 +1,106 @@
|
|||
package rewrite
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConditions(t *testing.T) {
|
||||
tests := []struct {
|
||||
condition string
|
||||
isTrue bool
|
||||
}{
|
||||
{"a is b", false},
|
||||
{"a is a", true},
|
||||
{"a not b", true},
|
||||
{"a not a", false},
|
||||
{"a has a", true},
|
||||
{"a has b", false},
|
||||
{"ba has b", true},
|
||||
{"bab has b", true},
|
||||
{"bab has bb", false},
|
||||
{"a not_has a", false},
|
||||
{"a not_has b", true},
|
||||
{"ba not_has b", false},
|
||||
{"bab not_has b", false},
|
||||
{"bab not_has bb", true},
|
||||
{"bab starts_with bb", false},
|
||||
{"bab starts_with ba", true},
|
||||
{"bab starts_with bab", true},
|
||||
{"bab ends_with bb", false},
|
||||
{"bab ends_with bab", true},
|
||||
{"bab ends_with ab", true},
|
||||
{"a match *", false},
|
||||
{"a match a", true},
|
||||
{"a match .*", true},
|
||||
{"a match a.*", true},
|
||||
{"a match b.*", false},
|
||||
{"ba match b.*", true},
|
||||
{"ba match b[a-z]", true},
|
||||
{"b0 match b[a-z]", false},
|
||||
{"b0a match b[a-z]", false},
|
||||
{"b0a match b[a-z]+", false},
|
||||
{"b0a match b[a-z0-9]+", true},
|
||||
{"a not_match *", true},
|
||||
{"a not_match a", false},
|
||||
{"a not_match .*", false},
|
||||
{"a not_match a.*", false},
|
||||
{"a not_match b.*", true},
|
||||
{"ba not_match b.*", false},
|
||||
{"ba not_match b[a-z]", false},
|
||||
{"b0 not_match b[a-z]", true},
|
||||
{"b0a not_match b[a-z]", true},
|
||||
{"b0a not_match b[a-z]+", true},
|
||||
{"b0a not_match b[a-z0-9]+", false},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
str := strings.Fields(test.condition)
|
||||
ifCond, err := NewIf(str[0], str[1], str[2])
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
isTrue := ifCond.True(nil)
|
||||
if isTrue != test.isTrue {
|
||||
t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue)
|
||||
}
|
||||
}
|
||||
|
||||
invalidOperators := []string{"ss", "and", "if"}
|
||||
for _, op := range invalidOperators {
|
||||
_, err := NewIf("a", op, "b")
|
||||
if err == nil {
|
||||
t.Errorf("Invalid operator %v used, expected error.", op)
|
||||
}
|
||||
}
|
||||
|
||||
replaceTests := []struct {
|
||||
url string
|
||||
condition string
|
||||
isTrue bool
|
||||
}{
|
||||
{"/home", "{uri} match /home", true},
|
||||
{"/hom", "{uri} match /home", false},
|
||||
{"/hom", "{uri} starts_with /home", false},
|
||||
{"/hom", "{uri} starts_with /h", true},
|
||||
{"/home/.hiddenfile", `{uri} match \/\.(.*)`, true},
|
||||
{"/home/.hiddendir/afile", `{uri} match \/\.(.*)`, true},
|
||||
}
|
||||
|
||||
for i, test := range replaceTests {
|
||||
r, err := http.NewRequest("GET", test.url, nil)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
str := strings.Fields(test.condition)
|
||||
ifCond, err := NewIf(str[0], str[1], str[2])
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
isTrue := ifCond.True(r)
|
||||
if isTrue != test.isTrue {
|
||||
t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue)
|
||||
}
|
||||
}
|
||||
}
|
236
caddyhttp/rewrite/rewrite.go
Normal file
236
caddyhttp/rewrite/rewrite.go
Normal file
|
@ -0,0 +1,236 @@
|
|||
// Package rewrite is middleware for rewriting requests internally to
|
||||
// a different path.
|
||||
package rewrite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
// Result is the result of a rewrite
|
||||
type Result int
|
||||
|
||||
const (
|
||||
// RewriteIgnored is returned when rewrite is not done on request.
|
||||
RewriteIgnored Result = iota
|
||||
// RewriteDone is returned when rewrite is done on request.
|
||||
RewriteDone
|
||||
// RewriteStatus is returned when rewrite is not needed and status code should be set
|
||||
// for the request.
|
||||
RewriteStatus
|
||||
)
|
||||
|
||||
// Rewrite is middleware to rewrite request locations internally before being handled.
|
||||
type Rewrite struct {
|
||||
Next httpserver.Handler
|
||||
FileSys http.FileSystem
|
||||
Rules []Rule
|
||||
}
|
||||
|
||||
// ServeHTTP implements the httpserver.Handler interface.
|
||||
func (rw Rewrite) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
outer:
|
||||
for _, rule := range rw.Rules {
|
||||
switch result := rule.Rewrite(rw.FileSys, r); result {
|
||||
case RewriteDone:
|
||||
break outer
|
||||
case RewriteIgnored:
|
||||
break
|
||||
case RewriteStatus:
|
||||
// only valid for complex rules.
|
||||
if cRule, ok := rule.(*ComplexRule); ok && cRule.Status != 0 {
|
||||
return cRule.Status, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return rw.Next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// Rule describes an internal location rewrite rule.
|
||||
type Rule interface {
|
||||
// Rewrite rewrites the internal location of the current request.
|
||||
Rewrite(http.FileSystem, *http.Request) Result
|
||||
}
|
||||
|
||||
// SimpleRule is a simple rewrite rule.
|
||||
type SimpleRule struct {
|
||||
From, To string
|
||||
}
|
||||
|
||||
// NewSimpleRule creates a new Simple Rule
|
||||
func NewSimpleRule(from, to string) SimpleRule {
|
||||
return SimpleRule{from, to}
|
||||
}
|
||||
|
||||
// Rewrite rewrites the internal location of the current request.
|
||||
func (s SimpleRule) Rewrite(fs http.FileSystem, r *http.Request) Result {
|
||||
if s.From == r.URL.Path {
|
||||
// take note of this rewrite for internal use by fastcgi
|
||||
// all we need is the URI, not full URL
|
||||
r.Header.Set(headerFieldName, r.URL.RequestURI())
|
||||
|
||||
// attempt rewrite
|
||||
return To(fs, r, s.To, newReplacer(r))
|
||||
}
|
||||
return RewriteIgnored
|
||||
}
|
||||
|
||||
// ComplexRule is a rewrite rule based on a regular expression
|
||||
type ComplexRule struct {
|
||||
// Path base. Request to this path and subpaths will be rewritten
|
||||
Base string
|
||||
|
||||
// Path to rewrite to
|
||||
To string
|
||||
|
||||
// If set, neither performs rewrite nor proceeds
|
||||
// with request. Only returns code.
|
||||
Status int
|
||||
|
||||
// Extensions to filter by
|
||||
Exts []string
|
||||
|
||||
// Rewrite conditions
|
||||
Ifs []If
|
||||
|
||||
*regexp.Regexp
|
||||
}
|
||||
|
||||
// NewComplexRule creates a new RegexpRule. It returns an error if regexp
|
||||
// pattern (pattern) or extensions (ext) are invalid.
|
||||
func NewComplexRule(base, pattern, to string, status int, ext []string, ifs []If) (*ComplexRule, error) {
|
||||
// validate regexp if present
|
||||
var r *regexp.Regexp
|
||||
if pattern != "" {
|
||||
var err error
|
||||
r, err = regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// validate extensions if present
|
||||
for _, v := range ext {
|
||||
if len(v) < 2 || (len(v) < 3 && v[0] == '!') {
|
||||
// check if no extension is specified
|
||||
if v != "/" && v != "!/" {
|
||||
return nil, fmt.Errorf("invalid extension %v", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &ComplexRule{
|
||||
Base: base,
|
||||
To: to,
|
||||
Status: status,
|
||||
Exts: ext,
|
||||
Ifs: ifs,
|
||||
Regexp: r,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Rewrite rewrites the internal location of the current request.
|
||||
func (r *ComplexRule) Rewrite(fs http.FileSystem, req *http.Request) (re Result) {
|
||||
rPath := req.URL.Path
|
||||
replacer := newReplacer(req)
|
||||
|
||||
// validate base
|
||||
if !httpserver.Path(rPath).Matches(r.Base) {
|
||||
return
|
||||
}
|
||||
|
||||
// validate extensions
|
||||
if !r.matchExt(rPath) {
|
||||
return
|
||||
}
|
||||
|
||||
// validate regexp if present
|
||||
if r.Regexp != nil {
|
||||
// include trailing slash in regexp if present
|
||||
start := len(r.Base)
|
||||
if strings.HasSuffix(r.Base, "/") {
|
||||
start--
|
||||
}
|
||||
|
||||
matches := r.FindStringSubmatch(rPath[start:])
|
||||
switch len(matches) {
|
||||
case 0:
|
||||
// no match
|
||||
return
|
||||
default:
|
||||
// set regexp match variables {1}, {2} ...
|
||||
|
||||
// url escaped values of ? and #.
|
||||
q, f := url.QueryEscape("?"), url.QueryEscape("#")
|
||||
|
||||
for i := 1; i < len(matches); i++ {
|
||||
// Special case of unescaped # and ? by stdlib regexp.
|
||||
// Reverse the unescape.
|
||||
if strings.ContainsAny(matches[i], "?#") {
|
||||
matches[i] = strings.NewReplacer("?", q, "#", f).Replace(matches[i])
|
||||
}
|
||||
|
||||
replacer.Set(fmt.Sprint(i), matches[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validate rewrite conditions
|
||||
for _, i := range r.Ifs {
|
||||
if !i.True(req) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// if status is present, stop rewrite and return it.
|
||||
if r.Status != 0 {
|
||||
return RewriteStatus
|
||||
}
|
||||
|
||||
// attempt rewrite
|
||||
return To(fs, req, r.To, replacer)
|
||||
}
|
||||
|
||||
// matchExt matches rPath against registered file extensions.
|
||||
// Returns true if a match is found and false otherwise.
|
||||
func (r *ComplexRule) matchExt(rPath string) bool {
|
||||
f := filepath.Base(rPath)
|
||||
ext := path.Ext(f)
|
||||
if ext == "" {
|
||||
ext = "/"
|
||||
}
|
||||
|
||||
mustUse := false
|
||||
for _, v := range r.Exts {
|
||||
use := true
|
||||
if v[0] == '!' {
|
||||
use = false
|
||||
v = v[1:]
|
||||
}
|
||||
|
||||
if use {
|
||||
mustUse = true
|
||||
}
|
||||
|
||||
if ext == v {
|
||||
return use
|
||||
}
|
||||
}
|
||||
|
||||
if mustUse {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// When a rewrite is performed, this header is added to the request
|
||||
// and is for internal use only, specifically the fastcgi middleware.
|
||||
// It contains the original request URI before the rewrite.
|
||||
const headerFieldName = "Caddy-Rewrite-Original-URI"
|
163
caddyhttp/rewrite/rewrite_test.go
Normal file
163
caddyhttp/rewrite/rewrite_test.go
Normal file
|
@ -0,0 +1,163 @@
|
|||
package rewrite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func TestRewrite(t *testing.T) {
|
||||
rw := Rewrite{
|
||||
Next: httpserver.HandlerFunc(urlPrinter),
|
||||
Rules: []Rule{
|
||||
NewSimpleRule("/from", "/to"),
|
||||
NewSimpleRule("/a", "/b"),
|
||||
NewSimpleRule("/b", "/b{uri}"),
|
||||
},
|
||||
FileSys: http.Dir("."),
|
||||
}
|
||||
|
||||
regexps := [][]string{
|
||||
{"/reg/", ".*", "/to", ""},
|
||||
{"/r/", "[a-z]+", "/toaz", "!.html|"},
|
||||
{"/url/", "a([a-z0-9]*)s([A-Z]{2})", "/to/{path}", ""},
|
||||
{"/ab/", "ab", "/ab?{query}", ".txt|"},
|
||||
{"/ab/", "ab", "/ab?type=html&{query}", ".html|"},
|
||||
{"/abc/", "ab", "/abc/{file}", ".html|"},
|
||||
{"/abcd/", "ab", "/a/{dir}/{file}", ".html|"},
|
||||
{"/abcde/", "ab", "/a#{fragment}", ".html|"},
|
||||
{"/ab/", `.*\.jpg`, "/ajpg", ""},
|
||||
{"/reggrp", `/ad/([0-9]+)([a-z]*)`, "/a{1}/{2}", ""},
|
||||
{"/reg2grp", `(.*)`, "/{1}", ""},
|
||||
{"/reg3grp", `(.*)/(.*)/(.*)`, "/{1}{2}{3}", ""},
|
||||
{"/hashtest", "(.*)", "/{1}", ""},
|
||||
}
|
||||
|
||||
for _, regexpRule := range regexps {
|
||||
var ext []string
|
||||
if s := strings.Split(regexpRule[3], "|"); len(s) > 1 {
|
||||
ext = s[:len(s)-1]
|
||||
}
|
||||
rule, err := NewComplexRule(regexpRule[0], regexpRule[1], regexpRule[2], 0, ext, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rw.Rules = append(rw.Rules, rule)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
from string
|
||||
expectedTo string
|
||||
}{
|
||||
{"/from", "/to"},
|
||||
{"/a", "/b"},
|
||||
{"/b", "/b/b"},
|
||||
{"/aa", "/aa"},
|
||||
{"/", "/"},
|
||||
{"/a?foo=bar", "/b?foo=bar"},
|
||||
{"/asdf?foo=bar", "/asdf?foo=bar"},
|
||||
{"/foo#bar", "/foo#bar"},
|
||||
{"/a#foo", "/b#foo"},
|
||||
{"/reg/foo", "/to"},
|
||||
{"/re", "/re"},
|
||||
{"/r/", "/r/"},
|
||||
{"/r/123", "/r/123"},
|
||||
{"/r/a123", "/toaz"},
|
||||
{"/r/abcz", "/toaz"},
|
||||
{"/r/z", "/toaz"},
|
||||
{"/r/z.html", "/r/z.html"},
|
||||
{"/r/z.js", "/toaz"},
|
||||
{"/url/asAB", "/to/url/asAB"},
|
||||
{"/url/aBsAB", "/url/aBsAB"},
|
||||
{"/url/a00sAB", "/to/url/a00sAB"},
|
||||
{"/url/a0z0sAB", "/to/url/a0z0sAB"},
|
||||
{"/ab/aa", "/ab/aa"},
|
||||
{"/ab/ab", "/ab/ab"},
|
||||
{"/ab/ab.txt", "/ab"},
|
||||
{"/ab/ab.txt?name=name", "/ab?name=name"},
|
||||
{"/ab/ab.html?name=name", "/ab?type=html&name=name"},
|
||||
{"/abc/ab.html", "/abc/ab.html"},
|
||||
{"/abcd/abcd.html", "/a/abcd/abcd.html"},
|
||||
{"/abcde/abcde.html", "/a"},
|
||||
{"/abcde/abcde.html#1234", "/a#1234"},
|
||||
{"/ab/ab.jpg", "/ajpg"},
|
||||
{"/reggrp/ad/12", "/a12"},
|
||||
{"/reggrp/ad/124a", "/a124/a"},
|
||||
{"/reggrp/ad/124abc", "/a124/abc"},
|
||||
{"/reg2grp/ad/124abc", "/ad/124abc"},
|
||||
{"/reg3grp/ad/aa/66", "/adaa66"},
|
||||
{"/reg3grp/ad612/n1n/ab", "/ad612n1nab"},
|
||||
{"/hashtest/a%20%23%20test", "/a%20%23%20test"},
|
||||
{"/hashtest/a%20%3F%20test", "/a%20%3F%20test"},
|
||||
{"/hashtest/a%20%3F%23test", "/a%20%3F%23test"},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
req, err := http.NewRequest("GET", test.from, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
rw.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Body.String() != test.expectedTo {
|
||||
t.Errorf("Test %d: Expected URL to be '%s' but was '%s'",
|
||||
i, test.expectedTo, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
statusTests := []struct {
|
||||
status int
|
||||
base string
|
||||
to string
|
||||
regexp string
|
||||
statusExpected bool
|
||||
}{
|
||||
{400, "/status", "", "", true},
|
||||
{400, "/ignore", "", "", false},
|
||||
{400, "/", "", "^/ignore", false},
|
||||
{400, "/", "", "(.*)", true},
|
||||
{400, "/status", "", "", true},
|
||||
}
|
||||
|
||||
for i, s := range statusTests {
|
||||
urlPath := fmt.Sprintf("/status%d", i)
|
||||
rule, err := NewComplexRule(s.base, s.regexp, s.to, s.status, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Test %d: No error expected for rule but found %v", i, err)
|
||||
}
|
||||
rw.Rules = []Rule{rule}
|
||||
req, err := http.NewRequest("GET", urlPath, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
code, err := rw.ServeHTTP(rec, req)
|
||||
if err != nil {
|
||||
t.Fatalf("Test %d: No error expected for handler but found %v", i, err)
|
||||
}
|
||||
if s.statusExpected {
|
||||
if rec.Body.String() != "" {
|
||||
t.Errorf("Test %d: Expected empty body but found %s", i, rec.Body.String())
|
||||
}
|
||||
if code != s.status {
|
||||
t.Errorf("Test %d: Expected status code %d found %d", i, s.status, code)
|
||||
}
|
||||
} else {
|
||||
if code != 0 {
|
||||
t.Errorf("Test %d: Expected no status code found %d", i, code)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func urlPrinter(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
fmt.Fprint(w, r.URL.String())
|
||||
return 0, nil
|
||||
}
|
121
caddyhttp/rewrite/setup.go
Normal file
121
caddyhttp/rewrite/setup.go
Normal file
|
@ -0,0 +1,121 @@
|
|||
package rewrite
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy.RegisterPlugin(caddy.Plugin{
|
||||
Name: "rewrite",
|
||||
ServerType: "http",
|
||||
Action: setup,
|
||||
})
|
||||
}
|
||||
|
||||
// setup configures a new Rewrite middleware instance.
|
||||
func setup(c *caddy.Controller) error {
|
||||
rewrites, err := rewriteParse(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg := httpserver.GetConfig(c.Key)
|
||||
|
||||
cfg.AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
|
||||
return Rewrite{
|
||||
Next: next,
|
||||
FileSys: http.Dir(cfg.Root),
|
||||
Rules: rewrites,
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func rewriteParse(c *caddy.Controller) ([]Rule, error) {
|
||||
var simpleRules []Rule
|
||||
var regexpRules []Rule
|
||||
|
||||
for c.Next() {
|
||||
var rule Rule
|
||||
var err error
|
||||
var base = "/"
|
||||
var pattern, to string
|
||||
var status int
|
||||
var ext []string
|
||||
|
||||
args := c.RemainingArgs()
|
||||
|
||||
var ifs []If
|
||||
|
||||
switch len(args) {
|
||||
case 1:
|
||||
base = args[0]
|
||||
fallthrough
|
||||
case 0:
|
||||
for c.NextBlock() {
|
||||
switch c.Val() {
|
||||
case "r", "regexp":
|
||||
if !c.NextArg() {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
pattern = c.Val()
|
||||
case "to":
|
||||
args1 := c.RemainingArgs()
|
||||
if len(args1) == 0 {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
to = strings.Join(args1, " ")
|
||||
case "ext":
|
||||
args1 := c.RemainingArgs()
|
||||
if len(args1) == 0 {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
ext = args1
|
||||
case "if":
|
||||
args1 := c.RemainingArgs()
|
||||
if len(args1) != 3 {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
ifCond, err := NewIf(args1[0], args1[1], args1[2])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ifs = append(ifs, ifCond)
|
||||
case "status":
|
||||
if !c.NextArg() {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
status, _ = strconv.Atoi(c.Val())
|
||||
if status < 200 || (status > 299 && status < 400) || status > 499 {
|
||||
return nil, c.Err("status must be 2xx or 4xx")
|
||||
}
|
||||
default:
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
}
|
||||
// ensure to or status is specified
|
||||
if to == "" && status == 0 {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
if rule, err = NewComplexRule(base, pattern, to, status, ext, ifs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
regexpRules = append(regexpRules, rule)
|
||||
|
||||
// the only unhandled case is 2 and above
|
||||
default:
|
||||
rule = NewSimpleRule(args[0], strings.Join(args[1:], " "))
|
||||
simpleRules = append(simpleRules, rule)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// put simple rules in front to avoid regexp computation for them
|
||||
return append(simpleRules, regexpRules...), nil
|
||||
}
|
239
caddyhttp/rewrite/setup_test.go
Normal file
239
caddyhttp/rewrite/setup_test.go
Normal file
|
@ -0,0 +1,239 @@
|
|||
package rewrite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
err := setup(caddy.NewTestController(`rewrite /from /to`))
|
||||
if err != nil {
|
||||
t.Errorf("Expected no errors, but got: %v", err)
|
||||
}
|
||||
mids := httpserver.GetConfig("").Middleware()
|
||||
if len(mids) == 0 {
|
||||
t.Fatal("Expected middleware, had 0 instead")
|
||||
}
|
||||
|
||||
handler := mids[0](httpserver.EmptyNext)
|
||||
myHandler, ok := handler.(Rewrite)
|
||||
if !ok {
|
||||
t.Fatalf("Expected handler to be type Rewrite, got: %#v", handler)
|
||||
}
|
||||
|
||||
if !httpserver.SameNext(myHandler.Next, httpserver.EmptyNext) {
|
||||
t.Error("'Next' field of handler was not set properly")
|
||||
}
|
||||
|
||||
if len(myHandler.Rules) != 1 {
|
||||
t.Errorf("Expected handler to have %d rule, has %d instead", 1, len(myHandler.Rules))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteParse(t *testing.T) {
|
||||
simpleTests := []struct {
|
||||
input string
|
||||
shouldErr bool
|
||||
expected []Rule
|
||||
}{
|
||||
{`rewrite /from /to`, false, []Rule{
|
||||
SimpleRule{From: "/from", To: "/to"},
|
||||
}},
|
||||
{`rewrite /from /to
|
||||
rewrite a b`, false, []Rule{
|
||||
SimpleRule{From: "/from", To: "/to"},
|
||||
SimpleRule{From: "a", To: "b"},
|
||||
}},
|
||||
{`rewrite a`, true, []Rule{}},
|
||||
{`rewrite`, true, []Rule{}},
|
||||
{`rewrite a b c`, false, []Rule{
|
||||
SimpleRule{From: "a", To: "b c"},
|
||||
}},
|
||||
}
|
||||
|
||||
for i, test := range simpleTests {
|
||||
actual, err := rewriteParse(caddy.NewTestController(test.input))
|
||||
|
||||
if err == nil && test.shouldErr {
|
||||
t.Errorf("Test %d didn't error, but it should have", i)
|
||||
} else if err != nil && !test.shouldErr {
|
||||
t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err)
|
||||
} else if err != nil && test.shouldErr {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(actual) != len(test.expected) {
|
||||
t.Fatalf("Test %d expected %d rules, but got %d",
|
||||
i, len(test.expected), len(actual))
|
||||
}
|
||||
|
||||
for j, e := range test.expected {
|
||||
actualRule := actual[j].(SimpleRule)
|
||||
expectedRule := e.(SimpleRule)
|
||||
|
||||
if actualRule.From != expectedRule.From {
|
||||
t.Errorf("Test %d, rule %d: Expected From=%s, got %s",
|
||||
i, j, expectedRule.From, actualRule.From)
|
||||
}
|
||||
|
||||
if actualRule.To != expectedRule.To {
|
||||
t.Errorf("Test %d, rule %d: Expected To=%s, got %s",
|
||||
i, j, expectedRule.To, actualRule.To)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
regexpTests := []struct {
|
||||
input string
|
||||
shouldErr bool
|
||||
expected []Rule
|
||||
}{
|
||||
{`rewrite {
|
||||
r .*
|
||||
to /to /index.php?
|
||||
}`, false, []Rule{
|
||||
&ComplexRule{Base: "/", To: "/to /index.php?", Regexp: regexp.MustCompile(".*")},
|
||||
}},
|
||||
{`rewrite {
|
||||
regexp .*
|
||||
to /to
|
||||
ext / html txt
|
||||
}`, false, []Rule{
|
||||
&ComplexRule{Base: "/", To: "/to", Exts: []string{"/", "html", "txt"}, Regexp: regexp.MustCompile(".*")},
|
||||
}},
|
||||
{`rewrite /path {
|
||||
r rr
|
||||
to /dest
|
||||
}
|
||||
rewrite / {
|
||||
regexp [a-z]+
|
||||
to /to /to2
|
||||
}
|
||||
`, false, []Rule{
|
||||
&ComplexRule{Base: "/path", To: "/dest", Regexp: regexp.MustCompile("rr")},
|
||||
&ComplexRule{Base: "/", To: "/to /to2", Regexp: regexp.MustCompile("[a-z]+")},
|
||||
}},
|
||||
{`rewrite {
|
||||
r .*
|
||||
}`, true, []Rule{
|
||||
&ComplexRule{},
|
||||
}},
|
||||
{`rewrite {
|
||||
|
||||
}`, true, []Rule{
|
||||
&ComplexRule{},
|
||||
}},
|
||||
{`rewrite /`, true, []Rule{
|
||||
&ComplexRule{},
|
||||
}},
|
||||
{`rewrite {
|
||||
to /to
|
||||
if {path} is a
|
||||
}`, false, []Rule{
|
||||
&ComplexRule{Base: "/", To: "/to", Ifs: []If{{A: "{path}", Operator: "is", B: "a"}}},
|
||||
}},
|
||||
{`rewrite {
|
||||
status 500
|
||||
}`, true, []Rule{
|
||||
&ComplexRule{},
|
||||
}},
|
||||
{`rewrite {
|
||||
status 400
|
||||
}`, false, []Rule{
|
||||
&ComplexRule{Base: "/", Status: 400},
|
||||
}},
|
||||
{`rewrite {
|
||||
to /to
|
||||
status 400
|
||||
}`, false, []Rule{
|
||||
&ComplexRule{Base: "/", To: "/to", Status: 400},
|
||||
}},
|
||||
{`rewrite {
|
||||
status 399
|
||||
}`, true, []Rule{
|
||||
&ComplexRule{},
|
||||
}},
|
||||
{`rewrite {
|
||||
status 200
|
||||
}`, false, []Rule{
|
||||
&ComplexRule{Base: "/", Status: 200},
|
||||
}},
|
||||
{`rewrite {
|
||||
to /to
|
||||
status 200
|
||||
}`, false, []Rule{
|
||||
&ComplexRule{Base: "/", To: "/to", Status: 200},
|
||||
}},
|
||||
{`rewrite {
|
||||
status 199
|
||||
}`, true, []Rule{
|
||||
&ComplexRule{},
|
||||
}},
|
||||
{`rewrite {
|
||||
status 0
|
||||
}`, true, []Rule{
|
||||
&ComplexRule{},
|
||||
}},
|
||||
{`rewrite {
|
||||
to /to
|
||||
status 0
|
||||
}`, true, []Rule{
|
||||
&ComplexRule{},
|
||||
}},
|
||||
}
|
||||
|
||||
for i, test := range regexpTests {
|
||||
actual, err := rewriteParse(caddy.NewTestController(test.input))
|
||||
|
||||
if err == nil && test.shouldErr {
|
||||
t.Errorf("Test %d didn't error, but it should have", i)
|
||||
} else if err != nil && !test.shouldErr {
|
||||
t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err)
|
||||
} else if err != nil && test.shouldErr {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(actual) != len(test.expected) {
|
||||
t.Fatalf("Test %d expected %d rules, but got %d",
|
||||
i, len(test.expected), len(actual))
|
||||
}
|
||||
|
||||
for j, e := range test.expected {
|
||||
actualRule := actual[j].(*ComplexRule)
|
||||
expectedRule := e.(*ComplexRule)
|
||||
|
||||
if actualRule.Base != expectedRule.Base {
|
||||
t.Errorf("Test %d, rule %d: Expected Base=%s, got %s",
|
||||
i, j, expectedRule.Base, actualRule.Base)
|
||||
}
|
||||
|
||||
if actualRule.To != expectedRule.To {
|
||||
t.Errorf("Test %d, rule %d: Expected To=%s, got %s",
|
||||
i, j, expectedRule.To, actualRule.To)
|
||||
}
|
||||
|
||||
if fmt.Sprint(actualRule.Exts) != fmt.Sprint(expectedRule.Exts) {
|
||||
t.Errorf("Test %d, rule %d: Expected Ext=%v, got %v",
|
||||
i, j, expectedRule.To, actualRule.To)
|
||||
}
|
||||
|
||||
if actualRule.Regexp != nil {
|
||||
if actualRule.String() != expectedRule.String() {
|
||||
t.Errorf("Test %d, rule %d: Expected Pattern=%s, got %s",
|
||||
i, j, expectedRule.String(), actualRule.String())
|
||||
}
|
||||
}
|
||||
|
||||
if fmt.Sprint(actualRule.Ifs) != fmt.Sprint(expectedRule.Ifs) {
|
||||
t.Errorf("Test %d, rule %d: Expected Pattern=%s, got %s",
|
||||
i, j, fmt.Sprint(expectedRule.Ifs), fmt.Sprint(actualRule.Ifs))
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
0
caddyhttp/rewrite/testdata/testdir/empty
vendored
Normal file
0
caddyhttp/rewrite/testdata/testdir/empty
vendored
Normal file
1
caddyhttp/rewrite/testdata/testfile
vendored
Normal file
1
caddyhttp/rewrite/testdata/testfile
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
empty
|
87
caddyhttp/rewrite/to.go
Normal file
87
caddyhttp/rewrite/to.go
Normal file
|
@ -0,0 +1,87 @@
|
|||
package rewrite
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
// To attempts rewrite. It attempts to rewrite to first valid path
|
||||
// or the last path if none of the paths are valid.
|
||||
// Returns true if rewrite is successful and false otherwise.
|
||||
func To(fs http.FileSystem, r *http.Request, to string, replacer httpserver.Replacer) Result {
|
||||
tos := strings.Fields(to)
|
||||
|
||||
// try each rewrite paths
|
||||
t := ""
|
||||
for _, v := range tos {
|
||||
t = path.Clean(replacer.Replace(v))
|
||||
|
||||
// add trailing slash for directories, if present
|
||||
if strings.HasSuffix(v, "/") && !strings.HasSuffix(t, "/") {
|
||||
t += "/"
|
||||
}
|
||||
|
||||
// validate file
|
||||
if isValidFile(fs, t) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// validate resulting path
|
||||
u, err := url.Parse(t)
|
||||
if err != nil {
|
||||
// Let the user know we got here. Rewrite is expected but
|
||||
// the resulting url is invalid.
|
||||
log.Printf("[ERROR] rewrite: resulting path '%v' is invalid. error: %v", t, err)
|
||||
return RewriteIgnored
|
||||
}
|
||||
|
||||
// take note of this rewrite for internal use by fastcgi
|
||||
// all we need is the URI, not full URL
|
||||
r.Header.Set(headerFieldName, r.URL.RequestURI())
|
||||
|
||||
// perform rewrite
|
||||
r.URL.Path = u.Path
|
||||
if u.RawQuery != "" {
|
||||
// overwrite query string if present
|
||||
r.URL.RawQuery = u.RawQuery
|
||||
}
|
||||
if u.Fragment != "" {
|
||||
// overwrite fragment if present
|
||||
r.URL.Fragment = u.Fragment
|
||||
}
|
||||
|
||||
return RewriteDone
|
||||
}
|
||||
|
||||
// isValidFile checks if file exists on the filesystem.
|
||||
// if file ends with `/`, it is validated as a directory.
|
||||
func isValidFile(fs http.FileSystem, file string) bool {
|
||||
if fs == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
f, err := fs.Open(file)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
stat, err := f.Stat()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// directory
|
||||
if strings.HasSuffix(file, "/") {
|
||||
return stat.IsDir()
|
||||
}
|
||||
|
||||
// file
|
||||
return !stat.IsDir()
|
||||
}
|
44
caddyhttp/rewrite/to_test.go
Normal file
44
caddyhttp/rewrite/to_test.go
Normal file
|
@ -0,0 +1,44 @@
|
|||
package rewrite
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTo(t *testing.T) {
|
||||
fs := http.Dir("testdata")
|
||||
tests := []struct {
|
||||
url string
|
||||
to string
|
||||
expected string
|
||||
}{
|
||||
{"/", "/somefiles", "/somefiles"},
|
||||
{"/somefiles", "/somefiles /index.php{uri}", "/index.php/somefiles"},
|
||||
{"/somefiles", "/testfile /index.php{uri}", "/testfile"},
|
||||
{"/somefiles", "/testfile/ /index.php{uri}", "/index.php/somefiles"},
|
||||
{"/somefiles", "/somefiles /index.php{uri}", "/index.php/somefiles"},
|
||||
{"/?a=b", "/somefiles /index.php?{query}", "/index.php?a=b"},
|
||||
{"/?a=b", "/testfile /index.php?{query}", "/testfile?a=b"},
|
||||
{"/?a=b", "/testdir /index.php?{query}", "/index.php?a=b"},
|
||||
{"/?a=b", "/testdir/ /index.php?{query}", "/testdir/?a=b"},
|
||||
}
|
||||
|
||||
uri := func(r *url.URL) string {
|
||||
uri := r.Path
|
||||
if r.RawQuery != "" {
|
||||
uri += "?" + r.RawQuery
|
||||
}
|
||||
return uri
|
||||
}
|
||||
for i, test := range tests {
|
||||
r, err := http.NewRequest("GET", test.url, nil)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
To(fs, r, test.to, newReplacer(r))
|
||||
if uri(r.URL) != test.expected {
|
||||
t.Errorf("Test %v: expected %v found %v", i, test.expected, uri(r.URL))
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user