@@ -10,6 +10,7 @@ package testing
10
10
11
11
import (
12
12
"bytes"
13
+ "context"
13
14
"errors"
14
15
"flag"
15
16
"fmt"
@@ -78,6 +79,9 @@ type common struct {
78
79
tempDir string
79
80
tempDirErr error
80
81
tempDirSeq int32
82
+
83
+ ctx context.Context
84
+ cancelCtx context.CancelFunc
81
85
}
82
86
83
87
type logger struct {
@@ -152,6 +156,7 @@ func fmtDuration(d time.Duration) string {
152
156
// TB is the interface common to T and B.
153
157
type TB interface {
154
158
Cleanup (func ())
159
+ Context () context.Context
155
160
Error (args ... interface {})
156
161
Errorf (format string , args ... interface {})
157
162
Fail ()
@@ -307,6 +312,15 @@ func (c *common) Cleanup(f func()) {
307
312
c .cleanups = append (c .cleanups , f )
308
313
}
309
314
315
+ // Context returns a context that is canceled just before
316
+ // Cleanup-registered functions are called.
317
+ //
318
+ // Cleanup functions can wait for any resources
319
+ // that shut down on [context.Context.Done] before the test or benchmark completes.
320
+ func (c * common ) Context () context.Context {
321
+ return c .ctx
322
+ }
323
+
310
324
// TempDir returns a temporary directory for the test to use.
311
325
// The directory is automatically removed by Cleanup when the test and
312
326
// all its subtests complete.
@@ -447,6 +461,9 @@ func (c *common) runCleanup() {
447
461
if cleanup == nil {
448
462
return
449
463
}
464
+ if c .cancelCtx != nil {
465
+ c .cancelCtx ()
466
+ }
450
467
cleanup ()
451
468
}
452
469
}
@@ -488,12 +505,15 @@ func (t *T) Run(name string, f func(t *T)) bool {
488
505
}
489
506
490
507
// Create a subtest.
508
+ ctx , cancelCtx := context .WithCancel (context .Background ())
491
509
sub := T {
492
510
common : common {
493
- output : & logger {logToStdout : flagVerbose },
494
- name : testName ,
495
- parent : & t .common ,
496
- level : t .level + 1 ,
511
+ output : & logger {logToStdout : flagVerbose },
512
+ name : testName ,
513
+ parent : & t .common ,
514
+ level : t .level + 1 ,
515
+ ctx : ctx ,
516
+ cancelCtx : cancelCtx ,
497
517
},
498
518
context : t .context ,
499
519
}
@@ -606,9 +626,12 @@ func runTests(matchString func(pat, str string) (bool, error), tests []InternalT
606
626
ok = true
607
627
608
628
ctx := newTestContext (newMatcher (matchString , flagRunRegexp , "-test.run" , flagSkipRegexp ))
629
+ runCtx , cancelCtx := context .WithCancel (context .Background ())
609
630
t := & T {
610
631
common : common {
611
- output : & logger {logToStdout : flagVerbose },
632
+ output : & logger {logToStdout : flagVerbose },
633
+ ctx : runCtx ,
634
+ cancelCtx : cancelCtx ,
612
635
},
613
636
context : ctx ,
614
637
}
0 commit comments