From 8e94ea963ba8dacd83d1d9fe5c6dccb932e7245d Mon Sep 17 00:00:00 2001 From: Nick Cabatoff Date: Fri, 16 Apr 2021 17:03:22 -0400 Subject: [PATCH] On lease deletion, also delete non-orphan batch token parent index (#11377) --- changelog/11377.txt | 3 +++ vault/expiration.go | 51 +++++++++++++++++++++++++++------------- vault/expiration_test.go | 33 +++++++++++++++++--------- vault/token_store.go | 13 ++++++++-- 4 files changed, 71 insertions(+), 29 deletions(-) create mode 100644 changelog/11377.txt diff --git a/changelog/11377.txt b/changelog/11377.txt new file mode 100644 index 0000000000..171947399d --- /dev/null +++ b/changelog/11377.txt @@ -0,0 +1,3 @@ +```release-note:bug +core: Fix storage entry leak when revoking leases created with non-orphan batch tokens. +``` diff --git a/vault/expiration.go b/vault/expiration.go index 457690110e..bff06d9515 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -915,8 +915,24 @@ func (m *ExpirationManager) revokeCommon(ctx context.Context, leaseID string, fo // Delete the secondary index, but only if it's a leased secret (not auth) if le.Secret != nil { - if err := m.removeIndexByToken(ctx, le); err != nil { - return err + var indexToken string + // Maintain secondary index by token, except for orphan batch tokens + switch le.ClientTokenType { + case logical.TokenTypeBatch: + te, err := m.tokenStore.lookupBatchTokenInternal(ctx, le.ClientToken) + if err != nil { + return err + } + // If it's a non-orphan batch token, assign the secondary index to its + // parent + indexToken = te.Parent + default: + indexToken = le.ClientToken + } + if indexToken != "" { + if err := m.removeIndexByToken(ctx, le, indexToken); err != nil { + return err + } } } @@ -1364,6 +1380,17 @@ func (m *ExpirationManager) Register(ctx context.Context, req *logical.Request, Version: 1, } + var indexToken string + // Maintain secondary index by token, except for orphan batch tokens + switch { + case te.Type != logical.TokenTypeBatch: + indexToken = le.ClientToken + case te.Parent != "": + // If it's a non-orphan batch token, assign the secondary index to its + // parent + indexToken = te.Parent + } + defer func() { // If there is an error we want to rollback as much as possible (note // that errors here are ignored to do as much cleanup as we can). We @@ -1382,7 +1409,7 @@ func (m *ExpirationManager) Register(ctx context.Context, req *logical.Request, retErr = multierror.Append(retErr, errwrap.Wrapf("an additional error was encountered deleting any lease associated with the newly-generated secret: {{err}}", err)) } - if err := m.removeIndexByToken(ctx, le); err != nil { + if err := m.removeIndexByToken(ctx, le, indexToken); err != nil { retErr = multierror.Append(retErr, errwrap.Wrapf("an additional error was encountered removing lease indexes associated with the newly-generated secret: {{err}}", err)) } } @@ -1408,16 +1435,8 @@ func (m *ExpirationManager) Register(ctx context.Context, req *logical.Request, return "", err } - // Maintain secondary index by token, except for orphan batch tokens - switch { - case te.Type != logical.TokenTypeBatch: - if err := m.createIndexByToken(ctx, le, le.ClientToken); err != nil { - return "", err - } - case te.Parent != "": - // If it's a non-orphan batch token, assign the secondary index to its - // parent - if err := m.createIndexByToken(ctx, le, te.Parent); err != nil { + if indexToken != "" { + if err := m.createIndexByToken(ctx, le, indexToken); err != nil { return "", err } } @@ -1966,10 +1985,10 @@ func (m *ExpirationManager) indexByToken(ctx context.Context, le *leaseEntry) (* } // removeIndexByToken removes the secondary index from the token to a lease entry -func (m *ExpirationManager) removeIndexByToken(ctx context.Context, le *leaseEntry) error { +func (m *ExpirationManager) removeIndexByToken(ctx context.Context, le *leaseEntry, token string) error { tokenNS := namespace.RootNamespace saltCtx := namespace.ContextWithNamespace(ctx, namespace.RootNamespace) - _, nsID := namespace.SplitIDFromString(le.ClientToken) + _, nsID := namespace.SplitIDFromString(token) if nsID != "" { var err error tokenNS, err = NamespaceByID(ctx, nsID, m.core) @@ -1990,7 +2009,7 @@ func (m *ExpirationManager) removeIndexByToken(ctx context.Context, le *leaseEnt } } - saltedID, err := m.tokenStore.SaltID(saltCtx, le.ClientToken) + saltedID, err := m.tokenStore.SaltID(saltCtx, token) if err != nil { return err } diff --git a/vault/expiration_test.go b/vault/expiration_test.go index dfa41c3cfc..fd92a4c6ba 100644 --- a/vault/expiration_test.go +++ b/vault/expiration_test.go @@ -677,7 +677,8 @@ func TestExpiration_Register(t *testing.T) { } func TestExpiration_Register_BatchToken(t *testing.T) { - exp := mockExpiration(t) + c, _, rootToken := TestCoreUnsealed(t) + exp := c.expiration noop := &NoopBackend{ RequestHandler: func(ctx context.Context, req *logical.Request) (*logical.Response, error) { resp := &logical.Response{Secret: req.Secret} @@ -685,15 +686,17 @@ func TestExpiration_Register_BatchToken(t *testing.T) { return resp, nil }, } - _, barrier, _ := mockBarrier(t) - view := NewBarrierView(barrier, "logical/") - meUUID, err := uuid.GenerateUUID() - if err != nil { - t.Fatal(err) - } - err = exp.router.Mount(noop, "prod/aws/", &MountEntry{Path: "prod/aws/", Type: "noop", UUID: meUUID, Accessor: "noop-accessor", namespace: namespace.RootNamespace}, view) - if err != nil { - t.Fatal(err) + { + _, barrier, _ := mockBarrier(t) + view := NewBarrierView(barrier, "logical/") + meUUID, err := uuid.GenerateUUID() + if err != nil { + t.Fatal(err) + } + err = exp.router.Mount(noop, "prod/aws/", &MountEntry{Path: "prod/aws/", Type: "noop", UUID: meUUID, Accessor: "noop-accessor", namespace: namespace.RootNamespace}, view) + if err != nil { + t.Fatal(err) + } } te := &logical.TokenEntry{ @@ -701,9 +704,10 @@ func TestExpiration_Register_BatchToken(t *testing.T) { TTL: 1 * time.Second, NamespaceID: "root", CreationTime: time.Now().Unix(), + Parent: rootToken, } - err = exp.tokenStore.create(context.Background(), te) + err := exp.tokenStore.create(context.Background(), te) if err != nil { t.Fatal(err) } @@ -760,6 +764,13 @@ func TestExpiration_Register_BatchToken(t *testing.T) { break } + idEnts, err := exp.tokenView.List(context.Background(), "") + if err != nil { + t.Fatal(err) + } + if len(idEnts) != 0 { + t.Fatalf("expected no entries in sys/expire/token, got: %v", idEnts) + } } func TestExpiration_RegisterAuth(t *testing.T) { diff --git a/vault/token_store.go b/vault/token_store.go index b1c6725c4e..3c17c61b3c 100644 --- a/vault/token_store.go +++ b/vault/token_store.go @@ -1122,7 +1122,7 @@ func (ts *TokenStore) lookupTainted(ctx context.Context, id string) (*logical.To return ts.lookupInternal(ctx, id, false, true) } -func (ts *TokenStore) lookupBatchToken(ctx context.Context, id string) (*logical.TokenEntry, error) { +func (ts *TokenStore) lookupBatchTokenInternal(ctx context.Context, id string) (*logical.TokenEntry, error) { // Strip the b. from the front and namespace ID from the back bEntry, _ := namespace.SplitIDFromString(id[2:]) @@ -1146,6 +1146,16 @@ func (ts *TokenStore) lookupBatchToken(ctx context.Context, id string) (*logical return nil, err } + te.ID = id + return te, nil +} + +func (ts *TokenStore) lookupBatchToken(ctx context.Context, id string) (*logical.TokenEntry, error) { + te, err := ts.lookupBatchTokenInternal(ctx, id) + if err != nil { + return nil, err + } + if time.Now().After(time.Unix(te.CreationTime, 0).Add(te.TTL)) { return nil, nil } @@ -1160,7 +1170,6 @@ func (ts *TokenStore) lookupBatchToken(ctx context.Context, id string) (*logical } } - te.ID = id return te, nil }