From 245ac321c3169babe70a700108ea91fe5149c326 Mon Sep 17 00:00:00 2001 From: wxiaoguang Date: Tue, 11 Feb 2025 16:46:03 +0800 Subject: [PATCH] Fix context usage (#33554) Some old code use direct type-casting to get context, it causes problems. This PR fixes all legacy problems and use correct `ctx.Value` to get low-level contexts. Fix #33518 --- routers/install/install.go | 2 +- routers/private/internal.go | 2 +- routers/web/web.go | 2 +- services/auth/auth.go | 2 +- services/auth/sspi.go | 2 +- services/context/context.go | 9 +++++---- services/context/package.go | 2 +- services/contexttest/context_tests.go | 2 +- services/markup/renderhelper.go | 4 ++-- services/markup/renderhelper_codepreview.go | 4 ++-- services/markup/renderhelper_issueicontitle.go | 4 ++-- 11 files changed, 18 insertions(+), 17 deletions(-) diff --git a/routers/install/install.go b/routers/install/install.go index 8a1d57aa0b3..2cede3685d1 100644 --- a/routers/install/install.go +++ b/routers/install/install.go @@ -64,7 +64,7 @@ func Contexter() func(next http.Handler) http.Handler { return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { base := context.NewBaseContext(resp, req) ctx := context.NewWebContext(base, rnd, session.GetSession(req)) - ctx.SetContextValue(context.WebContextKey, ctx) + ctx.SetContextValue(context.WebContextKey, ctx) // FIXME: this should be removed because NewWebContext should already set it ctx.Data.MergeFrom(middleware.CommonTemplateContextData()) ctx.Data.MergeFrom(reqctx.ContextData{ "Title": ctx.Locale.Tr("install.install"), diff --git a/routers/private/internal.go b/routers/private/internal.go index a78c76f8970..2232c1b78c8 100644 --- a/routers/private/internal.go +++ b/routers/private/internal.go @@ -88,7 +88,7 @@ func Routes() *web.Router { // Fortunately, the LFS handlers are able to handle requests without a complete web context common.AddOwnerRepoGitLFSRoutes(r, func(ctx *context.PrivateContext) { webContext := &context.Context{Base: ctx.Base} - ctx.SetContextValue(context.WebContextKey, webContext) + ctx.SetContextValue(context.WebContextKey, webContext) // FIXME: this is not ideal but no other way at the moment }) }) diff --git a/routers/web/web.go b/routers/web/web.go index 65548073d28..f5bd6a92979 100644 --- a/routers/web/web.go +++ b/routers/web/web.go @@ -1637,7 +1637,7 @@ func registerRoutes(m *web.Router) { } m.NotFound(func(w http.ResponseWriter, req *http.Request) { - ctx := context.GetWebContext(req) + ctx := context.GetWebContext(req.Context()) defer routing.RecordFuncInfo(ctx, routing.GetFuncInfo(ctx.NotFound, "WebNotFound"))() ctx.NotFound("", nil) }) diff --git a/services/auth/auth.go b/services/auth/auth.go index 7deca9bc3d7..f7deeb4c502 100644 --- a/services/auth/auth.go +++ b/services/auth/auth.go @@ -149,7 +149,7 @@ func handleSignIn(resp http.ResponseWriter, req *http.Request, sess SessionStore middleware.SetLocaleCookie(resp, user.Language, 0) // force to generate a new CSRF token - if ctx := gitea_context.GetWebContext(req); ctx != nil { + if ctx := gitea_context.GetWebContext(req.Context()); ctx != nil { ctx.Csrf.PrepareForSessionUser(ctx) } } diff --git a/services/auth/sspi.go b/services/auth/sspi.go index 3882740ae37..8cb39886c4f 100644 --- a/services/auth/sspi.go +++ b/services/auth/sspi.go @@ -88,7 +88,7 @@ func (s *SSPI) Verify(req *http.Request, w http.ResponseWriter, store DataStore, store.GetData()["EnableSSPI"] = true // in this case, the Verify function is called in Gitea's web context // FIXME: it doesn't look good to render the page here, why not redirect? - gitea_context.GetWebContext(req).HTML(http.StatusUnauthorized, tplSignIn) + gitea_context.GetWebContext(req.Context()).HTML(http.StatusUnauthorized, tplSignIn) return nil, err } if outToken != "" { diff --git a/services/context/context.go b/services/context/context.go index 7aeb0de7ab4..ffce1d617a5 100644 --- a/services/context/context.go +++ b/services/context/context.go @@ -79,9 +79,9 @@ type webContextKeyType struct{} var WebContextKey = webContextKeyType{} -func GetWebContext(req *http.Request) *Context { - ctx, _ := req.Context().Value(WebContextKey).(*Context) - return ctx +func GetWebContext(ctx context.Context) *Context { + webCtx, _ := ctx.Value(WebContextKey).(*Context) + return webCtx } // ValidateContext is a special context for form validation middleware. It may be different from other contexts. @@ -135,6 +135,7 @@ func NewWebContext(base *Base, render Render, session session.Store) *Context { } ctx.TemplateContext = NewTemplateContextForWeb(ctx) ctx.Flash = &middleware.Flash{DataStore: ctx, Values: url.Values{}} + ctx.SetContextValue(WebContextKey, ctx) return ctx } @@ -165,7 +166,7 @@ func Contexter() func(next http.Handler) http.Handler { ctx.PageData = map[string]any{} ctx.Data["PageData"] = ctx.PageData - ctx.Base.SetContextValue(WebContextKey, ctx) + ctx.Base.SetContextValue(WebContextKey, ctx) // FIXME: this should be removed because NewWebContext should already set it ctx.Csrf = NewCSRFProtector(csrfOpts) // get the last flash message from cookie diff --git a/services/context/package.go b/services/context/package.go index e98e01acbb0..e566b7e5322 100644 --- a/services/context/package.go +++ b/services/context/package.go @@ -156,7 +156,7 @@ func PackageContexter() func(next http.Handler) http.Handler { base := NewBaseContext(resp, req) // it is still needed when rendering 500 page in a package handler ctx := NewWebContext(base, renderer, nil) - ctx.SetContextValue(WebContextKey, ctx) + ctx.SetContextValue(WebContextKey, ctx) // FIXME: this should be removed because NewWebContext should already set it next.ServeHTTP(ctx.Resp, ctx.Req) }) } diff --git a/services/contexttest/context_tests.go b/services/contexttest/context_tests.go index b0f71cad20e..4615c8404bb 100644 --- a/services/contexttest/context_tests.go +++ b/services/contexttest/context_tests.go @@ -67,7 +67,7 @@ func MockContext(t *testing.T, reqPath string, opts ...MockContextOption) (*cont chiCtx := chi.NewRouteContext() ctx := context.NewWebContext(base, opt.Render, nil) - ctx.SetContextValue(context.WebContextKey, ctx) + ctx.SetContextValue(context.WebContextKey, ctx) // FIXME: this should be removed because NewWebContext should already set it ctx.SetContextValue(chi.RouteCtxKey, chiCtx) if opt.SessionStore != nil { ctx.SetContextValue(session.MockStoreContextKey, opt.SessionStore) diff --git a/services/markup/renderhelper.go b/services/markup/renderhelper.go index 4b9852b48bf..ea494146a7b 100644 --- a/services/markup/renderhelper.go +++ b/services/markup/renderhelper.go @@ -21,8 +21,8 @@ func FormalRenderHelperFuncs() *markup.RenderHelperFuncs { return false } - giteaCtx, ok := ctx.(*gitea_context.Context) - if !ok { + giteaCtx := gitea_context.GetWebContext(ctx) + if giteaCtx == nil { // when using general context, use user's visibility to check return mentionedUser.Visibility.IsPublic() } diff --git a/services/markup/renderhelper_codepreview.go b/services/markup/renderhelper_codepreview.go index 170c70c4098..d638af7ff06 100644 --- a/services/markup/renderhelper_codepreview.go +++ b/services/markup/renderhelper_codepreview.go @@ -36,8 +36,8 @@ func renderRepoFileCodePreview(ctx context.Context, opts markup.RenderCodePrevie return "", err } - webCtx, ok := ctx.Value(gitea_context.WebContextKey).(*gitea_context.Context) - if !ok { + webCtx := gitea_context.GetWebContext(ctx) + if webCtx == nil { return "", fmt.Errorf("context is not a web context") } doer := webCtx.Doer diff --git a/services/markup/renderhelper_issueicontitle.go b/services/markup/renderhelper_issueicontitle.go index 53a508e9081..fd8f9d43fa1 100644 --- a/services/markup/renderhelper_issueicontitle.go +++ b/services/markup/renderhelper_issueicontitle.go @@ -18,8 +18,8 @@ import ( ) func renderRepoIssueIconTitle(ctx context.Context, opts markup.RenderIssueIconTitleOptions) (_ template.HTML, err error) { - webCtx, ok := ctx.Value(gitea_context.WebContextKey).(*gitea_context.Context) - if !ok { + webCtx := gitea_context.GetWebContext(ctx) + if webCtx == nil { return "", fmt.Errorf("context is not a web context") }