Refactor git command context & pipeline (#36406)

Less and simpler code, fewer bugs
This commit is contained in:
wxiaoguang
2026-01-21 09:35:14 +08:00
committed by GitHub
parent f6db180a80
commit 9ea91e036f
36 changed files with 286 additions and 434 deletions

View File

@@ -28,13 +28,6 @@ import (
// In most cases, it shouldn't be used. Use AddXxx function instead
type TrustedCmdArgs []internal.CmdArg
// defaultCommandExecutionTimeout default command execution timeout duration
var defaultCommandExecutionTimeout = 360 * time.Second
func SetDefaultCommandExecutionTimeout(timeout time.Duration) {
defaultCommandExecutionTimeout = timeout
}
// DefaultLocale is the default LC_ALL to run git commands in.
const DefaultLocale = "C"
@@ -44,13 +37,14 @@ type Command struct {
prog string
args []string
preErrors []error
cmd *exec.Cmd // for debug purpose only
configArgs []string
opts runOpts
cmd *exec.Cmd
cmdCtx context.Context
cmdCancel context.CancelFunc
cmdFinished context.CancelFunc
cmdCancel process.CancelCauseFunc
cmdFinished process.FinishedFunc
cmdStartTime time.Time
cmdStdinWriter *io.WriteCloser
@@ -209,11 +203,9 @@ func ToTrustedCmdArgs(args []string) TrustedCmdArgs {
return ret
}
// runOpts represents parameters to run the command. If UseContextTimeout is specified, then Timeout is ignored.
type runOpts struct {
Env []string
Timeout time.Duration
UseContextTimeout bool
Env []string
Timeout time.Duration
// Dir is the working dir for the git command, however:
// FIXME: this could be incorrect in many cases, for example:
@@ -236,7 +228,7 @@ type runOpts struct {
// Use new functions like WithStdinWriter to avoid such problems.
Stdin io.Reader
PipelineFunc func(context.Context, context.CancelFunc) error
PipelineFunc func(Context) error
}
func commonBaseEnvs() []string {
@@ -321,16 +313,11 @@ func (c *Command) WithStdin(stdin io.Reader) *Command {
return c
}
func (c *Command) WithPipelineFunc(f func(context.Context, context.CancelFunc) error) *Command {
func (c *Command) WithPipelineFunc(f func(Context) error) *Command {
c.opts.PipelineFunc = f
return c
}
func (c *Command) WithUseContextTimeout(useContextTimeout bool) *Command {
c.opts.UseContextTimeout = useContextTimeout
return c
}
// WithParentCallerInfo can be used to set the caller info (usually function name) of the parent function of the caller.
// For most cases, "Run" family functions can get its caller info automatically
// But if you need to call "Run" family functions in a wrapper function: "FeatureFunc -> GeneralWrapperFunc -> RunXxx",
@@ -363,9 +350,7 @@ func (c *Command) Start(ctx context.Context) (retErr error) {
defer func() {
if retErr != nil {
// release the pipes to avoid resource leak
safeClosePtrCloser(c.cmdStdoutReader)
safeClosePtrCloser(c.cmdStderrReader)
safeClosePtrCloser(c.cmdStdinWriter)
c.closeStdioPipes()
// if error occurs, we must also finish the task, otherwise, cmdFinished will be called in "Wait" function
if c.cmdFinished != nil {
c.cmdFinished()
@@ -380,12 +365,6 @@ func (c *Command) Start(ctx context.Context) (retErr error) {
return err
}
// We must not change the provided options
timeout := c.opts.Timeout
if timeout <= 0 {
timeout = defaultCommandExecutionTimeout
}
cmdLogString := c.LogString()
if c.callerInfo == "" {
c.WithParentCallerInfo()
@@ -399,83 +378,85 @@ func (c *Command) Start(ctx context.Context) (retErr error) {
span.SetAttributeString(gtprof.TraceAttrFuncCaller, c.callerInfo)
span.SetAttributeString(gtprof.TraceAttrGitCommand, cmdLogString)
if c.opts.UseContextTimeout {
if c.opts.Timeout <= 0 {
c.cmdCtx, c.cmdCancel, c.cmdFinished = process.GetManager().AddContext(ctx, desc)
} else {
c.cmdCtx, c.cmdCancel, c.cmdFinished = process.GetManager().AddContextTimeout(ctx, timeout, desc)
c.cmdCtx, c.cmdCancel, c.cmdFinished = process.GetManager().AddContextTimeout(ctx, c.opts.Timeout, desc)
}
c.cmdStartTime = time.Now()
cmd := exec.CommandContext(ctx, c.prog, append(c.configArgs, c.args...)...)
c.cmd = cmd // for debug purpose only
c.cmd = exec.CommandContext(ctx, c.prog, append(c.configArgs, c.args...)...)
if c.opts.Env == nil {
cmd.Env = os.Environ()
c.cmd.Env = os.Environ()
} else {
cmd.Env = c.opts.Env
c.cmd.Env = c.opts.Env
}
process.SetSysProcAttribute(cmd)
cmd.Env = append(cmd.Env, CommonGitCmdEnvs()...)
cmd.Dir = c.opts.Dir
cmd.Stdout = c.opts.Stdout
cmd.Stdin = c.opts.Stdin
process.SetSysProcAttribute(c.cmd)
c.cmd.Env = append(c.cmd.Env, CommonGitCmdEnvs()...)
c.cmd.Dir = c.opts.Dir
c.cmd.Stdout = c.opts.Stdout
c.cmd.Stdin = c.opts.Stdin
if _, err := safeAssignPipe(c.cmdStdinWriter, cmd.StdinPipe); err != nil {
if _, err := safeAssignPipe(c.cmdStdinWriter, c.cmd.StdinPipe); err != nil {
return err
}
if _, err := safeAssignPipe(c.cmdStdoutReader, cmd.StdoutPipe); err != nil {
if _, err := safeAssignPipe(c.cmdStdoutReader, c.cmd.StdoutPipe); err != nil {
return err
}
if _, err := safeAssignPipe(c.cmdStderrReader, cmd.StderrPipe); err != nil {
if _, err := safeAssignPipe(c.cmdStderrReader, c.cmd.StderrPipe); err != nil {
return err
}
if c.cmdManagedStderr != nil {
if cmd.Stderr != nil {
if c.cmd.Stderr != nil {
panic("CombineStderr needs managed (but not caller-provided) stderr pipe")
}
cmd.Stderr = c.cmdManagedStderr
c.cmd.Stderr = c.cmdManagedStderr
}
return cmd.Start()
return c.cmd.Start()
}
func (c *Command) closeStdioPipes() {
safeClosePtrCloser(c.cmdStdoutReader)
safeClosePtrCloser(c.cmdStderrReader)
safeClosePtrCloser(c.cmdStdinWriter)
}
func (c *Command) Wait() error {
defer func() {
safeClosePtrCloser(c.cmdStdoutReader)
safeClosePtrCloser(c.cmdStderrReader)
safeClosePtrCloser(c.cmdStdinWriter)
c.closeStdioPipes()
c.cmdFinished()
}()
cmd, ctx, cancel := c.cmd, c.cmdCtx, c.cmdCancel
if c.opts.PipelineFunc != nil {
err := c.opts.PipelineFunc(ctx, cancel)
if err != nil {
cancel()
errWait := cmd.Wait()
return errors.Join(err, errWait)
errCallback := c.opts.PipelineFunc(&cmdContext{Context: c.cmdCtx, cmd: c})
// after the pipeline function returns, we can safely cancel the command context and close the stdio pipes
c.cmdCancel(errCallback)
c.closeStdioPipes()
errWait := c.cmd.Wait()
errCause := context.Cause(c.cmdCtx)
// the pipeline function should be able to know whether it succeeds or fails
if errCallback == nil && (errCause == nil || errors.Is(errCause, context.Canceled)) {
return nil
}
return errors.Join(errCallback, errCause, errWait)
}
errWait := cmd.Wait()
// there might be other goroutines using the context or pipes, so we just wait for the command to finish
errWait := c.cmd.Wait()
elapsed := time.Since(c.cmdStartTime)
if elapsed > time.Second {
log.Debug("slow git.Command.Run: %s (%s)", c, elapsed)
log.Debug("slow git.Command.Run: %s (%s)", c, elapsed) // TODO: no need to log this for long-running commands
}
// Here the logic is different from "PipelineFunc" case,
// because PipelineFunc can return error if it fails, it knows whether it succeeds or fails.
// But in normal case, the caller just runs the git command, the command's exit code is the source of truth.
// If the caller need to know whether the command error is caused by cancellation, it should check the "err" by itself.
errCause := context.Cause(c.cmdCtx)
if errors.Is(errCause, context.Canceled) {
// if the ctx is canceled without other error, it must be caused by normal cancellation
return errCause
}
if errWait != nil {
// no matter whether there is other cause error, if "Wait" also has error,
// it's likely the error is caused by Wait error (from git command)
return errWait
}
return errCause
return errors.Join(errCause, errWait)
}
func (c *Command) StartWithStderr(ctx context.Context) RunStdError {
@@ -513,59 +494,6 @@ func (c *Command) Run(ctx context.Context) (err error) {
return c.Wait()
}
type RunStdError interface {
error
Unwrap() error
Stderr() string
}
type runStdError struct {
err error // usually the low-level error like `*exec.ExitError`
stderr string // git command's stderr output
errMsg string // the cached error message for Error() method
}
func (r *runStdError) Error() string {
// FIXME: GIT-CMD-STDERR: it is a bad design, the stderr should not be put in the error message
// But a lot of code only checks `strings.Contains(err.Error(), "git error")`
if r.errMsg == "" {
r.errMsg = fmt.Sprintf("%s - %s", r.err.Error(), strings.TrimSpace(r.stderr))
}
return r.errMsg
}
func (r *runStdError) Unwrap() error {
return r.err
}
func (r *runStdError) Stderr() string {
return r.stderr
}
func ErrorAsStderr(err error) (string, bool) {
var runErr RunStdError
if errors.As(err, &runErr) {
return runErr.Stderr(), true
}
return "", false
}
func StderrHasPrefix(err error, prefix string) bool {
stderr, ok := ErrorAsStderr(err)
if !ok {
return false
}
return strings.HasPrefix(stderr, prefix)
}
func IsErrorExitCode(err error, code int) bool {
var exitError *exec.ExitError
if errors.As(err, &exitError) {
return exitError.ExitCode() == code
}
return false
}
// RunStdString runs the command and returns stdout/stderr as string. and store stderr to returned error (err combined with stderr).
func (c *Command) RunStdString(ctx context.Context) (stdout, stderr string, runErr RunStdError) {
stdoutBytes, stderrBytes, runErr := c.WithParentCallerInfo().runStdBytes(ctx)