diff --git a/_fixtures/defercall.go b/_fixtures/defercall.go new file mode 100644 index 0000000000000000000000000000000000000000..a584794c978f403f08d7ccd7f0ce73eabc78b9f5 --- /dev/null +++ b/_fixtures/defercall.go @@ -0,0 +1,29 @@ +package main + +var n = 0 + +func sampleFunction() { + n++ +} + +func callAndDeferReturn() { + defer sampleFunction() + sampleFunction() + n++ +} + +func callAndPanic2() { + defer sampleFunction() + sampleFunction() + panic("panicking") +} + +func callAndPanic() { + defer recover() + callAndPanic2() +} + +func main() { + callAndDeferReturn() + callAndPanic() +} diff --git a/proc/breakpoints.go b/proc/breakpoints.go index 7a28ec28aadfb500af5d3c49ba2a3c38392875aa..4a78808363c40c254b60e9fa9a9eabac982c51a9 100644 --- a/proc/breakpoints.go +++ b/proc/breakpoints.go @@ -33,7 +33,18 @@ type Breakpoint struct { HitCount map[int]uint64 // Number of times a breakpoint has been reached in a certain goroutine TotalHitCount uint64 // Number of times a breakpoint has been reached - Cond ast.Expr // When Cond is not nil the breakpoint will be triggered only if evaluating Cond returns true + // When DeferCond is set the breakpoint will only trigger + // if the caller is runtime.gopanic or if the return address + // is in the DeferReturns array. + // Next sets DeferCond for the breakpoint it sets on the + // deferred function, DeferReturns is populated with the + // addresses of calls to runtime.deferreturn in the current + // function. This insures that the breakpoint on the deferred + // function only triggers on panic or on the defer call to + // the function, not when the function is called directly + DeferCond bool + DeferReturns []uint64 + Cond ast.Expr // When Cond is not nil the breakpoint will be triggered only if evaluating Cond returns true } func (bp *Breakpoint) String() string { @@ -122,6 +133,24 @@ func (bp *Breakpoint) checkCondition(thread *Thread) (bool, error) { if bp.Cond == nil { return true, nil } + if bp.DeferCond { + frames, err := thread.Stacktrace(2) + if err == nil { + ispanic := len(frames) >= 3 && frames[2].Current.Fn != nil && frames[2].Current.Fn.Name == "runtime.gopanic" + isdeferreturn := false + if len(frames) >= 1 { + for _, pc := range bp.DeferReturns { + if frames[0].Ret == pc { + isdeferreturn = true + break + } + } + } + if !ispanic && !isdeferreturn { + return false, nil + } + } + } scope, err := thread.Scope() if err != nil { return true, err diff --git a/proc/proc_test.go b/proc/proc_test.go index fb40494c428395afae8b79cf203584c58143ddd8..48efedc9f7d96d312d6ac339d00d666c492dd313 100644 --- a/proc/proc_test.go +++ b/proc/proc_test.go @@ -1907,3 +1907,30 @@ func TestTestvariables2Prologue(t *testing.T) { } }) } + +func TestNextDeferReturnAndDirectCall(t *testing.T) { + // Next should not step into a deferred function if it is called + // directly, only if it is called through a panic or a deferreturn. + // Here we test the case where the function is called by a deferreturn + testnext("defercall", []nextTest{ + {9, 10}, + {10, 11}, + {11, 12}, + {12, 13}, + {13, 5}, + {5, 6}, + {6, 7}, + {7, 13}, + {13, 28}}, "main.callAndDeferReturn", t) +} + +func TestNextPanicAndDirectCall(t *testing.T) { + // Next should not step into a deferred function if it is called + // directly, only if it is called through a panic or a deferreturn. + // Here we test the case where the function is called by a panic + testnext("defercall", []nextTest{ + {15, 16}, + {16, 17}, + {17, 18}, + {18, 5}}, "main.callAndPanic2", t) +} diff --git a/proc/threads.go b/proc/threads.go index 8c9749daabb99b42df354f0cfc61317f4f9248a8..2b106ae5b1ce1c44c463ea71b70cc78e55f3f24d 100644 --- a/proc/threads.go +++ b/proc/threads.go @@ -162,8 +162,21 @@ func (dbp *Process) setNextBreakpoints() (err error) { // Set breakpoints at every line, and the return address. Also look for // a deferred function and set a breakpoint there too. func (dbp *Process) next(g *G, topframe Stackframe) error { - pcs := dbp.lineInfo.AllPCsBetween(topframe.FDE.Begin(), topframe.FDE.End()-1, topframe.Current.File) + cond := sameGoroutineCondition(dbp.SelectedGoroutine) + + // Disassembles function to find all runtime.deferreturn locations + // See documentation of Breakpoint.DeferCond for why this is necessary + deferreturns := []uint64{} + text, err := dbp.CurrentThread.Disassemble(topframe.FDE.Begin(), topframe.FDE.End(), false) + if err == nil { + for _, instr := range text { + if instr.IsCall() && instr.DestLoc != nil && instr.DestLoc.Fn != nil && instr.DestLoc.Fn.Name == "runtime.deferreturn" { + deferreturns = append(deferreturns, instr.Loc.PC) + } + } + } + // Set breakpoint on the most recently deferred function (if any) var deferpc uint64 = 0 if g != nil && g.DeferPC != 0 { _, _, deferfn := dbp.goSymTable.PCToLine(g.DeferPC) @@ -173,6 +186,20 @@ func (dbp *Process) next(g *G, topframe Stackframe) error { return err } } + if deferpc != 0 { + bp, err := dbp.SetTempBreakpoint(deferpc, cond) + if err != nil { + if _, ok := err.(BreakpointExistsError); !ok { + dbp.ClearTempBreakpoints() + return err + } + } + bp.DeferCond = true + bp.DeferReturns = deferreturns + } + + // Add breakpoints on all the lines in the current function + pcs := dbp.lineInfo.AllPCsBetween(topframe.FDE.Begin(), topframe.FDE.End()-1, topframe.Current.File) var covered bool for i := range pcs { @@ -188,11 +215,10 @@ func (dbp *Process) next(g *G, topframe Stackframe) error { return nil } } - if deferpc != 0 { - pcs = append(pcs, deferpc) - } + + // Add a breakpoint on the return address for the current frame pcs = append(pcs, topframe.Ret) - return dbp.setTempBreakpoints(topframe.Current.PC, pcs, sameGoroutineCondition(dbp.SelectedGoroutine)) + return dbp.setTempBreakpoints(topframe.Current.PC, pcs, cond) } // Set a breakpoint at every reachable location, as well as the return address. Without @@ -385,5 +411,11 @@ func (thread *Thread) onNextGoroutine() (bool, error) { if bp == nil { return false, nil } + // we just want to check the condition on the goroutine id here + dc := bp.DeferCond + bp.DeferCond = false + defer func() { + bp.DeferCond = dc + }() return bp.checkCondition(thread) }