diff --git a/test/utils/ktesting/signals.go b/test/utils/ktesting/signals.go index f498c46259c..4edef18bd96 100644 --- a/test/utils/ktesting/signals.go +++ b/test/utils/ktesting/signals.go @@ -19,6 +19,7 @@ package ktesting import ( "context" "errors" + "fmt" "io" "os" "os/signal" @@ -29,11 +30,14 @@ import ( var ( // defaultProgressReporter is inactive until init is called. - defaultProgressReporter = &progressReporter{ - // os.Stderr gets redirected by "go test". "go test -v" has to be - // used to see the output while a test runs. - out: os.Stderr, - } + defaultProgressReporter = &progressReporter{} + + // interruptCtx tracks whether the process got interrupted via SIGINT. + // In that case, interrupted gets called to cancel interruptCtx with + // a suitable message. + // + // This gets set up once per process and never gets reset. + interruptCtx, interrupted = context.WithCancelCause(context.Background()) ) const ginkgoSpecContextKey = "GINKGO_SPEC_CONTEXT" @@ -46,17 +50,18 @@ type progressReporter struct { // initMutex protects initialization and finalization of the reporter. initMutex sync.Mutex - usageCount int64 - wg sync.WaitGroup - signalCtx, interruptCtx context.Context - signalCancel func() - progressChannel chan os.Signal + usageCount int64 + wg sync.WaitGroup + testCtx context.Context + signalChannel chan os.Signal + progressChannel chan os.Signal // reportMutex protects report creation and settings. reportMutex sync.Mutex reporterCounter int64 reporters map[int64]func() string out io.Writer + closeOut func() error } var _ ginkgoReporter = &progressReporter{} @@ -78,6 +83,16 @@ func (p *progressReporter) init(tb TB) context.Context { return context.Background() } + tb.Helper() + + // If already interrupted, then don't start the new test. + // This is necessary because normally CTRL-C would exit + // the entire process immediately. Now we keep running + // to clean up. + if interruptCtx.Err() != nil { + tb.Fatalf("testing has been interrupted: %v", context.Cause(interruptCtx)) + } + p.initMutex.Lock() defer p.initMutex.Unlock() @@ -85,14 +100,35 @@ func (p *progressReporter) init(tb TB) context.Context { tb.Cleanup(p.finalize) if p.usageCount > 1 { // Was already initialized. - return p.interruptCtx + return p.testCtx } - p.signalCtx, p.signalCancel = signal.NotifyContext(context.Background(), os.Interrupt) - cancelCtx, cancel := context.WithCancelCause(context.Background()) + // Might have been set for testing purposes. + if p.out == nil { + // os.Stderr gets redirected by "go test". "go test -v" has to be + // used to see that output while a test runs. + // + // Opening /dev/tty during init avoids the redirection. + // May fail, depending on the OS, in which case + // os.Stderr is used. + if console, err := os.OpenFile("/dev/tty", os.O_RDWR|os.O_APPEND, 0); err == nil { + p.out = console + p.closeOut = console.Close + + } else { + p.out = os.Stdout + p.closeOut = nil + } + } + + p.signalChannel = make(chan os.Signal) + signal.Notify(p.signalChannel, os.Interrupt) p.wg.Go(func() { - <-p.signalCtx.Done() - cancel(errors.New("received interrupt signal")) + _, ok := <-p.signalChannel + if ok { + _, _ = fmt.Fprint(p.out, "\n\nINFO: canceling test context: received interrupt signal\n\n") + interrupted(errors.New("received interrupt signal")) + } }) // This reimplements the contract between Ginkgo and Gomega for progress reporting. @@ -101,7 +137,7 @@ func (p *progressReporter) init(tb TB) context.Context { // nolint:staticcheck // It complains about using a plain string. This can only be fixed // by Ginkgo and Gomega formalizing this interface and define a type (somewhere... // probably cannot be in either Ginkgo or Gomega). - p.interruptCtx = context.WithValue(cancelCtx, ginkgoSpecContextKey, defaultProgressReporter) + p.testCtx = context.WithValue(interruptCtx, ginkgoSpecContextKey, defaultProgressReporter) p.progressChannel = make(chan os.Signal, 1) // progressSignals will be empty on Windows. @@ -111,7 +147,7 @@ func (p *progressReporter) init(tb TB) context.Context { p.wg.Go(p.run) - return p.interruptCtx + return p.testCtx } func (p *progressReporter) finalize() { @@ -124,8 +160,17 @@ func (p *progressReporter) finalize() { return } - p.signalCancel() + signal.Stop(p.signalChannel) + close(p.signalChannel) + signal.Stop(p.progressChannel) + close(p.progressChannel) p.wg.Wait() + + // Now that all goroutines are stopped, we can clean up some more. + if p.closeOut != nil { + _ = p.closeOut() + p.out = nil + } } // AttachProgressReporter implements Gomega's contextWithAttachProgressReporter. @@ -154,21 +199,12 @@ func (p *progressReporter) detachProgressReporter(id int64) { func (p *progressReporter) run() { for { - select { - case <-p.interruptCtx.Done(): - // Maybe do one last progress report? - // - // This is primarily for unit testing of ktesting itself, - // in a normal test we don't care anymore. - select { - case <-p.progressChannel: - p.dumpProgress() - default: - } + _, ok := <-p.progressChannel + if !ok { + // Shut down. return - case <-p.progressChannel: - p.dumpProgress() } + p.dumpProgress() } } diff --git a/test/utils/ktesting/stepcontext_test.go b/test/utils/ktesting/stepcontext_test.go index c8f3c143011..314fa485daf 100644 --- a/test/utils/ktesting/stepcontext_test.go +++ b/test/utils/ktesting/stepcontext_test.go @@ -17,9 +17,11 @@ limitations under the License. package ktesting import ( + "context" "io" "os" "testing" + "time" "github.com/onsi/gomega" "go.uber.org/goleak" @@ -73,15 +75,20 @@ func TestStepContext(t *testing.T) { } func TestProgressReport(t *testing.T) { + oldOut := defaultProgressReporter.out + out := newOutputStream() + defaultProgressReporter.out = out t.Cleanup(func() { goleak.VerifyNone(t) - }) - - oldOut := defaultProgressReporter.out - reportStream := newOutputStream() - defaultProgressReporter.out = reportStream - t.Cleanup(func() { defaultProgressReporter.out = oldOut + + // If we get here, the defaultProgressReporter is not active anymore, + // but the interrupt context should still be canceled. + gomega.NewGomegaWithT(t).Expect(defaultProgressReporter.usageCount).To(gomega.Equal(int64(0)), "usage count") + gomega.NewGomegaWithT(t).Expect(context.Cause(interruptCtx)).To(gomega.MatchError(gomega.Equal("received interrupt signal")), "interrupted persistently") + + // Reset for next test. + interruptCtx, interrupted = context.WithCancelCause(context.Background()) }) // This must use a real testing.T, otherwise Init doesn't initialize signal handling. @@ -93,11 +100,21 @@ func TestProgressReport(t *testing.T) { // Trigger report and wait for it. defaultProgressReporter.progressChannel <- os.Interrupt - report := <-reportStream.stream + report := <-out.stream tCtx.Expect(report).To(gomega.Equal(`You requested a progress report. step: hello world `), "report") + + gomega.NewGomegaWithT(t).Expect(context.Cause(interruptCtx)).To(gomega.Succeed(), "not interrupted yet") + defaultProgressReporter.signalChannel <- os.Interrupt + message := <-out.stream + tCtx.Expect(message).To(gomega.Equal(` + +INFO: canceling test context: received interrupt signal + +`)) + gomega.NewGomegaWithT(t).Eventually(func() error { return context.Cause(tCtx) }).WithTimeout(30*time.Second).To(gomega.MatchError(gomega.Equal("received interrupt signal")), "interrupted") } // outputStream forwards exactly one Write call to a stream. @@ -116,6 +133,5 @@ func newOutputStream() *outputStream { func (s *outputStream) Write(buf []byte) (int, error) { s.stream <- string(buf) - close(s.stream) return len(buf), nil }