diff --git a/fs/rc/internal.go b/fs/rc/internal.go index 3f2b02644..a3018bc4a 100644 --- a/fs/rc/internal.go +++ b/fs/rc/internal.go @@ -353,17 +353,22 @@ func init() { - command - a string with the command name - arg - a list of arguments for the backend command - opt - a map of string to string of options +- returnType - one of ("COMBINED_OUTPUT", "STREAM", "STREAM_ONLY_STDOUT", "STREAM_ONLY_STDERR") + - defaults to "COMBINED_OUTPUT" if not set + - the STREAM returnTypes will write the output to the body of the HTTP message + - the COMBINED_OUTPUT will write the output to the "result" parameter Returns - result - result from the backend command + - only set when using returnType "COMBINED_OUTPUT" - error - set if rclone exits with an error code -- returnType - one of ("COMBINED_OUTPUT", "STREAM", "STREAM_ONLY_STDOUT". "STREAM_ONLY_STDERR") +- returnType - one of ("COMBINED_OUTPUT", "STREAM", "STREAM_ONLY_STDOUT", "STREAM_ONLY_STDERR") For example rclone rc core/command command=ls -a mydrive:/ -o max-depth=1 - rclone rc core/command -a ls -a mydrive:/ -o max-depth=1 + rclone rc core/command -a ls -a mydrive:/ -o max-depth=1 Returns @@ -386,7 +391,6 @@ OR // rcRunCommand runs an rclone command with the given args and flags func rcRunCommand(ctx context.Context, in Params) (out Params, err error) { - command, err := in.GetString("command") if err != nil { command = "" @@ -409,7 +413,7 @@ func rcRunCommand(ctx context.Context, in Params) (out Params, err error) { returnType = "COMBINED_OUTPUT" } - var httpResponse *http.ResponseWriter + var httpResponse http.ResponseWriter httpResponse, err = in.GetHTTPResponseWriter() if err != nil { return nil, errors.Errorf("response object is required\n" + err.Error()) @@ -460,12 +464,14 @@ func rcRunCommand(ctx context.Context, in Params) (out Params, err error) { "error": false, }, nil } else if returnType == "STREAM_ONLY_STDOUT" { - cmd.Stdout = *httpResponse + cmd.Stdout = httpResponse } else if returnType == "STREAM_ONLY_STDERR" { - cmd.Stderr = *httpResponse + cmd.Stderr = httpResponse } else if returnType == "STREAM" { - cmd.Stdout = *httpResponse - cmd.Stderr = *httpResponse + cmd.Stdout = httpResponse + cmd.Stderr = httpResponse + } else { + return nil, errors.Errorf("Unknown returnType %q", returnType) } err = cmd.Run() diff --git a/fs/rc/internal_test.go b/fs/rc/internal_test.go index 3eac2be57..1b9b13f57 100644 --- a/fs/rc/internal_test.go +++ b/fs/rc/internal_test.go @@ -7,6 +7,7 @@ import ( "net/http/httptest" "os" "runtime" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -22,6 +23,12 @@ func TestMain(m *testing.M) { fmt.Printf("rclone %s\n", fs.Version) os.Exit(0) } + // Pretend to error if we have an unknown command + if os.Args[len(os.Args)-1] == "unknown_command" { + fmt.Printf("rclone %s\n", fs.Version) + fmt.Fprintf(os.Stderr, "Unknown command\n") + os.Exit(1) + } os.Exit(m.Run()) } @@ -136,17 +143,56 @@ func TestCoreQuit(t *testing.T) { func TestCoreCommand(t *testing.T) { call := Calls.Get("core/command") - var httpResponse http.ResponseWriter = httptest.NewRecorder() + test := func(command string, returnType string, wantOutput string, fail bool) { + var rec = httptest.NewRecorder() + var w http.ResponseWriter = rec - in := Params{ - "command": "version", - "opt": map[string]string{}, - "arg": []string{}, - "_response": &httpResponse, + in := Params{ + "command": command, + "opt": map[string]string{}, + "arg": []string{}, + "_response": w, + } + if returnType != "" { + in["returnType"] = returnType + } else { + returnType = "COMBINED_OUTPUT" + } + stream := strings.HasPrefix(returnType, "STREAM") + got, err := call.Fn(context.Background(), in) + if stream && fail { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + if !stream { + assert.Equal(t, wantOutput, got["result"]) + assert.Equal(t, fail, got["error"]) + } else { + assert.Equal(t, wantOutput, rec.Body.String()) + } + assert.Equal(t, http.StatusOK, rec.Result().StatusCode) } - got, err := call.Fn(context.Background(), in) - require.NoError(t, err) - assert.Equal(t, fmt.Sprintf("rclone %s\n", fs.Version), got["result"]) - assert.Equal(t, false, got["error"]) + version := fmt.Sprintf("rclone %s\n", fs.Version) + errorString := "Unknown command\n" + t.Run("OK", func(t *testing.T) { + test("version", "", version, false) + }) + t.Run("Fail", func(t *testing.T) { + test("unknown_command", "", version+errorString, true) + }) + t.Run("Combined", func(t *testing.T) { + test("unknown_command", "COMBINED_OUTPUT", version+errorString, true) + }) + t.Run("Stderr", func(t *testing.T) { + test("unknown_command", "STREAM_ONLY_STDERR", errorString, true) + }) + t.Run("Stdout", func(t *testing.T) { + test("unknown_command", "STREAM_ONLY_STDOUT", version, true) + }) + t.Run("Stream", func(t *testing.T) { + test("unknown_command", "STREAM", version+errorString, true) + }) } diff --git a/fs/rc/params.go b/fs/rc/params.go index bdd020b1e..2684b5120 100644 --- a/fs/rc/params.go +++ b/fs/rc/params.go @@ -112,15 +112,15 @@ func (p Params) GetHTTPRequest() (*http.Request, error) { // // If the parameter isn't found then error will be of type // ErrParamNotFound and the returned value will be nil. -func (p Params) GetHTTPResponseWriter() (*http.ResponseWriter, error) { +func (p Params) GetHTTPResponseWriter() (http.ResponseWriter, error) { key := "_response" value, err := p.Get(key) if err != nil { return nil, err } - request, ok := value.(*http.ResponseWriter) + request, ok := value.(http.ResponseWriter) if !ok { - return nil, ErrParamInvalid{errors.Errorf("expecting *http.ResponseWriter value for key %q (was %T)", key, value)} + return nil, ErrParamInvalid{errors.Errorf("expecting http.ResponseWriter value for key %q (was %T)", key, value)} } return request, nil } diff --git a/fs/rc/params_test.go b/fs/rc/params_test.go index 1ceeee430..e9b770c41 100644 --- a/fs/rc/params_test.go +++ b/fs/rc/params_test.go @@ -2,6 +2,8 @@ package rc import ( "fmt" + "net/http" + "net/http/httptest" "testing" "time" @@ -346,3 +348,53 @@ func TestParamsGetStructMissingOK(t *testing.T) { assert.Equal(t, 4.2, out.Float) assert.Equal(t, true, IsErrParamInvalid(e3), e3.Error()) } + +func TestParamsGetHTTPRequest(t *testing.T) { + in := Params{} + req, err := in.GetHTTPRequest() + assert.Nil(t, req) + assert.Error(t, err) + assert.Equal(t, true, IsErrParamNotFound(err), err.Error()) + + in = Params{ + "_request": 42, + } + req, err = in.GetHTTPRequest() + assert.Nil(t, req) + assert.Error(t, err) + assert.Equal(t, true, IsErrParamInvalid(err), err.Error()) + + r := new(http.Request) + in = Params{ + "_request": r, + } + req, err = in.GetHTTPRequest() + assert.NotNil(t, req) + assert.NoError(t, err) + assert.Equal(t, r, req) +} + +func TestParamsGetHTTPResponseWriter(t *testing.T) { + in := Params{} + wr, err := in.GetHTTPResponseWriter() + assert.Nil(t, wr) + assert.Error(t, err) + assert.Equal(t, true, IsErrParamNotFound(err), err.Error()) + + in = Params{ + "_response": 42, + } + wr, err = in.GetHTTPResponseWriter() + assert.Nil(t, wr) + assert.Error(t, err) + assert.Equal(t, true, IsErrParamInvalid(err), err.Error()) + + var w http.ResponseWriter = httptest.NewRecorder() + in = Params{ + "_response": w, + } + wr, err = in.GetHTTPResponseWriter() + assert.NotNil(t, wr) + assert.NoError(t, err) + assert.Equal(t, w, wr) +}