From ec14ccdd4075e46dd0235cc99ecfd92ed131c10a Mon Sep 17 00:00:00 2001 From: Tim Culverhouse Date: Mon, 29 Nov 2021 11:29:40 -0600 Subject: [PATCH] templates: fix inconsistent nested includes (#4452) --- .../caddyhttp/templates/frontmatter_fuzz.go | 1 + modules/caddyhttp/templates/tplcontext.go | 30 ++++--- .../caddyhttp/templates/tplcontext_test.go | 90 +++++++++++++++++++ 3 files changed, 109 insertions(+), 12 deletions(-) diff --git a/modules/caddyhttp/templates/frontmatter_fuzz.go b/modules/caddyhttp/templates/frontmatter_fuzz.go index 8d8427b02..361b4b626 100644 --- a/modules/caddyhttp/templates/frontmatter_fuzz.go +++ b/modules/caddyhttp/templates/frontmatter_fuzz.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build gofuzz // +build gofuzz package templates diff --git a/modules/caddyhttp/templates/tplcontext.go b/modules/caddyhttp/templates/tplcontext.go index b1646c127..4f3cbf50a 100644 --- a/modules/caddyhttp/templates/tplcontext.go +++ b/modules/caddyhttp/templates/tplcontext.go @@ -91,7 +91,12 @@ func (c TemplateContext) OriginalReq() http.Request { // trusted files. If it is not trusted, be sure to use escaping functions // in your template. func (c TemplateContext) funcInclude(filename string, args ...interface{}) (string, error) { - bodyBuf, err := c.readFileToBuffer(filename) + + bodyBuf := bufPool.Get().(*bytes.Buffer) + bodyBuf.Reset() + defer bufPool.Put(bodyBuf) + + err := c.readFileToBuffer(filename, bodyBuf) if err != nil { return "", err @@ -107,28 +112,24 @@ func (c TemplateContext) funcInclude(filename string, args ...interface{}) (stri return bodyBuf.String(), nil } -// readFileToBuffer returns the contents of filename relative to root as a buffer -func (c TemplateContext) readFileToBuffer(filename string) (*bytes.Buffer, error) { +// readFileToBuffer reads a file into a buffer +func (c TemplateContext) readFileToBuffer(filename string, bodyBuf *bytes.Buffer) error { if c.Root == nil { - return nil, fmt.Errorf("root file system not specified") + return fmt.Errorf("root file system not specified") } file, err := c.Root.Open(filename) if err != nil { - return nil, err + return err } defer file.Close() - bodyBuf := bufPool.Get().(*bytes.Buffer) - bodyBuf.Reset() - defer bufPool.Put(bodyBuf) - _, err = io.Copy(bodyBuf, file) if err != nil { - return nil, err + return err } - return bodyBuf, nil + return nil } // funcHTTPInclude returns the body of a virtual (lightweight) request @@ -185,7 +186,12 @@ func (c TemplateContext) funcHTTPInclude(uri string) (string, error) { // {{ template }} from the standard template library. If the imported file has // no {{ define }} blocks, the name of the import will be the path func (c *TemplateContext) funcImport(filename string) (string, error) { - bodyBuf, err := c.readFileToBuffer(filename) + + bodyBuf := bufPool.Get().(*bytes.Buffer) + bodyBuf.Reset() + defer bufPool.Put(bodyBuf) + + err := c.readFileToBuffer(filename, bodyBuf) if err != nil { return "", err } diff --git a/modules/caddyhttp/templates/tplcontext_test.go b/modules/caddyhttp/templates/tplcontext_test.go index ddc8b99ba..bd0497514 100644 --- a/modules/caddyhttp/templates/tplcontext_test.go +++ b/modules/caddyhttp/templates/tplcontext_test.go @@ -18,6 +18,7 @@ import ( "bytes" "context" "fmt" + "io/ioutil" "net/http" "os" "path/filepath" @@ -187,6 +188,95 @@ func TestImport(t *testing.T) { } } +func TestNestedInclude(t *testing.T) { + for i, test := range []struct { + child string + childFile string + parent string + parentFile string + shouldErr bool + expect string + child2 string + child2File string + }{ + { + // include in parent + child: `{{ include "file1" }}`, + childFile: "file0", + parent: `{{ $content := "file2" }}{{ $p := include $content}}`, + parentFile: "file1", + shouldErr: false, + expect: ``, + child2: `This shouldn't show`, + child2File: "file2", + }, + } { + context := getContextOrFail(t) + var absFilePath string + var absFilePath0 string + var absFilePath1 string + var buf *bytes.Buffer + var err error + + // create files and for test case + if test.parentFile != "" { + absFilePath = filepath.Join(fmt.Sprintf("%s", context.Root), test.parentFile) + if err := ioutil.WriteFile(absFilePath, []byte(test.parent), os.ModePerm); err != nil { + os.Remove(absFilePath) + t.Fatalf("Test %d: Expected no error creating file, got: '%s'", i, err.Error()) + } + } + if test.childFile != "" { + absFilePath0 = filepath.Join(fmt.Sprintf("%s", context.Root), test.childFile) + if err := ioutil.WriteFile(absFilePath0, []byte(test.child), os.ModePerm); err != nil { + os.Remove(absFilePath0) + t.Fatalf("Test %d: Expected no error creating file, got: '%s'", i, err.Error()) + } + } + if test.child2File != "" { + absFilePath1 = filepath.Join(fmt.Sprintf("%s", context.Root), test.child2File) + if err := ioutil.WriteFile(absFilePath1, []byte(test.child2), os.ModePerm); err != nil { + os.Remove(absFilePath0) + t.Fatalf("Test %d: Expected no error creating file, got: '%s'", i, err.Error()) + } + } + + buf = bufPool.Get().(*bytes.Buffer) + buf.Reset() + defer bufPool.Put(buf) + buf.WriteString(test.child) + err = context.executeTemplateInBuffer(test.childFile, buf) + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: Expected no error, got: '%s'", i, err) + } + } else if test.shouldErr { + t.Errorf("Test %d: Expected error but had none", i) + } else if buf.String() != test.expect { + // + t.Errorf("Test %d: Expected '%s' but got '%s'", i, test.expect, buf.String()) + + } + + if absFilePath != "" { + if err := os.Remove(absFilePath); err != nil && !os.IsNotExist(err) { + t.Fatalf("Test %d: Expected no error removing temporary test file, got: %v", i, err) + } + } + if absFilePath0 != "" { + if err := os.Remove(absFilePath0); err != nil && !os.IsNotExist(err) { + t.Fatalf("Test %d: Expected no error removing temporary test file, got: %v", i, err) + } + } + if absFilePath1 != "" { + if err := os.Remove(absFilePath1); err != nil && !os.IsNotExist(err) { + t.Fatalf("Test %d: Expected no error removing temporary test file, got: %v", i, err) + } + } + } +} + func TestInclude(t *testing.T) { for i, test := range []struct { fileContent string