diff --git a/audit/types.go b/audit/types.go index 8d8bd158c3..956bf35c0a 100644 --- a/audit/types.go +++ b/audit/types.go @@ -284,23 +284,11 @@ type Backend interface { // filtered pipelines. IsFallback() bool - // LogRequest is used to synchronously log a request. This is done after the - // request is authorized but before the request is executed. The arguments - // MUST not be modified in any way. They should be deep copied if this is - // a possibility. - LogRequest(context.Context, *logical.LogInput) error - - // LogResponse is used to synchronously log a response. This is done after - // the request is processed but before the response is sent. The arguments - // MUST not be modified in any way. They should be deep copied if this is - // a possibility. - LogResponse(context.Context, *logical.LogInput) error - // LogTestMessage is used to check an audit backend before adding it // permanently. It should attempt to synchronously log the given test // message, WITHOUT using the normal Salt (which would require a storage // operation on creation, which is currently disallowed.) - LogTestMessage(context.Context, *logical.LogInput, map[string]string) error + LogTestMessage(context.Context, *logical.LogInput) error // Reload is called on SIGHUP for supporting backends. Reload(context.Context) error @@ -326,4 +314,4 @@ type BackendConfig struct { } // Factory is the factory function to create an audit backend. -type Factory func(context.Context, *BackendConfig, bool, HeaderFormatter) (Backend, error) +type Factory func(context.Context, *BackendConfig, HeaderFormatter) (Backend, error) diff --git a/builtin/audit/file/backend.go b/builtin/audit/file/backend.go index d9f8e0ef53..f1d04bfa53 100644 --- a/builtin/audit/file/backend.go +++ b/builtin/audit/file/backend.go @@ -4,12 +4,8 @@ package file import ( - "bytes" "context" "fmt" - "io" - "os" - "path/filepath" "strconv" "strings" "sync" @@ -36,23 +32,18 @@ var _ audit.Backend = (*Backend)(nil) // It doesn't do anything more at the moment to assist with rotation // or reset the write cursor, this should be done in the future. type Backend struct { - f *os.File - fallback bool - fileLock sync.RWMutex - formatter *audit.EntryFormatterWriter - formatConfig audit.FormatterConfig - mode os.FileMode - name string - nodeIDList []eventlogger.NodeID - nodeMap map[eventlogger.NodeID]eventlogger.Node - filePath string - salt *atomic.Value - saltConfig *salt.Config - saltMutex sync.RWMutex - saltView logical.Storage + fallback bool + name string + nodeIDList []eventlogger.NodeID + nodeMap map[eventlogger.NodeID]eventlogger.Node + filePath string + salt *atomic.Value + saltConfig *salt.Config + saltMutex sync.RWMutex + saltView logical.Storage } -func Factory(_ context.Context, conf *audit.BackendConfig, useEventLogger bool, headersConfig audit.HeaderFormatter) (audit.Backend, error) { +func Factory(_ context.Context, conf *audit.BackendConfig, headersConfig audit.HeaderFormatter) (audit.Backend, error) { const op = "file.Factory" if conf.SaltConfig == nil { @@ -96,25 +87,24 @@ func Factory(_ context.Context, conf *audit.BackendConfig, useEventLogger bool, filePath = discard } - mode := os.FileMode(0o600) - if modeRaw, ok := conf.Config["mode"]; ok { - m, err := strconv.ParseUint(modeRaw, 8, 32) - if err != nil { - return nil, fmt.Errorf("%s: unable to parse 'mode': %w", op, err) - } - switch m { - case 0: - // if mode is 0000, then do not modify file mode - if filePath != stdout && filePath != discard { - fileInfo, err := os.Stat(filePath) - if err != nil { - return nil, fmt.Errorf("%s: unable to stat %q: %w", op, filePath, err) - } - mode = fileInfo.Mode() - } - default: - mode = os.FileMode(m) - } + b := &Backend{ + fallback: fallback, + filePath: filePath, + name: conf.MountPath, + saltConfig: conf.SaltConfig, + saltView: conf.SaltView, + salt: new(atomic.Value), + nodeIDList: []eventlogger.NodeID{}, + nodeMap: make(map[eventlogger.NodeID]eventlogger.Node), + } + + // Ensure we are working with the right type by explicitly storing a nil of + // the right type + b.salt.Store((*salt.Salt)(nil)) + + err = b.configureFilterNode(conf.Config["filter"]) + if err != nil { + return nil, fmt.Errorf("%s: error configuring filter node: %w", op, err) } cfg, err := formatterConfig(conf.Config) @@ -122,78 +112,19 @@ func Factory(_ context.Context, conf *audit.BackendConfig, useEventLogger bool, return nil, fmt.Errorf("%s: failed to create formatter config: %w", op, err) } - b := &Backend{ - fallback: fallback, - filePath: filePath, - formatConfig: cfg, - mode: mode, - name: conf.MountPath, - saltConfig: conf.SaltConfig, - saltView: conf.SaltView, - salt: new(atomic.Value), + formatterOpts := []audit.Option{ + audit.WithHeaderFormatter(headersConfig), + audit.WithPrefix(conf.Config["prefix"]), } - // Ensure we are working with the right type by explicitly storing a nil of - // the right type - b.salt.Store((*salt.Salt)(nil)) - - // Configure the formatter for either case. - f, err := audit.NewEntryFormatter(b.formatConfig, b, audit.WithHeaderFormatter(headersConfig), audit.WithPrefix(conf.Config["prefix"])) + err = b.configureFormatterNode(cfg, formatterOpts...) if err != nil { - return nil, fmt.Errorf("%s: error creating formatter: %w", op, err) + return nil, fmt.Errorf("%s: error configuring formatter node: %w", op, err) } - var w audit.Writer - switch b.formatConfig.RequiredFormat { - case audit.JSONFormat: - w = &audit.JSONWriter{Prefix: conf.Config["prefix"]} - case audit.JSONxFormat: - w = &audit.JSONxWriter{Prefix: conf.Config["prefix"]} - default: - return nil, fmt.Errorf("%s: unknown format type %q", op, b.formatConfig.RequiredFormat) - } - - fw, err := audit.NewEntryFormatterWriter(b.formatConfig, f, w) + err = b.configureSinkNode(conf.MountPath, filePath, conf.Config["mode"], cfg.RequiredFormat.String()) if err != nil { - return nil, fmt.Errorf("%s: error creating formatter writer: %w", op, err) - } - b.formatter = fw - - if useEventLogger { - b.nodeIDList = []eventlogger.NodeID{} - b.nodeMap = make(map[eventlogger.NodeID]eventlogger.Node) - - err := b.configureFilterNode(conf.Config["filter"]) - if err != nil { - return nil, fmt.Errorf("%s: error configuring filter node: %w", op, err) - } - - formatterOpts := []audit.Option{ - audit.WithHeaderFormatter(headersConfig), - audit.WithPrefix(conf.Config["prefix"]), - } - - err = b.configureFormatterNode(cfg, formatterOpts...) - if err != nil { - return nil, fmt.Errorf("%s: error configuring formatter node: %w", op, err) - } - - err = b.configureSinkNode(conf.MountPath, filePath, conf.Config["mode"], cfg.RequiredFormat.String()) - if err != nil { - return nil, fmt.Errorf("%s: error configuring sink node: %w", op, err) - } - } else { - switch filePath { - case stdout: - case discard: - default: - // Ensure that the file can be successfully opened for writing; - // otherwise it will be too late to catch later without problems - // (ref: https://github.com/hashicorp/vault/issues/550) - if err := b.open(); err != nil { - return nil, fmt.Errorf("%s: sanity check failed; unable to open %q for writing: %w", op, filePath, err) - } - } + return nil, fmt.Errorf("%s: error configuring sink node: %w", op, err) } return b, nil @@ -223,178 +154,22 @@ func (b *Backend) Salt(ctx context.Context) (*salt.Salt, error) { return newSalt, nil } -// Deprecated: Use eventlogger. -func (b *Backend) LogRequest(ctx context.Context, in *logical.LogInput) error { - var writer io.Writer - switch b.filePath { - case stdout: - writer = os.Stdout - case discard: - return nil - } - - buf := bytes.NewBuffer(make([]byte, 0, 2000)) - err := b.formatter.FormatAndWriteRequest(ctx, buf, in) - if err != nil { - return err - } - - return b.log(ctx, buf, writer) -} - -// Deprecated: Use eventlogger. -func (b *Backend) log(_ context.Context, buf *bytes.Buffer, writer io.Writer) error { - reader := bytes.NewReader(buf.Bytes()) - - b.fileLock.Lock() - - if writer == nil { - if err := b.open(); err != nil { - b.fileLock.Unlock() - return err - } - writer = b.f - } - - if _, err := reader.WriteTo(writer); err == nil { - b.fileLock.Unlock() - return nil - } else if b.filePath == stdout { - b.fileLock.Unlock() - return err - } - - // If writing to stdout there's no real reason to think anything would have - // changed so return above. Otherwise, opportunistically try to re-open the - // FD, once per call. - b.f.Close() - b.f = nil - - if err := b.open(); err != nil { - b.fileLock.Unlock() - return err - } - - reader.Seek(0, io.SeekStart) - _, err := reader.WriteTo(writer) - b.fileLock.Unlock() - return err -} - -// Deprecated: Use eventlogger. -func (b *Backend) LogResponse(ctx context.Context, in *logical.LogInput) error { - var writer io.Writer - switch b.filePath { - case stdout: - writer = os.Stdout - case discard: - return nil - } - - buf := bytes.NewBuffer(make([]byte, 0, 6000)) - err := b.formatter.FormatAndWriteResponse(ctx, buf, in) - if err != nil { - return err - } - - return b.log(ctx, buf, writer) -} - -func (b *Backend) LogTestMessage(ctx context.Context, in *logical.LogInput, config map[string]string) error { - // Event logger behavior - manually Process each node +func (b *Backend) LogTestMessage(ctx context.Context, in *logical.LogInput) error { if len(b.nodeIDList) > 0 { return audit.ProcessManual(ctx, in, b.nodeIDList, b.nodeMap) } - // Old behavior - var writer io.Writer - switch b.filePath { - case stdout: - writer = os.Stdout - case discard: - return nil - } - - var buf bytes.Buffer - - temporaryFormatter, err := audit.NewTemporaryFormatter(config["format"], config["prefix"]) - if err != nil { - return err - } - - if err = temporaryFormatter.FormatAndWriteRequest(ctx, &buf, in); err != nil { - return err - } - - return b.log(ctx, &buf, writer) -} - -// The file lock must be held before calling this -// Deprecated: Use eventlogger. -func (b *Backend) open() error { - if b.f != nil { - return nil - } - if err := os.MkdirAll(filepath.Dir(b.filePath), b.mode); err != nil { - return err - } - - var err error - b.f, err = os.OpenFile(b.filePath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, b.mode) - if err != nil { - return err - } - - // Change the file mode in case the log file already existed. We special - // case /dev/null since we can't chmod it and bypass if the mode is zero - switch b.filePath { - case "/dev/null": - default: - if b.mode != 0 { - err = os.Chmod(b.filePath, b.mode) - if err != nil { - return err - } - } - } - return nil } func (b *Backend) Reload(_ context.Context) error { - // When there are nodes created in the map, use the eventlogger behavior. - if len(b.nodeMap) > 0 { - for _, n := range b.nodeMap { - if n.Type() == eventlogger.NodeTypeSink { - return n.Reopen() - } + for _, n := range b.nodeMap { + if n.Type() == eventlogger.NodeTypeSink { + return n.Reopen() } - - return nil - } else { - // old non-eventlogger behavior - switch b.filePath { - case stdout, discard: - return nil - } - - b.fileLock.Lock() - defer b.fileLock.Unlock() - - if b.f == nil { - return b.open() - } - - err := b.f.Close() - // Set to nil here so that even if we error out, on the next access open() - // will be tried - b.f = nil - if err != nil { - return err - } - - return b.open() } + + return nil } func (b *Backend) Invalidate(_ context.Context) { diff --git a/builtin/audit/file/backend_test.go b/builtin/audit/file/backend_test.go index 00f0bccacf..99436c4b80 100644 --- a/builtin/audit/file/backend_test.go +++ b/builtin/audit/file/backend_test.go @@ -5,130 +5,110 @@ package file import ( "context" - "io/ioutil" "os" "path/filepath" "strconv" "testing" - "time" "github.com/hashicorp/eventlogger" "github.com/hashicorp/vault/audit" - "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/internal/observability/event" "github.com/hashicorp/vault/sdk/helper/salt" "github.com/hashicorp/vault/sdk/logical" "github.com/stretchr/testify/require" ) +// TestAuditFile_fileModeNew verifies that the backend Factory correctly sets +// the file mode when the mode argument is set. func TestAuditFile_fileModeNew(t *testing.T) { + t.Parallel() + modeStr := "0777" mode, err := strconv.ParseUint(modeStr, 8, 32) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) file := filepath.Join(t.TempDir(), "auditTest.txt") - config := map[string]string{ - "path": file, - "mode": modeStr, - } - _, err = Factory(context.Background(), &audit.BackendConfig{ + backendConfig := &audit.BackendConfig{ + Config: map[string]string{ + "path": file, + "mode": modeStr, + }, + MountPath: "foo/bar", SaltConfig: &salt.Config{}, SaltView: &logical.InmemStorage{}, - Config: config, - }, false, nil) - if err != nil { - t.Fatal(err) } + _, err = Factory(context.Background(), backendConfig, nil) + require.NoError(t, err) info, err := os.Stat(file) - if err != nil { - t.Fatalf("Cannot retrieve file mode from `Stat`") - } - if info.Mode() != os.FileMode(mode) { - t.Fatalf("File mode does not match.") - } + require.NoErrorf(t, err, "cannot retrieve file mode from `Stat`") + require.Equalf(t, os.FileMode(mode), info.Mode(), "File mode does not match.") } +// TestAuditFile_fileModeExisting verifies that the backend Factory correctly sets +// the mode on an existing file. func TestAuditFile_fileModeExisting(t *testing.T) { - f, err := ioutil.TempFile("", "test") - if err != nil { - t.Fatalf("Failure to create test file.") - } - defer os.Remove(f.Name()) + t.Parallel() + + dir := t.TempDir() + f, err := os.CreateTemp(dir, "auditTest.log") + require.NoErrorf(t, err, "Failure to create test file.") err = os.Chmod(f.Name(), 0o777) - if err != nil { - t.Fatalf("Failure to chmod temp file for testing.") - } + require.NoErrorf(t, err, "Failure to chmod temp file for testing.") err = f.Close() - if err != nil { - t.Fatalf("Failure to close temp file for test.") - } + require.NoErrorf(t, err, "Failure to close temp file for test.") - config := map[string]string{ - "path": f.Name(), - } - - _, err = Factory(context.Background(), &audit.BackendConfig{ - Config: config, + backendConfig := &audit.BackendConfig{ + Config: map[string]string{ + "path": f.Name(), + }, + MountPath: "foo/bar", SaltConfig: &salt.Config{}, SaltView: &logical.InmemStorage{}, - }, false, nil) - if err != nil { - t.Fatal(err) } + _, err = Factory(context.Background(), backendConfig, nil) + require.NoError(t, err) + info, err := os.Stat(f.Name()) - if err != nil { - t.Fatalf("cannot retrieve file mode from `Stat`") - } - if info.Mode() != os.FileMode(0o600) { - t.Fatalf("File mode does not match.") - } + require.NoErrorf(t, err, "cannot retrieve file mode from `Stat`") + require.Equalf(t, os.FileMode(0o600), info.Mode(), "File mode does not match.") } +// TestAuditFile_fileMode0000 verifies that setting the audit file mode to +// "0000" prevents Vault from modifying the permissions of the file. func TestAuditFile_fileMode0000(t *testing.T) { - f, err := ioutil.TempFile("", "test") - if err != nil { - t.Fatalf("Failure to create test file. The error is %v", err) - } - defer os.Remove(f.Name()) + t.Parallel() + + dir := t.TempDir() + f, err := os.CreateTemp(dir, "auditTest.log") + require.NoErrorf(t, err, "Failure to create test file.") err = os.Chmod(f.Name(), 0o777) - if err != nil { - t.Fatalf("Failure to chmod temp file for testing. The error is %v", err) - } + require.NoErrorf(t, err, "Failure to chmod temp file for testing.") err = f.Close() - if err != nil { - t.Fatalf("Failure to close temp file for test. The error is %v", err) - } + require.NoErrorf(t, err, "Failure to close temp file for test.") - config := map[string]string{ - "path": f.Name(), - "mode": "0000", - } - - _, err = Factory(context.Background(), &audit.BackendConfig{ - Config: config, + backendConfig := &audit.BackendConfig{ + Config: map[string]string{ + "path": f.Name(), + "mode": "0000", + }, + MountPath: "foo/bar", SaltConfig: &salt.Config{}, SaltView: &logical.InmemStorage{}, - }, false, nil) - if err != nil { - t.Fatal(err) } + _, err = Factory(context.Background(), backendConfig, nil) + require.NoError(t, err) + info, err := os.Stat(f.Name()) - if err != nil { - t.Fatalf("cannot retrieve file mode from `Stat`. The error is %v", err) - } - if info.Mode() != os.FileMode(0o777) { - t.Fatalf("File mode does not match.") - } + require.NoErrorf(t, err, "cannot retrieve file mode from `Stat`. The error is %v", err) + require.Equalf(t, os.FileMode(0o777), info.Mode(), "File mode does not match.") } // TestAuditFile_EventLogger_fileModeNew verifies that the Factory function @@ -137,82 +117,26 @@ func TestAuditFile_fileMode0000(t *testing.T) { func TestAuditFile_EventLogger_fileModeNew(t *testing.T) { modeStr := "0777" mode, err := strconv.ParseUint(modeStr, 8, 32) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) file := filepath.Join(t.TempDir(), "auditTest.txt") - config := map[string]string{ - "path": file, - "mode": modeStr, - } - _, err = Factory(context.Background(), &audit.BackendConfig{ + backendConfig := &audit.BackendConfig{ + Config: map[string]string{ + "path": file, + "mode": modeStr, + }, MountPath: "foo/bar", SaltConfig: &salt.Config{}, SaltView: &logical.InmemStorage{}, - Config: config, - }, true, nil) - if err != nil { - t.Fatal(err) } + _, err = Factory(context.Background(), backendConfig, nil) + require.NoError(t, err) + info, err := os.Stat(file) - if err != nil { - t.Fatalf("Cannot retrieve file mode from `Stat`") - } - if info.Mode() != os.FileMode(mode) { - t.Fatalf("File mode does not match.") - } -} - -func BenchmarkAuditFile_request(b *testing.B) { - config := map[string]string{ - "path": "/dev/null", - } - sink, err := Factory(context.Background(), &audit.BackendConfig{ - Config: config, - SaltConfig: &salt.Config{}, - SaltView: &logical.InmemStorage{}, - }, false, nil) - if err != nil { - b.Fatal(err) - } - - in := &logical.LogInput{ - Auth: &logical.Auth{ - ClientToken: "foo", - Accessor: "bar", - EntityID: "foobarentity", - DisplayName: "testtoken", - NoDefaultPolicy: true, - Policies: []string{"root"}, - TokenType: logical.TokenTypeService, - }, - Request: &logical.Request{ - Operation: logical.UpdateOperation, - Path: "/foo", - Connection: &logical.Connection{ - RemoteAddr: "127.0.0.1", - }, - WrapInfo: &logical.RequestWrapInfo{ - TTL: 60 * time.Second, - }, - Headers: map[string][]string{ - "foo": {"bar"}, - }, - }, - } - - ctx := namespace.RootContext(context.Background()) - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if err := sink.LogRequest(ctx, in); err != nil { - panic(err) - } - } - }) + require.NoErrorf(t, err, "Cannot retrieve file mode from `Stat`") + require.Equalf(t, os.FileMode(mode), info.Mode(), "File mode does not match.") } // TestBackend_formatterConfig ensures that all the configuration values are parsed correctly. @@ -639,7 +563,7 @@ func TestBackend_Factory_Conf(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - be, err := Factory(ctx, tc.backendConfig, true, nil) + be, err := Factory(ctx, tc.backendConfig, nil) switch { case tc.isErrorExpected: @@ -696,7 +620,7 @@ func TestBackend_IsFallback(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - be, err := Factory(ctx, tc.backendConfig, true, nil) + be, err := Factory(ctx, tc.backendConfig, nil) require.NoError(t, err) require.NotNil(t, be) require.Equal(t, tc.isFallbackExpected, be.IsFallback()) diff --git a/builtin/audit/socket/backend.go b/builtin/audit/socket/backend.go index 9cf65130a2..c35f512e04 100644 --- a/builtin/audit/socket/backend.go +++ b/builtin/audit/socket/backend.go @@ -4,7 +4,6 @@ package socket import ( - "bytes" "context" "fmt" "net" @@ -14,7 +13,6 @@ import ( "time" "github.com/hashicorp/eventlogger" - "github.com/hashicorp/go-multierror" "github.com/hashicorp/go-secure-stdlib/parseutil" "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/internal/observability/event" @@ -30,8 +28,6 @@ type Backend struct { address string connection net.Conn fallback bool - formatter *audit.EntryFormatterWriter - formatConfig audit.FormatterConfig name string nodeIDList []eventlogger.NodeID nodeMap map[eventlogger.NodeID]eventlogger.Node @@ -43,7 +39,7 @@ type Backend struct { writeDuration time.Duration } -func Factory(_ context.Context, conf *audit.BackendConfig, useEventLogger bool, headersConfig audit.HeaderFormatter) (audit.Backend, error) { +func Factory(_ context.Context, conf *audit.BackendConfig, headersConfig audit.HeaderFormatter) (audit.Backend, error) { const op = "socket.Factory" if conf.SaltConfig == nil { @@ -88,206 +84,66 @@ func Factory(_ context.Context, conf *audit.BackendConfig, useEventLogger bool, return nil, fmt.Errorf("%s: cannot configure a fallback device with a filter: %w", op, event.ErrInvalidParameter) } - cfg, err := formatterConfig(conf.Config) - if err != nil { - return nil, fmt.Errorf("%s: failed to create formatter config: %w", op, err) - } - b := &Backend{ fallback: fallback, address: address, - formatConfig: cfg, name: conf.MountPath, saltConfig: conf.SaltConfig, saltView: conf.SaltView, socketType: socketType, writeDuration: writeDuration, + nodeIDList: []eventlogger.NodeID{}, + nodeMap: make(map[eventlogger.NodeID]eventlogger.Node), } - // Configure the formatter for either case. - f, err := audit.NewEntryFormatter(cfg, b, audit.WithHeaderFormatter(headersConfig)) + err = b.configureFilterNode(conf.Config["filter"]) if err != nil { - return nil, fmt.Errorf("%s: error creating formatter: %w", op, err) - } - var w audit.Writer - switch b.formatConfig.RequiredFormat { - case audit.JSONFormat: - w = &audit.JSONWriter{Prefix: conf.Config["prefix"]} - case audit.JSONxFormat: - w = &audit.JSONxWriter{Prefix: conf.Config["prefix"]} + return nil, fmt.Errorf("%s: error configuring filter node: %w", op, err) } - fw, err := audit.NewEntryFormatterWriter(b.formatConfig, f, w) + cfg, err := formatterConfig(conf.Config) if err != nil { - return nil, fmt.Errorf("%s: error creating formatter writer: %w", op, err) + return nil, fmt.Errorf("%s: failed to create formatter config: %w", op, err) } - b.formatter = fw + opts := []audit.Option{ + audit.WithHeaderFormatter(headersConfig), + } - if useEventLogger { - b.nodeIDList = []eventlogger.NodeID{} - b.nodeMap = make(map[eventlogger.NodeID]eventlogger.Node) + err = b.configureFormatterNode(cfg, opts...) + if err != nil { + return nil, fmt.Errorf("%s: error configuring formatter node: %w", op, err) + } - err := b.configureFilterNode(conf.Config["filter"]) - if err != nil { - return nil, fmt.Errorf("%s: error configuring filter node: %w", op, err) - } + sinkOpts := []event.Option{ + event.WithSocketType(socketType), + event.WithMaxDuration(writeDeadline), + } - opts := []audit.Option{ - audit.WithHeaderFormatter(headersConfig), - } - - err = b.configureFormatterNode(cfg, opts...) - if err != nil { - return nil, fmt.Errorf("%s: error configuring formatter node: %w", op, err) - } - - sinkOpts := []event.Option{ - event.WithSocketType(socketType), - event.WithMaxDuration(writeDeadline), - } - - err = b.configureSinkNode(conf.MountPath, address, cfg.RequiredFormat.String(), sinkOpts...) - if err != nil { - return nil, fmt.Errorf("%s: error configuring sink node: %w", op, err) - } + err = b.configureSinkNode(conf.MountPath, address, cfg.RequiredFormat.String(), sinkOpts...) + if err != nil { + return nil, fmt.Errorf("%s: error configuring sink node: %w", op, err) } return b, nil } -// Deprecated: Use eventlogger. -func (b *Backend) LogRequest(ctx context.Context, in *logical.LogInput) error { - var buf bytes.Buffer - if err := b.formatter.FormatAndWriteRequest(ctx, &buf, in); err != nil { - return err - } - - b.Lock() - defer b.Unlock() - - err := b.write(ctx, buf.Bytes()) - if err != nil { - rErr := b.reconnect(ctx) - if rErr != nil { - err = multierror.Append(err, rErr) - } else { - // Try once more after reconnecting - err = b.write(ctx, buf.Bytes()) - } - } - - return err -} - -// Deprecated: Use eventlogger. -func (b *Backend) LogResponse(ctx context.Context, in *logical.LogInput) error { - var buf bytes.Buffer - if err := b.formatter.FormatAndWriteResponse(ctx, &buf, in); err != nil { - return err - } - - b.Lock() - defer b.Unlock() - - err := b.write(ctx, buf.Bytes()) - if err != nil { - rErr := b.reconnect(ctx) - if rErr != nil { - err = multierror.Append(err, rErr) - } else { - // Try once more after reconnecting - err = b.write(ctx, buf.Bytes()) - } - } - - return err -} - -func (b *Backend) LogTestMessage(ctx context.Context, in *logical.LogInput, config map[string]string) error { - // Event logger behavior - manually Process each node +func (b *Backend) LogTestMessage(ctx context.Context, in *logical.LogInput) error { if len(b.nodeIDList) > 0 { return audit.ProcessManual(ctx, in, b.nodeIDList, b.nodeMap) } - // Old behavior - var buf bytes.Buffer - - temporaryFormatter, err := audit.NewTemporaryFormatter(config["format"], config["prefix"]) - if err != nil { - return err - } - - if err = temporaryFormatter.FormatAndWriteRequest(ctx, &buf, in); err != nil { - return err - } - - b.Lock() - defer b.Unlock() - - err = b.write(ctx, buf.Bytes()) - if err != nil { - rErr := b.reconnect(ctx) - if rErr != nil { - err = multierror.Append(err, rErr) - } else { - // Try once more after reconnecting - err = b.write(ctx, buf.Bytes()) - } - } - - return err -} - -// Deprecated: Use eventlogger. -func (b *Backend) write(ctx context.Context, buf []byte) error { - if b.connection == nil { - if err := b.reconnect(ctx); err != nil { - return err - } - } - - err := b.connection.SetWriteDeadline(time.Now().Add(b.writeDuration)) - if err != nil { - return err - } - - _, err = b.connection.Write(buf) - if err != nil { - return err - } - - return nil -} - -// Deprecated: Use eventlogger. -func (b *Backend) reconnect(ctx context.Context) error { - if b.connection != nil { - b.connection.Close() - b.connection = nil - } - - timeoutContext, cancel := context.WithTimeout(ctx, b.writeDuration) - defer cancel() - - dialer := net.Dialer{} - conn, err := dialer.DialContext(timeoutContext, b.socketType, b.address) - if err != nil { - return err - } - - b.connection = conn - return nil } func (b *Backend) Reload(ctx context.Context) error { - b.Lock() - defer b.Unlock() + for _, n := range b.nodeMap { + if n.Type() == eventlogger.NodeTypeSink { + return n.Reopen() + } + } - err := b.reconnect(ctx) - - return err + return nil } func (b *Backend) Salt(ctx context.Context) (*salt.Salt, error) { diff --git a/builtin/audit/socket/backend_test.go b/builtin/audit/socket/backend_test.go index db50f3529b..c118df6093 100644 --- a/builtin/audit/socket/backend_test.go +++ b/builtin/audit/socket/backend_test.go @@ -456,7 +456,7 @@ func TestBackend_Factory_Conf(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - be, err := Factory(ctx, tc.backendConfig, true, nil) + be, err := Factory(ctx, tc.backendConfig, nil) switch { case tc.isErrorExpected: @@ -515,7 +515,7 @@ func TestBackend_IsFallback(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - be, err := Factory(ctx, tc.backendConfig, true, nil) + be, err := Factory(ctx, tc.backendConfig, nil) require.NoError(t, err) require.NotNil(t, be) require.Equal(t, tc.isFallbackExpected, be.IsFallback()) diff --git a/builtin/audit/syslog/backend.go b/builtin/audit/syslog/backend.go index c09302950f..6d0be428f6 100644 --- a/builtin/audit/syslog/backend.go +++ b/builtin/audit/syslog/backend.go @@ -4,7 +4,6 @@ package syslog import ( - "bytes" "context" "fmt" "strconv" @@ -24,20 +23,18 @@ var _ audit.Backend = (*Backend)(nil) // Backend is the audit backend for the syslog-based audit store. type Backend struct { - fallback bool - formatter *audit.EntryFormatterWriter - formatConfig audit.FormatterConfig - logger gsyslog.Syslogger - name string - nodeIDList []eventlogger.NodeID - nodeMap map[eventlogger.NodeID]eventlogger.Node - salt *salt.Salt - saltConfig *salt.Config - saltMutex sync.RWMutex - saltView logical.Storage + fallback bool + logger gsyslog.Syslogger + name string + nodeIDList []eventlogger.NodeID + nodeMap map[eventlogger.NodeID]eventlogger.Node + salt *salt.Salt + saltConfig *salt.Config + saltMutex sync.RWMutex + saltView logical.Storage } -func Factory(_ context.Context, conf *audit.BackendConfig, useEventLogger bool, headersConfig audit.HeaderFormatter) (audit.Backend, error) { +func Factory(_ context.Context, conf *audit.BackendConfig, headersConfig audit.HeaderFormatter) (audit.Backend, error) { const op = "syslog.Factory" if conf.SaltConfig == nil { @@ -75,11 +72,6 @@ func Factory(_ context.Context, conf *audit.BackendConfig, useEventLogger bool, return nil, fmt.Errorf("%s: cannot configure a fallback device with a filter: %w", op, event.ErrInvalidParameter) } - cfg, err := formatterConfig(conf.Config) - if err != nil { - return nil, fmt.Errorf("%s: failed to create formatter config: %w", op, err) - } - // Get the logger logger, err := gsyslog.NewLogger(gsyslog.LOG_INFO, facility, tag) if err != nil { @@ -87,113 +79,54 @@ func Factory(_ context.Context, conf *audit.BackendConfig, useEventLogger bool, } b := &Backend{ - fallback: fallback, - formatConfig: cfg, - logger: logger, - name: conf.MountPath, - saltConfig: conf.SaltConfig, - saltView: conf.SaltView, + fallback: fallback, + logger: logger, + name: conf.MountPath, + saltConfig: conf.SaltConfig, + saltView: conf.SaltView, + nodeIDList: []eventlogger.NodeID{}, + nodeMap: make(map[eventlogger.NodeID]eventlogger.Node), } - // Configure the formatter for either case. - f, err := audit.NewEntryFormatter(b.formatConfig, b, audit.WithHeaderFormatter(headersConfig), audit.WithPrefix(conf.Config["prefix"])) + err = b.configureFilterNode(conf.Config["filter"]) if err != nil { - return nil, fmt.Errorf("%s: error creating formatter: %w", op, err) + return nil, fmt.Errorf("%s: error configuring filter node: %w", op, err) } - var w audit.Writer - switch b.formatConfig.RequiredFormat { - case audit.JSONFormat: - w = &audit.JSONWriter{Prefix: conf.Config["prefix"]} - case audit.JSONxFormat: - w = &audit.JSONxWriter{Prefix: conf.Config["prefix"]} - } - - fw, err := audit.NewEntryFormatterWriter(b.formatConfig, f, w) + cfg, err := formatterConfig(conf.Config) if err != nil { - return nil, fmt.Errorf("%s: error creating formatter writer: %w", op, err) + return nil, fmt.Errorf("%s: failed to create formatter config: %w", op, err) } - b.formatter = fw + formatterOpts := []audit.Option{ + audit.WithHeaderFormatter(headersConfig), + audit.WithPrefix(conf.Config["prefix"]), + } - if useEventLogger { - b.nodeIDList = []eventlogger.NodeID{} - b.nodeMap = make(map[eventlogger.NodeID]eventlogger.Node) + err = b.configureFormatterNode(cfg, formatterOpts...) + if err != nil { + return nil, fmt.Errorf("%s: error configuring formatter node: %w", op, err) + } - err := b.configureFilterNode(conf.Config["filter"]) - if err != nil { - return nil, fmt.Errorf("%s: error configuring filter node: %w", op, err) - } + sinkOpts := []event.Option{ + event.WithFacility(facility), + event.WithTag(tag), + } - formatterOpts := []audit.Option{ - audit.WithHeaderFormatter(headersConfig), - audit.WithPrefix(conf.Config["prefix"]), - } - - err = b.configureFormatterNode(cfg, formatterOpts...) - if err != nil { - return nil, fmt.Errorf("%s: error configuring formatter node: %w", op, err) - } - - sinkOpts := []event.Option{ - event.WithFacility(facility), - event.WithTag(tag), - } - - err = b.configureSinkNode(conf.MountPath, cfg.RequiredFormat.String(), sinkOpts...) - if err != nil { - return nil, fmt.Errorf("%s: error configuring sink node: %w", op, err) - } + err = b.configureSinkNode(conf.MountPath, cfg.RequiredFormat.String(), sinkOpts...) + if err != nil { + return nil, fmt.Errorf("%s: error configuring sink node: %w", op, err) } return b, nil } -// Deprecated: Use eventlogger. -func (b *Backend) LogRequest(ctx context.Context, in *logical.LogInput) error { - var buf bytes.Buffer - if err := b.formatter.FormatAndWriteRequest(ctx, &buf, in); err != nil { - return err - } - - // Write out to syslog - _, err := b.logger.Write(buf.Bytes()) - return err -} - -// Deprecated: Use eventlogger. -func (b *Backend) LogResponse(ctx context.Context, in *logical.LogInput) error { - var buf bytes.Buffer - if err := b.formatter.FormatAndWriteResponse(ctx, &buf, in); err != nil { - return err - } - - // Write out to syslog - _, err := b.logger.Write(buf.Bytes()) - return err -} - -func (b *Backend) LogTestMessage(ctx context.Context, in *logical.LogInput, config map[string]string) error { - // Event logger behavior - manually Process each node +func (b *Backend) LogTestMessage(ctx context.Context, in *logical.LogInput) error { if len(b.nodeIDList) > 0 { return audit.ProcessManual(ctx, in, b.nodeIDList, b.nodeMap) } - // Old behavior - var buf bytes.Buffer - - temporaryFormatter, err := audit.NewTemporaryFormatter(config["format"], config["prefix"]) - if err != nil { - return err - } - - if err = temporaryFormatter.FormatAndWriteRequest(ctx, &buf, in); err != nil { - return err - } - - // Send to syslog - _, err = b.logger.Write(buf.Bytes()) - return err + return nil } func (b *Backend) Reload(_ context.Context) error { diff --git a/builtin/audit/syslog/backend_test.go b/builtin/audit/syslog/backend_test.go index f3c728d5c7..c60addb547 100644 --- a/builtin/audit/syslog/backend_test.go +++ b/builtin/audit/syslog/backend_test.go @@ -375,7 +375,7 @@ func TestBackend_Factory_Conf(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - be, err := Factory(ctx, tc.backendConfig, true, nil) + be, err := Factory(ctx, tc.backendConfig, nil) switch { case tc.isErrorExpected: @@ -430,7 +430,7 @@ func TestBackend_IsFallback(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - be, err := Factory(ctx, tc.backendConfig, true, nil) + be, err := Factory(ctx, tc.backendConfig, nil) require.NoError(t, err) require.NotNil(t, be) require.Equal(t, tc.isFallbackExpected, be.IsFallback()) diff --git a/helper/testhelpers/corehelpers/corehelpers.go b/helper/testhelpers/corehelpers/corehelpers.go index 8c3d7cfb96..8b2c6e7d15 100644 --- a/helper/testhelpers/corehelpers/corehelpers.go +++ b/helper/testhelpers/corehelpers/corehelpers.go @@ -249,31 +249,32 @@ func NewNoopAudit(config *audit.BackendConfig, opts ...audit.Option) (*NoopAudit MountPath: config.MountPath, } - n := &NoopAudit{Config: backendConfig} + noopBackend := &NoopAudit{ + Config: backendConfig, + nodeIDList: make([]eventlogger.NodeID, 2), + nodeMap: make(map[eventlogger.NodeID]eventlogger.Node, 2), + } cfg, err := audit.NewFormatterConfig() if err != nil { return nil, err } - f, err := audit.NewEntryFormatter(cfg, n, opts...) - if err != nil { - return nil, fmt.Errorf("error creating formatter: %w", err) - } - - n.nodeIDList = make([]eventlogger.NodeID, 2) - n.nodeMap = make(map[eventlogger.NodeID]eventlogger.Node, 2) - formatterNodeID, err := event.GenerateNodeID() if err != nil { return nil, fmt.Errorf("error generating random NodeID for formatter node: %w", err) } - // Wrap the formatting node, so we can get any bytes that were formatted etc. - wrappedFormatter := &noopWrapper{format: "json", node: f, backend: n} + formatterNode, err := audit.NewEntryFormatter(cfg, noopBackend, opts...) + if err != nil { + return nil, fmt.Errorf("error creating formatter: %w", err) + } - n.nodeIDList[0] = formatterNodeID - n.nodeMap[formatterNodeID] = wrappedFormatter + // Wrap the formatting node, so we can get any bytes that were formatted etc. + wrappedFormatter := &noopWrapper{format: "json", node: formatterNode, backend: noopBackend} + + noopBackend.nodeIDList[0] = formatterNodeID + noopBackend.nodeMap[formatterNodeID] = wrappedFormatter sinkNode := event.NewNoopSink() sinkNodeID, err := event.GenerateNodeID() @@ -281,17 +282,17 @@ func NewNoopAudit(config *audit.BackendConfig, opts ...audit.Option) (*NoopAudit return nil, fmt.Errorf("error generating random NodeID for sink node: %w", err) } - n.nodeIDList[1] = sinkNodeID - n.nodeMap[sinkNodeID] = sinkNode + noopBackend.nodeIDList[1] = sinkNodeID + noopBackend.nodeMap[sinkNodeID] = sinkNode - return n, nil + return noopBackend, nil } // NoopAuditFactory should be used when the test needs a way to access bytes that // have been formatted by the pipeline during audit requests. // The records parameter will be repointed to the one used within the pipeline. func NoopAuditFactory(records **[][]byte) audit.Factory { - return func(_ context.Context, config *audit.BackendConfig, _ bool, headerFormatter audit.HeaderFormatter) (audit.Backend, error) { + return func(_ context.Context, config *audit.BackendConfig, headerFormatter audit.HeaderFormatter) (audit.Backend, error) { n, err := NewNoopAudit(config, audit.WithHeaderFormatter(headerFormatter)) if err != nil { return nil, err @@ -429,65 +430,10 @@ func (n *noopWrapper) Type() eventlogger.NodeType { return n.node.Type() } -// Deprecated: use eventlogger. -func (n *NoopAudit) LogRequest(ctx context.Context, in *logical.LogInput) error { - return nil -} - -// Deprecated: use eventlogger. -func (n *NoopAudit) LogResponse(ctx context.Context, in *logical.LogInput) error { - return nil -} - // LogTestMessage will manually crank the handle on the nodes associated with this backend. -func (n *NoopAudit) LogTestMessage(ctx context.Context, in *logical.LogInput, config map[string]string) error { - n.l.Lock() - defer n.l.Unlock() - - // Fake event for test purposes. - e := &eventlogger.Event{ - Type: eventlogger.EventType(event.AuditType.String()), - CreatedAt: time.Now(), - Formatted: make(map[string][]byte), - Payload: in, - } - - // Try to get the required format from config and default to JSON. - format, ok := config["format"] - if !ok { - format = "json" - } - cfg, err := audit.NewFormatterConfig(audit.WithFormat(format)) - if err != nil { - return fmt.Errorf("cannot create config for formatter node: %w", err) - } - // Create a temporary formatter node for reuse. - f, err := audit.NewEntryFormatter(cfg, n, audit.WithPrefix(config["prefix"])) - - // Go over each node in order from our list. - for _, id := range n.nodeIDList { - node, ok := n.nodeMap[id] - if !ok { - return fmt.Errorf("node not found: %v", id) - } - - switch node.Type() { - case eventlogger.NodeTypeFormatter: - // Use a temporary formatter node which doesn't persist its salt anywhere. - if formatNode, ok := node.(*audit.EntryFormatter); ok && formatNode != nil { - e, err = f.Process(ctx, e) - - // Housekeeping, we should update that we processed some bytes. - if e != nil { - b, ok := e.Format(format) - if ok { - n.records = append(n.records, b) - } - } - } - default: - e, err = node.Process(ctx, e) - } +func (n *NoopAudit) LogTestMessage(ctx context.Context, in *logical.LogInput) error { + if len(n.nodeIDList) > 0 { + return audit.ProcessManual(ctx, in, n.nodeIDList, n.nodeMap) } return nil diff --git a/http/logical_test.go b/http/logical_test.go index 88964ac874..bd90ac4ea2 100644 --- a/http/logical_test.go +++ b/http/logical_test.go @@ -573,7 +573,7 @@ func TestLogical_Audit_invalidWrappingToken(t *testing.T) { noop := corehelpers.TestNoopAudit(t, "noop/", nil) c, _, root := vault.TestCoreUnsealedWithConfig(t, &vault.CoreConfig{ AuditBackends: map[string]audit.Factory{ - "noop": func(ctx context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { + "noop": func(ctx context.Context, config *audit.BackendConfig, _ audit.HeaderFormatter) (audit.Backend, error) { return noop, nil }, }, diff --git a/vault/audit.go b/vault/audit.go index 7f2e5cac47..9c1aaaca3f 100644 --- a/vault/audit.go +++ b/vault/audit.go @@ -8,14 +8,12 @@ import ( "crypto/sha256" "errors" "fmt" - "os" "strconv" "strings" "time" "github.com/hashicorp/go-secure-stdlib/parseutil" - - uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/helper/consts" @@ -41,12 +39,6 @@ const ( // auditTableType is the value we expect to find for the audit table and // corresponding entries auditTableType = "audit" - - // featureFlagDisableEventLogger contains the feature flag name which can be - // used to disable internal eventlogger behavior for the audit system. - // NOTE: this is an undocumented and temporary feature flag, it should not - // be relied on to remain part of Vault for any subsequent releases. - featureFlagDisableEventLogger = "VAULT_AUDIT_DISABLE_EVENTLOGGER" ) // loadAuditFailed if loading audit tables encounters an error @@ -152,7 +144,7 @@ func (c *Core) enableAudit(ctx context.Context, entry *MountEntry, updateStorage if err != nil { return err } - err = backend.LogTestMessage(ctx, testProbe, entry.Options) + err = backend.LogTestMessage(ctx, testProbe) if err != nil { c.logger.Error("new audit backend failed test", "path", entry.Path, "type", entry.Type, "error", err) return fmt.Errorf("audit backend failed test message: %w", err) @@ -416,14 +408,9 @@ func (c *Core) setupAudits(ctx context.Context) error { c.auditLock.Lock() defer c.auditLock.Unlock() - disableEventLogger, err := parseutil.ParseBool(os.Getenv(featureFlagDisableEventLogger)) - if err != nil { - return fmt.Errorf("unable to parse feature flag: %q: %w", featureFlagDisableEventLogger, err) - } - brokerLogger := c.baseLogger.Named("audit") - broker, err := NewAuditBroker(brokerLogger, !disableEventLogger) + broker, err := NewAuditBroker(brokerLogger) if err != nil { return err } @@ -530,11 +517,6 @@ func (c *Core) newAuditBackend(ctx context.Context, entry *MountEntry, view logi Location: salt.DefaultLocation, } - disableEventLogger, err := parseutil.ParseBool(os.Getenv(featureFlagDisableEventLogger)) - if err != nil { - return nil, fmt.Errorf("unable to parse feature flag: %q: %w", featureFlagDisableEventLogger, err) - } - be, err := f( ctx, &audit.BackendConfig{ SaltView: view, @@ -542,7 +524,6 @@ func (c *Core) newAuditBackend(ctx context.Context, entry *MountEntry, view logi Config: conf, MountPath: entry.Path, }, - !disableEventLogger, c.auditedHeaders) if err != nil { return nil, fmt.Errorf("unable to create new audit backend: %w", err) @@ -607,14 +588,14 @@ func (b *basicAuditor) AuditRequest(ctx context.Context, input *logical.LogInput if b.c.auditBroker == nil { return consts.ErrSealed } - return b.c.auditBroker.LogRequest(ctx, input, b.c.auditedHeaders) + return b.c.auditBroker.LogRequest(ctx, input) } func (b *basicAuditor) AuditResponse(ctx context.Context, input *logical.LogInput) error { if b.c.auditBroker == nil { return consts.ErrSealed } - return b.c.auditBroker.LogResponse(ctx, input, b.c.auditedHeaders) + return b.c.auditBroker.LogResponse(ctx, input) } type genericAuditor struct { @@ -627,12 +608,12 @@ func (g genericAuditor) AuditRequest(ctx context.Context, input *logical.LogInpu ctx = namespace.ContextWithNamespace(ctx, g.namespace) logInput := *input logInput.Type = g.mountType + "-request" - return g.c.auditBroker.LogRequest(ctx, &logInput, g.c.auditedHeaders) + return g.c.auditBroker.LogRequest(ctx, &logInput) } func (g genericAuditor) AuditResponse(ctx context.Context, input *logical.LogInput) error { ctx = namespace.ContextWithNamespace(ctx, g.namespace) logInput := *input logInput.Type = g.mountType + "-response" - return g.c.auditBroker.LogResponse(ctx, &logInput, g.c.auditedHeaders) + return g.c.auditBroker.LogResponse(ctx, &logInput) } diff --git a/vault/audit_broker.go b/vault/audit_broker.go index 31f99bd087..f783106fd4 100644 --- a/vault/audit_broker.go +++ b/vault/audit_broker.go @@ -50,40 +50,32 @@ type AuditBroker struct { } // NewAuditBroker creates a new audit broker -func NewAuditBroker(log hclog.Logger, useEventLogger bool) (*AuditBroker, error) { - var eventBroker *eventlogger.Broker - var fallbackBroker *eventlogger.Broker - var err error +func NewAuditBroker(log hclog.Logger) (*AuditBroker, error) { + const op = "vault.NewAuditBroker" - // The reason for this check is due to 1.15.x supporting the env var: - // 'VAULT_AUDIT_DISABLE_EVENTLOGGER' - // When NewAuditBroker is called, it is supplied a bool to determine whether - // we initialize the broker (and fallback broker), which are left nil otherwise. - // In 1.16.x this check should go away and the env var removed. - if useEventLogger { - eventBroker, err = eventlogger.NewBroker() - if err != nil { - return nil, fmt.Errorf("error creating event broker for audit events: %w", err) - } - - // Set up the broker that will support a single fallback device. - fallbackBroker, err = eventlogger.NewBroker() - if err != nil { - return nil, fmt.Errorf("error creating event fallback broker for audit event: %w", err) - } + eventBroker, err := eventlogger.NewBroker() + if err != nil { + return nil, fmt.Errorf("%s: error creating event broker for audit events: %w", op, err) } - b := &AuditBroker{ + // Set up the broker that will support a single fallback device. + fallbackEventBroker, err := eventlogger.NewBroker() + if err != nil { + return nil, fmt.Errorf("%s: error creating event fallback broker for audit event: %w", op, err) + } + + broker := &AuditBroker{ backends: make(map[string]backendEntry), logger: log, broker: eventBroker, - fallbackBroker: fallbackBroker, + fallbackBroker: fallbackEventBroker, } - return b, nil + + return broker, nil } // Register is used to add new audit backend to the broker -func (a *AuditBroker) Register(name string, b audit.Backend, local bool) error { +func (a *AuditBroker) Register(name string, backend audit.Backend, local bool) error { const op = "vault.(AuditBroker).Register" a.Lock() @@ -100,7 +92,7 @@ func (a *AuditBroker) Register(name string, b audit.Backend, local bool) error { } // Fallback devices are singleton instances, we cannot register more than one or overwrite the existing one. - if b.IsFallback() && a.fallbackBroker.IsAnyPipelineRegistered(eventlogger.EventType(event.AuditType.String())) { + if backend.IsFallback() && a.fallbackBroker.IsAnyPipelineRegistered(eventlogger.EventType(event.AuditType.String())) { existing, err := a.existingFallbackName() if err != nil { return fmt.Errorf("%s: existing fallback device already registered: %w", op, err) @@ -109,32 +101,25 @@ func (a *AuditBroker) Register(name string, b audit.Backend, local bool) error { return fmt.Errorf("%s: existing fallback device already registered: %q", op, existing) } - // The reason for this check is due to 1.15.x supporting the env var: - // 'VAULT_AUDIT_DISABLE_EVENTLOGGER' - // When NewAuditBroker is called, it is supplied a bool to determine whether - // we initialize the broker (and fallback broker), which are left nil otherwise. - // In 1.16.x this check should go away and the env var removed. - if a.broker != nil { - if name != b.Name() { - return fmt.Errorf("%s: audit registration failed due to device name mismatch: %q, %q", op, name, b.Name()) - } + if name != backend.Name() { + return fmt.Errorf("%s: audit registration failed due to device name mismatch: %q, %q", op, name, backend.Name()) + } - switch { - case b.IsFallback(): - err := a.registerFallback(name, b) - if err != nil { - return fmt.Errorf("%s: unable to register fallback device for %q: %w", op, name, err) - } - default: - err := a.register(name, b) - if err != nil { - return fmt.Errorf("%s: unable to register device for %q: %w", op, name, err) - } + switch { + case backend.IsFallback(): + err := a.registerFallback(name, backend) + if err != nil { + return fmt.Errorf("%s: unable to register fallback device for %q: %w", op, name, err) + } + default: + err := a.register(name, backend) + if err != nil { + return fmt.Errorf("%s: unable to register device for %q: %w", op, name, err) } } a.backends[name] = backendEntry{ - backend: b, + backend: backend, local: local, } @@ -164,23 +149,16 @@ func (a *AuditBroker) Deregister(ctx context.Context, name string) error { // the error. delete(a.backends, name) - // The reason for this check is due to 1.15.x supporting the env var: - // 'VAULT_AUDIT_DISABLE_EVENTLOGGER' - // When NewAuditBroker is called, it is supplied a bool to determine whether - // we initialize the broker (and fallback broker), which are left nil otherwise. - // In 1.16.x this check should go away and the env var removed. - if a.broker != nil { - switch { - case name == a.fallbackName: - err := a.deregisterFallback(ctx, name) - if err != nil { - return fmt.Errorf("%s: deregistration failed for fallback audit device %q: %w", op, name, err) - } - default: - err := a.deregister(ctx, name) - if err != nil { - return fmt.Errorf("%s: deregistration failed for audit device %q: %w", op, name, err) - } + switch { + case name == a.fallbackName: + err := a.deregisterFallback(ctx, name) + if err != nil { + return fmt.Errorf("%s: deregistration failed for fallback audit device %q: %w", op, name, err) + } + default: + err := a.deregister(ctx, name) + if err != nil { + return fmt.Errorf("%s: deregistration failed for audit device %q: %w", op, name, err) } } @@ -206,10 +184,12 @@ func (a *AuditBroker) isRegistered(name string) bool { func (a *AuditBroker) IsLocal(name string) (bool, error) { a.RLock() defer a.RUnlock() + be, ok := a.backends[name] if ok { return be.local, nil } + return false, fmt.Errorf("unknown audit backend %q", name) } @@ -217,6 +197,7 @@ func (a *AuditBroker) IsLocal(name string) (bool, error) { func (a *AuditBroker) GetHash(ctx context.Context, name string, input string) (string, error) { a.RLock() defer a.RUnlock() + be, ok := a.backends[name] if !ok { return "", fmt.Errorf("unknown audit backend %q", name) @@ -227,7 +208,12 @@ func (a *AuditBroker) GetHash(ctx context.Context, name string, input string) (s // LogRequest is used to ensure all the audit backends have an opportunity to // log the given request and that *at least one* succeeds. -func (a *AuditBroker) LogRequest(ctx context.Context, in *logical.LogInput, headersConfig *AuditedHeadersConfig) (ret error) { +func (a *AuditBroker) LogRequest(ctx context.Context, in *logical.LogInput) (ret error) { + // If no backends are registered then we have no devices to log the request. + if len(a.backends) < 1 { + return nil + } + defer metrics.MeasureSince([]string{"audit", "log_request"}, time.Now()) a.RLock() @@ -264,76 +250,47 @@ func (a *AuditBroker) LogRequest(ctx context.Context, in *logical.LogInput, head in.Request.Headers = headers }() - // Old behavior (no events) - if a.broker == nil { - // Ensure at least one backend logs - anyLogged := false - for name, be := range a.backends { - in.Request.Headers = nil - transHeaders, thErr := headersConfig.ApplyConfig(ctx, headers, be.backend) - if thErr != nil { - a.logger.Error("backend failed to include headers", "backend", name, "error", thErr) - continue - } - in.Request.Headers = transHeaders + e, err := audit.NewEvent(audit.RequestType) + if err != nil { + retErr = multierror.Append(retErr, err) + return retErr.ErrorOrNil() + } - start := time.Now() - lrErr := be.backend.LogRequest(ctx, in) - metrics.MeasureSince([]string{"audit", name, "log_request"}, start) - if lrErr != nil { - a.logger.Error("backend failed to log request", "backend", name, "error", lrErr) - } else { - anyLogged = true - } + e.Data = in + + // There may be cases where only the fallback device was added but no other + // normal audit devices, so check if the broker had an audit based pipeline + // registered before trying to send to it. + var status eventlogger.Status + if a.broker.IsAnyPipelineRegistered(eventlogger.EventType(event.AuditType.String())) { + status, err = a.broker.Send(ctx, eventlogger.EventType(event.AuditType.String()), e) + if err != nil { + retErr = multierror.Append(retErr, multierror.Append(err, status.Warnings...)) + return retErr.ErrorOrNil() } - if !anyLogged && len(a.backends) > 0 { - retErr = multierror.Append(retErr, fmt.Errorf("no audit backend succeeded in logging the request")) + } + + // Audit event ended up in at least 1 sink. + if len(status.CompleteSinks()) > 0 { + return retErr.ErrorOrNil() + } + + // There were errors from inside the pipeline and we didn't write to a sink. + if len(status.Warnings) > 0 { + retErr = multierror.Append(retErr, multierror.Append(errors.New("error during audit pipeline processing"), status.Warnings...)) + return retErr.ErrorOrNil() + } + + // If a fallback device is registered we can rely on that to 'catch all' + // and also the broker level guarantee for completed sinks. + if a.fallbackBroker.IsAnyPipelineRegistered(eventlogger.EventType(event.AuditType.String())) { + status, err = a.fallbackBroker.Send(ctx, eventlogger.EventType(event.AuditType.String()), e) + if err != nil { + retErr = multierror.Append(retErr, multierror.Append(fmt.Errorf("auditing request to fallback device failed: %w", err), status.Warnings...)) } } else { - if len(a.backends) > 0 { - e, err := audit.NewEvent(audit.RequestType) - if err != nil { - retErr = multierror.Append(retErr, err) - return retErr.ErrorOrNil() - } - - e.Data = in - - // There may be cases where only the fallback device was added but no other - // normal audit devices, so check if the broker had an audit based pipeline - // registered before trying to send to it. - var status eventlogger.Status - if a.broker.IsAnyPipelineRegistered(eventlogger.EventType(event.AuditType.String())) { - status, err = a.broker.Send(ctx, eventlogger.EventType(event.AuditType.String()), e) - if err != nil { - retErr = multierror.Append(retErr, multierror.Append(err, status.Warnings...)) - return retErr.ErrorOrNil() - } - } - - // Audit event ended up in at least 1 sink. - if len(status.CompleteSinks()) > 0 { - return retErr.ErrorOrNil() - } - - // There were errors from inside the pipeline and we didn't write to a sink. - if len(status.Warnings) > 0 { - retErr = multierror.Append(retErr, multierror.Append(errors.New("error during audit pipeline processing"), status.Warnings...)) - return retErr.ErrorOrNil() - } - - // If a fallback device is registered we can rely on that to 'catch all' - // and also the broker level guarantee for completed sinks. - if a.fallbackBroker.IsAnyPipelineRegistered(eventlogger.EventType(event.AuditType.String())) { - status, err = a.fallbackBroker.Send(ctx, eventlogger.EventType(event.AuditType.String()), e) - if err != nil { - retErr = multierror.Append(retErr, multierror.Append(fmt.Errorf("auditing request to fallback device failed: %w", err), status.Warnings...)) - } - } else { - // This audit event won't make it to any devices, we class this as a 'miss' for auditing. - metrics.IncrCounter(audit.MetricLabelsFallbackMiss(), 1) - } - } + // This audit event won't make it to any devices, we class this as a 'miss' for auditing. + metrics.IncrCounter(audit.MetricLabelsFallbackMiss(), 1) } return retErr.ErrorOrNil() @@ -341,18 +298,21 @@ func (a *AuditBroker) LogRequest(ctx context.Context, in *logical.LogInput, head // LogResponse is used to ensure all the audit backends have an opportunity to // log the given response and that *at least one* succeeds. -func (a *AuditBroker) LogResponse(ctx context.Context, in *logical.LogInput, headersConfig *AuditedHeadersConfig) (ret error) { +func (a *AuditBroker) LogResponse(ctx context.Context, in *logical.LogInput) (ret error) { + // If no backends are registered then we have no devices to send audit entries to. + if len(a.backends) < 1 { + return nil + } + defer metrics.MeasureSince([]string{"audit", "log_response"}, time.Now()) + a.RLock() defer a.RUnlock() - if in.Request.InboundSSCToken != "" { - if in.Auth != nil { - reqAuthToken := in.Auth.ClientToken - in.Auth.ClientToken = in.Request.InboundSSCToken - defer func() { - in.Auth.ClientToken = reqAuthToken - }() - } + + if in.Request.InboundSSCToken != "" && in.Auth != nil { + reqAuthToken := in.Auth.ClientToken + in.Auth.ClientToken = in.Request.InboundSSCToken + defer func() { in.Auth.ClientToken = reqAuthToken }() } var retErr *multierror.Error @@ -364,7 +324,6 @@ func (a *AuditBroker) LogResponse(ctx context.Context, in *logical.LogInput, hea } ret = retErr.ErrorOrNil() - failure := float32(0.0) if ret != nil { failure = 1.0 @@ -377,101 +336,74 @@ func (a *AuditBroker) LogResponse(ctx context.Context, in *logical.LogInput, hea in.Request.Headers = headers }() - // Ensure at least one backend logs - if a.broker == nil { - anyLogged := false - for name, be := range a.backends { - in.Request.Headers = nil - transHeaders, thErr := headersConfig.ApplyConfig(ctx, headers, be.backend) - if thErr != nil { - a.logger.Error("backend failed to include headers", "backend", name, "error", thErr) - continue - } - in.Request.Headers = transHeaders + e, err := audit.NewEvent(audit.ResponseType) + if err != nil { + retErr = multierror.Append(retErr, err) + return retErr.ErrorOrNil() + } - start := time.Now() - lrErr := be.backend.LogResponse(ctx, in) - metrics.MeasureSince([]string{"audit", name, "log_response"}, start) - if lrErr != nil { - a.logger.Error("backend failed to log response", "backend", name, "error", lrErr) - } else { - anyLogged = true - } + e.Data = in + + // In cases where we are trying to audit the response, we detach + // ourselves from the original context (keeping only the namespace). + // This is so that we get a fair run at writing audit entries if Vault + // has taken up a lot of time handling the request before audit (response) + // is triggered. Pipeline nodes and the eventlogger.Broker may check for a + // cancelled context and refuse to process the nodes further. + ns, err := namespace.FromContext(ctx) + if err != nil { + retErr = multierror.Append(retErr, fmt.Errorf("namespace missing from context: %w", err)) + return retErr.ErrorOrNil() + } + + auditContext, auditCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer auditCancel() + auditContext = namespace.ContextWithNamespace(auditContext, ns) + + // There may be cases where only the fallback device was added but no other + // normal audit devices, so check if the broker had an audit based pipeline + // registered before trying to send to it. + var status eventlogger.Status + if a.broker.IsAnyPipelineRegistered(eventlogger.EventType(event.AuditType.String())) { + status, err = a.broker.Send(auditContext, eventlogger.EventType(event.AuditType.String()), e) + if err != nil { + retErr = multierror.Append(retErr, multierror.Append(err, status.Warnings...)) + return retErr.ErrorOrNil() } - if !anyLogged && len(a.backends) > 0 { - retErr = multierror.Append(retErr, fmt.Errorf("no audit backend succeeded in logging the response")) + } + + // Audit event ended up in at least 1 sink. + if len(status.CompleteSinks()) > 0 { + return retErr.ErrorOrNil() + } + + // There were errors from inside the pipeline and we didn't write to a sink. + if len(status.Warnings) > 0 { + retErr = multierror.Append(retErr, multierror.Append(errors.New("error during audit pipeline processing"), status.Warnings...)) + return retErr.ErrorOrNil() + } + + // If a fallback device is registered we can rely on that to 'catch all' + // and also the broker level guarantee for completed sinks. + if a.fallbackBroker.IsAnyPipelineRegistered(eventlogger.EventType(event.AuditType.String())) { + status, err = a.fallbackBroker.Send(auditContext, eventlogger.EventType(event.AuditType.String()), e) + if err != nil { + retErr = multierror.Append(retErr, multierror.Append(fmt.Errorf("auditing response to fallback device failed: %w", err), status.Warnings...)) } } else { - if len(a.backends) > 0 { - e, err := audit.NewEvent(audit.ResponseType) - if err != nil { - retErr = multierror.Append(retErr, err) - return retErr.ErrorOrNil() - } - - e.Data = in - - // In cases where we are trying to audit the response, we detach - // ourselves from the original context (keeping only the namespace). - // This is so that we get a fair run at writing audit entries if Vault - // Took up a lot of time handling the request before audit (response) - // is triggered. Pipeline nodes may check for a cancelled context and - // refuse to process the nodes further. - ns, err := namespace.FromContext(ctx) - if err != nil { - retErr = multierror.Append(retErr, fmt.Errorf("namespace missing from context: %w", err)) - return retErr.ErrorOrNil() - } - - auditContext, auditCancel := context.WithTimeout(context.Background(), 5*time.Second) - defer auditCancel() - auditContext = namespace.ContextWithNamespace(auditContext, ns) - - // There may be cases where only the fallback device was added but no other - // normal audit devices, so check if the broker had an audit based pipeline - // registered before trying to send to it. - var status eventlogger.Status - if a.broker.IsAnyPipelineRegistered(eventlogger.EventType(event.AuditType.String())) { - status, err = a.broker.Send(auditContext, eventlogger.EventType(event.AuditType.String()), e) - if err != nil { - retErr = multierror.Append(retErr, multierror.Append(err, status.Warnings...)) - return retErr.ErrorOrNil() - } - } - - // Audit event ended up in at least 1 sink. - if len(status.CompleteSinks()) > 0 { - return retErr.ErrorOrNil() - } - - // There were errors from inside the pipeline and we didn't write to a sink. - if len(status.Warnings) > 0 { - retErr = multierror.Append(retErr, multierror.Append(errors.New("error during audit pipeline processing"), status.Warnings...)) - return retErr.ErrorOrNil() - } - - // If a fallback device is registered we can rely on that to 'catch all' - // and also the broker level guarantee for completed sinks. - if a.fallbackBroker.IsAnyPipelineRegistered(eventlogger.EventType(event.AuditType.String())) { - status, err = a.fallbackBroker.Send(auditContext, eventlogger.EventType(event.AuditType.String()), e) - if err != nil { - retErr = multierror.Append(retErr, multierror.Append(fmt.Errorf("auditing response to fallback device failed: %w", err), status.Warnings...)) - } - } else { - // This audit event won't make it to any devices, we class this as a 'miss' for auditing. - metrics.IncrCounter(audit.MetricLabelsFallbackMiss(), 1) - } - } + // This audit event won't make it to any devices, we class this as a 'miss' for auditing. + metrics.IncrCounter(audit.MetricLabelsFallbackMiss(), 1) } return retErr.ErrorOrNil() } -func (a *AuditBroker) Invalidate(ctx context.Context, key string) { - // For now, we ignore the key as this would only apply to salts. We just - // sort of brute force it on each one. +func (a *AuditBroker) Invalidate(ctx context.Context, _ string) { + // For now, we ignore the key as this would only apply to salts. + // We just sort of brute force it on each one. a.Lock() defer a.Unlock() + for _, be := range a.backends { be.backend.Invalidate(ctx) } diff --git a/vault/audit_broker_test.go b/vault/audit_broker_test.go index 18efaa5601..b4fd21a9cb 100644 --- a/vault/audit_broker_test.go +++ b/vault/audit_broker_test.go @@ -7,6 +7,10 @@ import ( "context" "crypto/sha256" "testing" + "time" + + "github.com/hashicorp/vault/builtin/audit/file" + "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/eventlogger" "github.com/hashicorp/vault/audit" @@ -42,7 +46,7 @@ func testAuditBackend(t *testing.T, path string, config map[string]string) audit MountPath: path, } - be, err := syslog.Factory(context.Background(), cfg, true, headersCfg) + be, err := syslog.Factory(context.Background(), cfg, headersCfg) require.NoError(t, err) require.NotNil(t, be) @@ -58,7 +62,7 @@ func testAuditBackend(t *testing.T, path string, config map[string]string) audit func TestAuditBroker_Register_SuccessThresholdSinks(t *testing.T) { t.Parallel() l := corehelpers.NewTestLogger(t) - a, err := NewAuditBroker(l, true) + a, err := NewAuditBroker(l) require.NoError(t, err) require.NotNil(t, a) @@ -100,7 +104,7 @@ func TestAuditBroker_Register_SuccessThresholdSinks(t *testing.T) { func TestAuditBroker_Deregister_SuccessThresholdSinks(t *testing.T) { t.Parallel() l := corehelpers.NewTestLogger(t) - a, err := NewAuditBroker(l, true) + a, err := NewAuditBroker(l) require.NoError(t, err) require.NotNil(t, a) @@ -147,7 +151,7 @@ func TestAuditBroker_Register_Fallback(t *testing.T) { t.Parallel() l := corehelpers.NewTestLogger(t) - a, err := NewAuditBroker(l, true) + a, err := NewAuditBroker(l) require.NoError(t, err) require.NotNil(t, a) @@ -168,7 +172,7 @@ func TestAuditBroker_Register_FallbackMultiple(t *testing.T) { t.Parallel() l := corehelpers.NewTestLogger(t) - a, err := NewAuditBroker(l, true) + a, err := NewAuditBroker(l) require.NoError(t, err) require.NotNil(t, a) @@ -194,7 +198,7 @@ func TestAuditBroker_Deregister_Fallback(t *testing.T) { t.Parallel() l := corehelpers.NewTestLogger(t) - a, err := NewAuditBroker(l, true) + a, err := NewAuditBroker(l) require.NoError(t, err) require.NotNil(t, a) @@ -225,7 +229,7 @@ func TestAuditBroker_Deregister_Multiple(t *testing.T) { t.Parallel() l := corehelpers.NewTestLogger(t) - a, err := NewAuditBroker(l, true) + a, err := NewAuditBroker(l) require.NoError(t, err) require.NotNil(t, a) @@ -242,7 +246,7 @@ func TestAuditBroker_Register_MultipleFails(t *testing.T) { t.Parallel() l := corehelpers.NewTestLogger(t) - a, err := NewAuditBroker(l, true) + a, err := NewAuditBroker(l) require.NoError(t, err) require.NotNil(t, a) @@ -256,3 +260,73 @@ func TestAuditBroker_Register_MultipleFails(t *testing.T) { require.Error(t, err) require.EqualError(t, err, "vault.(AuditBroker).Register: backend already registered 'b2-no-filter'") } + +// BenchmarkAuditBroker_File_Request_DevNull Attempts to register a single `file` +// audit device on the broker, which points at /dev/null. +// It will then attempt to benchmark how long it takes Vault to complete logging +// a request, this really only shows us how Vault can handle lots of calls to the +// broker to trigger the eventlogger pipelines that audit devices are configured as. +// Since we aren't writing anything to file or doing any I/O. +// This test used to live in the file package for the file backend, but once the +// move to eventlogger was complete, there wasn't a way to create a file backend +// and manually just write to the underlying file itself, the old code used to do +// formatting and writing all together, but we've split this up with eventlogger +// with different nodes in a pipeline (think 1 audit device:1 pipeline) each +// handling a responsibility, for example: +// filter nodes filter events, so you can select which ones make it to your audit log +// formatter nodes format the events (to JSON/JSONX and perform HMACing etc) +// sink nodes handle sending the formatted data to a file, syslog or socket. +func BenchmarkAuditBroker_File_Request_DevNull(b *testing.B) { + backendConfig := &audit.BackendConfig{ + Config: map[string]string{ + "path": "/dev/null", + }, + MountPath: "test", + SaltConfig: &salt.Config{}, + SaltView: &logical.InmemStorage{}, + } + + sink, err := file.Factory(context.Background(), backendConfig, nil) + require.NoError(b, err) + + broker, err := NewAuditBroker(nil) + require.NoError(b, err) + + err = broker.Register("test", sink, false) + require.NoError(b, err) + + in := &logical.LogInput{ + Auth: &logical.Auth{ + ClientToken: "foo", + Accessor: "bar", + EntityID: "foobarentity", + DisplayName: "testtoken", + NoDefaultPolicy: true, + Policies: []string{"root"}, + TokenType: logical.TokenTypeService, + }, + Request: &logical.Request{ + Operation: logical.UpdateOperation, + Path: "/foo", + Connection: &logical.Connection{ + RemoteAddr: "127.0.0.1", + }, + WrapInfo: &logical.RequestWrapInfo{ + TTL: 60 * time.Second, + }, + Headers: map[string][]string{ + "foo": {"bar"}, + }, + }, + } + + ctx := namespace.RootContext(context.Background()) + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := broker.LogRequest(ctx, in); err != nil { + panic(err) + } + } + }) +} diff --git a/vault/audit_test.go b/vault/audit_test.go index 87f06f4a57..e58764cb24 100644 --- a/vault/audit_test.go +++ b/vault/audit_test.go @@ -12,24 +12,22 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" - - "github.com/hashicorp/vault/helper/testhelpers/corehelpers" - "github.com/hashicorp/errwrap" log "github.com/hashicorp/go-hclog" - uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/helper/namespace" + "github.com/hashicorp/vault/helper/testhelpers/corehelpers" "github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/helper/logging" "github.com/hashicorp/vault/sdk/logical" "github.com/mitchellh/copystructure" + "github.com/stretchr/testify/require" ) func TestAudit_ReadOnlyViewDuringMount(t *testing.T) { c, _, _ := TestCoreUnsealed(t) - c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { + c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, _ audit.HeaderFormatter) (audit.Backend, error) { err := config.SaltView.Put(ctx, &logical.StorageEntry{ Key: "bar", Value: []byte("baz"), @@ -38,7 +36,7 @@ func TestAudit_ReadOnlyViewDuringMount(t *testing.T) { t.Fatalf("expected a read-only error") } factory := corehelpers.NoopAuditFactory(nil) - return factory(ctx, config, false, nil) + return factory(ctx, config, nil) } me := &MountEntry{ @@ -105,7 +103,7 @@ func TestCore_EnableAudit(t *testing.T) { func TestCore_EnableAudit_MixedFailures(t *testing.T) { c, _, _ := TestCoreUnsealed(t) c.auditBackends["noop"] = corehelpers.NoopAuditFactory(nil) - c.auditBackends["fail"] = func(ctx context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { + c.auditBackends["fail"] = func(ctx context.Context, config *audit.BackendConfig, _ audit.HeaderFormatter) (audit.Backend, error) { return nil, fmt.Errorf("failing enabling") } @@ -154,7 +152,7 @@ func TestCore_EnableAudit_MixedFailures(t *testing.T) { func TestCore_EnableAudit_Local(t *testing.T) { c, _, _ := TestCoreUnsealed(t) c.auditBackends["noop"] = corehelpers.NoopAuditFactory(nil) - c.auditBackends["fail"] = func(ctx context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { + c.auditBackends["fail"] = func(ctx context.Context, config *audit.BackendConfig, _ audit.HeaderFormatter) (audit.Backend, error) { return nil, fmt.Errorf("failing enabling") } @@ -405,7 +403,7 @@ func verifyDefaultAuditTable(t *testing.T, table *MountTable) { func TestAuditBroker_LogRequest(t *testing.T) { l := logging.NewVaultLogger(log.Trace) - b, err := NewAuditBroker(l, true) + b, err := NewAuditBroker(l) if err != nil { t.Fatal(err) } @@ -451,17 +449,13 @@ func TestAuditBroker_LogRequest(t *testing.T) { reqErrs := errors.New("errs") - headersConf := &AuditedHeadersConfig{ - Headers: make(map[string]*auditedHeaderSettings), - } - logInput := &logical.LogInput{ Auth: authCopy, Request: reqCopy, OuterErr: reqErrs, } ctx := namespace.RootContext(context.Background()) - err = b.LogRequest(ctx, logInput, headersConf) + err = b.LogRequest(ctx, logInput) if err != nil { t.Fatalf("err: %v", err) } @@ -484,20 +478,20 @@ func TestAuditBroker_LogRequest(t *testing.T) { Auth: auth, Request: req, } - if err := b.LogRequest(ctx, logInput, headersConf); err != nil { + if err := b.LogRequest(ctx, logInput); err != nil { t.Fatalf("err: %v", err) } // Should FAIL work with both failing backends a2.ReqErr = fmt.Errorf("failed") - if err := b.LogRequest(ctx, logInput, headersConf); !errwrap.Contains(err, "event not processed by enough 'sink' nodes") { + if err := b.LogRequest(ctx, logInput); !errwrap.Contains(err, "event not processed by enough 'sink' nodes") { t.Fatalf("err: %v", err) } } func TestAuditBroker_LogResponse(t *testing.T) { l := logging.NewVaultLogger(log.Trace) - b, err := NewAuditBroker(l, true) + b, err := NewAuditBroker(l) if err != nil { t.Fatal(err) } @@ -553,10 +547,6 @@ func TestAuditBroker_LogResponse(t *testing.T) { } respCopy := respCopyRaw.(*logical.Response) - headersConf := &AuditedHeadersConfig{ - Headers: make(map[string]*auditedHeaderSettings), - } - logInput := &logical.LogInput{ Auth: authCopy, Request: reqCopy, @@ -564,7 +554,7 @@ func TestAuditBroker_LogResponse(t *testing.T) { OuterErr: respErr, } ctx := namespace.RootContext(context.Background()) - err = b.LogResponse(ctx, logInput, headersConf) + err = b.LogResponse(ctx, logInput) if err != nil { t.Fatalf("err: %v", err) } @@ -592,12 +582,12 @@ func TestAuditBroker_LogResponse(t *testing.T) { Response: resp, OuterErr: respErr, } - err = b.LogResponse(ctx, logInput, headersConf) + err = b.LogResponse(ctx, logInput) require.NoError(t, err) // Should FAIL work with both failing backends a2.RespErr = fmt.Errorf("failed") - err = b.LogResponse(ctx, logInput, headersConf) + err = b.LogResponse(ctx, logInput) require.Error(t, err) require.ErrorContains(t, err, "event not processed by enough 'sink' nodes") } @@ -605,7 +595,7 @@ func TestAuditBroker_LogResponse(t *testing.T) { func TestAuditBroker_AuditHeaders(t *testing.T) { logger := logging.NewVaultLogger(log.Trace) - b, err := NewAuditBroker(logger, true) + b, err := NewAuditBroker(logger) if err != nil { t.Fatal(err) } @@ -660,7 +650,7 @@ func TestAuditBroker_AuditHeaders(t *testing.T) { OuterErr: respErr, } ctx := namespace.RootContext(context.Background()) - err = b.LogRequest(ctx, logInput, nil) + err = b.LogRequest(ctx, logInput) if err != nil { t.Fatalf("err: %v", err) } @@ -683,14 +673,14 @@ func TestAuditBroker_AuditHeaders(t *testing.T) { Request: req, OuterErr: respErr, } - err = b.LogRequest(ctx, logInput, headersConf) + err = b.LogRequest(ctx, logInput) if err != nil { t.Fatalf("err: %v", err) } // Should FAIL work with both failing backends a2.ReqErr = fmt.Errorf("failed") - err = b.LogRequest(ctx, logInput, headersConf) + err = b.LogRequest(ctx, logInput) if !errwrap.Contains(err, "event not processed by enough 'sink' nodes") { t.Fatalf("err: %v", err) } diff --git a/vault/core.go b/vault/core.go index 542a62b249..22f19fc5b0 100644 --- a/vault/core.go +++ b/vault/core.go @@ -28,8 +28,6 @@ import ( "sync/atomic" "time" - kv "github.com/hashicorp/vault-plugin-secrets-kv" - "github.com/armon/go-metrics" "github.com/hashicorp/errwrap" log "github.com/hashicorp/go-hclog" @@ -38,11 +36,11 @@ import ( "github.com/hashicorp/go-kms-wrapping/wrappers/awskms/v2" "github.com/hashicorp/go-multierror" "github.com/hashicorp/go-secure-stdlib/mlock" - "github.com/hashicorp/go-secure-stdlib/parseutil" "github.com/hashicorp/go-secure-stdlib/reloadutil" "github.com/hashicorp/go-secure-stdlib/strutil" "github.com/hashicorp/go-secure-stdlib/tlsutil" "github.com/hashicorp/go-uuid" + kv "github.com/hashicorp/vault-plugin-secrets-kv" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/command/server" @@ -2177,7 +2175,7 @@ func (c *Core) sealInitCommon(ctx context.Context, req *logical.Request) (retErr Auth: auth, Request: req, } - if err := c.auditBroker.LogRequest(ctx, logInput, c.auditedHeaders); err != nil { + if err := c.auditBroker.LogRequest(ctx, logInput); err != nil { c.logger.Error("failed to audit request", "request_path", req.Path, "error", err) return errors.New("failed to audit request, cannot continue") } @@ -2435,15 +2433,11 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c return err } } else { - var err error - disableEventLogger, err := parseutil.ParseBool(os.Getenv(featureFlagDisableEventLogger)) - if err != nil { - return fmt.Errorf("unable to parse feature flag: %q: %w", featureFlagDisableEventLogger, err) - } - c.auditBroker, err = NewAuditBroker(logger, !disableEventLogger) + broker, err := NewAuditBroker(logger) if err != nil { return err } + c.auditBroker = broker } if c.isPrimary() { diff --git a/vault/core_test.go b/vault/core_test.go index 3f16e825ac..d6d1d01b94 100644 --- a/vault/core_test.go +++ b/vault/core_test.go @@ -1467,7 +1467,7 @@ func TestCore_HandleRequest_AuditTrail(t *testing.T) { // Create a noop audit backend var noop *corehelpers.NoopAudit c, _, root := TestCoreUnsealed(t) - c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, _ bool, headerFormatter audit.HeaderFormatter) (audit.Backend, error) { + c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, headerFormatter audit.HeaderFormatter) (audit.Backend, error) { var err error noop, err = corehelpers.NewNoopAudit(config, audit.WithHeaderFormatter(headerFormatter)) return noop, err @@ -1530,7 +1530,7 @@ func TestCore_HandleRequest_AuditTrail_noHMACKeys(t *testing.T) { // Create a noop audit backend var noop *corehelpers.NoopAudit c, _, root := TestCoreUnsealed(t) - c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, _ bool, headerFormatter audit.HeaderFormatter) (audit.Backend, error) { + c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, headerFormatter audit.HeaderFormatter) (audit.Backend, error) { var err error noop, err = corehelpers.NewNoopAudit(config, audit.WithHeaderFormatter(headerFormatter)) return noop, err @@ -1651,7 +1651,7 @@ func TestCore_HandleLogin_AuditTrail(t *testing.T) { c.credentialBackends["noop"] = func(context.Context, *logical.BackendConfig) (logical.Backend, error) { return noopBack, nil } - c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, _ bool, headerFormatter audit.HeaderFormatter) (audit.Backend, error) { + c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, headerFormatter audit.HeaderFormatter) (audit.Backend, error) { var err error noop, err = corehelpers.NewNoopAudit(config, audit.WithHeaderFormatter(headerFormatter)) return noop, err diff --git a/vault/external_tests/identity/login_mfa_totp_test.go b/vault/external_tests/identity/login_mfa_totp_test.go index 8e6c476d72..b5794b2b85 100644 --- a/vault/external_tests/identity/login_mfa_totp_test.go +++ b/vault/external_tests/identity/login_mfa_totp_test.go @@ -61,7 +61,7 @@ func TestLoginMfaGenerateTOTPTestAuditIncluded(t *testing.T) { "totp": totp.Factory, }, AuditBackends: map[string]audit.Factory{ - "noop": func(ctx context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { + "noop": func(ctx context.Context, config *audit.BackendConfig, _ audit.HeaderFormatter) (audit.Backend, error) { return noop, nil }, }, diff --git a/vault/ha.go b/vault/ha.go index e81787e7e9..536142e579 100644 --- a/vault/ha.go +++ b/vault/ha.go @@ -347,7 +347,7 @@ func (c *Core) StepDown(httpCtx context.Context, req *logical.Request) (retErr e Auth: auth, Request: req, } - if err := c.auditBroker.LogRequest(ctx, logInput, c.auditedHeaders); err != nil { + if err := c.auditBroker.LogRequest(ctx, logInput); err != nil { c.logger.Error("failed to audit request", "request_path", req.Path, "error", err) return errors.New("failed to audit request, cannot continue") } diff --git a/vault/request_handling.go b/vault/request_handling.go index 034146b5c0..12c8ac2778 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -25,8 +25,6 @@ import ( "github.com/hashicorp/go-secure-stdlib/strutil" "github.com/hashicorp/go-sockaddr" "github.com/hashicorp/go-uuid" - uberAtomic "go.uber.org/atomic" - "github.com/hashicorp/vault/command/server" "github.com/hashicorp/vault/helper/identity" "github.com/hashicorp/vault/helper/identity/mfa" @@ -42,6 +40,7 @@ import ( "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault/quotas" "github.com/hashicorp/vault/vault/tokens" + uberAtomic "go.uber.org/atomic" ) const ( @@ -898,7 +897,7 @@ func (c *Core) handleCancelableRequest(ctx context.Context, req *logical.Request NonHMACReqDataKeys: nonHMACReqDataKeys, NonHMACRespDataKeys: nonHMACRespDataKeys, } - if auditErr := c.auditBroker.LogResponse(ctx, logInput, c.auditedHeaders); auditErr != nil { + if auditErr := c.auditBroker.LogResponse(ctx, logInput); auditErr != nil { c.logger.Error("failed to audit response", "request_path", req.Path, "error", auditErr) return nil, ErrInternalError } @@ -1088,7 +1087,7 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp OuterErr: ctErr, NonHMACReqDataKeys: nonHMACReqDataKeys, } - if err := c.auditBroker.LogRequest(ctx, logInput, c.auditedHeaders); err != nil { + if err := c.auditBroker.LogRequest(ctx, logInput); err != nil { c.logger.Error("failed to audit request", "path", req.Path, "error", err) } } @@ -1109,7 +1108,7 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp Request: req, NonHMACReqDataKeys: nonHMACReqDataKeys, } - if err := c.auditBroker.LogRequest(ctx, logInput, c.auditedHeaders); err != nil { + if err := c.auditBroker.LogRequest(ctx, logInput); err != nil { c.logger.Error("failed to audit request", "path", req.Path, "error", err) retErr = multierror.Append(retErr, ErrInternalError) return nil, auth, retErr @@ -1451,7 +1450,7 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re OuterErr: ctErr, NonHMACReqDataKeys: nonHMACReqDataKeys, } - if err := c.auditBroker.LogRequest(ctx, logInput, c.auditedHeaders); err != nil { + if err := c.auditBroker.LogRequest(ctx, logInput); err != nil { c.logger.Error("failed to audit request", "path", req.Path, "error", err) return nil, nil, ErrInternalError } @@ -1475,7 +1474,7 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re Request: req, NonHMACReqDataKeys: nonHMACReqDataKeys, } - if err := c.auditBroker.LogRequest(ctx, logInput, c.auditedHeaders); err != nil { + if err := c.auditBroker.LogRequest(ctx, logInput); err != nil { c.logger.Error("failed to audit request", "path", req.Path, "error", err) return nil, nil, ErrInternalError } diff --git a/vault/wrapping.go b/vault/wrapping.go index 1be5542da8..accca9a65e 100644 --- a/vault/wrapping.go +++ b/vault/wrapping.go @@ -376,7 +376,7 @@ func (c *Core) validateWrappingToken(ctx context.Context, req *logical.Request) if !valid { logInput.OuterErr = consts.ErrInvalidWrappingToken } - if err := c.auditBroker.LogRequest(ctx, logInput, c.auditedHeaders); err != nil { + if err := c.auditBroker.LogRequest(ctx, logInput); err != nil { c.logger.Error("failed to audit request", "path", req.Path, "error", err) } }