diff --git a/cmd/cmd.go b/cmd/cmd.go index bb768cc15..8d9d1ee07 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -7,9 +7,13 @@ package cmd import ( + "context" "errors" "fmt" + "os" + "os/signal" "strings" + "syscall" "code.gitea.io/gitea/models" "code.gitea.io/gitea/modules/setting" @@ -66,3 +70,25 @@ func initDBDisableConsole(disableConsole bool) error { } return nil } + +func installSignals() (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(context.Background()) + go func() { + // install notify + signalChannel := make(chan os.Signal, 1) + + signal.Notify( + signalChannel, + syscall.SIGINT, + syscall.SIGTERM, + ) + select { + case <-signalChannel: + case <-ctx.Done(): + } + cancel() + signal.Reset() + }() + + return ctx, cancel +} diff --git a/cmd/hook.go b/cmd/hook.go index 067a0bfb8..87f1f3756 100644 --- a/cmd/hook.go +++ b/cmd/hook.go @@ -152,17 +152,18 @@ func runHookPreReceive(c *cli.Context) error { if os.Getenv(models.EnvIsInternal) == "true" { return nil } + ctx, cancel := installSignals() + defer cancel() setup("hooks/pre-receive.log", c.Bool("debug")) if len(os.Getenv("SSH_ORIGINAL_COMMAND")) == 0 { if setting.OnlyAllowPushIfGiteaEnvironmentSet { - fail(`Rejecting changes as Gitea environment not set. + return fail(`Rejecting changes as Gitea environment not set. If you are pushing over SSH you must push with a key managed by Gitea or set your environment appropriately.`, "") - } else { - return nil } + return nil } // the environment is set by serv command @@ -235,14 +236,14 @@ Gitea or set your environment appropriately.`, "") hookOptions.OldCommitIDs = oldCommitIDs hookOptions.NewCommitIDs = newCommitIDs hookOptions.RefFullNames = refFullNames - statusCode, msg := private.HookPreReceive(username, reponame, hookOptions) + statusCode, msg := private.HookPreReceive(ctx, username, reponame, hookOptions) switch statusCode { case http.StatusOK: // no-op case http.StatusInternalServerError: - fail("Internal Server Error", msg) + return fail("Internal Server Error", msg) default: - fail(msg, "") + return fail(msg, "") } count = 0 lastline = 0 @@ -263,12 +264,12 @@ Gitea or set your environment appropriately.`, "") fmt.Fprintf(out, " Checking %d references\n", count) - statusCode, msg := private.HookPreReceive(username, reponame, hookOptions) + statusCode, msg := private.HookPreReceive(ctx, username, reponame, hookOptions) switch statusCode { case http.StatusInternalServerError: - fail("Internal Server Error", msg) + return fail("Internal Server Error", msg) case http.StatusForbidden: - fail(msg, "") + return fail(msg, "") } } else if lastline > 0 { fmt.Fprintf(out, "\n") @@ -285,8 +286,11 @@ func runHookUpdate(c *cli.Context) error { } func runHookPostReceive(c *cli.Context) error { + ctx, cancel := installSignals() + defer cancel() + // First of all run update-server-info no matter what - if _, err := git.NewCommand("update-server-info").Run(); err != nil { + if _, err := git.NewCommand("update-server-info").SetParentContext(ctx).Run(); err != nil { return fmt.Errorf("Failed to call 'git update-server-info': %v", err) } @@ -299,12 +303,11 @@ func runHookPostReceive(c *cli.Context) error { if len(os.Getenv("SSH_ORIGINAL_COMMAND")) == 0 { if setting.OnlyAllowPushIfGiteaEnvironmentSet { - fail(`Rejecting changes as Gitea environment not set. + return fail(`Rejecting changes as Gitea environment not set. If you are pushing over SSH you must push with a key managed by Gitea or set your environment appropriately.`, "") - } else { - return nil } + return nil } var out io.Writer @@ -371,11 +374,11 @@ Gitea or set your environment appropriately.`, "") hookOptions.OldCommitIDs = oldCommitIDs hookOptions.NewCommitIDs = newCommitIDs hookOptions.RefFullNames = refFullNames - resp, err := private.HookPostReceive(repoUser, repoName, hookOptions) + resp, err := private.HookPostReceive(ctx, repoUser, repoName, hookOptions) if resp == nil { _ = dWriter.Close() hookPrintResults(results) - fail("Internal Server Error", err) + return fail("Internal Server Error", err) } wasEmpty = wasEmpty || resp.RepoWasEmpty results = append(results, resp.Results...) @@ -386,9 +389,9 @@ Gitea or set your environment appropriately.`, "") if count == 0 { if wasEmpty && masterPushed { // We need to tell the repo to reset the default branch to master - err := private.SetDefaultBranch(repoUser, repoName, "master") + err := private.SetDefaultBranch(ctx, repoUser, repoName, "master") if err != nil { - fail("Internal Server Error", "SetDefaultBranch failed with Error: %v", err) + return fail("Internal Server Error", "SetDefaultBranch failed with Error: %v", err) } } fmt.Fprintf(out, "Processed %d references in total\n", total) @@ -404,11 +407,11 @@ Gitea or set your environment appropriately.`, "") fmt.Fprintf(out, " Processing %d references\n", count) - resp, err := private.HookPostReceive(repoUser, repoName, hookOptions) + resp, err := private.HookPostReceive(ctx, repoUser, repoName, hookOptions) if resp == nil { _ = dWriter.Close() hookPrintResults(results) - fail("Internal Server Error", err) + return fail("Internal Server Error", err) } wasEmpty = wasEmpty || resp.RepoWasEmpty results = append(results, resp.Results...) @@ -417,9 +420,9 @@ Gitea or set your environment appropriately.`, "") if wasEmpty && masterPushed { // We need to tell the repo to reset the default branch to master - err := private.SetDefaultBranch(repoUser, repoName, "master") + err := private.SetDefaultBranch(ctx, repoUser, repoName, "master") if err != nil { - fail("Internal Server Error", "SetDefaultBranch failed with Error: %v", err) + return fail("Internal Server Error", "SetDefaultBranch failed with Error: %v", err) } } _ = dWriter.Close() diff --git a/cmd/keys.go b/cmd/keys.go index 7456815cd..684aca64e 100644 --- a/cmd/keys.go +++ b/cmd/keys.go @@ -62,9 +62,12 @@ func runKeys(c *cli.Context) error { return errors.New("No key type and content provided") } + ctx, cancel := installSignals() + defer cancel() + setup("keys.log", false) - authorizedString, err := private.AuthorizedPublicKeyByContent(content) + authorizedString, err := private.AuthorizedPublicKeyByContent(ctx, content) if err != nil { return err } diff --git a/cmd/mailer.go b/cmd/mailer.go index ee11b56cc..1a4b0902e 100644 --- a/cmd/mailer.go +++ b/cmd/mailer.go @@ -14,6 +14,9 @@ import ( ) func runSendMail(c *cli.Context) error { + ctx, cancel := installSignals() + defer cancel() + setting.NewContext() if err := argsSet(c, "title"); err != nil { @@ -39,7 +42,7 @@ func runSendMail(c *cli.Context) error { } } - status, message := private.SendEmail(subject, body, nil) + status, message := private.SendEmail(ctx, subject, body, nil) if status != http.StatusOK { fmt.Printf("error: %s\n", message) return nil diff --git a/cmd/manager.go b/cmd/manager.go index 20c785868..99d283b44 100644 --- a/cmd/manager.go +++ b/cmd/manager.go @@ -236,10 +236,13 @@ func runRemoveLogger(c *cli.Context) error { group = log.DEFAULT } name := c.Args().First() - statusCode, msg := private.RemoveLogger(group, name) + ctx, cancel := installSignals() + defer cancel() + + statusCode, msg := private.RemoveLogger(ctx, group, name) switch statusCode { case http.StatusInternalServerError: - fail("InternalServerError", msg) + return fail("InternalServerError", msg) } fmt.Fprintln(os.Stdout, msg) @@ -371,10 +374,13 @@ func commonAddLogger(c *cli.Context, mode string, vals map[string]interface{}) e if c.IsSet("name") { name = c.String("name") } - statusCode, msg := private.AddLogger(group, name, mode, vals) + ctx, cancel := installSignals() + defer cancel() + + statusCode, msg := private.AddLogger(ctx, group, name, mode, vals) switch statusCode { case http.StatusInternalServerError: - fail("InternalServerError", msg) + return fail("InternalServerError", msg) } fmt.Fprintln(os.Stdout, msg) @@ -382,11 +388,14 @@ func commonAddLogger(c *cli.Context, mode string, vals map[string]interface{}) e } func runShutdown(c *cli.Context) error { + ctx, cancel := installSignals() + defer cancel() + setup("manager", c.Bool("debug")) - statusCode, msg := private.Shutdown() + statusCode, msg := private.Shutdown(ctx) switch statusCode { case http.StatusInternalServerError: - fail("InternalServerError", msg) + return fail("InternalServerError", msg) } fmt.Fprintln(os.Stdout, msg) @@ -394,11 +403,14 @@ func runShutdown(c *cli.Context) error { } func runRestart(c *cli.Context) error { + ctx, cancel := installSignals() + defer cancel() + setup("manager", c.Bool("debug")) - statusCode, msg := private.Restart() + statusCode, msg := private.Restart(ctx) switch statusCode { case http.StatusInternalServerError: - fail("InternalServerError", msg) + return fail("InternalServerError", msg) } fmt.Fprintln(os.Stdout, msg) @@ -406,11 +418,14 @@ func runRestart(c *cli.Context) error { } func runFlushQueues(c *cli.Context) error { + ctx, cancel := installSignals() + defer cancel() + setup("manager", c.Bool("debug")) - statusCode, msg := private.FlushQueues(c.Duration("timeout"), c.Bool("non-blocking")) + statusCode, msg := private.FlushQueues(ctx, c.Duration("timeout"), c.Bool("non-blocking")) switch statusCode { case http.StatusInternalServerError: - fail("InternalServerError", msg) + return fail("InternalServerError", msg) } fmt.Fprintln(os.Stdout, msg) @@ -418,11 +433,14 @@ func runFlushQueues(c *cli.Context) error { } func runPauseLogging(c *cli.Context) error { + ctx, cancel := installSignals() + defer cancel() + setup("manager", c.Bool("debug")) - statusCode, msg := private.PauseLogging() + statusCode, msg := private.PauseLogging(ctx) switch statusCode { case http.StatusInternalServerError: - fail("InternalServerError", msg) + return fail("InternalServerError", msg) } fmt.Fprintln(os.Stdout, msg) @@ -430,11 +448,14 @@ func runPauseLogging(c *cli.Context) error { } func runResumeLogging(c *cli.Context) error { + ctx, cancel := installSignals() + defer cancel() + setup("manager", c.Bool("debug")) - statusCode, msg := private.ResumeLogging() + statusCode, msg := private.ResumeLogging(ctx) switch statusCode { case http.StatusInternalServerError: - fail("InternalServerError", msg) + return fail("InternalServerError", msg) } fmt.Fprintln(os.Stdout, msg) @@ -442,11 +463,14 @@ func runResumeLogging(c *cli.Context) error { } func runReleaseReopenLogging(c *cli.Context) error { + ctx, cancel := installSignals() + defer cancel() + setup("manager", c.Bool("debug")) - statusCode, msg := private.ReleaseReopenLogging() + statusCode, msg := private.ReleaseReopenLogging(ctx) switch statusCode { case http.StatusInternalServerError: - fail("InternalServerError", msg) + return fail("InternalServerError", msg) } fmt.Fprintln(os.Stdout, msg) diff --git a/cmd/restore_repo.go b/cmd/restore_repo.go index b83247192..1208796c9 100644 --- a/cmd/restore_repo.go +++ b/cmd/restore_repo.go @@ -40,20 +40,24 @@ var CmdRestoreRepository = cli.Command{ cli.StringFlag{ Name: "units", Value: "", - Usage: `Which items will be restored, one or more units should be separated as comma. + Usage: `Which items will be restored, one or more units should be separated as comma. wiki, issues, labels, releases, release_assets, milestones, pull_requests, comments are allowed. Empty means all units.`, }, }, } -func runRestoreRepository(ctx *cli.Context) error { +func runRestoreRepository(c *cli.Context) error { + ctx, cancel := installSignals() + defer cancel() + setting.NewContext() statusCode, errStr := private.RestoreRepo( - ctx.String("repo_dir"), - ctx.String("owner_name"), - ctx.String("repo_name"), - ctx.StringSlice("units"), + ctx, + c.String("repo_dir"), + c.String("owner_name"), + c.String("repo_name"), + c.StringSlice("units"), ) if statusCode == http.StatusOK { return nil diff --git a/cmd/serv.go b/cmd/serv.go index 40f8b89c9..97ae901d2 100644 --- a/cmd/serv.go +++ b/cmd/serv.go @@ -6,17 +6,14 @@ package cmd import ( - "context" "fmt" "net/http" "net/url" "os" "os/exec" - "os/signal" "regexp" "strconv" "strings" - "syscall" "time" "code.gitea.io/gitea/models" @@ -75,7 +72,10 @@ var ( alphaDashDotPattern = regexp.MustCompile(`[^\w-\.]`) ) -func fail(userMessage, logMessage string, args ...interface{}) { +func fail(userMessage, logMessage string, args ...interface{}) error { + // There appears to be a chance to cause a zombie process and failure to read the Exit status + // if nothing is outputted on stdout. + fmt.Fprintln(os.Stdout, "") fmt.Fprintln(os.Stderr, "Gitea:", userMessage) if len(logMessage) > 0 { @@ -83,15 +83,19 @@ func fail(userMessage, logMessage string, args ...interface{}) { fmt.Fprintf(os.Stderr, logMessage+"\n", args...) } } + ctx, cancel := installSignals() + defer cancel() if len(logMessage) > 0 { - _ = private.SSHLog(true, fmt.Sprintf(logMessage+": ", args...)) + _ = private.SSHLog(ctx, true, fmt.Sprintf(logMessage+": ", args...)) } - - os.Exit(1) + return cli.NewExitError(fmt.Sprintf("Gitea: %s", userMessage), 1) } func runServ(c *cli.Context) error { + ctx, cancel := installSignals() + defer cancel() + // FIXME: This needs to internationalised setup("serv.log", c.Bool("debug")) @@ -109,18 +113,18 @@ func runServ(c *cli.Context) error { keys := strings.Split(c.Args()[0], "-") if len(keys) != 2 || keys[0] != "key" { - fail("Key ID format error", "Invalid key argument: %s", c.Args()[0]) + return fail("Key ID format error", "Invalid key argument: %s", c.Args()[0]) } keyID, err := strconv.ParseInt(keys[1], 10, 64) if err != nil { - fail("Key ID format error", "Invalid key argument: %s", c.Args()[1]) + return fail("Key ID format error", "Invalid key argument: %s", c.Args()[1]) } cmd := os.Getenv("SSH_ORIGINAL_COMMAND") if len(cmd) == 0 { - key, user, err := private.ServNoCommand(keyID) + key, user, err := private.ServNoCommand(ctx, keyID) if err != nil { - fail("Internal error", "Failed to check provided key: %v", err) + return fail("Internal error", "Failed to check provided key: %v", err) } switch key.Type { case models.KeyTypeDeploy: @@ -138,11 +142,11 @@ func runServ(c *cli.Context) error { words, err := shellquote.Split(cmd) if err != nil { - fail("Error parsing arguments", "Failed to parse arguments: %v", err) + return fail("Error parsing arguments", "Failed to parse arguments: %v", err) } if len(words) < 2 { - fail("Too few arguments", "Too few arguments in cmd: %s", cmd) + return fail("Too few arguments", "Too few arguments in cmd: %s", cmd) } verb := words[0] @@ -154,7 +158,7 @@ func runServ(c *cli.Context) error { var lfsVerb string if verb == lfsAuthenticateVerb { if !setting.LFS.StartServer { - fail("Unknown git command", "LFS authentication request over SSH denied, LFS support is disabled") + return fail("Unknown git command", "LFS authentication request over SSH denied, LFS support is disabled") } if len(words) > 2 { @@ -167,37 +171,37 @@ func runServ(c *cli.Context) error { rr := strings.SplitN(repoPath, "/", 2) if len(rr) != 2 { - fail("Invalid repository path", "Invalid repository path: %v", repoPath) + return fail("Invalid repository path", "Invalid repository path: %v", repoPath) } username := strings.ToLower(rr[0]) reponame := strings.ToLower(strings.TrimSuffix(rr[1], ".git")) if alphaDashDotPattern.MatchString(reponame) { - fail("Invalid repo name", "Invalid repo name: %s", reponame) + return fail("Invalid repo name", "Invalid repo name: %s", reponame) } if setting.EnablePprof || c.Bool("enable-pprof") { if err := os.MkdirAll(setting.PprofDataPath, os.ModePerm); err != nil { - fail("Error while trying to create PPROF_DATA_PATH", "Error while trying to create PPROF_DATA_PATH: %v", err) + return fail("Error while trying to create PPROF_DATA_PATH", "Error while trying to create PPROF_DATA_PATH: %v", err) } stopCPUProfiler, err := pprof.DumpCPUProfileForUsername(setting.PprofDataPath, username) if err != nil { - fail("Internal Server Error", "Unable to start CPU profile: %v", err) + return fail("Internal Server Error", "Unable to start CPU profile: %v", err) } defer func() { stopCPUProfiler() err := pprof.DumpMemProfileForUsername(setting.PprofDataPath, username) if err != nil { - fail("Internal Server Error", "Unable to dump Mem Profile: %v", err) + _ = fail("Internal Server Error", "Unable to dump Mem Profile: %v", err) } }() } requestedMode, has := allowedCommands[verb] if !has { - fail("Unknown git command", "Unknown git command %s", verb) + return fail("Unknown git command", "Unknown git command %s", verb) } if verb == lfsAuthenticateVerb { @@ -206,21 +210,20 @@ func runServ(c *cli.Context) error { } else if lfsVerb == "download" { requestedMode = models.AccessModeRead } else { - fail("Unknown LFS verb", "Unknown lfs verb %s", lfsVerb) + return fail("Unknown LFS verb", "Unknown lfs verb %s", lfsVerb) } } - results, err := private.ServCommand(keyID, username, reponame, requestedMode, verb, lfsVerb) + results, err := private.ServCommand(ctx, keyID, username, reponame, requestedMode, verb, lfsVerb) if err != nil { if private.IsErrServCommand(err) { errServCommand := err.(private.ErrServCommand) if errServCommand.StatusCode != http.StatusInternalServerError { - fail("Unauthorized", "%s", errServCommand.Error()) - } else { - fail("Internal Server Error", "%s", errServCommand.Error()) + return fail("Unauthorized", "%s", errServCommand.Error()) } + return fail("Internal Server Error", "%s", errServCommand.Error()) } - fail("Internal Server Error", "%s", err.Error()) + return fail("Internal Server Error", "%s", err.Error()) } os.Setenv(models.EnvRepoIsWiki, strconv.FormatBool(results.IsWiki)) os.Setenv(models.EnvRepoName, results.RepoName) @@ -253,7 +256,7 @@ func runServ(c *cli.Context) error { // Sign and get the complete encoded token as a string using the secret tokenString, err := token.SignedString(setting.LFS.JWTSecretBytes) if err != nil { - fail("Internal error", "Failed to sign JWT token: %v", err) + return fail("Internal error", "Failed to sign JWT token: %v", err) } tokenAuthentication := &models.LFSTokenResponse{ @@ -266,7 +269,7 @@ func runServ(c *cli.Context) error { enc := json.NewEncoder(os.Stdout) err = enc.Encode(tokenAuthentication) if err != nil { - fail("Internal error", "Failed to encode LFS json response: %v", err) + return fail("Internal error", "Failed to encode LFS json response: %v", err) } return nil } @@ -276,25 +279,6 @@ func runServ(c *cli.Context) error { verb = strings.Replace(verb, "-", " ", 1) } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - go func() { - // install notify - signalChannel := make(chan os.Signal, 1) - - signal.Notify( - signalChannel, - syscall.SIGINT, - syscall.SIGTERM, - ) - select { - case <-signalChannel: - case <-ctx.Done(): - } - cancel() - signal.Reset() - }() - var gitcmd *exec.Cmd verbs := strings.Split(verb, " ") if len(verbs) == 2 { @@ -308,13 +292,13 @@ func runServ(c *cli.Context) error { gitcmd.Stdin = os.Stdin gitcmd.Stderr = os.Stderr if err = gitcmd.Run(); err != nil { - fail("Internal error", "Failed to execute git command: %v", err) + return fail("Internal error", "Failed to execute git command: %v", err) } // Update user key activity. if results.KeyID > 0 { - if err = private.UpdatePublicKeyInRepo(results.KeyID, results.RepoID); err != nil { - fail("Internal error", "UpdatePublicKeyInRepo: %v", err) + if err = private.UpdatePublicKeyInRepo(ctx, results.KeyID, results.RepoID); err != nil { + return fail("Internal error", "UpdatePublicKeyInRepo: %v", err) } } diff --git a/integrations/mssql.ini.tmpl b/integrations/mssql.ini.tmpl index 1867070ff..0a7710fc5 100644 --- a/integrations/mssql.ini.tmpl +++ b/integrations/mssql.ini.tmpl @@ -83,6 +83,7 @@ MODE = test,file ROOT_PATH = mssql-log ROUTER = , XORM = file +ENABLE_SSH_LOG = true [log.test] LEVEL = Info diff --git a/integrations/mysql.ini.tmpl b/integrations/mysql.ini.tmpl index 176992cb2..a78b0425a 100644 --- a/integrations/mysql.ini.tmpl +++ b/integrations/mysql.ini.tmpl @@ -101,6 +101,7 @@ MODE = test,file ROOT_PATH = mysql-log ROUTER = , XORM = file +ENABLE_SSH_LOG = true [log.test] LEVEL = Info diff --git a/integrations/mysql8.ini.tmpl b/integrations/mysql8.ini.tmpl index 7c5bcb58d..1151b6abc 100644 --- a/integrations/mysql8.ini.tmpl +++ b/integrations/mysql8.ini.tmpl @@ -80,6 +80,7 @@ MODE = test,file ROOT_PATH = mysql8-log ROUTER = , XORM = file +ENABLE_SSH_LOG = true [log.test] LEVEL = Info diff --git a/integrations/pgsql.ini.tmpl b/integrations/pgsql.ini.tmpl index 3a4a5e6c4..f11d4faba 100644 --- a/integrations/pgsql.ini.tmpl +++ b/integrations/pgsql.ini.tmpl @@ -84,6 +84,7 @@ MODE = test,file ROOT_PATH = pgsql-log ROUTER = , XORM = file +ENABLE_SSH_LOG = true [log.test] LEVEL = Info diff --git a/integrations/sqlite.ini.tmpl b/integrations/sqlite.ini.tmpl index 4a796e931..71ac39a44 100644 --- a/integrations/sqlite.ini.tmpl +++ b/integrations/sqlite.ini.tmpl @@ -79,6 +79,7 @@ MODE = test,file ROOT_PATH = sqlite-log ROUTER = , XORM = file +ENABLE_SSH_LOG = true [log.test] LEVEL = Info diff --git a/modules/httplib/httplib.go b/modules/httplib/httplib.go index 294ad0b70..5c8eac8b4 100644 --- a/modules/httplib/httplib.go +++ b/modules/httplib/httplib.go @@ -7,6 +7,7 @@ package httplib import ( "bytes" + "context" "crypto/tls" "encoding/xml" "io" @@ -122,6 +123,12 @@ func (r *Request) Setting(setting Settings) *Request { return r } +// SetContext sets the request's Context +func (r *Request) SetContext(ctx context.Context) *Request { + r.req = r.req.WithContext(ctx) + return r +} + // SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password. func (r *Request) SetBasicAuth(username, password string) *Request { r.req.SetBasicAuth(username, password) @@ -325,7 +332,7 @@ func (r *Request) getResponse() (*http.Response, error) { trans = &http.Transport{ TLSClientConfig: r.setting.TLSClientConfig, Proxy: proxy, - Dial: TimeoutDialer(r.setting.ConnectTimeout), + DialContext: TimeoutDialer(r.setting.ConnectTimeout), } } else if t, ok := trans.(*http.Transport); ok { if t.TLSClientConfig == nil { @@ -334,8 +341,8 @@ func (r *Request) getResponse() (*http.Response, error) { if t.Proxy == nil { t.Proxy = r.setting.Proxy } - if t.Dial == nil { - t.Dial = TimeoutDialer(r.setting.ConnectTimeout) + if t.DialContext == nil { + t.DialContext = TimeoutDialer(r.setting.ConnectTimeout) } } @@ -458,9 +465,10 @@ func (r *Request) Response() (*http.Response, error) { } // TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field. -func TimeoutDialer(cTimeout time.Duration) func(net, addr string) (c net.Conn, err error) { - return func(netw, addr string) (net.Conn, error) { - conn, err := net.DialTimeout(netw, addr, cTimeout) +func TimeoutDialer(cTimeout time.Duration) func(ctx context.Context, net, addr string) (c net.Conn, err error) { + return func(ctx context.Context, netw, addr string) (net.Conn, error) { + d := net.Dialer{Timeout: cTimeout} + conn, err := d.DialContext(ctx, netw, addr) if err != nil { return nil, err } diff --git a/modules/private/hook.go b/modules/private/hook.go index 82dcaf3fc..79fae052d 100644 --- a/modules/private/hook.go +++ b/modules/private/hook.go @@ -5,6 +5,7 @@ package private import ( + "context" "encoding/json" "fmt" "net/http" @@ -80,12 +81,12 @@ type HookPostReceiveBranchResult struct { } // HookPreReceive check whether the provided commits are allowed -func HookPreReceive(ownerName, repoName string, opts HookOptions) (int, string) { +func HookPreReceive(ctx context.Context, ownerName, repoName string, opts HookOptions) (int, string) { reqURL := setting.LocalURL + fmt.Sprintf("api/internal/hook/pre-receive/%s/%s", url.PathEscape(ownerName), url.PathEscape(repoName), ) - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") req = req.Header("Content-Type", "application/json") json := jsoniter.ConfigCompatibleWithStandardLibrary jsonBytes, _ := json.Marshal(opts) @@ -105,13 +106,13 @@ func HookPreReceive(ownerName, repoName string, opts HookOptions) (int, string) } // HookPostReceive updates services and users -func HookPostReceive(ownerName, repoName string, opts HookOptions) (*HookPostReceiveResult, string) { +func HookPostReceive(ctx context.Context, ownerName, repoName string, opts HookOptions) (*HookPostReceiveResult, string) { reqURL := setting.LocalURL + fmt.Sprintf("api/internal/hook/post-receive/%s/%s", url.PathEscape(ownerName), url.PathEscape(repoName), ) - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") req = req.Header("Content-Type", "application/json") req.SetTimeout(60*time.Second, time.Duration(60+len(opts.OldCommitIDs))*time.Second) json := jsoniter.ConfigCompatibleWithStandardLibrary @@ -133,13 +134,13 @@ func HookPostReceive(ownerName, repoName string, opts HookOptions) (*HookPostRec } // SetDefaultBranch will set the default branch to the provided branch for the provided repository -func SetDefaultBranch(ownerName, repoName, branch string) error { +func SetDefaultBranch(ctx context.Context, ownerName, repoName, branch string) error { reqURL := setting.LocalURL + fmt.Sprintf("api/internal/hook/set-default-branch/%s/%s/%s", url.PathEscape(ownerName), url.PathEscape(repoName), url.PathEscape(branch), ) - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") req = req.Header("Content-Type", "application/json") req.SetTimeout(60*time.Second, 60*time.Second) @@ -155,9 +156,9 @@ func SetDefaultBranch(ownerName, repoName, branch string) error { } // SSHLog sends ssh error log response -func SSHLog(isErr bool, msg string) error { +func SSHLog(ctx context.Context, isErr bool, msg string) error { reqURL := setting.LocalURL + "api/internal/ssh/log" - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") req = req.Header("Content-Type", "application/json") jsonBytes, _ := json.Marshal(&SSHLogOption{ @@ -171,6 +172,7 @@ func SSHLog(isErr bool, msg string) error { if err != nil { return fmt.Errorf("unable to contact gitea: %v", err) } + defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return fmt.Errorf("Error returned from gitea: %v", decodeJSONError(resp).Err) diff --git a/modules/private/internal.go b/modules/private/internal.go index 360fae47b..672ac7497 100644 --- a/modules/private/internal.go +++ b/modules/private/internal.go @@ -5,6 +5,7 @@ package private import ( + "context" "crypto/tls" "fmt" "net" @@ -15,9 +16,11 @@ import ( jsoniter "github.com/json-iterator/go" ) -func newRequest(url, method string) *httplib.Request { - return httplib.NewRequest(url, method).Header("Authorization", - fmt.Sprintf("Bearer %s", setting.InternalToken)) +func newRequest(ctx context.Context, url, method string) *httplib.Request { + return httplib.NewRequest(url, method). + SetContext(ctx). + Header("Authorization", + fmt.Sprintf("Bearer %s", setting.InternalToken)) } // Response internal request response @@ -35,8 +38,8 @@ func decodeJSONError(resp *http.Response) *Response { return &res } -func newInternalRequest(url, method string) *httplib.Request { - req := newRequest(url, method).SetTLSClientConfig(&tls.Config{ +func newInternalRequest(ctx context.Context, url, method string) *httplib.Request { + req := newRequest(ctx, url, method).SetTLSClientConfig(&tls.Config{ InsecureSkipVerify: true, ServerName: setting.Domain, }) @@ -45,6 +48,10 @@ func newInternalRequest(url, method string) *httplib.Request { Dial: func(_, _ string) (net.Conn, error) { return net.Dial("unix", setting.HTTPAddr) }, + DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "unix", setting.HTTPAddr) + }, }) } return req diff --git a/modules/private/key.go b/modules/private/key.go index bea783790..d0b11a96e 100644 --- a/modules/private/key.go +++ b/modules/private/key.go @@ -5,6 +5,7 @@ package private import ( + "context" "fmt" "io/ioutil" "net/http" @@ -13,10 +14,10 @@ import ( ) // UpdatePublicKeyInRepo update public key and if necessary deploy key updates -func UpdatePublicKeyInRepo(keyID, repoID int64) error { +func UpdatePublicKeyInRepo(ctx context.Context, keyID, repoID int64) error { // Ask for running deliver hook and test pull request tasks. reqURL := setting.LocalURL + fmt.Sprintf("api/internal/ssh/%d/update/%d", keyID, repoID) - resp, err := newInternalRequest(reqURL, "POST").Response() + resp, err := newInternalRequest(ctx, reqURL, "POST").Response() if err != nil { return err } @@ -32,10 +33,10 @@ func UpdatePublicKeyInRepo(keyID, repoID int64) error { // AuthorizedPublicKeyByContent searches content as prefix (leak e-mail part) // and returns public key found. -func AuthorizedPublicKeyByContent(content string) (string, error) { +func AuthorizedPublicKeyByContent(ctx context.Context, content string) (string, error) { // Ask for running deliver hook and test pull request tasks. reqURL := setting.LocalURL + "api/internal/ssh/authorized_keys" - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") req.Param("content", content) resp, err := req.Response() if err != nil { diff --git a/modules/private/mail.go b/modules/private/mail.go index 9c0912a6e..4a5a3eedd 100644 --- a/modules/private/mail.go +++ b/modules/private/mail.go @@ -5,6 +5,7 @@ package private import ( + "context" "fmt" "io/ioutil" "net/http" @@ -27,10 +28,10 @@ type Email struct { // // If to list == nil its supposed to send an email to every // user present in DB -func SendEmail(subject, message string, to []string) (int, string) { +func SendEmail(ctx context.Context, subject, message string, to []string) (int, string) { reqURL := setting.LocalURL + "api/internal/mail/send" - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") req = req.Header("Content-Type", "application/json") json := jsoniter.ConfigCompatibleWithStandardLibrary jsonBytes, _ := json.Marshal(Email{ diff --git a/modules/private/manager.go b/modules/private/manager.go index 2bc6cec3b..0bcc3f811 100644 --- a/modules/private/manager.go +++ b/modules/private/manager.go @@ -5,6 +5,7 @@ package private import ( + "context" "fmt" "net/http" "net/url" @@ -15,10 +16,10 @@ import ( ) // Shutdown calls the internal shutdown function -func Shutdown() (int, string) { +func Shutdown(ctx context.Context) (int, string) { reqURL := setting.LocalURL + "api/internal/manager/shutdown" - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") resp, err := req.Response() if err != nil { return http.StatusInternalServerError, fmt.Sprintf("Unable to contact gitea: %v", err.Error()) @@ -33,10 +34,10 @@ func Shutdown() (int, string) { } // Restart calls the internal restart function -func Restart() (int, string) { +func Restart(ctx context.Context) (int, string) { reqURL := setting.LocalURL + "api/internal/manager/restart" - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") resp, err := req.Response() if err != nil { return http.StatusInternalServerError, fmt.Sprintf("Unable to contact gitea: %v", err.Error()) @@ -57,10 +58,10 @@ type FlushOptions struct { } // FlushQueues calls the internal flush-queues function -func FlushQueues(timeout time.Duration, nonBlocking bool) (int, string) { +func FlushQueues(ctx context.Context, timeout time.Duration, nonBlocking bool) (int, string) { reqURL := setting.LocalURL + "api/internal/manager/flush-queues" - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") if timeout > 0 { req.SetTimeout(timeout+10*time.Second, timeout+10*time.Second) } @@ -85,10 +86,10 @@ func FlushQueues(timeout time.Duration, nonBlocking bool) (int, string) { } // PauseLogging pauses logging -func PauseLogging() (int, string) { +func PauseLogging(ctx context.Context) (int, string) { reqURL := setting.LocalURL + "api/internal/manager/pause-logging" - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") resp, err := req.Response() if err != nil { return http.StatusInternalServerError, fmt.Sprintf("Unable to contact gitea: %v", err.Error()) @@ -103,10 +104,10 @@ func PauseLogging() (int, string) { } // ResumeLogging resumes logging -func ResumeLogging() (int, string) { +func ResumeLogging(ctx context.Context) (int, string) { reqURL := setting.LocalURL + "api/internal/manager/resume-logging" - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") resp, err := req.Response() if err != nil { return http.StatusInternalServerError, fmt.Sprintf("Unable to contact gitea: %v", err.Error()) @@ -121,10 +122,10 @@ func ResumeLogging() (int, string) { } // ReleaseReopenLogging releases and reopens logging files -func ReleaseReopenLogging() (int, string) { +func ReleaseReopenLogging(ctx context.Context) (int, string) { reqURL := setting.LocalURL + "api/internal/manager/release-and-reopen-logging" - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") resp, err := req.Response() if err != nil { return http.StatusInternalServerError, fmt.Sprintf("Unable to contact gitea: %v", err.Error()) @@ -147,10 +148,10 @@ type LoggerOptions struct { } // AddLogger adds a logger -func AddLogger(group, name, mode string, config map[string]interface{}) (int, string) { +func AddLogger(ctx context.Context, group, name, mode string, config map[string]interface{}) (int, string) { reqURL := setting.LocalURL + "api/internal/manager/add-logger" - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") req = req.Header("Content-Type", "application/json") json := jsoniter.ConfigCompatibleWithStandardLibrary jsonBytes, _ := json.Marshal(LoggerOptions{ @@ -175,10 +176,10 @@ func AddLogger(group, name, mode string, config map[string]interface{}) (int, st } // RemoveLogger removes a logger -func RemoveLogger(group, name string) (int, string) { +func RemoveLogger(ctx context.Context, group, name string) (int, string) { reqURL := setting.LocalURL + fmt.Sprintf("api/internal/manager/remove-logger/%s/%s", url.PathEscape(group), url.PathEscape(name)) - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") resp, err := req.Response() if err != nil { return http.StatusInternalServerError, fmt.Sprintf("Unable to contact gitea: %v", err.Error()) diff --git a/modules/private/restore_repo.go b/modules/private/restore_repo.go index 6fe2e6844..66b60d8d1 100644 --- a/modules/private/restore_repo.go +++ b/modules/private/restore_repo.go @@ -5,6 +5,7 @@ package private import ( + "context" "fmt" "io/ioutil" "net/http" @@ -23,10 +24,10 @@ type RestoreParams struct { } // RestoreRepo calls the internal RestoreRepo function -func RestoreRepo(repoDir, ownerName, repoName string, units []string) (int, string) { +func RestoreRepo(ctx context.Context, repoDir, ownerName, repoName string, units []string) (int, string) { reqURL := setting.LocalURL + "api/internal/restore_repo" - req := newInternalRequest(reqURL, "POST") + req := newInternalRequest(ctx, reqURL, "POST") req.SetTimeout(3*time.Second, 0) // since the request will spend much time, don't timeout req = req.Header("Content-Type", "application/json") json := jsoniter.ConfigCompatibleWithStandardLibrary diff --git a/modules/private/serv.go b/modules/private/serv.go index 659af6dff..9643dad67 100644 --- a/modules/private/serv.go +++ b/modules/private/serv.go @@ -5,6 +5,7 @@ package private import ( + "context" "fmt" "net/http" "net/url" @@ -21,10 +22,10 @@ type KeyAndOwner struct { } // ServNoCommand returns information about the provided key -func ServNoCommand(keyID int64) (*models.PublicKey, *models.User, error) { +func ServNoCommand(ctx context.Context, keyID int64) (*models.PublicKey, *models.User, error) { reqURL := setting.LocalURL + fmt.Sprintf("api/internal/serv/none/%d", keyID) - resp, err := newInternalRequest(reqURL, "GET").Response() + resp, err := newInternalRequest(ctx, reqURL, "GET").Response() if err != nil { return nil, nil, err } @@ -73,7 +74,7 @@ func IsErrServCommand(err error) bool { } // ServCommand preps for a serv call -func ServCommand(keyID int64, ownerName, repoName string, mode models.AccessMode, verbs ...string) (*ServCommandResults, error) { +func ServCommand(ctx context.Context, keyID int64, ownerName, repoName string, mode models.AccessMode, verbs ...string) (*ServCommandResults, error) { reqURL := setting.LocalURL + fmt.Sprintf("api/internal/serv/command/%d/%s/%s?mode=%d", keyID, url.PathEscape(ownerName), @@ -85,7 +86,7 @@ func ServCommand(keyID int64, ownerName, repoName string, mode models.AccessMode } } - resp, err := newInternalRequest(reqURL, "GET").Response() + resp, err := newInternalRequest(ctx, reqURL, "GET").Response() if err != nil { return nil, err } diff --git a/modules/ssh/ssh.go b/modules/ssh/ssh.go index c0897377c..efe952534 100644 --- a/modules/ssh/ssh.go +++ b/modules/ssh/ssh.go @@ -6,6 +6,7 @@ package ssh import ( "bytes" + "context" "crypto/rand" "crypto/rsa" "crypto/x509" @@ -66,7 +67,11 @@ func sessionHandler(session ssh.Session) { args := []string{"serv", "key-" + keyID, "--config=" + setting.CustomConf} log.Trace("SSH: Arguments: %v", args) - cmd := exec.CommandContext(session.Context(), setting.AppPath, args...) + + ctx, cancel := context.WithCancel(session.Context()) + defer cancel() + + cmd := exec.CommandContext(ctx, setting.AppPath, args...) cmd.Env = append( os.Environ(), "SSH_ORIGINAL_COMMAND="+command, @@ -78,16 +83,21 @@ func sessionHandler(session ssh.Session) { log.Error("SSH: StdoutPipe: %v", err) return } + defer stdout.Close() + stderr, err := cmd.StderrPipe() if err != nil { log.Error("SSH: StderrPipe: %v", err) return } + defer stderr.Close() + stdin, err := cmd.StdinPipe() if err != nil { log.Error("SSH: StdinPipe: %v", err) return } + defer stdin.Close() wg := &sync.WaitGroup{} wg.Add(2) @@ -106,6 +116,7 @@ func sessionHandler(session ssh.Session) { go func() { defer wg.Done() + defer stdout.Close() if _, err := io.Copy(session, stdout); err != nil { log.Error("Failed to write stdout to session. %s", err) } @@ -113,6 +124,7 @@ func sessionHandler(session ssh.Session) { go func() { defer wg.Done() + defer stderr.Close() if _, err := io.Copy(session.Stderr(), stderr); err != nil { log.Error("Failed to write stderr to session. %s", err) }