diff --git a/test/utils/ktesting/signals.go b/test/utils/ktesting/signals.go index 8f3fe3cca73..4edef18bd96 100644 --- a/test/utils/ktesting/signals.go +++ b/test/utils/ktesting/signals.go @@ -19,6 +19,7 @@ package ktesting import ( "context" "errors" + "fmt" "io" "os" "os/signal" @@ -30,6 +31,13 @@ import ( var ( // defaultProgressReporter is inactive until init is called. defaultProgressReporter = &progressReporter{} + + // interruptCtx tracks whether the process got interrupted via SIGINT. + // In that case, interrupted gets called to cancel interruptCtx with + // a suitable message. + // + // This gets set up once per process and never gets reset. + interruptCtx, interrupted = context.WithCancelCause(context.Background()) ) const ginkgoSpecContextKey = "GINKGO_SPEC_CONTEXT" @@ -42,11 +50,11 @@ type progressReporter struct { // initMutex protects initialization and finalization of the reporter. initMutex sync.Mutex - usageCount int64 - wg sync.WaitGroup - signalCtx, interruptCtx context.Context - signalChannel chan os.Signal - progressChannel chan os.Signal + usageCount int64 + wg sync.WaitGroup + testCtx context.Context + signalChannel chan os.Signal + progressChannel chan os.Signal // reportMutex protects report creation and settings. reportMutex sync.Mutex @@ -75,6 +83,16 @@ func (p *progressReporter) init(tb TB) context.Context { return context.Background() } + tb.Helper() + + // If already interrupted, then don't start the new test. + // This is necessary because normally CTRL-C would exit + // the entire process immediately. Now we keep running + // to clean up. + if interruptCtx.Err() != nil { + tb.Fatalf("testing has been interrupted: %v", context.Cause(interruptCtx)) + } + p.initMutex.Lock() defer p.initMutex.Unlock() @@ -82,7 +100,7 @@ func (p *progressReporter) init(tb TB) context.Context { tb.Cleanup(p.finalize) if p.usageCount > 1 { // Was already initialized. - return p.interruptCtx + return p.testCtx } // Might have been set for testing purposes. @@ -108,6 +126,7 @@ func (p *progressReporter) init(tb TB) context.Context { p.wg.Go(func() { _, ok := <-p.signalChannel if ok { + _, _ = fmt.Fprint(p.out, "\n\nINFO: canceling test context: received interrupt signal\n\n") interrupted(errors.New("received interrupt signal")) } }) @@ -118,7 +137,7 @@ func (p *progressReporter) init(tb TB) context.Context { // nolint:staticcheck // It complains about using a plain string. This can only be fixed // by Ginkgo and Gomega formalizing this interface and define a type (somewhere... // probably cannot be in either Ginkgo or Gomega). - p.interruptCtx = context.WithValue(cancelCtx, ginkgoSpecContextKey, defaultProgressReporter) + p.testCtx = context.WithValue(interruptCtx, ginkgoSpecContextKey, defaultProgressReporter) p.progressChannel = make(chan os.Signal, 1) // progressSignals will be empty on Windows. @@ -128,7 +147,7 @@ func (p *progressReporter) init(tb TB) context.Context { p.wg.Go(p.run) - return p.interruptCtx + return p.testCtx } func (p *progressReporter) finalize() { diff --git a/test/utils/ktesting/stepcontext_test.go b/test/utils/ktesting/stepcontext_test.go index c8f3c143011..314fa485daf 100644 --- a/test/utils/ktesting/stepcontext_test.go +++ b/test/utils/ktesting/stepcontext_test.go @@ -17,9 +17,11 @@ limitations under the License. package ktesting import ( + "context" "io" "os" "testing" + "time" "github.com/onsi/gomega" "go.uber.org/goleak" @@ -73,15 +75,20 @@ func TestStepContext(t *testing.T) { } func TestProgressReport(t *testing.T) { + oldOut := defaultProgressReporter.out + out := newOutputStream() + defaultProgressReporter.out = out t.Cleanup(func() { goleak.VerifyNone(t) - }) - - oldOut := defaultProgressReporter.out - reportStream := newOutputStream() - defaultProgressReporter.out = reportStream - t.Cleanup(func() { defaultProgressReporter.out = oldOut + + // If we get here, the defaultProgressReporter is not active anymore, + // but the interrupt context should still be canceled. + gomega.NewGomegaWithT(t).Expect(defaultProgressReporter.usageCount).To(gomega.Equal(int64(0)), "usage count") + gomega.NewGomegaWithT(t).Expect(context.Cause(interruptCtx)).To(gomega.MatchError(gomega.Equal("received interrupt signal")), "interrupted persistently") + + // Reset for next test. + interruptCtx, interrupted = context.WithCancelCause(context.Background()) }) // This must use a real testing.T, otherwise Init doesn't initialize signal handling. @@ -93,11 +100,21 @@ func TestProgressReport(t *testing.T) { // Trigger report and wait for it. defaultProgressReporter.progressChannel <- os.Interrupt - report := <-reportStream.stream + report := <-out.stream tCtx.Expect(report).To(gomega.Equal(`You requested a progress report. step: hello world `), "report") + + gomega.NewGomegaWithT(t).Expect(context.Cause(interruptCtx)).To(gomega.Succeed(), "not interrupted yet") + defaultProgressReporter.signalChannel <- os.Interrupt + message := <-out.stream + tCtx.Expect(message).To(gomega.Equal(` + +INFO: canceling test context: received interrupt signal + +`)) + gomega.NewGomegaWithT(t).Eventually(func() error { return context.Cause(tCtx) }).WithTimeout(30*time.Second).To(gomega.MatchError(gomega.Equal("received interrupt signal")), "interrupted") } // outputStream forwards exactly one Write call to a stream. @@ -116,6 +133,5 @@ func newOutputStream() *outputStream { func (s *outputStream) Write(buf []byte) (int, error) { s.stream <- string(buf) - close(s.stream) return len(buf), nil }