diff --git a/http/sys_raft.go b/http/sys_raft.go index b2055d674c..f05af7f490 100644 --- a/http/sys_raft.go +++ b/http/sys_raft.go @@ -87,7 +87,7 @@ func handleSysRaftJoinPost(core *vault.Core, w http.ResponseWriter, r *http.Requ }, } - joined, err := core.JoinRaftCluster(context.Background(), leaderInfos, req.NonVoter) + joined, err := core.JoinRaftCluster(core.ShutdownContext(), leaderInfos, req.NonVoter) if err != nil { respondError(w, http.StatusInternalServerError, err) return diff --git a/vault/core.go b/vault/core.go index 2da158e6ae..1b835edc2b 100644 --- a/vault/core.go +++ b/vault/core.go @@ -314,6 +314,12 @@ type Core struct { // that the join is complete raftJoinDoneCh chan struct{} + // shutdownCtx is a context that is canceled when the Core is shut down. + // It is used to scope background operations (such as raft join retries) + // that must not outlive the server process. + shutdownCtx context.Context + shutdownCtxCancel context.CancelFunc + // postUnsealStarted informs the raft retry join routine that unseal key // validation is completed and post unseal has started so that it can complete // the join process when Shamir seal is in use @@ -1081,6 +1087,8 @@ func CreateCore(conf *CoreConfig) (*Core, error) { mountsLock := locking.CreateConfigurableRWMutex(detectDeadlocks, "mountsLock") authLock := locking.CreateConfigurableRWMutex(detectDeadlocks, "authLock") + shutdownCtx, shutdownCtxCancel := context.WithCancel(context.Background()) + // Setup the core c := &Core{ entCore: entCore{}, @@ -1139,6 +1147,8 @@ func CreateCore(conf *CoreConfig) (*Core, error) { postUnsealStarted: new(uint32), raftInfo: new(atomic.Value), raftJoinDoneCh: make(chan struct{}), + shutdownCtx: shutdownCtx, + shutdownCtxCancel: shutdownCtxCancel, clusterHeartbeatInterval: clusterHeartbeatInterval, activityLogConfig: conf.ActivityLogConfig, billingConfig: conf.BillingConfig, @@ -1675,6 +1685,9 @@ func (c *Core) ShutdownCoreError(err error) { // happens as quickly as possible. func (c *Core) Shutdown() error { c.logger.Debug("shutdown called") + if c.shutdownCtxCancel != nil { + c.shutdownCtxCancel() + } err := c.sealInternal() c.stateLock.Lock() @@ -1703,6 +1716,15 @@ func (c *Core) ShutdownDone() <-chan struct{} { return c.shutdownDoneCh.Load().(chan struct{}) } +// ShutdownContext returns a context that is canceled when the Core shuts down. +// Use this for background operations that must not outlive the server process. +func (c *Core) ShutdownContext() context.Context { + if c.shutdownCtx == nil { + return context.Background() + } + return c.shutdownCtx +} + // CORSConfig returns the current CORS configuration func (c *Core) CORSConfig() *CORSConfig { return c.corsConfig