diff --git a/modules/context/api.go b/modules/context/api.go index e263dcbe8d..092ad73f31 100644 --- a/modules/context/api.go +++ b/modules/context/api.go @@ -13,18 +13,32 @@ import ( "code.gitea.io/gitea/models/auth" repo_model "code.gitea.io/gitea/models/repo" - "code.gitea.io/gitea/modules/cache" + "code.gitea.io/gitea/models/unit" + user_model "code.gitea.io/gitea/models/user" + mc "code.gitea.io/gitea/modules/cache" "code.gitea.io/gitea/modules/git" "code.gitea.io/gitea/modules/httpcache" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" - "code.gitea.io/gitea/modules/web/middleware" + + "gitea.com/go-chi/cache" ) // APIContext is a specific context for API service type APIContext struct { - *Context - Org *APIOrganization + *Base + + Cache cache.Cache + + Doer *user_model.User // current signed-in user + IsSigned bool + IsBasicAuth bool + + ContextUser *user_model.User // the user which is being visited, in most cases it differs from Doer + + Repo *Repository + Org *APIOrganization + Package *Package } // Currently, we have the following common fields in error response: @@ -128,11 +142,6 @@ type apiContextKeyType struct{} var apiContextKey = apiContextKeyType{} -// WithAPIContext set up api context in request -func WithAPIContext(req *http.Request, ctx *APIContext) *http.Request { - return req.WithContext(context.WithValue(req.Context(), apiContextKey, ctx)) -} - // GetAPIContext returns a context for API routes func GetAPIContext(req *http.Request) *APIContext { return req.Context().Value(apiContextKey).(*APIContext) @@ -195,21 +204,21 @@ func (ctx *APIContext) CheckForOTP() { } otpHeader := ctx.Req.Header.Get("X-Gitea-OTP") - twofa, err := auth.GetTwoFactorByUID(ctx.Context.Doer.ID) + twofa, err := auth.GetTwoFactorByUID(ctx.Doer.ID) if err != nil { if auth.IsErrTwoFactorNotEnrolled(err) { return // No 2FA enrollment for this user } - ctx.Context.Error(http.StatusInternalServerError) + ctx.Error(http.StatusInternalServerError, "GetTwoFactorByUID", err) return } ok, err := twofa.ValidateTOTP(otpHeader) if err != nil { - ctx.Context.Error(http.StatusInternalServerError) + ctx.Error(http.StatusInternalServerError, "ValidateTOTP", err) return } if !ok { - ctx.Context.Error(http.StatusUnauthorized) + ctx.Error(http.StatusUnauthorized, "", nil) return } } @@ -218,23 +227,17 @@ func (ctx *APIContext) CheckForOTP() { func APIContexter() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - locale := middleware.Locale(w, req) - ctx := APIContext{ - Context: &Context{ - Resp: NewResponse(w), - Data: middleware.GetContextData(req.Context()), - Locale: locale, - Cache: cache.GetCache(), - Repo: &Repository{ - PullRequest: &PullRequest{}, - }, - Org: &Organization{}, - }, - Org: &APIOrganization{}, + base, baseCleanUp := NewBaseContext(w, req) + ctx := &APIContext{ + Base: base, + Cache: mc.GetCache(), + Repo: &Repository{PullRequest: &PullRequest{}}, + Org: &APIOrganization{}, } - defer ctx.Close() + defer baseCleanUp() - ctx.Req = WithAPIContext(WithContext(req, ctx.Context), &ctx) + ctx.Base.AppendContextValue(apiContextKey, ctx) + ctx.Base.AppendContextValueFunc(git.RepositoryContextKey, func() any { return ctx.Repo.GitRepo }) // If request sends files, parse them here otherwise the Query() can't be parsed and the CsrfToken will be invalid. if ctx.Req.Method == "POST" && strings.Contains(ctx.Req.Header.Get("Content-Type"), "multipart/form-data") { @@ -247,8 +250,6 @@ func APIContexter() func(http.Handler) http.Handler { httpcache.SetCacheControlInHeader(ctx.Resp.Header(), 0, "no-transform") ctx.Resp.Header().Set(`X-Frame-Options`, setting.CORSConfig.XFrameOptions) - ctx.Data["Context"] = &ctx - next.ServeHTTP(ctx.Resp, ctx.Req) }) } @@ -301,7 +302,7 @@ func ReferencesGitRepo(allowEmpty ...bool) func(ctx *APIContext) (cancel context return func() { // If it's been set to nil then assume someone else has closed it. if ctx.Repo.GitRepo != nil { - ctx.Repo.GitRepo.Close() + _ = ctx.Repo.GitRepo.Close() } } } @@ -337,7 +338,7 @@ func RepoRefForAPI(next http.Handler) http.Handler { } var err error - refName := getRefName(ctx.Context, RepoRefAny) + refName := getRefName(ctx.Base, ctx.Repo, RepoRefAny) if ctx.Repo.GitRepo.IsBranchExist(refName) { ctx.Repo.Commit, err = ctx.Repo.GitRepo.GetBranchCommit(refName) @@ -368,3 +369,53 @@ func RepoRefForAPI(next http.Handler) http.Handler { next.ServeHTTP(w, req) }) } + +// HasAPIError returns true if error occurs in form validation. +func (ctx *APIContext) HasAPIError() bool { + hasErr, ok := ctx.Data["HasError"] + if !ok { + return false + } + return hasErr.(bool) +} + +// GetErrMsg returns error message in form validation. +func (ctx *APIContext) GetErrMsg() string { + msg, _ := ctx.Data["ErrorMsg"].(string) + if msg == "" { + msg = "invalid form data" + } + return msg +} + +// NotFoundOrServerError use error check function to determine if the error +// is about not found. It responds with 404 status code for not found error, +// or error context description for logging purpose of 500 server error. +func (ctx *APIContext) NotFoundOrServerError(logMsg string, errCheck func(error) bool, logErr error) { + if errCheck(logErr) { + ctx.JSON(http.StatusNotFound, nil) + return + } + ctx.Error(http.StatusInternalServerError, "NotFoundOrServerError", logMsg) +} + +// IsUserSiteAdmin returns true if current user is a site admin +func (ctx *APIContext) IsUserSiteAdmin() bool { + return ctx.IsSigned && ctx.Doer.IsAdmin +} + +// IsUserRepoAdmin returns true if current user is admin in current repo +func (ctx *APIContext) IsUserRepoAdmin() bool { + return ctx.Repo.IsAdmin() +} + +// IsUserRepoWriter returns true if current user has write privilege in current repo +func (ctx *APIContext) IsUserRepoWriter(unitTypes []unit.Type) bool { + for _, unitType := range unitTypes { + if ctx.Repo.CanWrite(unitType) { + return true + } + } + + return false +} diff --git a/modules/context/base.go b/modules/context/base.go new file mode 100644 index 0000000000..ac9b52d51c --- /dev/null +++ b/modules/context/base.go @@ -0,0 +1,300 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package context + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "code.gitea.io/gitea/modules/httplib" + "code.gitea.io/gitea/modules/json" + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/translation" + "code.gitea.io/gitea/modules/util" + "code.gitea.io/gitea/modules/web/middleware" + + "github.com/go-chi/chi/v5" +) + +type contextValuePair struct { + key any + valueFn func() any +} + +type Base struct { + originCtx context.Context + contextValues []contextValuePair + + Resp ResponseWriter + Req *http.Request + + // Data is prepared by ContextDataStore middleware, this field only refers to the pre-created/prepared ContextData. + // Although it's mainly used for MVC templates, sometimes it's also used to pass data between middlewares/handler + Data middleware.ContextData + + // Locale is mainly for Web context, although the API context also uses it in some cases: message response, form validation + Locale translation.Locale +} + +func (b *Base) Deadline() (deadline time.Time, ok bool) { + return b.originCtx.Deadline() +} + +func (b *Base) Done() <-chan struct{} { + return b.originCtx.Done() +} + +func (b *Base) Err() error { + return b.originCtx.Err() +} + +func (b *Base) Value(key any) any { + for _, pair := range b.contextValues { + if pair.key == key { + return pair.valueFn() + } + } + return b.originCtx.Value(key) +} + +func (b *Base) AppendContextValueFunc(key any, valueFn func() any) any { + b.contextValues = append(b.contextValues, contextValuePair{key, valueFn}) + return b +} + +func (b *Base) AppendContextValue(key, value any) any { + b.contextValues = append(b.contextValues, contextValuePair{key, func() any { return value }}) + return b +} + +func (b *Base) GetData() middleware.ContextData { + return b.Data +} + +// AppendAccessControlExposeHeaders append headers by name to "Access-Control-Expose-Headers" header +func (b *Base) AppendAccessControlExposeHeaders(names ...string) { + val := b.RespHeader().Get("Access-Control-Expose-Headers") + if len(val) != 0 { + b.RespHeader().Set("Access-Control-Expose-Headers", fmt.Sprintf("%s, %s", val, strings.Join(names, ", "))) + } else { + b.RespHeader().Set("Access-Control-Expose-Headers", strings.Join(names, ", ")) + } +} + +// SetTotalCountHeader set "X-Total-Count" header +func (b *Base) SetTotalCountHeader(total int64) { + b.RespHeader().Set("X-Total-Count", fmt.Sprint(total)) + b.AppendAccessControlExposeHeaders("X-Total-Count") +} + +// Written returns true if there are something sent to web browser +func (b *Base) Written() bool { + return b.Resp.Status() > 0 +} + +// Status writes status code +func (b *Base) Status(status int) { + b.Resp.WriteHeader(status) +} + +// Write writes data to web browser +func (b *Base) Write(bs []byte) (int, error) { + return b.Resp.Write(bs) +} + +// RespHeader returns the response header +func (b *Base) RespHeader() http.Header { + return b.Resp.Header() +} + +// Error returned an error to web browser +func (b *Base) Error(status int, contents ...string) { + v := http.StatusText(status) + if len(contents) > 0 { + v = contents[0] + } + http.Error(b.Resp, v, status) +} + +// JSON render content as JSON +func (b *Base) JSON(status int, content interface{}) { + b.Resp.Header().Set("Content-Type", "application/json;charset=utf-8") + b.Resp.WriteHeader(status) + if err := json.NewEncoder(b.Resp).Encode(content); err != nil { + log.Error("Render JSON failed: %v", err) + } +} + +// RemoteAddr returns the client machine ip address +func (b *Base) RemoteAddr() string { + return b.Req.RemoteAddr +} + +// Params returns the param on route +func (b *Base) Params(p string) string { + s, _ := url.PathUnescape(chi.URLParam(b.Req, strings.TrimPrefix(p, ":"))) + return s +} + +// ParamsInt64 returns the param on route as int64 +func (b *Base) ParamsInt64(p string) int64 { + v, _ := strconv.ParseInt(b.Params(p), 10, 64) + return v +} + +// SetParams set params into routes +func (b *Base) SetParams(k, v string) { + chiCtx := chi.RouteContext(b) + chiCtx.URLParams.Add(strings.TrimPrefix(k, ":"), url.PathEscape(v)) +} + +// FormString returns the first value matching the provided key in the form as a string +func (b *Base) FormString(key string) string { + return b.Req.FormValue(key) +} + +// FormStrings returns a string slice for the provided key from the form +func (b *Base) FormStrings(key string) []string { + if b.Req.Form == nil { + if err := b.Req.ParseMultipartForm(32 << 20); err != nil { + return nil + } + } + if v, ok := b.Req.Form[key]; ok { + return v + } + return nil +} + +// FormTrim returns the first value for the provided key in the form as a space trimmed string +func (b *Base) FormTrim(key string) string { + return strings.TrimSpace(b.Req.FormValue(key)) +} + +// FormInt returns the first value for the provided key in the form as an int +func (b *Base) FormInt(key string) int { + v, _ := strconv.Atoi(b.Req.FormValue(key)) + return v +} + +// FormInt64 returns the first value for the provided key in the form as an int64 +func (b *Base) FormInt64(key string) int64 { + v, _ := strconv.ParseInt(b.Req.FormValue(key), 10, 64) + return v +} + +// FormBool returns true if the value for the provided key in the form is "1", "true" or "on" +func (b *Base) FormBool(key string) bool { + s := b.Req.FormValue(key) + v, _ := strconv.ParseBool(s) + v = v || strings.EqualFold(s, "on") + return v +} + +// FormOptionalBool returns an OptionalBoolTrue or OptionalBoolFalse if the value +// for the provided key exists in the form else it returns OptionalBoolNone +func (b *Base) FormOptionalBool(key string) util.OptionalBool { + value := b.Req.FormValue(key) + if len(value) == 0 { + return util.OptionalBoolNone + } + s := b.Req.FormValue(key) + v, _ := strconv.ParseBool(s) + v = v || strings.EqualFold(s, "on") + return util.OptionalBoolOf(v) +} + +func (b *Base) SetFormString(key, value string) { + _ = b.Req.FormValue(key) // force parse form + b.Req.Form.Set(key, value) +} + +// PlainTextBytes renders bytes as plain text +func (b *Base) plainTextInternal(skip, status int, bs []byte) { + statusPrefix := status / 100 + if statusPrefix == 4 || statusPrefix == 5 { + log.Log(skip, log.TRACE, "plainTextInternal (status=%d): %s", status, string(bs)) + } + b.Resp.Header().Set("Content-Type", "text/plain;charset=utf-8") + b.Resp.Header().Set("X-Content-Type-Options", "nosniff") + b.Resp.WriteHeader(status) + if _, err := b.Resp.Write(bs); err != nil { + log.ErrorWithSkip(skip, "plainTextInternal (status=%d): write bytes failed: %v", status, err) + } +} + +// PlainTextBytes renders bytes as plain text +func (b *Base) PlainTextBytes(status int, bs []byte) { + b.plainTextInternal(2, status, bs) +} + +// PlainText renders content as plain text +func (b *Base) PlainText(status int, text string) { + b.plainTextInternal(2, status, []byte(text)) +} + +// Redirect redirects the request +func (b *Base) Redirect(location string, status ...int) { + code := http.StatusSeeOther + if len(status) == 1 { + code = status[0] + } + + if strings.Contains(location, "://") || strings.HasPrefix(location, "//") { + // Some browsers (Safari) have buggy behavior for Cookie + Cache + External Redirection, eg: /my-path => https://other/path + // 1. the first request to "/my-path" contains cookie + // 2. some time later, the request to "/my-path" doesn't contain cookie (caused by Prevent web tracking) + // 3. Gitea's Sessioner doesn't see the session cookie, so it generates a new session id, and returns it to browser + // 4. then the browser accepts the empty session, then the user is logged out + // So in this case, we should remove the session cookie from the response header + removeSessionCookieHeader(b.Resp) + } + http.Redirect(b.Resp, b.Req, location, code) +} + +type ServeHeaderOptions httplib.ServeHeaderOptions + +func (b *Base) SetServeHeaders(opt *ServeHeaderOptions) { + httplib.ServeSetHeaders(b.Resp, (*httplib.ServeHeaderOptions)(opt)) +} + +// ServeContent serves content to http request +func (b *Base) ServeContent(r io.ReadSeeker, opts *ServeHeaderOptions) { + httplib.ServeSetHeaders(b.Resp, (*httplib.ServeHeaderOptions)(opts)) + http.ServeContent(b.Resp, b.Req, opts.Filename, opts.LastModified, r) +} + +// Close frees all resources hold by Context +func (b *Base) cleanUp() { + if b.Req != nil && b.Req.MultipartForm != nil { + _ = b.Req.MultipartForm.RemoveAll() // remove the temp files buffered to tmp directory + } +} + +func (b *Base) Tr(msg string, args ...any) string { + return b.Locale.Tr(msg, args...) +} + +func (b *Base) TrN(cnt any, key1, keyN string, args ...any) string { + return b.Locale.TrN(cnt, key1, keyN, args...) +} + +func NewBaseContext(resp http.ResponseWriter, req *http.Request) (b *Base, closeFunc func()) { + b = &Base{ + originCtx: req.Context(), + Req: req, + Resp: WrapResponseWriter(resp), + Locale: middleware.Locale(resp, req), + Data: middleware.GetContextData(req.Context()), + } + b.AppendContextValue(translation.ContextKey, b.Locale) + b.Req = b.Req.WithContext(b) + return b, b.cleanUp +} diff --git a/modules/context/context.go b/modules/context/context.go index 9ba1985f36..1e15081479 100644 --- a/modules/context/context.go +++ b/modules/context/context.go @@ -5,7 +5,6 @@ package context import ( - "context" "html" "html/template" "io" @@ -36,38 +35,27 @@ type Render interface { // Context represents context of a request. type Context struct { - Resp ResponseWriter - Req *http.Request - Render Render + *Base - Data middleware.ContextData // data used by MVC templates - PageData map[string]any // data used by JavaScript modules in one page, it's `window.config.pageData` + Render Render + PageData map[string]any // data used by JavaScript modules in one page, it's `window.config.pageData` - Locale translation.Locale Cache cache.Cache Csrf CSRFProtector Flash *middleware.Flash Session session.Store - Link string // current request URL (without query string) - Doer *user_model.User + Link string // current request URL (without query string) + + Doer *user_model.User // current signed-in user IsSigned bool IsBasicAuth bool - ContextUser *user_model.User - Repo *Repository - Org *Organization - Package *Package -} + ContextUser *user_model.User // the user which is being visited, in most cases it differs from Doer -// Close frees all resources hold by Context -func (ctx *Context) Close() error { - var err error - if ctx.Req != nil && ctx.Req.MultipartForm != nil { - err = ctx.Req.MultipartForm.RemoveAll() // remove the temp files buffered to tmp directory - } - // TODO: close opened repo, and more - return err + Repo *Repository + Org *Organization + Package *Package } // TrHTMLEscapeArgs runs ".Locale.Tr()" but pre-escapes all arguments with html.EscapeString. @@ -80,55 +68,30 @@ func (ctx *Context) TrHTMLEscapeArgs(msg string, args ...string) string { return ctx.Locale.Tr(msg, trArgs...) } -func (ctx *Context) Tr(msg string, args ...any) string { - return ctx.Locale.Tr(msg, args...) -} - -func (ctx *Context) TrN(cnt any, key1, keyN string, args ...any) string { - return ctx.Locale.TrN(cnt, key1, keyN, args...) -} - -// Deadline is part of the interface for context.Context and we pass this to the request context -func (ctx *Context) Deadline() (deadline time.Time, ok bool) { - return ctx.Req.Context().Deadline() -} - -// Done is part of the interface for context.Context and we pass this to the request context -func (ctx *Context) Done() <-chan struct{} { - return ctx.Req.Context().Done() -} - -// Err is part of the interface for context.Context and we pass this to the request context -func (ctx *Context) Err() error { - return ctx.Req.Context().Err() -} - -// Value is part of the interface for context.Context and we pass this to the request context -func (ctx *Context) Value(key interface{}) interface{} { - if key == git.RepositoryContextKey && ctx.Repo != nil { - return ctx.Repo.GitRepo - } - if key == translation.ContextKey && ctx.Locale != nil { - return ctx.Locale - } - return ctx.Req.Context().Value(key) -} - type contextKeyType struct{} var contextKey interface{} = contextKeyType{} -// WithContext set up install context in request -func WithContext(req *http.Request, ctx *Context) *http.Request { - return req.WithContext(context.WithValue(req.Context(), contextKey, ctx)) +func GetContext(req *http.Request) *Context { + ctx, _ := req.Context().Value(contextKey).(*Context) + return ctx } -// GetContext retrieves install context from request -func GetContext(req *http.Request) *Context { - if ctx, ok := req.Context().Value(contextKey).(*Context); ok { - return ctx +// ValidateContext is a special context for form validation middleware. It may be different from other contexts. +type ValidateContext struct { + *Base +} + +// GetValidateContext gets a context for middleware form validation +func GetValidateContext(req *http.Request) (ctx *ValidateContext) { + if ctxAPI, ok := req.Context().Value(apiContextKey).(*APIContext); ok { + ctx = &ValidateContext{Base: ctxAPI.Base} + } else if ctxWeb, ok := req.Context().Value(contextKey).(*Context); ok { + ctx = &ValidateContext{Base: ctxWeb.Base} + } else { + panic("invalid context, expect either APIContext or Context") } - return nil + return ctx } // Contexter initializes a classic context for a request. @@ -150,20 +113,17 @@ func Contexter() func(next http.Handler) http.Handler { } return func(next http.Handler) http.Handler { return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - ctx := Context{ - Resp: NewResponse(resp), + base, baseCleanUp := NewBaseContext(resp, req) + ctx := &Context{ + Base: base, Cache: mc.GetCache(), - Locale: middleware.Locale(resp, req), Link: setting.AppSubURL + strings.TrimSuffix(req.URL.EscapedPath(), "/"), Render: rnd, Session: session.GetSession(req), - Repo: &Repository{ - PullRequest: &PullRequest{}, - }, - Org: &Organization{}, - Data: middleware.GetContextData(req.Context()), + Repo: &Repository{PullRequest: &PullRequest{}}, + Org: &Organization{}, } - defer ctx.Close() + defer baseCleanUp() ctx.Data.MergeFrom(middleware.CommonTemplateContextData()) ctx.Data["Context"] = &ctx @@ -175,15 +135,17 @@ func Contexter() func(next http.Handler) http.Handler { ctx.PageData = map[string]any{} ctx.Data["PageData"] = ctx.PageData - ctx.Req = WithContext(req, &ctx) - ctx.Csrf = PrepareCSRFProtector(csrfOpts, &ctx) + ctx.Base.AppendContextValue(contextKey, ctx) + ctx.Base.AppendContextValueFunc(git.RepositoryContextKey, func() any { return ctx.Repo.GitRepo }) + + ctx.Csrf = PrepareCSRFProtector(csrfOpts, ctx) // Get the last flash message from cookie lastFlashCookie := middleware.GetSiteCookie(ctx.Req, CookieNameFlash) if vals, _ := url.ParseQuery(lastFlashCookie); len(vals) > 0 { // store last Flash message into the template data, to render it ctx.Data["Flash"] = &middleware.Flash{ - DataStore: &ctx, + DataStore: ctx, Values: vals, ErrorMsg: vals.Get("error"), SuccessMsg: vals.Get("success"), @@ -193,7 +155,7 @@ func Contexter() func(next http.Handler) http.Handler { } // prepare an empty Flash message for current request - ctx.Flash = &middleware.Flash{DataStore: &ctx, Values: url.Values{}} + ctx.Flash = &middleware.Flash{DataStore: ctx, Values: url.Values{}} ctx.Resp.Before(func(resp ResponseWriter) { if val := ctx.Flash.Encode(); val != "" { middleware.SetSiteCookie(ctx.Resp, CookieNameFlash, val, 0) @@ -235,3 +197,24 @@ func Contexter() func(next http.Handler) http.Handler { }) } } + +// HasError returns true if error occurs in form validation. +// Attention: this function changes ctx.Data and ctx.Flash +func (ctx *Context) HasError() bool { + hasErr, ok := ctx.Data["HasError"] + if !ok { + return false + } + ctx.Flash.ErrorMsg = ctx.GetErrMsg() + ctx.Data["Flash"] = ctx.Flash + return hasErr.(bool) +} + +// GetErrMsg returns error message in form validation. +func (ctx *Context) GetErrMsg() string { + msg, _ := ctx.Data["ErrorMsg"].(string) + if msg == "" { + msg = "invalid form data" + } + return msg +} diff --git a/modules/context/context_data.go b/modules/context/context_data.go deleted file mode 100644 index cdf4ff9afe..0000000000 --- a/modules/context/context_data.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2023 The Gitea Authors. All rights reserved. -// SPDX-License-Identifier: MIT - -package context - -import "code.gitea.io/gitea/modules/web/middleware" - -// GetData returns the data -func (ctx *Context) GetData() middleware.ContextData { - return ctx.Data -} - -// HasAPIError returns true if error occurs in form validation. -func (ctx *Context) HasAPIError() bool { - hasErr, ok := ctx.Data["HasError"] - if !ok { - return false - } - return hasErr.(bool) -} - -// GetErrMsg returns error message -func (ctx *Context) GetErrMsg() string { - return ctx.Data["ErrorMsg"].(string) -} - -// HasError returns true if error occurs in form validation. -// Attention: this function changes ctx.Data and ctx.Flash -func (ctx *Context) HasError() bool { - hasErr, ok := ctx.Data["HasError"] - if !ok { - return false - } - ctx.Flash.ErrorMsg = ctx.Data["ErrorMsg"].(string) - ctx.Data["Flash"] = ctx.Flash - return hasErr.(bool) -} - -// HasValue returns true if value of given name exists. -func (ctx *Context) HasValue(name string) bool { - _, ok := ctx.Data[name] - return ok -} diff --git a/modules/context/context_form.go b/modules/context/context_form.go deleted file mode 100644 index 5c02152582..0000000000 --- a/modules/context/context_form.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2021 The Gitea Authors. All rights reserved. -// SPDX-License-Identifier: MIT - -package context - -import ( - "strconv" - "strings" - - "code.gitea.io/gitea/modules/util" -) - -// FormString returns the first value matching the provided key in the form as a string -func (ctx *Context) FormString(key string) string { - return ctx.Req.FormValue(key) -} - -// FormStrings returns a string slice for the provided key from the form -func (ctx *Context) FormStrings(key string) []string { - if ctx.Req.Form == nil { - if err := ctx.Req.ParseMultipartForm(32 << 20); err != nil { - return nil - } - } - if v, ok := ctx.Req.Form[key]; ok { - return v - } - return nil -} - -// FormTrim returns the first value for the provided key in the form as a space trimmed string -func (ctx *Context) FormTrim(key string) string { - return strings.TrimSpace(ctx.Req.FormValue(key)) -} - -// FormInt returns the first value for the provided key in the form as an int -func (ctx *Context) FormInt(key string) int { - v, _ := strconv.Atoi(ctx.Req.FormValue(key)) - return v -} - -// FormInt64 returns the first value for the provided key in the form as an int64 -func (ctx *Context) FormInt64(key string) int64 { - v, _ := strconv.ParseInt(ctx.Req.FormValue(key), 10, 64) - return v -} - -// FormBool returns true if the value for the provided key in the form is "1", "true" or "on" -func (ctx *Context) FormBool(key string) bool { - s := ctx.Req.FormValue(key) - v, _ := strconv.ParseBool(s) - v = v || strings.EqualFold(s, "on") - return v -} - -// FormOptionalBool returns an OptionalBoolTrue or OptionalBoolFalse if the value -// for the provided key exists in the form else it returns OptionalBoolNone -func (ctx *Context) FormOptionalBool(key string) util.OptionalBool { - value := ctx.Req.FormValue(key) - if len(value) == 0 { - return util.OptionalBoolNone - } - s := ctx.Req.FormValue(key) - v, _ := strconv.ParseBool(s) - v = v || strings.EqualFold(s, "on") - return util.OptionalBoolOf(v) -} - -func (ctx *Context) SetFormString(key, value string) { - _ = ctx.Req.FormValue(key) // force parse form - ctx.Req.Form.Set(key, value) -} diff --git a/modules/context/context_request.go b/modules/context/context_request.go index 0b87552c08..984b9ac793 100644 --- a/modules/context/context_request.go +++ b/modules/context/context_request.go @@ -6,36 +6,9 @@ package context import ( "io" "net/http" - "net/url" - "strconv" "strings" - - "github.com/go-chi/chi/v5" ) -// RemoteAddr returns the client machine ip address -func (ctx *Context) RemoteAddr() string { - return ctx.Req.RemoteAddr -} - -// Params returns the param on route -func (ctx *Context) Params(p string) string { - s, _ := url.PathUnescape(chi.URLParam(ctx.Req, strings.TrimPrefix(p, ":"))) - return s -} - -// ParamsInt64 returns the param on route as int64 -func (ctx *Context) ParamsInt64(p string) int64 { - v, _ := strconv.ParseInt(ctx.Params(p), 10, 64) - return v -} - -// SetParams set params into routes -func (ctx *Context) SetParams(k, v string) { - chiCtx := chi.RouteContext(ctx) - chiCtx.URLParams.Add(strings.TrimPrefix(k, ":"), url.PathEscape(v)) -} - // UploadStream returns the request body or the first form file // Only form files need to get closed. func (ctx *Context) UploadStream() (rd io.ReadCloser, needToClose bool, err error) { diff --git a/modules/context/context_response.go b/modules/context/context_response.go index 8adff96994..aeeb51ba37 100644 --- a/modules/context/context_response.go +++ b/modules/context/context_response.go @@ -16,49 +16,17 @@ import ( user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/base" - "code.gitea.io/gitea/modules/json" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/templates" "code.gitea.io/gitea/modules/web/middleware" ) -// SetTotalCountHeader set "X-Total-Count" header -func (ctx *Context) SetTotalCountHeader(total int64) { - ctx.RespHeader().Set("X-Total-Count", fmt.Sprint(total)) - ctx.AppendAccessControlExposeHeaders("X-Total-Count") -} - -// AppendAccessControlExposeHeaders append headers by name to "Access-Control-Expose-Headers" header -func (ctx *Context) AppendAccessControlExposeHeaders(names ...string) { - val := ctx.RespHeader().Get("Access-Control-Expose-Headers") - if len(val) != 0 { - ctx.RespHeader().Set("Access-Control-Expose-Headers", fmt.Sprintf("%s, %s", val, strings.Join(names, ", "))) - } else { - ctx.RespHeader().Set("Access-Control-Expose-Headers", strings.Join(names, ", ")) - } -} - -// Written returns true if there are something sent to web browser -func (ctx *Context) Written() bool { - return ctx.Resp.Status() > 0 -} - -// Status writes status code -func (ctx *Context) Status(status int) { - ctx.Resp.WriteHeader(status) -} - -// Write writes data to web browser -func (ctx *Context) Write(bs []byte) (int, error) { - return ctx.Resp.Write(bs) -} - // RedirectToUser redirect to a differently-named user -func RedirectToUser(ctx *Context, userName string, redirectUserID int64) { +func RedirectToUser(ctx *Base, userName string, redirectUserID int64) { user, err := user_model.GetUserByID(ctx, redirectUserID) if err != nil { - ctx.ServerError("GetUserByID", err) + ctx.Error(http.StatusInternalServerError, "unable to get user") return } @@ -211,69 +179,3 @@ func (ctx *Context) NotFoundOrServerError(logMsg string, errCheck func(error) bo } ctx.serverErrorInternal(logMsg, logErr) } - -// PlainTextBytes renders bytes as plain text -func (ctx *Context) plainTextInternal(skip, status int, bs []byte) { - statusPrefix := status / 100 - if statusPrefix == 4 || statusPrefix == 5 { - log.Log(skip, log.TRACE, "plainTextInternal (status=%d): %s", status, string(bs)) - } - ctx.Resp.Header().Set("Content-Type", "text/plain;charset=utf-8") - ctx.Resp.Header().Set("X-Content-Type-Options", "nosniff") - ctx.Resp.WriteHeader(status) - if _, err := ctx.Resp.Write(bs); err != nil { - log.ErrorWithSkip(skip, "plainTextInternal (status=%d): write bytes failed: %v", status, err) - } -} - -// PlainTextBytes renders bytes as plain text -func (ctx *Context) PlainTextBytes(status int, bs []byte) { - ctx.plainTextInternal(2, status, bs) -} - -// PlainText renders content as plain text -func (ctx *Context) PlainText(status int, text string) { - ctx.plainTextInternal(2, status, []byte(text)) -} - -// RespHeader returns the response header -func (ctx *Context) RespHeader() http.Header { - return ctx.Resp.Header() -} - -// Error returned an error to web browser -func (ctx *Context) Error(status int, contents ...string) { - v := http.StatusText(status) - if len(contents) > 0 { - v = contents[0] - } - http.Error(ctx.Resp, v, status) -} - -// JSON render content as JSON -func (ctx *Context) JSON(status int, content interface{}) { - ctx.Resp.Header().Set("Content-Type", "application/json;charset=utf-8") - ctx.Resp.WriteHeader(status) - if err := json.NewEncoder(ctx.Resp).Encode(content); err != nil { - ctx.ServerError("Render JSON failed", err) - } -} - -// Redirect redirects the request -func (ctx *Context) Redirect(location string, status ...int) { - code := http.StatusSeeOther - if len(status) == 1 { - code = status[0] - } - - if strings.Contains(location, "://") || strings.HasPrefix(location, "//") { - // Some browsers (Safari) have buggy behavior for Cookie + Cache + External Redirection, eg: /my-path => https://other/path - // 1. the first request to "/my-path" contains cookie - // 2. some time later, the request to "/my-path" doesn't contain cookie (caused by Prevent web tracking) - // 3. Gitea's Sessioner doesn't see the session cookie, so it generates a new session id, and returns it to browser - // 4. then the browser accepts the empty session, then the user is logged out - // So in this case, we should remove the session cookie from the response header - removeSessionCookieHeader(ctx.Resp) - } - http.Redirect(ctx.Resp, ctx.Req, location, code) -} diff --git a/modules/context/context_serve.go b/modules/context/context_serve.go deleted file mode 100644 index 5569efbc7e..0000000000 --- a/modules/context/context_serve.go +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2023 The Gitea Authors. All rights reserved. -// SPDX-License-Identifier: MIT - -package context - -import ( - "io" - "net/http" - - "code.gitea.io/gitea/modules/httplib" -) - -type ServeHeaderOptions httplib.ServeHeaderOptions - -func (ctx *Context) SetServeHeaders(opt *ServeHeaderOptions) { - httplib.ServeSetHeaders(ctx.Resp, (*httplib.ServeHeaderOptions)(opt)) -} - -// ServeContent serves content to http request -func (ctx *Context) ServeContent(r io.ReadSeeker, opts *ServeHeaderOptions) { - httplib.ServeSetHeaders(ctx.Resp, (*httplib.ServeHeaderOptions)(opts)) - http.ServeContent(ctx.Resp, ctx.Req, opts.Filename, opts.LastModified, r) -} diff --git a/modules/context/org.go b/modules/context/org.go index 39a3038f91..355ba0ebd0 100644 --- a/modules/context/org.go +++ b/modules/context/org.go @@ -47,7 +47,7 @@ func GetOrganizationByParams(ctx *Context) { if organization.IsErrOrgNotExist(err) { redirectUserID, err := user_model.LookupUserRedirect(orgName) if err == nil { - RedirectToUser(ctx, orgName, redirectUserID) + RedirectToUser(ctx.Base, orgName, redirectUserID) } else if user_model.IsErrUserRedirectNotExist(err) { ctx.NotFound("GetUserByName", err) } else { diff --git a/modules/context/package.go b/modules/context/package.go index fe5bdac19d..b1fd7088dd 100644 --- a/modules/context/package.go +++ b/modules/context/package.go @@ -4,7 +4,6 @@ package context import ( - gocontext "context" "fmt" "net/http" @@ -16,7 +15,6 @@ import ( "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/structs" "code.gitea.io/gitea/modules/templates" - "code.gitea.io/gitea/modules/web/middleware" ) // Package contains owner, access mode and optional the package descriptor @@ -26,10 +24,16 @@ type Package struct { Descriptor *packages_model.PackageDescriptor } +type packageAssignmentCtx struct { + *Base + Doer *user_model.User + ContextUser *user_model.User +} + // PackageAssignment returns a middleware to handle Context.Package assignment func PackageAssignment() func(ctx *Context) { return func(ctx *Context) { - packageAssignment(ctx, func(status int, title string, obj interface{}) { + errorFn := func(status int, title string, obj interface{}) { err, ok := obj.(error) if !ok { err = fmt.Errorf("%s", obj) @@ -39,68 +43,72 @@ func PackageAssignment() func(ctx *Context) { } else { ctx.ServerError(title, err) } - }) + } + paCtx := &packageAssignmentCtx{Base: ctx.Base, Doer: ctx.Doer, ContextUser: ctx.ContextUser} + ctx.Package = packageAssignment(paCtx, errorFn) } } // PackageAssignmentAPI returns a middleware to handle Context.Package assignment func PackageAssignmentAPI() func(ctx *APIContext) { return func(ctx *APIContext) { - packageAssignment(ctx.Context, ctx.Error) + paCtx := &packageAssignmentCtx{Base: ctx.Base, Doer: ctx.Doer, ContextUser: ctx.ContextUser} + ctx.Package = packageAssignment(paCtx, ctx.Error) } } -func packageAssignment(ctx *Context, errCb func(int, string, interface{})) { - ctx.Package = &Package{ +func packageAssignment(ctx *packageAssignmentCtx, errCb func(int, string, interface{})) *Package { + pkg := &Package{ Owner: ctx.ContextUser, } - var err error - ctx.Package.AccessMode, err = determineAccessMode(ctx) + pkg.AccessMode, err = determineAccessMode(ctx.Base, pkg, ctx.Doer) if err != nil { errCb(http.StatusInternalServerError, "determineAccessMode", err) - return + return pkg } packageType := ctx.Params("type") name := ctx.Params("name") version := ctx.Params("version") if packageType != "" && name != "" && version != "" { - pv, err := packages_model.GetVersionByNameAndVersion(ctx, ctx.Package.Owner.ID, packages_model.Type(packageType), name, version) + pv, err := packages_model.GetVersionByNameAndVersion(ctx, pkg.Owner.ID, packages_model.Type(packageType), name, version) if err != nil { if err == packages_model.ErrPackageNotExist { errCb(http.StatusNotFound, "GetVersionByNameAndVersion", err) } else { errCb(http.StatusInternalServerError, "GetVersionByNameAndVersion", err) } - return + return pkg } - ctx.Package.Descriptor, err = packages_model.GetPackageDescriptor(ctx, pv) + pkg.Descriptor, err = packages_model.GetPackageDescriptor(ctx, pv) if err != nil { errCb(http.StatusInternalServerError, "GetPackageDescriptor", err) - return + return pkg } } + + return pkg } -func determineAccessMode(ctx *Context) (perm.AccessMode, error) { - if setting.Service.RequireSignInView && ctx.Doer == nil { +func determineAccessMode(ctx *Base, pkg *Package, doer *user_model.User) (perm.AccessMode, error) { + if setting.Service.RequireSignInView && doer == nil { return perm.AccessModeNone, nil } - if ctx.Doer != nil && !ctx.Doer.IsGhost() && (!ctx.Doer.IsActive || ctx.Doer.ProhibitLogin) { + if doer != nil && !doer.IsGhost() && (!doer.IsActive || doer.ProhibitLogin) { return perm.AccessModeNone, nil } // TODO: ActionUser permission check accessMode := perm.AccessModeNone - if ctx.Package.Owner.IsOrganization() { - org := organization.OrgFromUser(ctx.Package.Owner) + if pkg.Owner.IsOrganization() { + org := organization.OrgFromUser(pkg.Owner) - if ctx.Doer != nil && !ctx.Doer.IsGhost() { + if doer != nil && !doer.IsGhost() { // 1. If user is logged in, check all team packages permissions - teams, err := organization.GetUserOrgTeams(ctx, org.ID, ctx.Doer.ID) + teams, err := organization.GetUserOrgTeams(ctx, org.ID, doer.ID) if err != nil { return accessMode, err } @@ -110,19 +118,19 @@ func determineAccessMode(ctx *Context) (perm.AccessMode, error) { accessMode = perm } } - } else if organization.HasOrgOrUserVisible(ctx, ctx.Package.Owner, ctx.Doer) { + } else if organization.HasOrgOrUserVisible(ctx, pkg.Owner, doer) { // 2. If user is non-login, check if org is visible to non-login user accessMode = perm.AccessModeRead } } else { - if ctx.Doer != nil && !ctx.Doer.IsGhost() { + if doer != nil && !doer.IsGhost() { // 1. Check if user is package owner - if ctx.Doer.ID == ctx.Package.Owner.ID { + if doer.ID == pkg.Owner.ID { accessMode = perm.AccessModeOwner - } else if ctx.Package.Owner.Visibility == structs.VisibleTypePublic || ctx.Package.Owner.Visibility == structs.VisibleTypeLimited { // 2. Check if package owner is public or limited + } else if pkg.Owner.Visibility == structs.VisibleTypePublic || pkg.Owner.Visibility == structs.VisibleTypeLimited { // 2. Check if package owner is public or limited accessMode = perm.AccessModeRead } - } else if ctx.Package.Owner.Visibility == structs.VisibleTypePublic { // 3. Check if package owner is public + } else if pkg.Owner.Visibility == structs.VisibleTypePublic { // 3. Check if package owner is public accessMode = perm.AccessModeRead } } @@ -131,19 +139,18 @@ func determineAccessMode(ctx *Context) (perm.AccessMode, error) { } // PackageContexter initializes a package context for a request. -func PackageContexter(ctx gocontext.Context) func(next http.Handler) http.Handler { - rnd := templates.HTMLRenderer() +func PackageContexter() func(next http.Handler) http.Handler { + renderer := templates.HTMLRenderer() return func(next http.Handler) http.Handler { return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - ctx := Context{ - Resp: NewResponse(resp), - Data: middleware.GetContextData(req.Context()), - Render: rnd, + base, baseCleanUp := NewBaseContext(resp, req) + ctx := &Context{ + Base: base, + Render: renderer, // it is still needed when rendering 500 page in a package handler } - defer ctx.Close() - - ctx.Req = WithContext(req, &ctx) + defer baseCleanUp() + ctx.Base.AppendContextValue(contextKey, ctx) next.ServeHTTP(ctx.Resp, ctx.Req) }) } diff --git a/modules/context/private.go b/modules/context/private.go index f621dd6839..41ca8a4709 100644 --- a/modules/context/private.go +++ b/modules/context/private.go @@ -11,13 +11,14 @@ import ( "code.gitea.io/gitea/modules/graceful" "code.gitea.io/gitea/modules/process" - "code.gitea.io/gitea/modules/web/middleware" ) // PrivateContext represents a context for private routes type PrivateContext struct { - *Context + *Base Override context.Context + + Repo *Repository } // Deadline is part of the interface for context.Context and we pass this to the request context @@ -25,7 +26,7 @@ func (ctx *PrivateContext) Deadline() (deadline time.Time, ok bool) { if ctx.Override != nil { return ctx.Override.Deadline() } - return ctx.Req.Context().Deadline() + return ctx.Base.Deadline() } // Done is part of the interface for context.Context and we pass this to the request context @@ -33,7 +34,7 @@ func (ctx *PrivateContext) Done() <-chan struct{} { if ctx.Override != nil { return ctx.Override.Done() } - return ctx.Req.Context().Done() + return ctx.Base.Done() } // Err is part of the interface for context.Context and we pass this to the request context @@ -41,16 +42,11 @@ func (ctx *PrivateContext) Err() error { if ctx.Override != nil { return ctx.Override.Err() } - return ctx.Req.Context().Err() + return ctx.Base.Err() } var privateContextKey interface{} = "default_private_context" -// WithPrivateContext set up private context in request -func WithPrivateContext(req *http.Request, ctx *PrivateContext) *http.Request { - return req.WithContext(context.WithValue(req.Context(), privateContextKey, ctx)) -} - // GetPrivateContext returns a context for Private routes func GetPrivateContext(req *http.Request) *PrivateContext { return req.Context().Value(privateContextKey).(*PrivateContext) @@ -60,16 +56,11 @@ func GetPrivateContext(req *http.Request) *PrivateContext { func PrivateContexter() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - ctx := &PrivateContext{ - Context: &Context{ - Resp: NewResponse(w), - Data: middleware.GetContextData(req.Context()), - }, - } - defer ctx.Close() + base, baseCleanUp := NewBaseContext(w, req) + ctx := &PrivateContext{Base: base} + defer baseCleanUp() + ctx.Base.AppendContextValue(privateContextKey, ctx) - ctx.Req = WithPrivateContext(req, ctx) - ctx.Data["Context"] = ctx next.ServeHTTP(ctx.Resp, ctx.Req) }) } diff --git a/modules/context/repo.go b/modules/context/repo.go index 5e90e8aec0..fd5f208576 100644 --- a/modules/context/repo.go +++ b/modules/context/repo.go @@ -331,13 +331,14 @@ func EarlyResponseForGoGetMeta(ctx *Context) { } // RedirectToRepo redirect to a differently-named repository -func RedirectToRepo(ctx *Context, redirectRepoID int64) { +func RedirectToRepo(ctx *Base, redirectRepoID int64) { ownerName := ctx.Params(":username") previousRepoName := ctx.Params(":reponame") repo, err := repo_model.GetRepositoryByID(ctx, redirectRepoID) if err != nil { - ctx.ServerError("GetRepositoryByID", err) + log.Error("GetRepositoryByID: %v", err) + ctx.Error(http.StatusInternalServerError, "GetRepositoryByID") return } @@ -456,7 +457,7 @@ func RepoAssignment(ctx *Context) (cancel context.CancelFunc) { } if redirectUserID, err := user_model.LookupUserRedirect(userName); err == nil { - RedirectToUser(ctx, userName, redirectUserID) + RedirectToUser(ctx.Base, userName, redirectUserID) } else if user_model.IsErrUserRedirectNotExist(err) { ctx.NotFound("GetUserByName", nil) } else { @@ -498,7 +499,7 @@ func RepoAssignment(ctx *Context) (cancel context.CancelFunc) { if repo_model.IsErrRepoNotExist(err) { redirectRepoID, err := repo_model.LookupRedirect(owner.ID, repoName) if err == nil { - RedirectToRepo(ctx, redirectRepoID) + RedirectToRepo(ctx.Base, redirectRepoID) } else if repo_model.IsErrRedirectNotExist(err) { if ctx.FormString("go-get") == "1" { EarlyResponseForGoGetMeta(ctx) @@ -781,46 +782,46 @@ func (rt RepoRefType) RefTypeIncludesTags() bool { return false } -func getRefNameFromPath(ctx *Context, path string, isExist func(string) bool) string { +func getRefNameFromPath(ctx *Base, repo *Repository, path string, isExist func(string) bool) string { refName := "" parts := strings.Split(path, "/") for i, part := range parts { refName = strings.TrimPrefix(refName+"/"+part, "/") if isExist(refName) { - ctx.Repo.TreePath = strings.Join(parts[i+1:], "/") + repo.TreePath = strings.Join(parts[i+1:], "/") return refName } } return "" } -func getRefName(ctx *Context, pathType RepoRefType) string { +func getRefName(ctx *Base, repo *Repository, pathType RepoRefType) string { path := ctx.Params("*") switch pathType { case RepoRefLegacy, RepoRefAny: - if refName := getRefName(ctx, RepoRefBranch); len(refName) > 0 { + if refName := getRefName(ctx, repo, RepoRefBranch); len(refName) > 0 { return refName } - if refName := getRefName(ctx, RepoRefTag); len(refName) > 0 { + if refName := getRefName(ctx, repo, RepoRefTag); len(refName) > 0 { return refName } // For legacy and API support only full commit sha parts := strings.Split(path, "/") if len(parts) > 0 && len(parts[0]) == git.SHAFullLength { - ctx.Repo.TreePath = strings.Join(parts[1:], "/") + repo.TreePath = strings.Join(parts[1:], "/") return parts[0] } - if refName := getRefName(ctx, RepoRefBlob); len(refName) > 0 { + if refName := getRefName(ctx, repo, RepoRefBlob); len(refName) > 0 { return refName } - ctx.Repo.TreePath = path - return ctx.Repo.Repository.DefaultBranch + repo.TreePath = path + return repo.Repository.DefaultBranch case RepoRefBranch: - ref := getRefNameFromPath(ctx, path, ctx.Repo.GitRepo.IsBranchExist) + ref := getRefNameFromPath(ctx, repo, path, repo.GitRepo.IsBranchExist) if len(ref) == 0 { // maybe it's a renamed branch - return getRefNameFromPath(ctx, path, func(s string) bool { - b, exist, err := git_model.FindRenamedBranch(ctx, ctx.Repo.Repository.ID, s) + return getRefNameFromPath(ctx, repo, path, func(s string) bool { + b, exist, err := git_model.FindRenamedBranch(ctx, repo.Repository.ID, s) if err != nil { log.Error("FindRenamedBranch", err) return false @@ -839,15 +840,15 @@ func getRefName(ctx *Context, pathType RepoRefType) string { return ref case RepoRefTag: - return getRefNameFromPath(ctx, path, ctx.Repo.GitRepo.IsTagExist) + return getRefNameFromPath(ctx, repo, path, repo.GitRepo.IsTagExist) case RepoRefCommit: parts := strings.Split(path, "/") if len(parts) > 0 && len(parts[0]) >= 7 && len(parts[0]) <= git.SHAFullLength { - ctx.Repo.TreePath = strings.Join(parts[1:], "/") + repo.TreePath = strings.Join(parts[1:], "/") return parts[0] } case RepoRefBlob: - _, err := ctx.Repo.GitRepo.GetBlob(path) + _, err := repo.GitRepo.GetBlob(path) if err != nil { return "" } @@ -922,7 +923,7 @@ func RepoRefByType(refType RepoRefType, ignoreNotExistErr ...bool) func(*Context } ctx.Repo.IsViewBranch = true } else { - refName = getRefName(ctx, refType) + refName = getRefName(ctx.Base, ctx.Repo, refType) ctx.Repo.RefName = refName isRenamedBranch, has := ctx.Data["IsRenamedBranch"].(bool) if isRenamedBranch && has { diff --git a/modules/context/response.go b/modules/context/response.go index 40eb5c0d35..ca52ea137d 100644 --- a/modules/context/response.go +++ b/modules/context/response.go @@ -10,10 +10,9 @@ import ( // ResponseWriter represents a response writer for HTTP type ResponseWriter interface { http.ResponseWriter - Flush() + http.Flusher Status() int Before(func(ResponseWriter)) - Size() int } var _ ResponseWriter = &Response{} @@ -27,11 +26,6 @@ type Response struct { beforeExecuted bool } -// Size return written size -func (r *Response) Size() int { - return r.written -} - // Write writes bytes to HTTP endpoint func (r *Response) Write(bs []byte) (int, error) { if !r.beforeExecuted { @@ -65,7 +59,7 @@ func (r *Response) WriteHeader(statusCode int) { } } -// Flush flush cached data +// Flush flushes cached data func (r *Response) Flush() { if f, ok := r.ResponseWriter.(http.Flusher); ok { f.Flush() @@ -83,8 +77,7 @@ func (r *Response) Before(f func(ResponseWriter)) { r.befores = append(r.befores, f) } -// NewResponse creates a response -func NewResponse(resp http.ResponseWriter) *Response { +func WrapResponseWriter(resp http.ResponseWriter) *Response { if v, ok := resp.(*Response); ok { return v } diff --git a/modules/context/utils.go b/modules/context/utils.go index 1fa99953a2..c0f619aa23 100644 --- a/modules/context/utils.go +++ b/modules/context/utils.go @@ -10,7 +10,7 @@ import ( ) // GetQueryBeforeSince return parsed time (unix format) from URL query's before and since -func GetQueryBeforeSince(ctx *Context) (before, since int64, err error) { +func GetQueryBeforeSince(ctx *Base) (before, since int64, err error) { qCreatedBefore, err := prepareQueryArg(ctx, "before") if err != nil { return 0, 0, err @@ -48,7 +48,7 @@ func parseTime(value string) (int64, error) { } // prepareQueryArg unescape and trim a query arg -func prepareQueryArg(ctx *Context, name string) (value string, err error) { +func prepareQueryArg(ctx *Base, name string) (value string, err error) { value, err = url.PathUnescape(ctx.FormString(name)) value = strings.TrimSpace(value) return value, err diff --git a/modules/test/context_tests.go b/modules/test/context_tests.go index 5ba2126126..349c7e3e80 100644 --- a/modules/test/context_tests.go +++ b/modules/test/context_tests.go @@ -4,7 +4,7 @@ package test import ( - scontext "context" + gocontext "context" "io" "net/http" "net/http/httptest" @@ -28,18 +28,7 @@ import ( // MockContext mock context for unit tests // TODO: move this function to other packages, because it depends on "models" package func MockContext(t *testing.T, path string) *context.Context { - resp := &mockResponseWriter{} - ctx := context.Context{ - Render: &mockRender{}, - Data: make(middleware.ContextData), - Flash: &middleware.Flash{ - Values: make(url.Values), - }, - Resp: context.NewResponse(resp), - Locale: &translation.MockLocale{}, - } - defer ctx.Close() - + resp := httptest.NewRecorder() requestURL, err := url.Parse(path) assert.NoError(t, err) req := &http.Request{ @@ -47,41 +36,105 @@ func MockContext(t *testing.T, path string) *context.Context { Form: url.Values{}, } + base, baseCleanUp := context.NewBaseContext(resp, req) + base.Data = middleware.ContextData{} + base.Locale = &translation.MockLocale{} + ctx := &context.Context{ + Base: base, + Render: &mockRender{}, + Flash: &middleware.Flash{Values: url.Values{}}, + } + _ = baseCleanUp // during test, it doesn't need to do clean up. TODO: this can be improved later + chiCtx := chi.NewRouteContext() - req = req.WithContext(scontext.WithValue(req.Context(), chi.RouteCtxKey, chiCtx)) - ctx.Req = context.WithContext(req, &ctx) - return &ctx + ctx.Base.AppendContextValue(chi.RouteCtxKey, chiCtx) + return ctx +} + +// MockAPIContext mock context for unit tests +// TODO: move this function to other packages, because it depends on "models" package +func MockAPIContext(t *testing.T, path string) *context.APIContext { + resp := httptest.NewRecorder() + requestURL, err := url.Parse(path) + assert.NoError(t, err) + req := &http.Request{ + URL: requestURL, + Form: url.Values{}, + } + + base, baseCleanUp := context.NewBaseContext(resp, req) + base.Data = middleware.ContextData{} + base.Locale = &translation.MockLocale{} + ctx := &context.APIContext{Base: base} + _ = baseCleanUp // during test, it doesn't need to do clean up. TODO: this can be improved later + + chiCtx := chi.NewRouteContext() + ctx.Base.AppendContextValue(chi.RouteCtxKey, chiCtx) + return ctx } // LoadRepo load a repo into a test context. -func LoadRepo(t *testing.T, ctx *context.Context, repoID int64) { - ctx.Repo = &context.Repository{} - ctx.Repo.Repository = unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: repoID}) +func LoadRepo(t *testing.T, ctx gocontext.Context, repoID int64) { + var doer *user_model.User + repo := &context.Repository{} + switch ctx := ctx.(type) { + case *context.Context: + ctx.Repo = repo + doer = ctx.Doer + case *context.APIContext: + ctx.Repo = repo + doer = ctx.Doer + default: + assert.Fail(t, "context is not *context.Context or *context.APIContext") + return + } + + repo.Repository = unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: repoID}) var err error - ctx.Repo.Owner, err = user_model.GetUserByID(ctx, ctx.Repo.Repository.OwnerID) + repo.Owner, err = user_model.GetUserByID(ctx, repo.Repository.OwnerID) assert.NoError(t, err) - ctx.Repo.RepoLink = ctx.Repo.Repository.Link() - ctx.Repo.Permission, err = access_model.GetUserRepoPermission(ctx, ctx.Repo.Repository, ctx.Doer) + repo.RepoLink = repo.Repository.Link() + repo.Permission, err = access_model.GetUserRepoPermission(ctx, repo.Repository, doer) assert.NoError(t, err) } // LoadRepoCommit loads a repo's commit into a test context. -func LoadRepoCommit(t *testing.T, ctx *context.Context) { - gitRepo, err := git.OpenRepository(ctx, ctx.Repo.Repository.RepoPath()) +func LoadRepoCommit(t *testing.T, ctx gocontext.Context) { + var repo *context.Repository + switch ctx := ctx.(type) { + case *context.Context: + repo = ctx.Repo + case *context.APIContext: + repo = ctx.Repo + default: + assert.Fail(t, "context is not *context.Context or *context.APIContext") + return + } + + gitRepo, err := git.OpenRepository(ctx, repo.Repository.RepoPath()) assert.NoError(t, err) defer gitRepo.Close() branch, err := gitRepo.GetHEADBranch() assert.NoError(t, err) assert.NotNil(t, branch) if branch != nil { - ctx.Repo.Commit, err = gitRepo.GetBranchCommit(branch.Name) + repo.Commit, err = gitRepo.GetBranchCommit(branch.Name) assert.NoError(t, err) } } // LoadUser load a user into a test context. -func LoadUser(t *testing.T, ctx *context.Context, userID int64) { - ctx.Doer = unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: userID}) +func LoadUser(t *testing.T, ctx gocontext.Context, userID int64) { + doer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: userID}) + switch ctx := ctx.(type) { + case *context.Context: + ctx.Doer = doer + case *context.APIContext: + ctx.Doer = doer + default: + assert.Fail(t, "context is not *context.Context or *context.APIContext") + return + } } // LoadGitRepo load a git repo into a test context. Requires that ctx.Repo has @@ -93,32 +146,6 @@ func LoadGitRepo(t *testing.T, ctx *context.Context) { assert.NoError(t, err) } -type mockResponseWriter struct { - httptest.ResponseRecorder - size int -} - -func (rw *mockResponseWriter) Write(b []byte) (int, error) { - rw.size += len(b) - return rw.ResponseRecorder.Write(b) -} - -func (rw *mockResponseWriter) Status() int { - return rw.ResponseRecorder.Code -} - -func (rw *mockResponseWriter) Written() bool { - return rw.ResponseRecorder.Code > 0 -} - -func (rw *mockResponseWriter) Size() int { - return rw.size -} - -func (rw *mockResponseWriter) Push(target string, opts *http.PushOptions) error { - return nil -} - type mockRender struct{} func (tr *mockRender) TemplateLookup(tmpl string) (templates.TemplateExecutor, error) { diff --git a/modules/translation/translation.go b/modules/translation/translation.go index 49dfa84d1b..dba4de6607 100644 --- a/modules/translation/translation.go +++ b/modules/translation/translation.go @@ -38,10 +38,12 @@ type LangType struct { } var ( - lock *sync.RWMutex + lock *sync.RWMutex + + allLangs []*LangType + allLangMap map[string]*LangType + matcher language.Matcher - allLangs []*LangType - allLangMap map[string]*LangType supportedTags []language.Tag ) @@ -251,3 +253,9 @@ func (l *locale) PrettyNumber(v any) string { } return l.msgPrinter.Sprintf("%v", number.Decimal(v)) } + +func init() { + // prepare a default matcher, especially for tests + supportedTags = []language.Tag{language.English} + matcher = language.NewMatcher(supportedTags) +} diff --git a/modules/web/handler.go b/modules/web/handler.go index bfb83820c8..5013bac93f 100644 --- a/modules/web/handler.go +++ b/modules/web/handler.go @@ -10,6 +10,7 @@ import ( "reflect" "code.gitea.io/gitea/modules/context" + "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/web/routing" ) @@ -25,6 +26,10 @@ var argTypeProvider = map[reflect.Type]func(req *http.Request) ResponseStatusPro reflect.TypeOf(&context.PrivateContext{}): func(req *http.Request) ResponseStatusProvider { return context.GetPrivateContext(req) }, } +func RegisterHandleTypeProvider[T any](fn func(req *http.Request) ResponseStatusProvider) { + argTypeProvider[reflect.TypeOf((*T)(nil)).Elem()] = fn +} + // responseWriter is a wrapper of http.ResponseWriter, to check whether the response has been written type responseWriter struct { respWriter http.ResponseWriter @@ -78,7 +83,13 @@ func preCheckHandler(fn reflect.Value, argsIn []reflect.Value) { } } -func prepareHandleArgsIn(resp http.ResponseWriter, req *http.Request, fn reflect.Value) []reflect.Value { +func prepareHandleArgsIn(resp http.ResponseWriter, req *http.Request, fn reflect.Value, fnInfo *routing.FuncInfo) []reflect.Value { + defer func() { + if err := recover(); err != nil { + log.Error("unable to prepare handler arguments for %s: %v", fnInfo.String(), err) + panic(err) + } + }() isPreCheck := req == nil argsIn := make([]reflect.Value, fn.Type().NumIn()) @@ -155,7 +166,7 @@ func toHandlerProvider(handler any) func(next http.Handler) http.Handler { } // prepare the arguments for the handler and do pre-check - argsIn := prepareHandleArgsIn(resp, req, fn) + argsIn := prepareHandleArgsIn(resp, req, fn, funcInfo) if req == nil { preCheckHandler(fn, argsIn) return // it's doing pre-check, just return diff --git a/routers/api/actions/artifacts.go b/routers/api/actions/artifacts.go index 61d432c862..4b10cd7ad1 100644 --- a/routers/api/actions/artifacts.go +++ b/routers/api/actions/artifacts.go @@ -3,7 +3,7 @@ package actions -// Github Actions Artifacts API Simple Description +// GitHub Actions Artifacts API Simple Description // // 1. Upload artifact // 1.1. Post upload url @@ -63,7 +63,6 @@ package actions import ( "compress/gzip" - gocontext "context" "crypto/md5" "encoding/base64" "errors" @@ -92,9 +91,25 @@ const ( const artifactRouteBase = "/_apis/pipelines/workflows/{run_id}/artifacts" -func ArtifactsRoutes(goctx gocontext.Context, prefix string) *web.Route { +type artifactContextKeyType struct{} + +var artifactContextKey = artifactContextKeyType{} + +type ArtifactContext struct { + *context.Base + + ActionTask *actions.ActionTask +} + +func init() { + web.RegisterHandleTypeProvider[*ArtifactContext](func(req *http.Request) web.ResponseStatusProvider { + return req.Context().Value(artifactContextKey).(*ArtifactContext) + }) +} + +func ArtifactsRoutes(prefix string) *web.Route { m := web.NewRoute() - m.Use(withContexter(goctx)) + m.Use(ArtifactContexter()) r := artifactRoutes{ prefix: prefix, @@ -115,15 +130,14 @@ func ArtifactsRoutes(goctx gocontext.Context, prefix string) *web.Route { return m } -// withContexter initializes a package context for a request. -func withContexter(goctx gocontext.Context) func(next http.Handler) http.Handler { +func ArtifactContexter() func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - ctx := context.Context{ - Resp: context.NewResponse(resp), - Data: map[string]interface{}{}, - } - defer ctx.Close() + base, baseCleanUp := context.NewBaseContext(resp, req) + defer baseCleanUp() + + ctx := &ArtifactContext{Base: base} + ctx.AppendContextValue(artifactContextKey, ctx) // action task call server api with Bearer ACTIONS_RUNTIME_TOKEN // we should verify the ACTIONS_RUNTIME_TOKEN @@ -132,6 +146,7 @@ func withContexter(goctx gocontext.Context) func(next http.Handler) http.Handler ctx.Error(http.StatusUnauthorized, "Bad authorization header") return } + authToken := strings.TrimPrefix(authHeader, "Bearer ") task, err := actions.GetRunningTaskByToken(req.Context(), authToken) if err != nil { @@ -139,16 +154,14 @@ func withContexter(goctx gocontext.Context) func(next http.Handler) http.Handler ctx.Error(http.StatusInternalServerError, "Error runner api getting task") return } - ctx.Data["task"] = task - if err := task.LoadJob(goctx); err != nil { + if err := task.LoadJob(req.Context()); err != nil { log.Error("Error runner api getting job: %v", err) ctx.Error(http.StatusInternalServerError, "Error runner api getting job") return } - ctx.Req = context.WithContext(req, &ctx) - + ctx.ActionTask = task next.ServeHTTP(ctx.Resp, ctx.Req) }) } @@ -175,13 +188,8 @@ type getUploadArtifactResponse struct { FileContainerResourceURL string `json:"fileContainerResourceUrl"` } -func (ar artifactRoutes) validateRunID(ctx *context.Context) (*actions.ActionTask, int64, bool) { - task, ok := ctx.Data["task"].(*actions.ActionTask) - if !ok { - log.Error("Error getting task in context") - ctx.Error(http.StatusInternalServerError, "Error getting task in context") - return nil, 0, false - } +func (ar artifactRoutes) validateRunID(ctx *ArtifactContext) (*actions.ActionTask, int64, bool) { + task := ctx.ActionTask runID := ctx.ParamsInt64("run_id") if task.Job.RunID != runID { log.Error("Error runID not match") @@ -192,7 +200,7 @@ func (ar artifactRoutes) validateRunID(ctx *context.Context) (*actions.ActionTas } // getUploadArtifactURL generates a URL for uploading an artifact -func (ar artifactRoutes) getUploadArtifactURL(ctx *context.Context) { +func (ar artifactRoutes) getUploadArtifactURL(ctx *ArtifactContext) { task, runID, ok := ar.validateRunID(ctx) if !ok { return @@ -220,7 +228,7 @@ func (ar artifactRoutes) getUploadArtifactURL(ctx *context.Context) { // getUploadFileSize returns the size of the file to be uploaded. // The raw size is the size of the file as reported by the header X-TFS-FileLength. -func (ar artifactRoutes) getUploadFileSize(ctx *context.Context) (int64, int64, error) { +func (ar artifactRoutes) getUploadFileSize(ctx *ArtifactContext) (int64, int64, error) { contentLength := ctx.Req.ContentLength xTfsLength, _ := strconv.ParseInt(ctx.Req.Header.Get(artifactXTfsFileLengthHeader), 10, 64) if xTfsLength > 0 { @@ -229,7 +237,7 @@ func (ar artifactRoutes) getUploadFileSize(ctx *context.Context) (int64, int64, return contentLength, contentLength, nil } -func (ar artifactRoutes) saveUploadChunk(ctx *context.Context, +func (ar artifactRoutes) saveUploadChunk(ctx *ArtifactContext, artifact *actions.ActionArtifact, contentSize, runID int64, ) (int64, error) { @@ -273,7 +281,7 @@ func (ar artifactRoutes) saveUploadChunk(ctx *context.Context, // The rules are from https://github.com/actions/toolkit/blob/main/packages/artifact/src/internal/path-and-artifact-name-validation.ts#L32 var invalidArtifactNameChars = strings.Join([]string{"\\", "/", "\"", ":", "<", ">", "|", "*", "?", "\r", "\n"}, "") -func (ar artifactRoutes) uploadArtifact(ctx *context.Context) { +func (ar artifactRoutes) uploadArtifact(ctx *ArtifactContext) { _, runID, ok := ar.validateRunID(ctx) if !ok { return @@ -341,7 +349,7 @@ func (ar artifactRoutes) uploadArtifact(ctx *context.Context) { // comfirmUploadArtifact comfirm upload artifact. // if all chunks are uploaded, merge them to one file. -func (ar artifactRoutes) comfirmUploadArtifact(ctx *context.Context) { +func (ar artifactRoutes) comfirmUploadArtifact(ctx *ArtifactContext) { _, runID, ok := ar.validateRunID(ctx) if !ok { return @@ -364,7 +372,7 @@ type chunkItem struct { Path string } -func (ar artifactRoutes) mergeArtifactChunks(ctx *context.Context, runID int64) error { +func (ar artifactRoutes) mergeArtifactChunks(ctx *ArtifactContext, runID int64) error { storageDir := fmt.Sprintf("tmp%d", runID) var chunks []*chunkItem if err := ar.fs.IterateObjects(storageDir, func(path string, obj storage.Object) error { @@ -415,14 +423,20 @@ func (ar artifactRoutes) mergeArtifactChunks(ctx *context.Context, runID int64) // use multiReader readers := make([]io.Reader, 0, len(allChunks)) - readerClosers := make([]io.Closer, 0, len(allChunks)) + closeReaders := func() { + for _, r := range readers { + _ = r.(io.Closer).Close() // it guarantees to be io.Closer by the following loop's Open function + } + readers = nil + } + defer closeReaders() + for _, c := range allChunks { - reader, err := ar.fs.Open(c.Path) - if err != nil { + var readCloser io.ReadCloser + if readCloser, err = ar.fs.Open(c.Path); err != nil { return fmt.Errorf("open chunk error: %v, %s", err, c.Path) } - readers = append(readers, reader) - readerClosers = append(readerClosers, reader) + readers = append(readers, readCloser) } mergedReader := io.MultiReader(readers...) @@ -445,11 +459,6 @@ func (ar artifactRoutes) mergeArtifactChunks(ctx *context.Context, runID int64) return fmt.Errorf("merged file size is not equal to chunk length") } - // close readers - for _, r := range readerClosers { - r.Close() - } - // save storage path to artifact log.Debug("[artifact] merge chunks to artifact: %d, %s", artifact.ID, storagePath) artifact.StoragePath = storagePath @@ -458,6 +467,8 @@ func (ar artifactRoutes) mergeArtifactChunks(ctx *context.Context, runID int64) return fmt.Errorf("update artifact error: %v", err) } + closeReaders() // close before delete + // drop chunks for _, c := range cs { if err := ar.fs.Delete(c.Path); err != nil { @@ -479,21 +490,21 @@ type ( } ) -func (ar artifactRoutes) listArtifacts(ctx *context.Context) { +func (ar artifactRoutes) listArtifacts(ctx *ArtifactContext) { _, runID, ok := ar.validateRunID(ctx) if !ok { return } - artficats, err := actions.ListArtifactsByRunID(ctx, runID) + artifacts, err := actions.ListArtifactsByRunID(ctx, runID) if err != nil { log.Error("Error getting artifacts: %v", err) ctx.Error(http.StatusInternalServerError, err.Error()) return } - artficatsData := make([]listArtifactsResponseItem, 0, len(artficats)) - for _, a := range artficats { + artficatsData := make([]listArtifactsResponseItem, 0, len(artifacts)) + for _, a := range artifacts { artficatsData = append(artficatsData, listArtifactsResponseItem{ Name: a.ArtifactName, FileContainerResourceURL: ar.buildArtifactURL(runID, a.ID, "path"), @@ -517,7 +528,7 @@ type ( } ) -func (ar artifactRoutes) getDownloadArtifactURL(ctx *context.Context) { +func (ar artifactRoutes) getDownloadArtifactURL(ctx *ArtifactContext) { _, runID, ok := ar.validateRunID(ctx) if !ok { return @@ -546,7 +557,7 @@ func (ar artifactRoutes) getDownloadArtifactURL(ctx *context.Context) { ctx.JSON(http.StatusOK, respData) } -func (ar artifactRoutes) downloadArtifact(ctx *context.Context) { +func (ar artifactRoutes) downloadArtifact(ctx *ArtifactContext) { _, runID, ok := ar.validateRunID(ctx) if !ok { return diff --git a/routers/api/packages/api.go b/routers/api/packages/api.go index aaceb8a92b..e715997e82 100644 --- a/routers/api/packages/api.go +++ b/routers/api/packages/api.go @@ -98,7 +98,7 @@ func verifyAuth(r *web.Route, authMethods []auth.Method) { func CommonRoutes(ctx gocontext.Context) *web.Route { r := web.NewRoute() - r.Use(context.PackageContexter(ctx)) + r.Use(context.PackageContexter()) verifyAuth(r, []auth.Method{ &auth.OAuth2{}, @@ -574,7 +574,7 @@ func CommonRoutes(ctx gocontext.Context) *web.Route { func ContainerRoutes(ctx gocontext.Context) *web.Route { r := web.NewRoute() - r.Use(context.PackageContexter(ctx)) + r.Use(context.PackageContexter()) verifyAuth(r, []auth.Method{ &auth.Basic{}, diff --git a/routers/api/v1/api.go b/routers/api/v1/api.go index a67a5420ac..f1e1cf946a 100644 --- a/routers/api/v1/api.go +++ b/routers/api/v1/api.go @@ -149,7 +149,7 @@ func repoAssignment() func(ctx *context.APIContext) { if err != nil { if user_model.IsErrUserNotExist(err) { if redirectUserID, err := user_model.LookupUserRedirect(userName); err == nil { - context.RedirectToUser(ctx.Context, userName, redirectUserID) + context.RedirectToUser(ctx.Base, userName, redirectUserID) } else if user_model.IsErrUserRedirectNotExist(err) { ctx.NotFound("GetUserByName", err) } else { @@ -170,7 +170,7 @@ func repoAssignment() func(ctx *context.APIContext) { if repo_model.IsErrRepoNotExist(err) { redirectRepoID, err := repo_model.LookupRedirect(owner.ID, repoName) if err == nil { - context.RedirectToRepo(ctx.Context, redirectRepoID) + context.RedirectToRepo(ctx.Base, redirectRepoID) } else if repo_model.IsErrRedirectNotExist(err) { ctx.NotFound() } else { @@ -274,7 +274,7 @@ func reqToken(requiredScope auth_model.AccessTokenScope) func(ctx *context.APICo ctx.Error(http.StatusForbidden, "reqToken", "token does not have required scope: "+requiredScope) return } - if ctx.Context.IsBasicAuth { + if ctx.IsBasicAuth { ctx.CheckForOTP() return } @@ -295,7 +295,7 @@ func reqExploreSignIn() func(ctx *context.APIContext) { func reqBasicAuth() func(ctx *context.APIContext) { return func(ctx *context.APIContext) { - if !ctx.Context.IsBasicAuth { + if !ctx.IsBasicAuth { ctx.Error(http.StatusUnauthorized, "reqBasicAuth", "auth required") return } @@ -375,7 +375,7 @@ func reqAnyRepoReader() func(ctx *context.APIContext) { // reqOrgOwnership user should be an organization owner, or a site admin func reqOrgOwnership() func(ctx *context.APIContext) { return func(ctx *context.APIContext) { - if ctx.Context.IsUserSiteAdmin() { + if ctx.IsUserSiteAdmin() { return } @@ -407,7 +407,7 @@ func reqOrgOwnership() func(ctx *context.APIContext) { // reqTeamMembership user should be an team member, or a site admin func reqTeamMembership() func(ctx *context.APIContext) { return func(ctx *context.APIContext) { - if ctx.Context.IsUserSiteAdmin() { + if ctx.IsUserSiteAdmin() { return } if ctx.Org.Team == nil { @@ -444,7 +444,7 @@ func reqTeamMembership() func(ctx *context.APIContext) { // reqOrgMembership user should be an organization member, or a site admin func reqOrgMembership() func(ctx *context.APIContext) { return func(ctx *context.APIContext) { - if ctx.Context.IsUserSiteAdmin() { + if ctx.IsUserSiteAdmin() { return } @@ -512,7 +512,7 @@ func orgAssignment(args ...bool) func(ctx *context.APIContext) { if organization.IsErrOrgNotExist(err) { redirectUserID, err := user_model.LookupUserRedirect(ctx.Params(":org")) if err == nil { - context.RedirectToUser(ctx.Context, ctx.Params(":org"), redirectUserID) + context.RedirectToUser(ctx.Base, ctx.Params(":org"), redirectUserID) } else if user_model.IsErrUserRedirectNotExist(err) { ctx.NotFound("GetOrgByName", err) } else { diff --git a/routers/api/v1/misc/markup.go b/routers/api/v1/misc/markup.go index 93d5754444..7b24b353b6 100644 --- a/routers/api/v1/misc/markup.go +++ b/routers/api/v1/misc/markup.go @@ -41,7 +41,7 @@ func Markup(ctx *context.APIContext) { return } - common.RenderMarkup(ctx.Context, form.Mode, form.Text, form.Context, form.FilePath, form.Wiki) + common.RenderMarkup(ctx.Base, ctx.Repo, form.Mode, form.Text, form.Context, form.FilePath, form.Wiki) } // Markdown render markdown document to HTML @@ -76,7 +76,7 @@ func Markdown(ctx *context.APIContext) { mode = form.Mode } - common.RenderMarkup(ctx.Context, mode, form.Text, form.Context, "", form.Wiki) + common.RenderMarkup(ctx.Base, ctx.Repo, mode, form.Text, form.Context, "", form.Wiki) } // MarkdownRaw render raw markdown HTML diff --git a/routers/api/v1/misc/markup_test.go b/routers/api/v1/misc/markup_test.go index 68776613b2..fdf540fd65 100644 --- a/routers/api/v1/misc/markup_test.go +++ b/routers/api/v1/misc/markup_test.go @@ -16,7 +16,6 @@ import ( "code.gitea.io/gitea/modules/markup" "code.gitea.io/gitea/modules/setting" api "code.gitea.io/gitea/modules/structs" - "code.gitea.io/gitea/modules/templates" "code.gitea.io/gitea/modules/util" "code.gitea.io/gitea/modules/web" "code.gitea.io/gitea/modules/web/middleware" @@ -30,26 +29,16 @@ const ( AppSubURL = AppURL + Repo + "/" ) -func createContext(req *http.Request) (*context.Context, *httptest.ResponseRecorder) { - rnd := templates.HTMLRenderer() +func createAPIContext(req *http.Request) (*context.APIContext, *httptest.ResponseRecorder) { resp := httptest.NewRecorder() - c := &context.Context{ - Req: req, - Resp: context.NewResponse(resp), - Render: rnd, - Data: make(middleware.ContextData), - } - defer c.Close() + base, baseCleanUp := context.NewBaseContext(resp, req) + base.Data = middleware.ContextData{} + c := &context.APIContext{Base: base} + _ = baseCleanUp // during test, it doesn't need to do clean up. TODO: this can be improved later return c, resp } -func wrap(ctx *context.Context) *context.APIContext { - return &context.APIContext{ - Context: ctx, - } -} - func testRenderMarkup(t *testing.T, mode, filePath, text, responseBody string, responseCode int) { setting.AppURL = AppURL @@ -65,8 +54,7 @@ func testRenderMarkup(t *testing.T, mode, filePath, text, responseBody string, r Method: "POST", URL: requrl, } - m, resp := createContext(req) - ctx := wrap(m) + ctx, resp := createAPIContext(req) options.Text = text web.SetForm(ctx, &options) @@ -90,8 +78,7 @@ func testRenderMarkdown(t *testing.T, mode, text, responseBody string, responseC Method: "POST", URL: requrl, } - m, resp := createContext(req) - ctx := wrap(m) + ctx, resp := createAPIContext(req) options.Text = text web.SetForm(ctx, &options) @@ -211,8 +198,7 @@ func TestAPI_RenderSimple(t *testing.T) { Method: "POST", URL: requrl, } - m, resp := createContext(req) - ctx := wrap(m) + ctx, resp := createAPIContext(req) for i := 0; i < len(simpleCases); i += 2 { options.Text = simpleCases[i] @@ -231,8 +217,7 @@ func TestAPI_RenderRaw(t *testing.T) { Method: "POST", URL: requrl, } - m, resp := createContext(req) - ctx := wrap(m) + ctx, resp := createAPIContext(req) for i := 0; i < len(simpleCases); i += 2 { ctx.Req.Body = io.NopCloser(strings.NewReader(simpleCases[i])) diff --git a/routers/api/v1/notify/notifications.go b/routers/api/v1/notify/notifications.go index 3b6a9bfdc2..b22ea8a771 100644 --- a/routers/api/v1/notify/notifications.go +++ b/routers/api/v1/notify/notifications.go @@ -25,7 +25,7 @@ func NewAvailable(ctx *context.APIContext) { } func getFindNotificationOptions(ctx *context.APIContext) *activities_model.FindNotificationOptions { - before, since, err := context.GetQueryBeforeSince(ctx.Context) + before, since, err := context.GetQueryBeforeSince(ctx.Base) if err != nil { ctx.Error(http.StatusUnprocessableEntity, "GetQueryBeforeSince", err) return nil diff --git a/routers/api/v1/repo/file.go b/routers/api/v1/repo/file.go index eb63dda590..786407827c 100644 --- a/routers/api/v1/repo/file.go +++ b/routers/api/v1/repo/file.go @@ -80,7 +80,7 @@ func GetRawFile(ctx *context.APIContext) { ctx.RespHeader().Set(giteaObjectTypeHeader, string(files_service.GetObjectTypeFromTreeEntry(entry))) - if err := common.ServeBlob(ctx.Context, blob, lastModified); err != nil { + if err := common.ServeBlob(ctx.Base, ctx.Repo.TreePath, blob, lastModified); err != nil { ctx.Error(http.StatusInternalServerError, "ServeBlob", err) } } @@ -137,7 +137,7 @@ func GetRawFileOrLFS(ctx *context.APIContext) { } // OK not cached - serve! - if err := common.ServeBlob(ctx.Context, blob, lastModified); err != nil { + if err := common.ServeBlob(ctx.Base, ctx.Repo.TreePath, blob, lastModified); err != nil { ctx.ServerError("ServeBlob", err) } return @@ -159,7 +159,7 @@ func GetRawFileOrLFS(ctx *context.APIContext) { } if err := dataRc.Close(); err != nil { - log.Error("Error whilst closing blob %s reader in %-v. Error: %v", blob.ID, ctx.Context.Repo.Repository, err) + log.Error("Error whilst closing blob %s reader in %-v. Error: %v", blob.ID, ctx.Repo.Repository, err) } // Check if the blob represents a pointer @@ -173,7 +173,7 @@ func GetRawFileOrLFS(ctx *context.APIContext) { } // OK not cached - serve! - common.ServeContentByReader(ctx.Context, ctx.Repo.TreePath, blob.Size(), bytes.NewReader(buf)) + common.ServeContentByReader(ctx.Base, ctx.Repo.TreePath, blob.Size(), bytes.NewReader(buf)) return } @@ -187,7 +187,7 @@ func GetRawFileOrLFS(ctx *context.APIContext) { return } - common.ServeContentByReader(ctx.Context, ctx.Repo.TreePath, blob.Size(), bytes.NewReader(buf)) + common.ServeContentByReader(ctx.Base, ctx.Repo.TreePath, blob.Size(), bytes.NewReader(buf)) return } else if err != nil { ctx.ServerError("GetLFSMetaObjectByOid", err) @@ -215,7 +215,7 @@ func GetRawFileOrLFS(ctx *context.APIContext) { } defer lfsDataRc.Close() - common.ServeContentByReadSeeker(ctx.Context, ctx.Repo.TreePath, lastModified, lfsDataRc) + common.ServeContentByReadSeeker(ctx.Base, ctx.Repo.TreePath, lastModified, lfsDataRc) } func getBlobForEntry(ctx *context.APIContext) (blob *git.Blob, entry *git.TreeEntry, lastModified time.Time) { diff --git a/routers/api/v1/repo/hook_test.go b/routers/api/v1/repo/hook_test.go index 34dc990c3d..56658b45d5 100644 --- a/routers/api/v1/repo/hook_test.go +++ b/routers/api/v1/repo/hook_test.go @@ -9,7 +9,6 @@ import ( "code.gitea.io/gitea/models/unittest" "code.gitea.io/gitea/models/webhook" - "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/test" "github.com/stretchr/testify/assert" @@ -18,12 +17,12 @@ import ( func TestTestHook(t *testing.T) { unittest.PrepareTestEnv(t) - ctx := test.MockContext(t, "user2/repo1/wiki/_pages") + ctx := test.MockAPIContext(t, "user2/repo1/wiki/_pages") ctx.SetParams(":id", "1") test.LoadRepo(t, ctx, 1) test.LoadRepoCommit(t, ctx) test.LoadUser(t, ctx, 2) - TestHook(&context.APIContext{Context: ctx, Org: nil}) + TestHook(ctx) assert.EqualValues(t, http.StatusNoContent, ctx.Resp.Status()) unittest.AssertExistsAndLoadBean(t, &webhook.HookTask{ diff --git a/routers/api/v1/repo/issue.go b/routers/api/v1/repo/issue.go index 95528d664d..49252f7a4b 100644 --- a/routers/api/v1/repo/issue.go +++ b/routers/api/v1/repo/issue.go @@ -116,7 +116,7 @@ func SearchIssues(ctx *context.APIContext) { // "200": // "$ref": "#/responses/IssueList" - before, since, err := context.GetQueryBeforeSince(ctx.Context) + before, since, err := context.GetQueryBeforeSince(ctx.Base) if err != nil { ctx.Error(http.StatusUnprocessableEntity, "GetQueryBeforeSince", err) return @@ -368,7 +368,7 @@ func ListIssues(ctx *context.APIContext) { // responses: // "200": // "$ref": "#/responses/IssueList" - before, since, err := context.GetQueryBeforeSince(ctx.Context) + before, since, err := context.GetQueryBeforeSince(ctx.Base) if err != nil { ctx.Error(http.StatusUnprocessableEntity, "GetQueryBeforeSince", err) return diff --git a/routers/api/v1/repo/issue_comment.go b/routers/api/v1/repo/issue_comment.go index 6ae6063303..7c8f30f116 100644 --- a/routers/api/v1/repo/issue_comment.go +++ b/routers/api/v1/repo/issue_comment.go @@ -59,7 +59,7 @@ func ListIssueComments(ctx *context.APIContext) { // "200": // "$ref": "#/responses/CommentList" - before, since, err := context.GetQueryBeforeSince(ctx.Context) + before, since, err := context.GetQueryBeforeSince(ctx.Base) if err != nil { ctx.Error(http.StatusUnprocessableEntity, "GetQueryBeforeSince", err) return @@ -156,7 +156,7 @@ func ListIssueCommentsAndTimeline(ctx *context.APIContext) { // "200": // "$ref": "#/responses/TimelineList" - before, since, err := context.GetQueryBeforeSince(ctx.Context) + before, since, err := context.GetQueryBeforeSince(ctx.Base) if err != nil { ctx.Error(http.StatusUnprocessableEntity, "GetQueryBeforeSince", err) return @@ -259,7 +259,7 @@ func ListRepoIssueComments(ctx *context.APIContext) { // "200": // "$ref": "#/responses/CommentList" - before, since, err := context.GetQueryBeforeSince(ctx.Context) + before, since, err := context.GetQueryBeforeSince(ctx.Base) if err != nil { ctx.Error(http.StatusUnprocessableEntity, "GetQueryBeforeSince", err) return diff --git a/routers/api/v1/repo/issue_tracked_time.go b/routers/api/v1/repo/issue_tracked_time.go index 16bb8cb73d..1ff934950c 100644 --- a/routers/api/v1/repo/issue_tracked_time.go +++ b/routers/api/v1/repo/issue_tracked_time.go @@ -103,7 +103,7 @@ func ListTrackedTimes(ctx *context.APIContext) { opts.UserID = user.ID } - if opts.CreatedBeforeUnix, opts.CreatedAfterUnix, err = context.GetQueryBeforeSince(ctx.Context); err != nil { + if opts.CreatedBeforeUnix, opts.CreatedAfterUnix, err = context.GetQueryBeforeSince(ctx.Base); err != nil { ctx.Error(http.StatusUnprocessableEntity, "GetQueryBeforeSince", err) return } @@ -522,7 +522,7 @@ func ListTrackedTimesByRepository(ctx *context.APIContext) { } var err error - if opts.CreatedBeforeUnix, opts.CreatedAfterUnix, err = context.GetQueryBeforeSince(ctx.Context); err != nil { + if opts.CreatedBeforeUnix, opts.CreatedAfterUnix, err = context.GetQueryBeforeSince(ctx.Base); err != nil { ctx.Error(http.StatusUnprocessableEntity, "GetQueryBeforeSince", err) return } @@ -596,7 +596,7 @@ func ListMyTrackedTimes(ctx *context.APIContext) { } var err error - if opts.CreatedBeforeUnix, opts.CreatedAfterUnix, err = context.GetQueryBeforeSince(ctx.Context); err != nil { + if opts.CreatedBeforeUnix, opts.CreatedAfterUnix, err = context.GetQueryBeforeSince(ctx.Base); err != nil { ctx.Error(http.StatusUnprocessableEntity, "GetQueryBeforeSince", err) return } diff --git a/routers/api/v1/repo/migrate.go b/routers/api/v1/repo/migrate.go index efce39e520..b458cd122b 100644 --- a/routers/api/v1/repo/migrate.go +++ b/routers/api/v1/repo/migrate.go @@ -79,7 +79,7 @@ func Migrate(ctx *context.APIContext) { return } - if ctx.HasError() { + if ctx.HasAPIError() { ctx.Error(http.StatusUnprocessableEntity, "", ctx.GetErrMsg()) return } diff --git a/routers/api/v1/repo/repo_test.go b/routers/api/v1/repo/repo_test.go index 59c3bde819..e1bdea5c82 100644 --- a/routers/api/v1/repo/repo_test.go +++ b/routers/api/v1/repo/repo_test.go @@ -9,7 +9,6 @@ import ( repo_model "code.gitea.io/gitea/models/repo" "code.gitea.io/gitea/models/unittest" - "code.gitea.io/gitea/modules/context" api "code.gitea.io/gitea/modules/structs" "code.gitea.io/gitea/modules/test" "code.gitea.io/gitea/modules/web" @@ -20,7 +19,7 @@ import ( func TestRepoEdit(t *testing.T) { unittest.PrepareTestEnv(t) - ctx := test.MockContext(t, "user2/repo1") + ctx := test.MockAPIContext(t, "user2/repo1") test.LoadRepo(t, ctx, 1) test.LoadUser(t, ctx, 2) ctx.Repo.Owner = ctx.Doer @@ -54,9 +53,8 @@ func TestRepoEdit(t *testing.T) { Archived: &archived, } - apiCtx := &context.APIContext{Context: ctx, Org: nil} - web.SetForm(apiCtx, &opts) - Edit(apiCtx) + web.SetForm(ctx, &opts) + Edit(ctx) assert.EqualValues(t, http.StatusOK, ctx.Resp.Status()) unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ @@ -67,7 +65,7 @@ func TestRepoEdit(t *testing.T) { func TestRepoEditNameChange(t *testing.T) { unittest.PrepareTestEnv(t) - ctx := test.MockContext(t, "user2/repo1") + ctx := test.MockAPIContext(t, "user2/repo1") test.LoadRepo(t, ctx, 1) test.LoadUser(t, ctx, 2) ctx.Repo.Owner = ctx.Doer @@ -76,9 +74,8 @@ func TestRepoEditNameChange(t *testing.T) { Name: &name, } - apiCtx := &context.APIContext{Context: ctx, Org: nil} - web.SetForm(apiCtx, &opts) - Edit(apiCtx) + web.SetForm(ctx, &opts) + Edit(ctx) assert.EqualValues(t, http.StatusOK, ctx.Resp.Status()) unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ diff --git a/routers/api/v1/repo/status.go b/routers/api/v1/repo/status.go index 5158f38e14..c1110ebce5 100644 --- a/routers/api/v1/repo/status.go +++ b/routers/api/v1/repo/status.go @@ -183,7 +183,7 @@ func getCommitStatuses(ctx *context.APIContext, sha string) { ctx.Error(http.StatusBadRequest, "ref/sha not given", nil) return } - sha = utils.MustConvertToSHA1(ctx.Context, sha) + sha = utils.MustConvertToSHA1(ctx.Base, ctx.Repo, sha) repo := ctx.Repo.Repository listOptions := utils.GetListOptions(ctx) diff --git a/routers/api/v1/user/helper.go b/routers/api/v1/user/helper.go index 28f600ad92..4b642910b1 100644 --- a/routers/api/v1/user/helper.go +++ b/routers/api/v1/user/helper.go @@ -17,7 +17,7 @@ func GetUserByParamsName(ctx *context.APIContext, name string) *user_model.User if err != nil { if user_model.IsErrUserNotExist(err) { if redirectUserID, err2 := user_model.LookupUserRedirect(username); err2 == nil { - context.RedirectToUser(ctx.Context, username, redirectUserID) + context.RedirectToUser(ctx.Base, username, redirectUserID) } else { ctx.NotFound("GetUserByName", err) } diff --git a/routers/api/v1/utils/git.go b/routers/api/v1/utils/git.go index eaf0f5fd37..32f5c85319 100644 --- a/routers/api/v1/utils/git.go +++ b/routers/api/v1/utils/git.go @@ -4,6 +4,7 @@ package utils import ( + gocontext "context" "fmt" "net/http" @@ -33,7 +34,7 @@ func ResolveRefOrSha(ctx *context.APIContext, ref string) string { } } - sha = MustConvertToSHA1(ctx.Context, sha) + sha = MustConvertToSHA1(ctx, ctx.Repo, sha) if ctx.Repo.GitRepo != nil { err := ctx.Repo.GitRepo.AddLastCommitCache(ctx.Repo.Repository.GetCommitsCountCacheKey(ref, ref != sha), ctx.Repo.Repository.FullName(), sha) @@ -69,7 +70,7 @@ func searchRefCommitByType(ctx *context.APIContext, refType, filter string) (str } // ConvertToSHA1 returns a full-length SHA1 from a potential ID string -func ConvertToSHA1(ctx *context.Context, commitID string) (git.SHA1, error) { +func ConvertToSHA1(ctx gocontext.Context, repo *context.Repository, commitID string) (git.SHA1, error) { if len(commitID) == git.SHAFullLength && git.IsValidSHAPattern(commitID) { sha1, err := git.NewIDFromString(commitID) if err == nil { @@ -77,7 +78,7 @@ func ConvertToSHA1(ctx *context.Context, commitID string) (git.SHA1, error) { } } - gitRepo, closer, err := git.RepositoryFromContextOrOpen(ctx, ctx.Repo.Repository.RepoPath()) + gitRepo, closer, err := git.RepositoryFromContextOrOpen(ctx, repo.Repository.RepoPath()) if err != nil { return git.SHA1{}, fmt.Errorf("RepositoryFromContextOrOpen: %w", err) } @@ -87,8 +88,8 @@ func ConvertToSHA1(ctx *context.Context, commitID string) (git.SHA1, error) { } // MustConvertToSHA1 returns a full-length SHA1 string from a potential ID string, or returns origin input if it can't convert to SHA1 -func MustConvertToSHA1(ctx *context.Context, commitID string) string { - sha, err := ConvertToSHA1(ctx, commitID) +func MustConvertToSHA1(ctx gocontext.Context, repo *context.Repository, commitID string) string { + sha, err := ConvertToSHA1(ctx, repo, commitID) if err != nil { return commitID } diff --git a/routers/common/markup.go b/routers/common/markup.go index 3acd12721e..5f412014d7 100644 --- a/routers/common/markup.go +++ b/routers/common/markup.go @@ -19,7 +19,7 @@ import ( ) // RenderMarkup renders markup text for the /markup and /markdown endpoints -func RenderMarkup(ctx *context.Context, mode, text, urlPrefix, filePath string, wiki bool) { +func RenderMarkup(ctx *context.Base, repo *context.Repository, mode, text, urlPrefix, filePath string, wiki bool) { var markupType string relativePath := "" @@ -63,11 +63,11 @@ func RenderMarkup(ctx *context.Context, mode, text, urlPrefix, filePath string, } meta := map[string]string{} - if ctx.Repo != nil && ctx.Repo.Repository != nil { + if repo != nil && repo.Repository != nil { if mode == "comment" { - meta = ctx.Repo.Repository.ComposeMetas() + meta = repo.Repository.ComposeMetas() } else { - meta = ctx.Repo.Repository.ComposeDocumentMetas() + meta = repo.Repository.ComposeDocumentMetas() } } if mode != "comment" { diff --git a/routers/common/middleware.go b/routers/common/middleware.go index c1ee9dd765..a25ff1ee00 100644 --- a/routers/common/middleware.go +++ b/routers/common/middleware.go @@ -42,7 +42,7 @@ func ProtocolMiddlewares() (handlers []any) { return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { ctx, _, finished := process.GetManager().AddTypedContext(req.Context(), fmt.Sprintf("%s: %s", req.Method, req.RequestURI), process.RequestProcessType, true) defer finished() - next.ServeHTTP(context.NewResponse(resp), req.WithContext(cache.WithCacheContext(ctx))) + next.ServeHTTP(context.WrapResponseWriter(resp), req.WithContext(cache.WithCacheContext(ctx))) }) }) diff --git a/routers/common/serve.go b/routers/common/serve.go index 59b993328e..3094ee6a6e 100644 --- a/routers/common/serve.go +++ b/routers/common/serve.go @@ -15,7 +15,7 @@ import ( ) // ServeBlob download a git.Blob -func ServeBlob(ctx *context.Context, blob *git.Blob, lastModified time.Time) error { +func ServeBlob(ctx *context.Base, filePath string, blob *git.Blob, lastModified time.Time) error { if httpcache.HandleGenericETagTimeCache(ctx.Req, ctx.Resp, `"`+blob.ID.String()+`"`, lastModified) { return nil } @@ -30,14 +30,14 @@ func ServeBlob(ctx *context.Context, blob *git.Blob, lastModified time.Time) err } }() - httplib.ServeContentByReader(ctx.Req, ctx.Resp, ctx.Repo.TreePath, blob.Size(), dataRc) + httplib.ServeContentByReader(ctx.Req, ctx.Resp, filePath, blob.Size(), dataRc) return nil } -func ServeContentByReader(ctx *context.Context, filePath string, size int64, reader io.Reader) { +func ServeContentByReader(ctx *context.Base, filePath string, size int64, reader io.Reader) { httplib.ServeContentByReader(ctx.Req, ctx.Resp, filePath, size, reader) } -func ServeContentByReadSeeker(ctx *context.Context, filePath string, modTime time.Time, reader io.ReadSeeker) { +func ServeContentByReadSeeker(ctx *context.Base, filePath string, modTime time.Time, reader io.ReadSeeker) { httplib.ServeContentByReadSeeker(ctx.Req, ctx.Resp, filePath, modTime, reader) } diff --git a/routers/init.go b/routers/init.go index 087d8c2915..5737ef3dc0 100644 --- a/routers/init.go +++ b/routers/init.go @@ -198,7 +198,7 @@ func NormalRoutes(ctx context.Context) *web.Route { // In Github, it uses ACTIONS_RUNTIME_URL=https://pipelines.actions.githubusercontent.com/fLgcSHkPGySXeIFrg8W8OBSfeg3b5Fls1A1CwX566g8PayEGlg/ // TODO: this prefix should be generated with a token string with runner ? prefix = "/api/actions_pipeline" - r.Mount(prefix, actions_router.ArtifactsRoutes(ctx, prefix)) + r.Mount(prefix, actions_router.ArtifactsRoutes(prefix)) } return r diff --git a/routers/install/install.go b/routers/install/install.go index 714ddd5548..89b91a5a48 100644 --- a/routers/install/install.go +++ b/routers/install/install.go @@ -58,15 +58,14 @@ func Contexter() func(next http.Handler) http.Handler { dbTypeNames := getSupportedDbTypeNames() return func(next http.Handler) http.Handler { return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + base, baseCleanUp := context.NewBaseContext(resp, req) ctx := context.Context{ - Resp: context.NewResponse(resp), + Base: base, Flash: &middleware.Flash{}, - Locale: middleware.Locale(resp, req), Render: rnd, - Data: middleware.GetContextData(req.Context()), Session: session.GetSession(req), } - defer ctx.Close() + defer baseCleanUp() ctx.Data.MergeFrom(middleware.CommonTemplateContextData()) ctx.Data.MergeFrom(middleware.ContextData{ @@ -78,7 +77,6 @@ func Contexter() func(next http.Handler) http.Handler { "PasswordHashAlgorithms": hash.RecommendedHashAlgorithms, }) - ctx.Req = context.WithContext(req, &ctx) next.ServeHTTP(resp, ctx.Req) }) } @@ -249,15 +247,8 @@ func SubmitInstall(ctx *context.Context) { ctx.Data["CurDbType"] = form.DbType if ctx.HasError() { - if ctx.HasValue("Err_SMTPUser") { - ctx.Data["Err_SMTP"] = true - } - if ctx.HasValue("Err_AdminName") || - ctx.HasValue("Err_AdminPasswd") || - ctx.HasValue("Err_AdminEmail") { - ctx.Data["Err_Admin"] = true - } - + ctx.Data["Err_SMTP"] = ctx.Data["Err_SMTPUser"] != nil + ctx.Data["Err_Admin"] = ctx.Data["Err_AdminName"] != nil || ctx.Data["Err_AdminPasswd"] != nil || ctx.Data["Err_AdminEmail"] != nil ctx.HTML(http.StatusOK, tplInstall) return } diff --git a/routers/web/misc/markup.go b/routers/web/misc/markup.go index 1690378945..c91da9a7f1 100644 --- a/routers/web/misc/markup.go +++ b/routers/web/misc/markup.go @@ -5,8 +5,6 @@ package misc import ( - "net/http" - "code.gitea.io/gitea/modules/context" api "code.gitea.io/gitea/modules/structs" "code.gitea.io/gitea/modules/web" @@ -16,11 +14,5 @@ import ( // Markup render markup document to HTML func Markup(ctx *context.Context) { form := web.GetForm(ctx).(*api.MarkupOption) - - if ctx.HasAPIError() { - ctx.Error(http.StatusUnprocessableEntity, "", ctx.GetErrMsg()) - return - } - - common.RenderMarkup(ctx, form.Mode, form.Text, form.Context, form.FilePath, form.Wiki) + common.RenderMarkup(ctx.Base, ctx.Repo, form.Mode, form.Text, form.Context, form.FilePath, form.Wiki) } diff --git a/routers/web/repo/attachment.go b/routers/web/repo/attachment.go index c46ec29841..fb95e63ecf 100644 --- a/routers/web/repo/attachment.go +++ b/routers/web/repo/attachment.go @@ -153,7 +153,7 @@ func ServeAttachment(ctx *context.Context, uuid string) { } defer fr.Close() - common.ServeContentByReadSeeker(ctx, attach.Name, attach.CreatedUnix.AsTime(), fr) + common.ServeContentByReadSeeker(ctx.Base, attach.Name, attach.CreatedUnix.AsTime(), fr) } // GetAttachment serve attachments diff --git a/routers/web/repo/download.go b/routers/web/repo/download.go index 1c87f9bed7..a498180f35 100644 --- a/routers/web/repo/download.go +++ b/routers/web/repo/download.go @@ -47,7 +47,7 @@ func ServeBlobOrLFS(ctx *context.Context, blob *git.Blob, lastModified time.Time log.Error("ServeBlobOrLFS: Close: %v", err) } closed = true - return common.ServeBlob(ctx, blob, lastModified) + return common.ServeBlob(ctx.Base, ctx.Repo.TreePath, blob, lastModified) } if httpcache.HandleGenericETagCache(ctx.Req, ctx.Resp, `"`+pointer.Oid+`"`) { return nil @@ -71,7 +71,7 @@ func ServeBlobOrLFS(ctx *context.Context, blob *git.Blob, lastModified time.Time log.Error("ServeBlobOrLFS: Close: %v", err) } }() - common.ServeContentByReadSeeker(ctx, ctx.Repo.TreePath, lastModified, lfsDataRc) + common.ServeContentByReadSeeker(ctx.Base, ctx.Repo.TreePath, lastModified, lfsDataRc) return nil } if err = dataRc.Close(); err != nil { @@ -79,7 +79,7 @@ func ServeBlobOrLFS(ctx *context.Context, blob *git.Blob, lastModified time.Time } closed = true - return common.ServeBlob(ctx, blob, lastModified) + return common.ServeBlob(ctx.Base, ctx.Repo.TreePath, blob, lastModified) } func getBlobForEntry(ctx *context.Context) (blob *git.Blob, lastModified time.Time) { @@ -120,7 +120,7 @@ func SingleDownload(ctx *context.Context) { return } - if err := common.ServeBlob(ctx, blob, lastModified); err != nil { + if err := common.ServeBlob(ctx.Base, ctx.Repo.TreePath, blob, lastModified); err != nil { ctx.ServerError("ServeBlob", err) } } @@ -148,7 +148,7 @@ func DownloadByID(ctx *context.Context) { } return } - if err = common.ServeBlob(ctx, blob, time.Time{}); err != nil { + if err = common.ServeBlob(ctx.Base, ctx.Repo.TreePath, blob, time.Time{}); err != nil { ctx.ServerError("ServeBlob", err) } } diff --git a/routers/web/repo/http.go b/routers/web/repo/http.go index 4e45a9b6e2..b6ebd25915 100644 --- a/routers/web/repo/http.go +++ b/routers/web/repo/http.go @@ -109,7 +109,7 @@ func httpBase(ctx *context.Context) (h *serviceHandler) { if err != nil { if repo_model.IsErrRepoNotExist(err) { if redirectRepoID, err := repo_model.LookupRedirect(owner.ID, reponame); err == nil { - context.RedirectToRepo(ctx, redirectRepoID) + context.RedirectToRepo(ctx.Base, redirectRepoID) return } repoExist = false diff --git a/routers/web/repo/issue.go b/routers/web/repo/issue.go index 88d2a97a7a..7a0dc9940b 100644 --- a/routers/web/repo/issue.go +++ b/routers/web/repo/issue.go @@ -2344,7 +2344,7 @@ func UpdatePullReviewRequest(ctx *context.Context) { // SearchIssues searches for issues across the repositories that the user has access to func SearchIssues(ctx *context.Context) { - before, since, err := context.GetQueryBeforeSince(ctx) + before, since, err := context.GetQueryBeforeSince(ctx.Base) if err != nil { ctx.Error(http.StatusUnprocessableEntity, err.Error()) return @@ -2545,7 +2545,7 @@ func getUserIDForFilter(ctx *context.Context, queryName string) int64 { // ListIssues list the issues of a repository func ListIssues(ctx *context.Context) { - before, since, err := context.GetQueryBeforeSince(ctx) + before, since, err := context.GetQueryBeforeSince(ctx.Base) if err != nil { ctx.Error(http.StatusUnprocessableEntity, err.Error()) return diff --git a/routers/web/repo/wiki.go b/routers/web/repo/wiki.go index a335c114be..115418887d 100644 --- a/routers/web/repo/wiki.go +++ b/routers/web/repo/wiki.go @@ -671,7 +671,7 @@ func WikiRaw(ctx *context.Context) { } if entry != nil { - if err = common.ServeBlob(ctx, entry.Blob(), time.Time{}); err != nil { + if err = common.ServeBlob(ctx.Base, ctx.Repo.TreePath, entry.Blob(), time.Time{}); err != nil { ctx.ServerError("ServeBlob", err) } return diff --git a/services/auth/middleware.go b/services/auth/middleware.go index 3b2f883d00..d1955a4c90 100644 --- a/services/auth/middleware.go +++ b/services/auth/middleware.go @@ -8,6 +8,7 @@ import ( "strings" "code.gitea.io/gitea/models/auth" + user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" @@ -17,11 +18,15 @@ import ( // Auth is a middleware to authenticate a web user func Auth(authMethod Method) func(*context.Context) { return func(ctx *context.Context) { - if err := authShared(ctx, authMethod); err != nil { + ar, err := authShared(ctx.Base, ctx.Session, authMethod) + if err != nil { log.Error("Failed to verify user: %v", err) ctx.Error(http.StatusUnauthorized, "Verify") return } + ctx.Doer = ar.Doer + ctx.IsSigned = ar.Doer != nil + ctx.IsBasicAuth = ar.IsBasicAuth if ctx.Doer == nil { // ensure the session uid is deleted _ = ctx.Session.Delete("uid") @@ -32,32 +37,41 @@ func Auth(authMethod Method) func(*context.Context) { // APIAuth is a middleware to authenticate an api user func APIAuth(authMethod Method) func(*context.APIContext) { return func(ctx *context.APIContext) { - if err := authShared(ctx.Context, authMethod); err != nil { + ar, err := authShared(ctx.Base, nil, authMethod) + if err != nil { ctx.Error(http.StatusUnauthorized, "APIAuth", err) + return } + ctx.Doer = ar.Doer + ctx.IsSigned = ar.Doer != nil + ctx.IsBasicAuth = ar.IsBasicAuth } } -func authShared(ctx *context.Context, authMethod Method) error { - var err error - ctx.Doer, err = authMethod.Verify(ctx.Req, ctx.Resp, ctx, ctx.Session) +type authResult struct { + Doer *user_model.User + IsBasicAuth bool +} + +func authShared(ctx *context.Base, sessionStore SessionStore, authMethod Method) (ar authResult, err error) { + ar.Doer, err = authMethod.Verify(ctx.Req, ctx.Resp, ctx, sessionStore) if err != nil { - return err + return ar, err } - if ctx.Doer != nil { - if ctx.Locale.Language() != ctx.Doer.Language { + if ar.Doer != nil { + if ctx.Locale.Language() != ar.Doer.Language { ctx.Locale = middleware.Locale(ctx.Resp, ctx.Req) } - ctx.IsBasicAuth = ctx.Data["AuthedMethod"].(string) == BasicMethodName - ctx.IsSigned = true - ctx.Data["IsSigned"] = ctx.IsSigned - ctx.Data[middleware.ContextDataKeySignedUser] = ctx.Doer - ctx.Data["SignedUserID"] = ctx.Doer.ID - ctx.Data["IsAdmin"] = ctx.Doer.IsAdmin + ar.IsBasicAuth = ctx.Data["AuthedMethod"].(string) == BasicMethodName + + ctx.Data["IsSigned"] = true + ctx.Data[middleware.ContextDataKeySignedUser] = ar.Doer + ctx.Data["SignedUserID"] = ar.Doer.ID + ctx.Data["IsAdmin"] = ar.Doer.IsAdmin } else { ctx.Data["SignedUserID"] = int64(0) } - return nil + return ar, nil } // VerifyOptions contains required or check options @@ -68,7 +82,7 @@ type VerifyOptions struct { DisableCSRF bool } -// Checks authentication according to options +// VerifyAuthWithOptions checks authentication according to options func VerifyAuthWithOptions(options *VerifyOptions) func(ctx *context.Context) { return func(ctx *context.Context) { // Check prohibit login users. @@ -153,7 +167,7 @@ func VerifyAuthWithOptions(options *VerifyOptions) func(ctx *context.Context) { } } -// Checks authentication according to options +// VerifyAuthWithOptionsAPI checks authentication according to options func VerifyAuthWithOptionsAPI(options *VerifyOptions) func(ctx *context.APIContext) { return func(ctx *context.APIContext) { // Check prohibit login users. @@ -197,7 +211,9 @@ func VerifyAuthWithOptionsAPI(options *VerifyOptions) func(ctx *context.APIConte return } else if !ctx.Doer.IsActive && setting.Service.RegisterEmailConfirm { ctx.Data["Title"] = ctx.Tr("auth.active_your_account") - ctx.HTML(http.StatusOK, "user/auth/activate") + ctx.JSON(http.StatusForbidden, map[string]string{ + "message": "This account is not activated.", + }) return } if ctx.IsSigned && ctx.IsBasicAuth { diff --git a/services/context/user.go b/services/context/user.go index c713667bca..4e74aa50bd 100644 --- a/services/context/user.go +++ b/services/context/user.go @@ -15,7 +15,7 @@ import ( // UserAssignmentWeb returns a middleware to handle context-user assignment for web routes func UserAssignmentWeb() func(ctx *context.Context) { return func(ctx *context.Context) { - userAssignment(ctx, func(status int, title string, obj interface{}) { + errorFn := func(status int, title string, obj interface{}) { err, ok := obj.(error) if !ok { err = fmt.Errorf("%s", obj) @@ -25,7 +25,8 @@ func UserAssignmentWeb() func(ctx *context.Context) { } else { ctx.ServerError(title, err) } - }) + } + ctx.ContextUser = userAssignment(ctx.Base, ctx.Doer, errorFn) } } @@ -53,18 +54,18 @@ func UserIDAssignmentAPI() func(ctx *context.APIContext) { // UserAssignmentAPI returns a middleware to handle context-user assignment for api routes func UserAssignmentAPI() func(ctx *context.APIContext) { return func(ctx *context.APIContext) { - userAssignment(ctx.Context, ctx.Error) + ctx.ContextUser = userAssignment(ctx.Base, ctx.Doer, ctx.Error) } } -func userAssignment(ctx *context.Context, errCb func(int, string, interface{})) { +func userAssignment(ctx *context.Base, doer *user_model.User, errCb func(int, string, interface{})) (contextUser *user_model.User) { username := ctx.Params(":username") - if ctx.IsSigned && ctx.Doer.LowerName == strings.ToLower(username) { - ctx.ContextUser = ctx.Doer + if doer != nil && doer.LowerName == strings.ToLower(username) { + contextUser = doer } else { var err error - ctx.ContextUser, err = user_model.GetUserByName(ctx, username) + contextUser, err = user_model.GetUserByName(ctx, username) if err != nil { if user_model.IsErrUserNotExist(err) { if redirectUserID, err := user_model.LookupUserRedirect(username); err == nil { @@ -79,4 +80,5 @@ func userAssignment(ctx *context.Context, errCb func(int, string, interface{})) } } } + return contextUser } diff --git a/services/forms/admin.go b/services/forms/admin.go index a749f863f3..4b3cacc606 100644 --- a/services/forms/admin.go +++ b/services/forms/admin.go @@ -27,7 +27,7 @@ type AdminCreateUserForm struct { // Validate validates form fields func (f *AdminCreateUserForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -55,7 +55,7 @@ type AdminEditUserForm struct { // Validate validates form fields func (f *AdminEditUserForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -67,6 +67,6 @@ type AdminDashboardForm struct { // Validate validates form fields func (f *AdminDashboardForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } diff --git a/services/forms/auth_form.go b/services/forms/auth_form.go index 5625aa1e2e..25acbbb99e 100644 --- a/services/forms/auth_form.go +++ b/services/forms/auth_form.go @@ -86,6 +86,6 @@ type AuthenticationForm struct { // Validate validates fields func (f *AuthenticationForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } diff --git a/services/forms/org.go b/services/forms/org.go index d753531371..c333bead31 100644 --- a/services/forms/org.go +++ b/services/forms/org.go @@ -30,7 +30,7 @@ type CreateOrgForm struct { // Validate validates the fields func (f *CreateOrgForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -48,7 +48,7 @@ type UpdateOrgSettingForm struct { // Validate validates the fields func (f *UpdateOrgSettingForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -70,6 +70,6 @@ type CreateTeamForm struct { // Validate validates the fields func (f *CreateTeamForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } diff --git a/services/forms/package_form.go b/services/forms/package_form.go index dfec98fff4..cf8abfb8fb 100644 --- a/services/forms/package_form.go +++ b/services/forms/package_form.go @@ -25,6 +25,6 @@ type PackageCleanupRuleForm struct { } func (f *PackageCleanupRuleForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } diff --git a/services/forms/repo_branch_form.go b/services/forms/repo_branch_form.go index bf1183fc43..5deb0ae463 100644 --- a/services/forms/repo_branch_form.go +++ b/services/forms/repo_branch_form.go @@ -21,7 +21,7 @@ type NewBranchForm struct { // Validate validates the fields func (f *NewBranchForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -33,6 +33,6 @@ type RenameBranchForm struct { // Validate validates the fields func (f *RenameBranchForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } diff --git a/services/forms/repo_form.go b/services/forms/repo_form.go index d705ecad3f..cacfb64b17 100644 --- a/services/forms/repo_form.go +++ b/services/forms/repo_form.go @@ -54,7 +54,7 @@ type CreateRepoForm struct { // Validate validates the fields func (f *CreateRepoForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -87,7 +87,7 @@ type MigrateRepoForm struct { // Validate validates the fields func (f *MigrateRepoForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -176,7 +176,7 @@ type RepoSettingForm struct { // Validate validates the fields func (f *RepoSettingForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -215,7 +215,7 @@ type ProtectBranchForm struct { // Validate validates the fields func (f *ProtectBranchForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -280,7 +280,7 @@ type NewWebhookForm struct { // Validate validates the fields func (f *NewWebhookForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -294,7 +294,7 @@ type NewGogshookForm struct { // Validate validates the fields func (f *NewGogshookForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -310,7 +310,7 @@ type NewSlackHookForm struct { // Validate validates the fields func (f *NewSlackHookForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) if !webhook.IsValidSlackChannel(strings.TrimSpace(f.Channel)) { errs = append(errs, binding.Error{ FieldNames: []string{"Channel"}, @@ -331,7 +331,7 @@ type NewDiscordHookForm struct { // Validate validates the fields func (f *NewDiscordHookForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -343,7 +343,7 @@ type NewDingtalkHookForm struct { // Validate validates the fields func (f *NewDingtalkHookForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -356,7 +356,7 @@ type NewTelegramHookForm struct { // Validate validates the fields func (f *NewTelegramHookForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -370,7 +370,7 @@ type NewMatrixHookForm struct { // Validate validates the fields func (f *NewMatrixHookForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -382,7 +382,7 @@ type NewMSTeamsHookForm struct { // Validate validates the fields func (f *NewMSTeamsHookForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -394,7 +394,7 @@ type NewFeishuHookForm struct { // Validate validates the fields func (f *NewFeishuHookForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -406,7 +406,7 @@ type NewWechatWorkHookForm struct { // Validate validates the fields func (f *NewWechatWorkHookForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -420,7 +420,7 @@ type NewPackagistHookForm struct { // Validate validates the fields func (f *NewPackagistHookForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -447,7 +447,7 @@ type CreateIssueForm struct { // Validate validates the fields func (f *CreateIssueForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -460,7 +460,7 @@ type CreateCommentForm struct { // Validate validates the fields func (f *CreateCommentForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -471,7 +471,7 @@ type ReactionForm struct { // Validate validates the fields func (f *ReactionForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -482,7 +482,7 @@ type IssueLockForm struct { // Validate validates the fields func (i *IssueLockForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, i, ctx.Locale) } @@ -550,7 +550,7 @@ type CreateMilestoneForm struct { // Validate validates the fields func (f *CreateMilestoneForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -572,7 +572,7 @@ type CreateLabelForm struct { // Validate validates the fields func (f *CreateLabelForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -583,7 +583,7 @@ type InitializeLabelsForm struct { // Validate validates the fields func (f *InitializeLabelsForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -611,7 +611,7 @@ type MergePullRequestForm struct { // Validate validates the fields func (f *MergePullRequestForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -629,7 +629,7 @@ type CodeCommentForm struct { // Validate validates the fields func (f *CodeCommentForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -643,7 +643,7 @@ type SubmitReviewForm struct { // Validate validates the fields func (f *SubmitReviewForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -704,7 +704,7 @@ type NewReleaseForm struct { // Validate validates the fields func (f *NewReleaseForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -719,7 +719,7 @@ type EditReleaseForm struct { // Validate validates the fields func (f *EditReleaseForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -740,7 +740,7 @@ type NewWikiForm struct { // Validate validates the fields // FIXME: use code generation to generate this method. func (f *NewWikiForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -765,7 +765,7 @@ type EditRepoFileForm struct { // Validate validates the fields func (f *EditRepoFileForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -776,7 +776,7 @@ type EditPreviewDiffForm struct { // Validate validates the fields func (f *EditPreviewDiffForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -800,7 +800,7 @@ type CherryPickForm struct { // Validate validates the fields func (f *CherryPickForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -825,7 +825,7 @@ type UploadRepoFileForm struct { // Validate validates the fields func (f *UploadRepoFileForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -836,7 +836,7 @@ type RemoveUploadFileForm struct { // Validate validates the fields func (f *RemoveUploadFileForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -859,7 +859,7 @@ type DeleteRepoFileForm struct { // Validate validates the fields func (f *DeleteRepoFileForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -878,7 +878,7 @@ type AddTimeManuallyForm struct { // Validate validates the fields func (f *AddTimeManuallyForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -894,6 +894,6 @@ type DeadlineForm struct { // Validate validates the fields func (f *DeadlineForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } diff --git a/services/forms/repo_tag_form.go b/services/forms/repo_tag_form.go index 1209d2346f..4dd99f9e32 100644 --- a/services/forms/repo_tag_form.go +++ b/services/forms/repo_tag_form.go @@ -21,6 +21,6 @@ type ProtectTagForm struct { // Validate validates the fields func (f *ProtectTagForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } diff --git a/services/forms/runner.go b/services/forms/runner.go index 9063060346..22dea49e31 100644 --- a/services/forms/runner.go +++ b/services/forms/runner.go @@ -20,6 +20,6 @@ type EditRunnerForm struct { // Validate validates form fields func (f *EditRunnerForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } diff --git a/services/forms/user_form.go b/services/forms/user_form.go index 285bc398b2..fa8129bf85 100644 --- a/services/forms/user_form.go +++ b/services/forms/user_form.go @@ -78,7 +78,7 @@ type InstallForm struct { // Validate validates the fields func (f *InstallForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -99,7 +99,7 @@ type RegisterForm struct { // Validate validates the fields func (f *RegisterForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -148,7 +148,7 @@ type MustChangePasswordForm struct { // Validate validates the fields func (f *MustChangePasswordForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -162,7 +162,7 @@ type SignInForm struct { // Validate validates the fields func (f *SignInForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -182,7 +182,7 @@ type AuthorizationForm struct { // Validate validates the fields func (f *AuthorizationForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -197,7 +197,7 @@ type GrantApplicationForm struct { // Validate validates the fields func (f *GrantApplicationForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -216,7 +216,7 @@ type AccessTokenForm struct { // Validate validates the fields func (f *AccessTokenForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -227,7 +227,7 @@ type IntrospectTokenForm struct { // Validate validates the fields func (f *IntrospectTokenForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -252,7 +252,7 @@ type UpdateProfileForm struct { // Validate validates the fields func (f *UpdateProfileForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -263,7 +263,7 @@ type UpdateLanguageForm struct { // Validate validates the fields func (f *UpdateLanguageForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -283,7 +283,7 @@ type AvatarForm struct { // Validate validates the fields func (f *AvatarForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -294,7 +294,7 @@ type AddEmailForm struct { // Validate validates the fields func (f *AddEmailForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -305,7 +305,7 @@ type UpdateThemeForm struct { // Validate validates the field func (f *UpdateThemeForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -332,7 +332,7 @@ type ChangePasswordForm struct { // Validate validates the fields func (f *ChangePasswordForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -343,7 +343,7 @@ type AddOpenIDForm struct { // Validate validates the fields func (f *AddOpenIDForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -360,7 +360,7 @@ type AddKeyForm struct { // Validate validates the fields func (f *AddKeyForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -372,7 +372,7 @@ type AddSecretForm struct { // Validate validates the fields func (f *AddSecretForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -384,7 +384,7 @@ type NewAccessTokenForm struct { // Validate validates the fields func (f *NewAccessTokenForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -403,7 +403,7 @@ type EditOAuth2ApplicationForm struct { // Validate validates the fields func (f *EditOAuth2ApplicationForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -414,7 +414,7 @@ type TwoFactorAuthForm struct { // Validate validates the fields func (f *TwoFactorAuthForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -425,7 +425,7 @@ type TwoFactorScratchAuthForm struct { // Validate validates the fields func (f *TwoFactorScratchAuthForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -436,7 +436,7 @@ type WebauthnRegistrationForm struct { // Validate validates the fields func (f *WebauthnRegistrationForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -447,7 +447,7 @@ type WebauthnDeleteForm struct { // Validate validates the fields func (f *WebauthnDeleteForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -459,6 +459,6 @@ type PackageSettingForm struct { // Validate validates the fields func (f *PackageSettingForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } diff --git a/services/forms/user_form_auth_openid.go b/services/forms/user_form_auth_openid.go index f95eb98405..d8137a8d13 100644 --- a/services/forms/user_form_auth_openid.go +++ b/services/forms/user_form_auth_openid.go @@ -20,7 +20,7 @@ type SignInOpenIDForm struct { // Validate validates the fields func (f *SignInOpenIDForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -32,7 +32,7 @@ type SignUpOpenIDForm struct { // Validate validates the fields func (f *SignUpOpenIDForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -44,6 +44,6 @@ type ConnectOpenIDForm struct { // Validate validates the fields func (f *ConnectOpenIDForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } diff --git a/services/markup/processorhelper_test.go b/services/markup/processorhelper_test.go index 6c9c1c27e7..2f48e03b22 100644 --- a/services/markup/processorhelper_test.go +++ b/services/markup/processorhelper_test.go @@ -6,6 +6,7 @@ package markup import ( "context" "net/http" + "net/http/httptest" "testing" "code.gitea.io/gitea/models/db" @@ -36,12 +37,12 @@ func TestProcessorHelper(t *testing.T) { assert.False(t, ProcessorHelper().IsUsernameMentionable(context.Background(), userNoSuch)) // when using web context, use user.IsUserVisibleToViewer to check - var err error - giteaCtx := &gitea_context.Context{} - giteaCtx.Req, err = http.NewRequest("GET", "/", nil) + req, err := http.NewRequest("GET", "/", nil) assert.NoError(t, err) + base, baseCleanUp := gitea_context.NewBaseContext(httptest.NewRecorder(), req) + defer baseCleanUp() + giteaCtx := &gitea_context.Context{Base: base} - giteaCtx.Doer = nil assert.True(t, ProcessorHelper().IsUsernameMentionable(giteaCtx, userPublic)) assert.False(t, ProcessorHelper().IsUsernameMentionable(giteaCtx, userPrivate))