diff --git a/modules/caddyhttp/replacer.go b/modules/caddyhttp/replacer.go index 57bc880f6..439d24586 100644 --- a/modules/caddyhttp/replacer.go +++ b/modules/caddyhttp/replacer.go @@ -15,86 +15,143 @@ package caddyhttp import ( - "fmt" "net" "net/http" + "net/textproto" "path" + "strconv" "strings" "github.com/caddyserver/caddy/v2" ) -// TODO: A simple way to format or escape or encode each value would be nice -// ... TODO: Should we just use templates? :-/ yeesh... - func addHTTPVarsToReplacer(repl caddy.Replacer, req *http.Request, w http.ResponseWriter) { - httpVars := func() map[string]string { - m := make(map[string]string) + httpVars := func(key string) (string, bool) { if req != nil { - m["http.request.host"] = func() string { + if strings.HasPrefix(key, queryReplPrefix) { + vals := req.URL.Query()[key[len(queryReplPrefix):]] + // always return true, since the query param might + // be present only in some requests + return strings.Join(vals, ","), true + } + + if strings.HasPrefix(key, reqHeaderReplPrefix) { + field := key[len(reqHeaderReplPrefix):] + vals := req.Header[textproto.CanonicalMIMEHeaderKey(field)] + // always return true, since the header field might + // be present only in some requests + return strings.Join(vals, ","), true + } + + if strings.HasPrefix(key, cookieReplPrefix) { + name := key[len(cookieReplPrefix):] + for _, cookie := range req.Cookies() { + if strings.EqualFold(name, cookie.Name) { + // always return true, since the cookie might + // be present only in some requests + return cookie.Value, true + } + } + } + + switch key { + case "http.request.host": host, _, err := net.SplitHostPort(req.Host) if err != nil { - return req.Host // OK; there probably was no port + return req.Host, true // OK; there probably was no port } - return host - }() - m["http.request.hostport"] = req.Host // may include both host and port - m["http.request.method"] = req.Method - m["http.request.port"] = func() string { - // if there is no port, there will be an error; in - // that case, port is the empty string anyway + return host, true + case "http.request.hostport": + return req.Host, true + case "http.request.method": + return req.Method, true + case "http.request.port": _, port, _ := net.SplitHostPort(req.Host) - return port - }() - m["http.request.scheme"] = func() string { + return port, true + case "http.request.scheme": if req.TLS != nil { - return "https" + return "https", true } - return "http" - }() - m["http.request.uri"] = req.URL.RequestURI() - m["http.request.uri.path"] = req.URL.Path - m["http.request.uri.path.file"] = func() string { + return "http", true + case "http.request.uri": + return req.URL.RequestURI(), true + case "http.request.uri.path": + return req.URL.Path, true + case "http.request.uri.path.file": _, file := path.Split(req.URL.Path) - return file - }() - m["http.request.uri.path.dir"] = func() string { + return file, true + case "http.request.uri.path.dir": dir, _ := path.Split(req.URL.Path) - return dir - }() - m["http.request.uri.query"] = req.URL.RawQuery - - for param, vals := range req.URL.Query() { - m["http.request.uri.query."+param] = strings.Join(vals, ",") - } - for field, vals := range req.Header { - m["http.request.header."+strings.ToLower(field)] = strings.Join(vals, ",") - } - for _, cookie := range req.Cookies() { - m["http.request.cookie."+cookie.Name] = cookie.Value + return dir, true + case "http.request.uri.query": + return req.URL.RawQuery, true } - hostLabels := strings.Split(req.Host, ".") - for i, label := range hostLabels { - key := fmt.Sprintf("http.request.host.labels.%d", len(hostLabels)-i-1) - m[key] = label + if strings.HasPrefix(key, respHeaderReplPrefix) { + field := key[len(respHeaderReplPrefix):] + vals := w.Header()[textproto.CanonicalMIMEHeaderKey(field)] + // always return true, since the header field might + // be present only in some requests + return strings.Join(vals, ","), true } - pathParts := strings.Split(req.URL.Path, "/") - for i, label := range pathParts { - key := fmt.Sprintf("http.request.uri.path.%d", i) - m[key] = label + if strings.HasPrefix(key, hostLabelReplPrefix) { + idxStr := key[len(hostLabelReplPrefix):] + idx, err := strconv.Atoi(idxStr) + if err != nil { + return "", false + } + hostLabels := strings.Split(req.Host, ".") + if idx < 0 { + return "", false + } + if idx >= len(hostLabels) { + return "", true + } + return hostLabels[idx], true + } + + if strings.HasPrefix(key, pathPartsReplPrefix) { + idxStr := key[len(pathPartsReplPrefix):] + idx, err := strconv.Atoi(idxStr) + if err != nil { + return "", false + } + pathParts := strings.Split(req.URL.Path, "/") + if len(pathParts) > 0 && pathParts[0] == "" { + pathParts = pathParts[1:] + } + if idx < 0 { + return "", false + } + if idx >= len(pathParts) { + return "", true + } + return pathParts[idx], true } } if w != nil { - for field, vals := range w.Header() { - m["http.response.header."+strings.ToLower(field)] = strings.Join(vals, ",") + if strings.HasPrefix(key, respHeaderReplPrefix) { + field := key[len(respHeaderReplPrefix):] + vals := w.Header()[textproto.CanonicalMIMEHeaderKey(field)] + // always return true, since the header field might + // be present only in some responses + return strings.Join(vals, ","), true } } - return m + return "", false } repl.Map(httpVars) } + +const ( + queryReplPrefix = "http.request.uri.query." + reqHeaderReplPrefix = "http.request.header." + cookieReplPrefix = "http.request.cookie." + hostLabelReplPrefix = "http.request.host.labels." + pathPartsReplPrefix = "http.request.uri.path." + respHeaderReplPrefix = "http.response.header." +) diff --git a/replacer.go b/replacer.go index e21e3cfc6..e8a4f9b13 100644 --- a/replacer.go +++ b/replacer.go @@ -25,7 +25,7 @@ import ( type Replacer interface { Set(variable, value string) Delete(variable string) - Map(func() map[string]string) + Map(ReplacementFunc) ReplaceAll(input, empty string) string } @@ -34,23 +34,22 @@ func NewReplacer() Replacer { rep := &replacer{ static: make(map[string]string), } - rep.providers = []ReplacementsFunc{ - defaultReplacements, - func() map[string]string { return rep.static }, + rep.providers = []ReplacementFunc{ + globalDefaultReplacements, + rep.fromStatic, } return rep } type replacer struct { - providers []ReplacementsFunc + providers []ReplacementFunc static map[string]string } -// Map augments the map of replacements with those returned -// by the given replacements function. The function is only -// executed at replace-time. -func (r *replacer) Map(replacements func() map[string]string) { - r.providers = append(r.providers, replacements) +// Map adds mapFunc to the list of value providers. +// mapFunc will be executed only at replace-time. +func (r *replacer) Map(mapFunc ReplacementFunc) { + r.providers = append(r.providers, mapFunc) } // Set sets a custom variable to a static value. @@ -64,55 +63,104 @@ func (r *replacer) Delete(variable string) { delete(r.static, variable) } -// ReplaceAll replaces placeholders in input with their values. -// Values that are empty string will be substituted with the -// empty parameter. -func (r *replacer) ReplaceAll(input, empty string) string { - if !strings.Contains(input, phOpen) { - return input - } - for _, replacements := range r.providers { - for key, val := range replacements() { - if val == "" { - val = empty - } - input = strings.ReplaceAll(input, phOpen+key+phClose, val) - } - } - return input +// fromStatic provides values from r.static. +func (r *replacer) fromStatic(key string) (val string, ok bool) { + val, ok = r.static[key] + return } -// ReplacementsFunc is a function that returns replacements, -// which is variable names mapped to their values. The -// function will be evaluated only at replace-time to ensure -// the most current values are mapped. -type ReplacementsFunc func() map[string]string - -var defaultReplacements = func() map[string]string { - m := map[string]string{ - "system.hostname": func() string { - // OK if there is an error; just return empty string - name, _ := os.Hostname() - return name - }(), - "system.slash": string(filepath.Separator), - "system.os": runtime.GOOS, - "system.arch": runtime.GOARCH, +// ReplaceAll efficiently replaces placeholders in input with +// their values. Unrecognized placeholders will not be replaced. +// Values that are empty string will be substituted with empty. +func (r *replacer) ReplaceAll(input, empty string) string { + if !strings.Contains(input, string(phOpen)) { + return input } - // add environment variables - for _, keyval := range os.Environ() { - parts := strings.SplitN(keyval, "=", 2) - if len(parts) != 2 { + var sb strings.Builder + + // it is reasonable to assume that the output + // will be approximately as long as the input + sb.Grow(len(input)) + + // iterate the input to find each placeholder + var lastWriteCursor int + for i := 0; i < len(input); i++ { + if input[i] != phOpen { continue } - m["env."+strings.ToLower(parts[0])] = parts[1] + + // write the substring from the last cursor to this point + sb.WriteString(input[lastWriteCursor:i]) + + // find the end of the placeholder + end := strings.Index(input[i:], string(phClose)) + i + + // trim opening bracket + key := input[i+1 : end] + + // try to get a value for this key; if + // the key is not recognized, do not + // perform any replacement + var found bool + for _, mapFunc := range r.providers { + if val, ok := mapFunc(key); ok { + found = true + if val != "" { + sb.WriteString(val) + } else if empty != "" { + sb.WriteString(empty) + } + break + } + } + if !found { + continue + } + + // advance cursor to end of placeholder + i = end + 1 + lastWriteCursor = i } - return m + // flush any unwritten remainder + sb.WriteString(input[lastWriteCursor:]) + + return sb.String() +} + +// ReplacementFunc is a function that returns a replacement +// for the given key along with true if the function is able +// to service that key (even if the value is blank). If the +// function does not recognize the key, false should be +// returned. +type ReplacementFunc func(key string) (val string, ok bool) + +func globalDefaultReplacements(key string) (string, bool) { + // check environment variable + const envPrefix = "env." + if strings.HasPrefix(key, envPrefix) { + val := os.Getenv(key[len(envPrefix):]) + return val, val != "" + } + + switch key { + case "system.hostname": + // OK if there is an error; just return empty string + name, _ := os.Hostname() + return name, true + case "system.slash": + return string(filepath.Separator), true + case "system.os": + return runtime.GOOS, true + case "system.arch": + return runtime.GOARCH, true + } + + return "", false } // ReplacerCtxKey is the context key for a replacer. const ReplacerCtxKey CtxKey = "replacer" -const phOpen, phClose = "{", "}" +const phOpen, phClose = '{', '}'