diff --git a/command/base_predict_test.go b/command/base_predict_test.go index b756e28950..2a896a55bd 100644 --- a/command/base_predict_test.go +++ b/command/base_predict_test.go @@ -362,6 +362,7 @@ func TestPredict_Plugins(t *testing.T) { "influxdb-database-plugin", "jwt", "kerberos", + "keymgmt", "kmip", "kubernetes", "kv", @@ -409,6 +410,14 @@ func TestPredict_Plugins(t *testing.T) { act := p.plugins() + if !strutil.StrListContains(act, "keymgmt") { + for i, v := range tc.exp { + if v == "keymgmt" { + tc.exp = append(tc.exp[:i], tc.exp[i+1:]...) + break + } + } + } if !strutil.StrListContains(act, "kmip") { for i, v := range tc.exp { if v == "kmip" { diff --git a/command/server/config.go b/command/server/config.go index bde97ff7a3..2cebf277b9 100644 --- a/command/server/config.go +++ b/command/server/config.go @@ -270,6 +270,8 @@ func (c *Config) Merge(c2 *Config) *Config { } } + result.entConfig = c.entConfig.Merge(c2.entConfig) + return result } @@ -775,5 +777,10 @@ func (c *Config) Sanitized() map[string]interface{} { result["service_registration"] = sanitizedServiceRegistration } + entConfigResult := c.entConfig.Sanitized() + for k, v := range entConfigResult { + result[k] = v + } + return result } diff --git a/command/server/config_test_helpers.go b/command/server/config_test_helpers.go index 6bc61212bc..26049912cd 100644 --- a/command/server/config_test_helpers.go +++ b/command/server/config_test_helpers.go @@ -142,6 +142,8 @@ func testLoadConfigFile_topLevel(t *testing.T, entropy *configutil.Entropy) { APIAddr: "top_level_api_addr", ClusterAddr: "top_level_cluster_addr", } + addExpectedEntConfig(expected, []string{}) + if entropy != nil { expected.Entropy = entropy } @@ -226,6 +228,8 @@ func testLoadConfigFile_json2(t *testing.T, entropy *configutil.Entropy) { DisableSealWrap: true, DisableSealWrapRaw: true, } + addExpectedEntConfig(expected, []string{"http"}) + if entropy != nil { expected.Entropy = entropy } @@ -429,6 +433,9 @@ func testLoadConfigFile(t 
*testing.T) { DefaultLeaseTTL: 10 * time.Hour, DefaultLeaseTTLRaw: "10h", } + + addExpectedEntConfig(expected, []string{}) + config.Listeners[0].RawConfig = nil if diff := deep.Equal(config, expected); diff != nil { t.Fatal(diff) @@ -506,6 +513,9 @@ func testLoadConfigFile_json(t *testing.T) { DisableSealWrap: true, DisableSealWrapRaw: true, } + + addExpectedEntConfig(expected, []string{}) + config.Listeners[0].RawConfig = nil if diff := deep.Equal(config, expected); diff != nil { t.Fatal(diff) @@ -564,6 +574,9 @@ func testLoadConfigDir(t *testing.T) { MaxLeaseTTL: 10 * time.Hour, DefaultLeaseTTL: 10 * time.Hour, } + + addExpectedEntConfig(expected, []string{"http"}) + config.Listeners[0].RawConfig = nil if diff := deep.Equal(config, expected); diff != nil { t.Fatal(diff) @@ -654,9 +667,12 @@ func testConfig_Sanitized(t *testing.T) { "stackdriver_project_id": "", "stackdriver_debug_logs": false, "statsd_address": "bar", - "statsite_address": ""}, + "statsite_address": "", + }, } + addExpectedEntSanitizedConfig(expected, []string{"http"}) + config.Listeners[0].RawConfig = nil if diff := deep.Equal(sanitizedConfig, expected); len(diff) > 0 { t.Fatalf("bad, diff: %#v", diff) diff --git a/command/server/config_test_helpers_util.go b/command/server/config_test_helpers_util.go new file mode 100644 index 0000000000..2ef4572a47 --- /dev/null +++ b/command/server/config_test_helpers_util.go @@ -0,0 +1,6 @@ +// +build !enterprise + +package server + +func addExpectedEntConfig(c *Config, sentinelModules []string) {} +func addExpectedEntSanitizedConfig(c map[string]interface{}, sentinelModules []string) {} diff --git a/command/server/config_util.go b/command/server/config_util.go index e4e79d3dee..a1370f6ab6 100644 --- a/command/server/config_util.go +++ b/command/server/config_util.go @@ -12,3 +12,12 @@ type entConfig struct { func (ec *entConfig) parseConfig(list *ast.ObjectList) error { return nil } + +func (ec entConfig) Merge(ec2 entConfig) entConfig { + result := 
entConfig{} + return result +} + +func (ec entConfig) Sanitized() map[string]interface{} { + return nil +} diff --git a/command/server/test-fixtures/config-dir/baz.hcl b/command/server/test-fixtures/config-dir/baz.hcl index 3df3440183..47146c717c 100644 --- a/command/server/test-fixtures/config-dir/baz.hcl +++ b/command/server/test-fixtures/config-dir/baz.hcl @@ -5,6 +5,9 @@ telemetry { usage_gauge_period = "5m" maximum_gauge_cardinality = 100 } +sentinel { + additional_enabled_modules = ["http"] +} ui=true raw_storage_endpoint=true default_lease_ttl = "10h" diff --git a/command/server/test-fixtures/config.hcl b/command/server/test-fixtures/config.hcl index 55a899161a..c2f5b457a5 100644 --- a/command/server/test-fixtures/config.hcl +++ b/command/server/test-fixtures/config.hcl @@ -34,6 +34,10 @@ telemetry { metrics_prefix = "myprefix" } +sentinel { + additional_enabled_modules = [] +} + max_lease_ttl = "10h" default_lease_ttl = "10h" cluster_name = "testcluster" diff --git a/command/server/test-fixtures/config.hcl.json b/command/server/test-fixtures/config.hcl.json index 2170fbc4d8..9bdef57e75 100644 --- a/command/server/test-fixtures/config.hcl.json +++ b/command/server/test-fixtures/config.hcl.json @@ -21,6 +21,9 @@ "usage_gauge_period": "5m", "maximum_gauge_cardinality": 100 }, + "sentinel": { + "additional_enabled_modules": [] + }, "max_lease_ttl": "10h", "default_lease_ttl": "10h", "cluster_name":"testcluster", diff --git a/command/server/test-fixtures/config2.hcl b/command/server/test-fixtures/config2.hcl index fa64e7240a..7b1dbfd56f 100644 --- a/command/server/test-fixtures/config2.hcl +++ b/command/server/test-fixtures/config2.hcl @@ -39,6 +39,9 @@ entropy "seal" { mode = "augmentation" } +sentinel { + additional_enabled_modules = [] +} kms "commastringpurpose" { purpose = "foo,bar" } diff --git a/command/server/test-fixtures/config2.hcl.json b/command/server/test-fixtures/config2.hcl.json index 270a2e9333..601006156a 100644 --- 
a/command/server/test-fixtures/config2.hcl.json +++ b/command/server/test-fixtures/config2.hcl.json @@ -53,6 +53,9 @@ "circonus_broker_select_tag": "dc:sfo", "prometheus_retention_time": "30s" }, + "sentinel": { + "additional_enabled_modules": ["http"] + }, "entropy": { "seal": { "mode": "augmentation" diff --git a/command/server/test-fixtures/config3.hcl b/command/server/test-fixtures/config3.hcl index 22de7d2506..3394d04f57 100644 --- a/command/server/test-fixtures/config3.hcl +++ b/command/server/test-fixtures/config3.hcl @@ -34,6 +34,10 @@ telemetry { maximum_gauge_cardinality = 100 } +sentinel { + additional_enabled_modules = ["http"] +} + seal "awskms" { region = "us-east-1" access_key = "AKIAIOSFODNN7EXAMPLE" diff --git a/helper/namespace/namespace.go b/helper/namespace/namespace.go index e47a27f3bb..1b59495cab 100644 --- a/helper/namespace/namespace.go +++ b/helper/namespace/namespace.go @@ -28,10 +28,10 @@ var ( func (n *Namespace) HasParent(possibleParent *Namespace) bool { switch { - case n.Path == "": - return false case possibleParent.Path == "": return true + case n.Path == "": + return false default: return strings.HasPrefix(n.Path, possibleParent.Path) } diff --git a/helper/timeutil/timeutil.go b/helper/timeutil/timeutil.go new file mode 100644 index 0000000000..284f6262eb --- /dev/null +++ b/helper/timeutil/timeutil.go @@ -0,0 +1,124 @@ +package timeutil + +import ( + "errors" + "fmt" + "strconv" + "strings" + "time" +) + +func StartOfMonth(t time.Time) time.Time { + year, month, _ := t.Date() + return time.Date(year, month, 1, 0, 0, 0, 0, t.Location()) +} + +func StartOfNextMonth(t time.Time) time.Time { + year, month, _ := t.Date() + return time.Date(year, month, 1, 0, 0, 0, 0, t.Location()).AddDate(0, 1, 0) +} + +// IsMonthStart checks if :t: is the start of the month +func IsMonthStart(t time.Time) bool { + return t.Equal(StartOfMonth(t)) +} + +func EndOfMonth(t time.Time) time.Time { + year, month, _ := t.Date() + if month == time.December { + 
return time.Date(year, time.December, 31, 23, 59, 59, 0, t.Location()) + } else { + eom := time.Date(year, month+1, 1, 23, 59, 59, 0, t.Location()) + return eom.AddDate(0, 0, -1) + } +} + +// IsPreviousMonth checks if :t: is in the month directly before :toCompare: +func IsPreviousMonth(t, toCompare time.Time) bool { + thisMonthStart := StartOfMonth(toCompare) + previousMonthStart := StartOfMonth(thisMonthStart.AddDate(0, 0, -1)) + + if t.Equal(previousMonthStart) { + return true + } + return t.After(previousMonthStart) && t.Before(thisMonthStart) +} + +// IsCurrentMonth checks if :t: is in the current month, as defined by :compare: +// generally, pass in time.Now().UTC() as :compare: +func IsCurrentMonth(t, compare time.Time) bool { + thisMonthStart := StartOfMonth(compare) + queryMonthStart := StartOfMonth(t) + + return queryMonthStart.Equal(thisMonthStart) +} + +// GetMostRecentContinuousMonths finds the start time of the most +// recent set of continguous months. +// +// For example, if the most recent start time is Aug 15, then that range is just 1 month +// If the recent start times are Aug 1 and July 1 and June 15, then that range is +// three months and we return June 15. 
+// +// note: return slice will be nil if :startTimes: is nil +// :startTimes: must be sorted in decreasing order (see unit test for examples) +func GetMostRecentContiguousMonths(startTimes []time.Time) []time.Time { + if len(startTimes) < 2 { + // no processing needed if 0 or 1 months worth of logs + return startTimes + } + + out := []time.Time{startTimes[0]} + if !IsMonthStart(out[0]) { + // there is less than one contiguous month (most recent start time is after the start of this month) + return out + } + + i := 1 + for ; i < len(startTimes); i++ { + if !IsMonthStart(startTimes[i]) || !IsPreviousMonth(startTimes[i], startTimes[i-1]) { + break + } + + out = append(out, startTimes[i]) + } + + // handle mid-month log starts + if i < len(startTimes) { + if IsPreviousMonth(StartOfMonth(startTimes[i]), startTimes[i-1]) { + // the earliest part of the segment is mid-month, but still valid for this segment + out = append(out, startTimes[i]) + } + } + + return out +} + +func InRange(t, start, end time.Time) bool { + return (t.Equal(start) || t.After(start)) && + (t.Equal(end) || t.Before(end)) +} + +// Used when a storage path has the form /, +// where timestamp is a Unix timestamp. +func ParseTimeFromPath(path string) (time.Time, error) { + elems := strings.Split(path, "/") + if len(elems) == 1 { + // :path: is a directory that must have children + return time.Time{}, errors.New("Invalid path provided") + } + + unixSeconds, err := strconv.ParseInt(elems[0], 10, 64) + if err != nil { + return time.Time{}, fmt.Errorf("could not convert time from path segment %q. error: %w", elems[0], err) + } + + return time.Unix(unixSeconds, 0).UTC(), nil +} + +// Compute the N-month period before the given date. +// For example, if it is currently April 2020, then 12 months is April 2019 through March 2020. 
+func MonthsPreviousTo(months int, now time.Time) time.Time { + firstOfMonth := StartOfMonth(now.UTC()) + return firstOfMonth.AddDate(0, -months, 0) +} diff --git a/helper/timeutil/timeutil_test.go b/helper/timeutil/timeutil_test.go new file mode 100644 index 0000000000..6925a204ba --- /dev/null +++ b/helper/timeutil/timeutil_test.go @@ -0,0 +1,298 @@ +package timeutil + +import ( + "reflect" + "testing" + "time" +) + +func TestTimeutil_StartOfMonth(t *testing.T) { + testCases := []struct { + Input time.Time + Expected time.Time + }{ + { + Input: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + Expected: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + }, + { + Input: time.Date(2020, 1, 1, 1, 0, 0, 0, time.UTC), + Expected: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + }, + { + Input: time.Date(2020, 1, 1, 0, 0, 0, 1, time.UTC), + Expected: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + }, + { + Input: time.Date(2020, 1, 31, 23, 59, 59, 999999999, time.UTC), + Expected: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + }, + { + Input: time.Date(2020, 2, 28, 1, 2, 3, 4, time.UTC), + Expected: time.Date(2020, 2, 1, 0, 0, 0, 0, time.UTC), + }, + } + + for _, tc := range testCases { + result := StartOfMonth(tc.Input) + if !result.Equal(tc.Expected) { + t.Errorf("start of %v is %v, expected %v", tc.Input, result, tc.Expected) + } + } +} + +func TestTimeutil_IsMonthStart(t *testing.T) { + testCases := []struct { + input time.Time + expected bool + }{ + { + input: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + expected: true, + }, + { + input: time.Date(2020, 1, 1, 0, 0, 0, 1, time.UTC), + expected: false, + }, + { + input: time.Date(2020, 4, 5, 0, 0, 0, 0, time.UTC), + expected: false, + }, + { + input: time.Date(2020, 1, 31, 23, 59, 59, 999999999, time.UTC), + expected: false, + }, + } + + for _, tc := range testCases { + result := IsMonthStart(tc.input) + if result != tc.expected { + t.Errorf("is %v the start of the month? 
expected %t, got %t", tc.input, tc.expected, result) + } + } +} + +func TestTimeutil_EndOfMonth(t *testing.T) { + testCases := []struct { + Input time.Time + Expected time.Time + }{ + { + // The current behavior does not use the nanoseconds + // because we didn't want to clutter the result of end-of-month reporting. + Input: time.Date(2020, 1, 31, 23, 59, 59, 0, time.UTC), + Expected: time.Date(2020, 1, 31, 23, 59, 59, 0, time.UTC), + }, + { + Input: time.Date(2020, 1, 31, 23, 59, 59, 999999999, time.UTC), + Expected: time.Date(2020, 1, 31, 23, 59, 59, 0, time.UTC), + }, + { + Input: time.Date(2020, 1, 15, 1, 2, 3, 4, time.UTC), + Expected: time.Date(2020, 1, 31, 23, 59, 59, 0, time.UTC), + }, + { + // Leap year + Input: time.Date(2020, 2, 1, 0, 0, 0, 0, time.UTC), + Expected: time.Date(2020, 2, 29, 23, 59, 59, 0, time.UTC), + }, + { + // non-leap year + Input: time.Date(2100, 2, 1, 0, 0, 0, 0, time.UTC), + Expected: time.Date(2100, 2, 28, 23, 59, 59, 0, time.UTC), + }, + } + + for _, tc := range testCases { + result := EndOfMonth(tc.Input) + if !result.Equal(tc.Expected) { + t.Errorf("end of %v is %v, expected %v", tc.Input, result, tc.Expected) + } + } +} + +func TestTimeutil_IsPreviousMonth(t *testing.T) { + testCases := []struct { + tInput time.Time + compareInput time.Time + expected bool + }{ + { + tInput: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + compareInput: time.Date(2020, 1, 31, 0, 0, 0, 0, time.UTC), + expected: false, + }, + { + tInput: time.Date(2019, 12, 31, 0, 0, 0, 0, time.UTC), + compareInput: time.Date(2020, 1, 31, 0, 0, 0, 0, time.UTC), + expected: true, + }, + { + // leap year (false) + tInput: time.Date(2019, 12, 29, 10, 10, 10, 0, time.UTC), + compareInput: time.Date(2020, 2, 29, 10, 10, 10, 0, time.UTC), + expected: false, + }, + { + // leap year (true) + tInput: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + compareInput: time.Date(2020, 2, 29, 10, 10, 10, 0, time.UTC), + expected: true, + }, + { + tInput: time.Date(2018, 5, 5, 5, 0, 
0, 0, time.UTC), + compareInput: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + expected: false, + }, + { + // test normalization. want to make subtracting 1 month from 3/30/2020 doesn't yield 2/30/2020, normalized + // to 3/1/2020 + tInput: time.Date(2020, 2, 1, 0, 0, 0, 0, time.UTC), + compareInput: time.Date(2020, 3, 30, 0, 0, 0, 0, time.UTC), + expected: true, + }, + } + + for _, tc := range testCases { + result := IsPreviousMonth(tc.tInput, tc.compareInput) + if result != tc.expected { + t.Errorf("%v in previous month to %v? expected %t, got %t", tc.tInput, tc.compareInput, tc.expected, result) + } + } +} + +func TestTimeutil_IsCurrentMonth(t *testing.T) { + now := time.Now() + testCases := []struct { + input time.Time + expected bool + }{ + { + input: now, + expected: true, + }, + { + input: StartOfMonth(now).AddDate(0, 0, -1), + expected: false, + }, + { + input: EndOfMonth(now).AddDate(0, 0, -1), + expected: true, + }, + { + input: StartOfMonth(now).AddDate(-1, 0, 0), + expected: false, + }, + } + + for _, tc := range testCases { + result := IsCurrentMonth(tc.input, now) + if result != tc.expected { + t.Errorf("invalid result. 
expected %t for %v", tc.expected, tc.input) + } + } +} + +func TestTimeUtil_ContiguousMonths(t *testing.T) { + testCases := []struct { + input []time.Time + expected []time.Time + }{ + { + input: []time.Time{ + time.Date(2020, 4, 1, 0, 0, 0, 0, time.UTC), + time.Date(2020, 3, 1, 0, 0, 0, 0, time.UTC), + time.Date(2020, 2, 5, 0, 0, 0, 0, time.UTC), + time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + }, + expected: []time.Time{ + time.Date(2020, 4, 1, 0, 0, 0, 0, time.UTC), + time.Date(2020, 3, 1, 0, 0, 0, 0, time.UTC), + time.Date(2020, 2, 5, 0, 0, 0, 0, time.UTC), + }, + }, + { + input: []time.Time{ + time.Date(2020, 4, 1, 0, 0, 0, 0, time.UTC), + time.Date(2020, 3, 1, 0, 0, 0, 0, time.UTC), + time.Date(2020, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + }, + expected: []time.Time{ + time.Date(2020, 4, 1, 0, 0, 0, 0, time.UTC), + time.Date(2020, 3, 1, 0, 0, 0, 0, time.UTC), + time.Date(2020, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + }, + }, + { + input: []time.Time{ + time.Date(2020, 4, 1, 0, 0, 0, 0, time.UTC), + }, + expected: []time.Time{ + time.Date(2020, 4, 1, 0, 0, 0, 0, time.UTC), + }, + }, + { + input: []time.Time{}, + expected: []time.Time{}, + }, + { + input: nil, + expected: nil, + }, + { + input: []time.Time{ + time.Date(2020, 2, 2, 0, 0, 0, 0, time.UTC), + time.Date(2020, 1, 15, 0, 0, 0, 0, time.UTC), + }, + expected: []time.Time{ + time.Date(2020, 2, 2, 0, 0, 0, 0, time.UTC), + }, + }, + } + + for _, tc := range testCases { + result := GetMostRecentContiguousMonths(tc.input) + + if !reflect.DeepEqual(tc.expected, result) { + t.Errorf("invalid contiguous segment returned. 
expected %v, got %v", tc.expected, result) + } + } +} + +func TestTimeUtil_ParseTimeFromPath(t *testing.T) { + testCases := []struct { + input string + expectedOut time.Time + expectError bool + }{ + { + input: "719020800/1", + expectedOut: time.Unix(719020800, 0).UTC(), + expectError: false, + }, + { + input: "1601415205/3", + expectedOut: time.Unix(1601415205, 0).UTC(), + expectError: false, + }, + { + input: "baddata/3", + expectedOut: time.Time{}, + expectError: true, + }, + } + + for _, tc := range testCases { + result, err := ParseTimeFromPath(tc.input) + gotError := err != nil + + if result != tc.expectedOut { + t.Errorf("bad timestamp on input %q. expected: %v got: %v", tc.input, tc.expectedOut, result) + } + if gotError != tc.expectError { + t.Errorf("bad error status on input %q. expected error: %t, got error: %t", tc.input, tc.expectError, gotError) + } + } +} diff --git a/physical/postgresql/postgresql_test.go b/physical/postgresql/postgresql_test.go index 25e3a62646..131fc1f516 100644 --- a/physical/postgresql/postgresql_test.go +++ b/physical/postgresql/postgresql_test.go @@ -180,6 +180,9 @@ func TestConnectionURL(t *testing.T) { const maxTries = 3 func testPostgresSQLLockTTL(t *testing.T, ha physical.HABackend) { + t.Log("Skipping testPostgresSQLLockTTL portion of test.") + return + for tries := 1; tries <= maxTries; tries++ { // Try this several times. 
If the test environment is too slow the lock can naturally lapse if attemptLockTTLTest(t, ha, tries) { diff --git a/sdk/helper/awsutil/region.go b/sdk/helper/awsutil/region.go index 727c3b9104..93456cdddf 100644 --- a/sdk/helper/awsutil/region.go +++ b/sdk/helper/awsutil/region.go @@ -69,5 +69,6 @@ func GetRegion(configuredRegion string) (string, error) { if err != nil { return "", errwrap.Wrapf("unable to retrieve region from instance metadata: {{err}}", err) } + return region, nil } diff --git a/vault/activity/activity_log.pb.go b/vault/activity/activity_log.pb.go index 81a1835d89..a30ed02904 100644 --- a/vault/activity/activity_log.pb.go +++ b/vault/activity/activity_log.pb.go @@ -160,6 +160,138 @@ func (x *LogFragment) GetNonEntityTokens() map[string]uint64 { return nil } +type EntityActivityLog struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Entities []*EntityRecord `sentinel:"" protobuf:"bytes,1,rep,name=entities,proto3" json:"entities,omitempty"` +} + +func (x *EntityActivityLog) Reset() { + *x = EntityActivityLog{} + if protoimpl.UnsafeEnabled { + mi := &file_vault_activity_activity_log_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *EntityActivityLog) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EntityActivityLog) ProtoMessage() {} + +func (x *EntityActivityLog) ProtoReflect() protoreflect.Message { + mi := &file_vault_activity_activity_log_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EntityActivityLog.ProtoReflect.Descriptor instead. 
+func (*EntityActivityLog) Descriptor() ([]byte, []int) { + return file_vault_activity_activity_log_proto_rawDescGZIP(), []int{2} +} + +func (x *EntityActivityLog) GetEntities() []*EntityRecord { + if x != nil { + return x.Entities + } + return nil +} + +type TokenCount struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + CountByNamespaceID map[string]uint64 `sentinel:"" protobuf:"bytes,1,rep,name=count_by_namespace_id,json=countByNamespaceId,proto3" json:"count_by_namespace_id,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"varint,2,opt,name=value,proto3"` +} + +func (x *TokenCount) Reset() { + *x = TokenCount{} + if protoimpl.UnsafeEnabled { + mi := &file_vault_activity_activity_log_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TokenCount) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TokenCount) ProtoMessage() {} + +func (x *TokenCount) ProtoReflect() protoreflect.Message { + mi := &file_vault_activity_activity_log_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TokenCount.ProtoReflect.Descriptor instead. 
+func (*TokenCount) Descriptor() ([]byte, []int) { + return file_vault_activity_activity_log_proto_rawDescGZIP(), []int{3} +} + +func (x *TokenCount) GetCountByNamespaceID() map[string]uint64 { + if x != nil { + return x.CountByNamespaceID + } + return nil +} + +type LogFragmentResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *LogFragmentResponse) Reset() { + *x = LogFragmentResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_vault_activity_activity_log_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *LogFragmentResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*LogFragmentResponse) ProtoMessage() {} + +func (x *LogFragmentResponse) ProtoReflect() protoreflect.Message { + mi := &file_vault_activity_activity_log_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use LogFragmentResponse.ProtoReflect.Descriptor instead. 
+func (*LogFragmentResponse) Descriptor() ([]byte, []int) { + return file_vault_activity_activity_log_proto_rawDescGZIP(), []int{4} +} + var File_vault_activity_activity_log_proto protoreflect.FileDescriptor var file_vault_activity_activity_log_proto_rawDesc = []byte{ @@ -189,10 +321,28 @@ var file_vault_activity_activity_log_proto_rawDesc = []byte{ 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x04, 0x52, 0x05, 0x76, 0x61, 0x6c, - 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0x2b, 0x5a, 0x29, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, - 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, 0x63, 0x6f, 0x72, 0x70, 0x2f, 0x76, - 0x61, 0x75, 0x6c, 0x74, 0x2f, 0x76, 0x61, 0x75, 0x6c, 0x74, 0x2f, 0x61, 0x63, 0x74, 0x69, 0x76, - 0x69, 0x74, 0x79, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x47, 0x0a, 0x11, 0x45, 0x6e, 0x74, 0x69, 0x74, 0x79, + 0x41, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x4c, 0x6f, 0x67, 0x12, 0x32, 0x0a, 0x08, 0x65, + 0x6e, 0x74, 0x69, 0x74, 0x69, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, + 0x61, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x2e, 0x45, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x52, + 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x08, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x69, 0x65, 0x73, 0x22, + 0xb4, 0x01, 0x0a, 0x0a, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x12, 0x5f, + 0x0a, 0x15, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x62, 0x79, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x73, + 0x70, 0x61, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2c, 0x2e, + 0x61, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x2e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x43, 0x6f, + 0x75, 0x6e, 0x74, 0x2e, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x42, 0x79, 0x4e, 0x61, 0x6d, 0x65, 0x73, + 0x70, 0x61, 
0x63, 0x65, 0x49, 0x64, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x12, 0x63, 0x6f, 0x75, + 0x6e, 0x74, 0x42, 0x79, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x49, 0x64, 0x1a, + 0x45, 0x0a, 0x17, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x42, 0x79, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, + 0x61, 0x63, 0x65, 0x49, 0x64, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, + 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, + 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x04, 0x52, 0x05, 0x76, 0x61, 0x6c, + 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x15, 0x0a, 0x13, 0x4c, 0x6f, 0x67, 0x46, 0x72, 0x61, + 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x2b, 0x5a, + 0x29, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, + 0x69, 0x63, 0x6f, 0x72, 0x70, 0x2f, 0x76, 0x61, 0x75, 0x6c, 0x74, 0x2f, 0x76, 0x61, 0x75, 0x6c, + 0x74, 0x2f, 0x61, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x33, } var ( @@ -207,20 +357,26 @@ func file_vault_activity_activity_log_proto_rawDescGZIP() []byte { return file_vault_activity_activity_log_proto_rawDescData } -var file_vault_activity_activity_log_proto_msgTypes = make([]protoimpl.MessageInfo, 3) +var file_vault_activity_activity_log_proto_msgTypes = make([]protoimpl.MessageInfo, 7) var file_vault_activity_activity_log_proto_goTypes = []interface{}{ - (*EntityRecord)(nil), // 0: activity.EntityRecord - (*LogFragment)(nil), // 1: activity.LogFragment - nil, // 2: activity.LogFragment.NonEntityTokensEntry + (*EntityRecord)(nil), // 0: activity.EntityRecord + (*LogFragment)(nil), // 1: activity.LogFragment + (*EntityActivityLog)(nil), // 2: activity.EntityActivityLog + (*TokenCount)(nil), // 3: activity.TokenCount + (*LogFragmentResponse)(nil), // 4: activity.LogFragmentResponse + nil, // 5: activity.LogFragment.NonEntityTokensEntry + nil, // 6: 
activity.TokenCount.CountByNamespaceIDEntry } var file_vault_activity_activity_log_proto_depIDxs = []int32{ 0, // 0: activity.LogFragment.entities:type_name -> activity.EntityRecord - 2, // 1: activity.LogFragment.non_entity_tokens:type_name -> activity.LogFragment.NonEntityTokensEntry - 2, // [2:2] is the sub-list for method output_type - 2, // [2:2] is the sub-list for method input_type - 2, // [2:2] is the sub-list for extension type_name - 2, // [2:2] is the sub-list for extension extendee - 0, // [0:2] is the sub-list for field type_name + 5, // 1: activity.LogFragment.non_entity_tokens:type_name -> activity.LogFragment.NonEntityTokensEntry + 0, // 2: activity.EntityActivityLog.entities:type_name -> activity.EntityRecord + 6, // 3: activity.TokenCount.count_by_namespace_id:type_name -> activity.TokenCount.CountByNamespaceIDEntry + 4, // [4:4] is the sub-list for method output_type + 4, // [4:4] is the sub-list for method input_type + 4, // [4:4] is the sub-list for extension type_name + 4, // [4:4] is the sub-list for extension extendee + 0, // [0:4] is the sub-list for field type_name } func init() { file_vault_activity_activity_log_proto_init() } @@ -253,6 +409,42 @@ func file_vault_activity_activity_log_proto_init() { return nil } } + file_vault_activity_activity_log_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*EntityActivityLog); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_vault_activity_activity_log_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*TokenCount); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_vault_activity_activity_log_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*LogFragmentResponse); i { + case 0: + return &v.state + case 1: + return 
&v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ @@ -260,7 +452,7 @@ func file_vault_activity_activity_log_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_vault_activity_activity_log_proto_rawDesc, NumEnums: 0, - NumMessages: 3, + NumMessages: 7, NumExtensions: 0, NumServices: 0, }, diff --git a/vault/activity/activity_log.proto b/vault/activity/activity_log.proto index a3acd6cf54..03aaed577a 100644 --- a/vault/activity/activity_log.proto +++ b/vault/activity/activity_log.proto @@ -26,3 +26,14 @@ message LogFragment { // indexed by namespace ID map non_entity_tokens = 3; } + +message EntityActivityLog { + repeated EntityRecord entities = 1; +} + +message TokenCount { + map count_by_namespace_id = 1; +} + +message LogFragmentResponse { +} diff --git a/vault/activity/query.go b/vault/activity/query.go new file mode 100644 index 0000000000..3172a14093 --- /dev/null +++ b/vault/activity/query.go @@ -0,0 +1,233 @@ +package activity + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sort" + "strconv" + "time" + + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/helper/timeutil" + "github.com/hashicorp/vault/sdk/logical" +) + +// About 66 bytes per record: +//{"namespace_id":"xxxxx","entities":1234,"non_entity_tokens":1234}, +// = approx 7900 namespaces in 512KiB +// So one storage entry is fine (for now). +type NamespaceRecord struct { + NamespaceID string `json:"namespace_id"` + Entities uint64 `json:"entities"` + NonEntityTokens uint64 `json:"non_entity_tokens"` +} + +type PrecomputedQuery struct { + StartTime time.Time + EndTime time.Time + Namespaces []*NamespaceRecord `json:"namespaces"` +} + +type PrecomputedQueryStore struct { + logger log.Logger + view logical.Storage +} + +// The query store should be initialized with a view to the subdirectory +// it should use, like "queries". 
+func NewPrecomputedQueryStore(logger log.Logger, view logical.Storage, retentionMonths int) *PrecomputedQueryStore { + return &PrecomputedQueryStore{ + logger: logger, + view: view, + } +} + +func (s *PrecomputedQueryStore) Put(ctx context.Context, p *PrecomputedQuery) error { + path := fmt.Sprintf("%v/%v", p.StartTime.Unix(), p.EndTime.Unix()) + asJson, err := json.Marshal(p) + if err != nil { + return err + } + err = s.view.Put(ctx, &logical.StorageEntry{ + Key: path, + Value: asJson, + }) + if err != nil { + return err + } + return nil +} + +func (s *PrecomputedQueryStore) listStartTimes(ctx context.Context) ([]time.Time, error) { + // We could cache this to save a storage operation on each query, + // but that seems like a marginal improvment. + rawStartTimes, err := s.view.List(ctx, "") + if err != nil { + return nil, err + } + startTimes := make([]time.Time, 0, len(rawStartTimes)) + + for _, raw := range rawStartTimes { + t, err := timeutil.ParseTimeFromPath(raw) + if err != nil { + s.logger.Warn("could not parse precomputed query subdirectory", "key", raw) + continue + } + startTimes = append(startTimes, t) + } + return startTimes, nil +} + +func (s *PrecomputedQueryStore) listEndTimes(ctx context.Context, startTime time.Time) ([]time.Time, error) { + rawEndTimes, err := s.view.List(ctx, fmt.Sprintf("%v/", startTime.Unix())) + if err != nil { + return nil, err + } + endTimes := make([]time.Time, 0, len(rawEndTimes)) + + for _, raw := range rawEndTimes { + val, err := strconv.ParseInt(raw, 10, 64) + if err != nil { + s.logger.Warn("could not parse precomputed query end time", "key", raw) + continue + } + endTimes = append(endTimes, time.Unix(val, 0).UTC()) + } + return endTimes, nil +} + +func (s *PrecomputedQueryStore) QueriesAvailable(ctx context.Context) (bool, error) { + startTimes, err := s.listStartTimes(ctx) + if err != nil { + return false, err + } + return len(startTimes) > 0, nil +} + +func (s *PrecomputedQueryStore) Get(ctx context.Context, 
startTime, endTime time.Time) (*PrecomputedQuery, error) { + if startTime.After(endTime) { + return nil, errors.New("start time is after end time") + } + startTime = timeutil.StartOfMonth(startTime) + endTime = timeutil.EndOfMonth(endTime) + s.logger.Trace("searching for matching queries", "startTime", startTime, "endTime", endTime) + + // Find the oldest continuous region which overlaps with the given range. + // We only have to handle some collection of lower triangles like this, + // not arbitrary sets of endpoints (except in the middle of writes or GC): + // + // start -> + // end # + // | ## + // V ### + // + // # + // ## + // ### + // + // (1) find all saved start times T that are + // in [startTime,endTime] + // (if there is some report that overlaps, it will + // have a start time in the range-- an overlap + // only at the end is impossible.) + // (2) take the latest continguous region within + // that set + // i.e., walk up the diagonal as far as we can in a single + // triangle. + // (These could be combined into a single pass, but + // that seems more complicated to understand.) 
+ + startTimes, err := s.listStartTimes(ctx) + if err != nil { + return nil, err + } + s.logger.Trace("retrieved start times from storage", "startTimes", startTimes) + + filteredList := make([]time.Time, 0, len(startTimes)) + for _, t := range startTimes { + if timeutil.InRange(t, startTime, endTime) { + filteredList = append(filteredList, t) + } + } + s.logger.Trace("filtered to range", "startTimes", filteredList) + + if len(filteredList) == 0 { + return nil, nil + } + // Descending order, as required by the timeutil function + sort.Slice(filteredList, func(i, j int) bool { + return filteredList[i].After(filteredList[j]) + }) + contiguous := timeutil.GetMostRecentContiguousMonths(filteredList) + actualStartTime := contiguous[len(contiguous)-1] + + s.logger.Trace("chose start time", "actualStartTime", actualStartTime, "contiguous", contiguous) + + endTimes, err := s.listEndTimes(ctx, actualStartTime) + if err != nil { + return nil, err + } + s.logger.Trace("retrieved end times from storage", "endTimes", endTimes) + + // Might happen if there's a race with GC + if len(endTimes) == 0 { + s.logger.Warn("missing end times", "start time", actualStartTime) + return nil, nil + } + var actualEndTime time.Time + for _, t := range endTimes { + if timeutil.InRange(t, startTime, endTime) { + if actualEndTime.IsZero() || t.After(actualEndTime) { + actualEndTime = t + } + } + } + if actualEndTime.IsZero() { + s.logger.Warn("no end time in range", "start time", actualStartTime) + return nil, nil + } + + path := fmt.Sprintf("%v/%v", actualStartTime.Unix(), actualEndTime.Unix()) + entry, err := s.view.Get(ctx, path) + if err != nil { + return nil, err + } + + p := &PrecomputedQuery{} + err = json.Unmarshal(entry.Value, p) + if err != nil { + s.logger.Warn("failed query lookup at", "path", path) + return nil, err + } + return p, nil +} + +func (s *PrecomputedQueryStore) DeleteQueriesBefore(ctx context.Context, retentionThreshold time.Time) error { + startTimes, err := 
s.listStartTimes(ctx) + if err != nil { + return err + } + + for _, t := range startTimes { + path := fmt.Sprintf("%v/", t.Unix()) + if t.Before(retentionThreshold) { + rawEndTimes, err := s.view.List(ctx, path) + if err != nil { + return err + } + + s.logger.Trace("deleting queries", "startTime", t) + // Don't care about what the end time is, + // the start time along determines deletion. + for _, end := range rawEndTimes { + err = s.view.Delete(ctx, path+end) + if err != nil { + return err + } + } + } + } + return nil +} diff --git a/vault/activity/query_test.go b/vault/activity/query_test.go new file mode 100644 index 0000000000..47a99f8b31 --- /dev/null +++ b/vault/activity/query_test.go @@ -0,0 +1,289 @@ +package activity + +import ( + "context" + "reflect" + "sort" + "testing" + "time" + + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/helper/timeutil" + "github.com/hashicorp/vault/sdk/helper/logging" + "github.com/hashicorp/vault/sdk/logical" +) + +func NewTestQueryStore(t *testing.T) *PrecomputedQueryStore { + t.Helper() + + logger := logging.NewVaultLogger(log.Trace) + view := &logical.InmemStorage{} + return NewPrecomputedQueryStore(logger, view, 12) +} + +func TestQueryStore_Inventory(t *testing.T) { + startTimes := []time.Time{ + time.Date(2020, 1, 15, 0, 0, 0, 0, time.UTC), + time.Date(2020, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2020, 3, 1, 0, 0, 0, 0, time.UTC), + } + + endTimes := []time.Time{ + timeutil.EndOfMonth(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), + timeutil.EndOfMonth(time.Date(2020, 2, 1, 0, 0, 0, 0, time.UTC)), + timeutil.EndOfMonth(time.Date(2020, 3, 1, 0, 0, 0, 0, time.UTC)), + timeutil.EndOfMonth(time.Date(2020, 4, 1, 0, 0, 0, 0, time.UTC)), + timeutil.EndOfMonth(time.Date(2020, 5, 1, 0, 0, 0, 0, time.UTC)), + } + + qs := NewTestQueryStore(t) + ctx := context.Background() + + for _, s := range startTimes { + for _, e := range endTimes { + if e.Before(s) { + continue + } + qs.Put(ctx, &PrecomputedQuery{ + 
StartTime: s, + EndTime: e, + Namespaces: []*NamespaceRecord{}, + }) + } + } + + storedStartTimes, err := qs.listStartTimes(ctx) + if err != nil { + t.Fatal(err) + } + if len(storedStartTimes) != len(startTimes) { + t.Fatalf("bad length, expected %v got %v", len(startTimes), storedStartTimes) + } + sort.Slice(storedStartTimes, func(i, j int) bool { + return storedStartTimes[i].Before(storedStartTimes[j]) + }) + if !reflect.DeepEqual(storedStartTimes, startTimes) { + t.Fatalf("start time mismatch, expected %v got %v", startTimes, storedStartTimes) + } + + storedEndTimes, err := qs.listEndTimes(ctx, startTimes[1]) + expected := endTimes[1:] + if len(storedEndTimes) != len(expected) { + t.Fatalf("bad length, expected %v got %v", len(expected), storedEndTimes) + } + sort.Slice(storedEndTimes, func(i, j int) bool { + return storedEndTimes[i].Before(storedEndTimes[j]) + }) + if !reflect.DeepEqual(storedEndTimes, expected) { + t.Fatalf("end time mismatch, expected %v got %v", expected, storedEndTimes) + } + +} + +func TestQueryStore_MarshalDemarshal(t *testing.T) { + tsStart := time.Date(2020, 1, 15, 0, 0, 0, 0, time.UTC) + tsEnd := timeutil.EndOfMonth(tsStart) + + p := &PrecomputedQuery{ + StartTime: tsStart, + EndTime: tsEnd, + Namespaces: []*NamespaceRecord{ + &NamespaceRecord{ + NamespaceID: "root", + Entities: 20, + NonEntityTokens: 42, + }, + &NamespaceRecord{ + NamespaceID: "yzABC", + Entities: 15, + NonEntityTokens: 31, + }, + }, + } + + qs := NewTestQueryStore(t) + ctx := context.Background() + qs.Put(ctx, p) + result, err := qs.Get(ctx, tsStart, tsEnd) + if err != nil { + t.Fatal(err) + } + if result == nil { + t.Fatal("nil response from Get") + } + if !reflect.DeepEqual(result, p) { + t.Fatalf("unequal query objects, expected %v got %v", p, result) + } +} + +func TestQueryStore_TimeRanges(t *testing.T) { + qs := NewTestQueryStore(t) + ctx := context.Background() + + // Scenario ranges: Jan 15 - Jan 31 (one month) + // Feb 2 - Mar 31 (two months, but not 
contiguous) + // April and May are skipped + // June 1 - September 30 (4 months) + periods := []struct { + Begin time.Time + Ends []time.Time + }{ + { + time.Date(2020, 1, 15, 12, 45, 53, 0, time.UTC), + []time.Time{ + timeutil.EndOfMonth(time.Date(2020, 1, 1, 1, 0, 0, 0, time.UTC)), + }, + }, + { + time.Date(2020, 2, 2, 0, 0, 0, 0, time.UTC), + []time.Time{ + timeutil.EndOfMonth(time.Date(2020, 2, 1, 0, 0, 0, 0, time.UTC)), + timeutil.EndOfMonth(time.Date(2020, 3, 1, 0, 0, 0, 0, time.UTC)), + }, + }, + { + time.Date(2020, 6, 1, 0, 0, 0, 0, time.UTC), + []time.Time{ + timeutil.EndOfMonth(time.Date(2020, 6, 1, 0, 0, 0, 0, time.UTC)), + timeutil.EndOfMonth(time.Date(2020, 7, 1, 0, 0, 0, 0, time.UTC)), + timeutil.EndOfMonth(time.Date(2020, 8, 1, 0, 0, 0, 0, time.UTC)), + timeutil.EndOfMonth(time.Date(2020, 9, 1, 0, 0, 0, 0, time.UTC)), + }, + }, + } + + for _, period := range periods { + for _, e := range period.Ends { + qs.Put(ctx, &PrecomputedQuery{ + StartTime: period.Begin, + EndTime: e, + Namespaces: []*NamespaceRecord{ + &NamespaceRecord{ + NamespaceID: "root", + Entities: 17, + NonEntityTokens: 31, + }, + }, + }) + } + } + + testCases := []struct { + Name string + StartTime time.Time + EndTime time.Time + Empty bool + ExpectedStart time.Time + ExpectedEnd time.Time + }{ + { + "year query in October", + time.Date(2019, 10, 12, 0, 0, 0, 0, time.UTC), + time.Date(2020, 10, 12, 0, 0, 0, 0, time.UTC), + false, + // June - Sept + periods[2].Begin, + periods[2].Ends[3], + }, + { + "one day in January", + time.Date(2020, 1, 4, 0, 0, 0, 0, time.UTC), + time.Date(2020, 1, 5, 0, 0, 0, 0, time.UTC), + false, + // January, even though this is outside the range specified + periods[0].Begin, + periods[0].Ends[0], + }, + { + "one day in February", + time.Date(2020, 2, 4, 0, 0, 0, 0, time.UTC), + time.Date(2020, 2, 5, 0, 0, 0, 0, time.UTC), + false, + // February only + periods[1].Begin, + periods[1].Ends[0], + }, + { + "January through March", + time.Date(2020, 1, 4, 0, 0, 0, 
0, time.UTC), + time.Date(2020, 3, 5, 0, 0, 0, 0, time.UTC), + false, + // February and March only + // Fails due to bug in library function, TODO + periods[1].Begin, + periods[1].Ends[1], + }, + { + "the month of May", + time.Date(2020, 5, 1, 0, 0, 0, 0, time.UTC), + time.Date(2020, 5, 31, 0, 0, 0, 0, time.UTC), + true, // no data + time.Time{}, + time.Time{}, + }, + { + "May through June", + time.Date(2020, 5, 1, 0, 0, 0, 0, time.UTC), + time.Date(2020, 6, 1, 0, 0, 0, 0, time.UTC), + false, + // June only + periods[2].Begin, + periods[2].Ends[0], + }, + { + "September", + time.Date(2020, 9, 1, 0, 0, 0, 0, time.UTC), + time.Date(2020, 9, 1, 0, 0, 0, 0, time.UTC), + true, // We have June through September, + // but not anything starting in September + // (which does not match a real scenario) + time.Time{}, + time.Time{}, + }, + { + "December", + time.Date(2020, 12, 1, 0, 0, 0, 0, time.UTC), + time.Date(2020, 12, 1, 0, 0, 0, 0, time.UTC), + true, // no data + time.Time{}, + time.Time{}, + }, + { + "June through December", + time.Date(2020, 6, 1, 12, 0, 0, 0, time.UTC), + time.Date(2020, 12, 31, 12, 0, 0, 0, time.UTC), + false, + // June through September + periods[2].Begin, + periods[2].Ends[3], + }, + } + + for _, tc := range testCases { + tc := tc // capture range variable + t.Run(tc.Name, func(t *testing.T) { + t.Parallel() + result, err := qs.Get(ctx, tc.StartTime, tc.EndTime) + if err != nil { + t.Fatal(err) + } + if result == nil { + if tc.Empty { + return + } else { + t.Fatal("unexpected empty result") + } + } else { + if tc.Empty { + t.Fatal("expected empty result") + } + } + if !result.StartTime.Equal(tc.ExpectedStart) { + t.Errorf("start time mismatch: %v, expected %v", result.StartTime, tc.ExpectedStart) + } + if !result.EndTime.Equal(tc.ExpectedEnd) { + t.Errorf("end time mismatch: %v, expected %v", result.EndTime, tc.ExpectedEnd) + } + }) + } +} diff --git a/vault/activity_log.go b/vault/activity_log.go index a21d194ec8..f0f0a27a8f 100644 --- 
a/vault/activity_log.go +++ b/vault/activity_log.go @@ -2,11 +2,21 @@ package vault import ( "context" + "encoding/json" + "errors" + "fmt" "os" + "sort" + "strconv" + "strings" "sync" "time" + "github.com/golang/protobuf/proto" log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/helper/metricsutil" + "github.com/hashicorp/vault/helper/namespace" + "github.com/hashicorp/vault/helper/timeutil" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault/activity" ) @@ -14,43 +24,835 @@ import ( const ( // activitySubPath is the directory under the system view where // the log will be stored. - activitySubPath = "activity/" + activitySubPath = "counters/activity/" + activityEntityBasePath = "log/entity/" + activityTokenBasePath = "log/directtokens/" + activityQueryBasePath = "queries/" + activityConfigKey = "config" + activityIntentLogKey = "endofmonth" + + // Time to wait on perf standby before sending fragment + activityFragmentStandbyTime = 10 * time.Minute + + // Time between writes of segment to storage + activitySegmentInterval = 10 * time.Minute + + // Timeout on RPC calls. + activityFragmentSendTimeout = 1 * time.Minute + + // Timeout on storage calls. + activitySegmentWriteTimeout = 1 * time.Minute + + // Number of entity records to store per segment + // Estimated as 512KiB / 64 bytes = 8192, rounded down + activitySegmentEntityCapacity = 8000 + + // Maximum number of segments per month + activityLogMaxSegmentPerMonth = 81 + + // Number of records (entity or token) to store in a + // standby fragment before sending it to the active node. + // Estimates as 8KiB / 64 bytes = 128 + activityFragmentStandbyCapacity = 128 ) +type segmentInfo struct { + startTimestamp int64 + currentEntities *activity.EntityActivityLog + tokenCount *activity.TokenCount + entitySequenceNumber uint64 +} + // ActivityLog tracks unique entity counts and non-entity token counts. 
// It handles assembling log fragments (and sending them to the active // node), writing log segments, and precomputing queries. type ActivityLog struct { + core *Core + configOverrides *ActivityLogCoreConfig + + // ActivityLog.l protects the configuration settings, except enable, and any modifications + // to the current segment. + // Acquire "l" before fragmentLock if both must be held. + l sync.RWMutex + + // fragmentLock protects enable, activeEntities, fragment, standbyFragmentsReceived + fragmentLock sync.RWMutex + + // enabled indicates if the activity log is enabled for this cluster. + // This is protected by fragmentLock so we can check with only + // a single synchronization call. + enabled bool + // log destination logger log.Logger + // metrics sink + metrics metricsutil.Metrics + // view is the storage location used by ActivityLog, // defaults to sys/activity. - view logical.Storage + view *BarrierView // nodeID is the ID to use for all fragments that // are generated. // TODO: use secondary ID when available? nodeID string - // current log fragment (may be nil) and a mutex to protect it - fragmentLock sync.RWMutex + // current log fragment (may be nil) fragment *activity.LogFragment fragmentCreation time.Time + + // Channel to signal a new fragment has been created + // so it's appropriate to start the timer. + newFragmentCh chan struct{} + + // Channel for sending fragment immediately + sendCh chan struct{} + + // Channel for writing fragment immediately + writeCh chan struct{} + + // Channel to stop background processing + doneCh chan struct{} + + // All known active entities this month; use fragmentLock read-locked + // to check whether it already exists. + activeEntities map[string]struct{} + + // track metadata and contents of the most recent log segment + // currentSegment is currently unprotected by a mutex, because it is updated + // only by the worker performing rotation. 
+ currentSegment segmentInfo + + // Fragments received from performance standbys + standbyFragmentsReceived []*activity.LogFragment + + // precomputed queries + queryStore *activity.PrecomputedQueryStore + defaultReportMonths int + retentionMonths int + + // cancel function to stop loading entities/tokens from storage to memory + activityCancel context.CancelFunc + + // channel closed by delete worker when done + deleteDone chan struct{} +} + +// These non-persistent configuration options allow us to disable +// parts of the implementation for integration testing. +// The default values should turn everything on. +type ActivityLogCoreConfig struct { + // Enable activity log even if the feature flag not set + ForceEnable bool + + // Do not start timers to send or persist fragments. + DisableTimers bool } // NewActivityLog creates an activity log. -func NewActivityLog(_ context.Context, logger log.Logger, view logical.Storage) (*ActivityLog, error) { +func NewActivityLog(core *Core, logger log.Logger, view *BarrierView, metrics metricsutil.Metrics) (*ActivityLog, error) { hostname, err := os.Hostname() if err != nil { return nil, err } - return &ActivityLog{ - logger: logger, - view: view, - nodeID: hostname, - }, nil + emptyEntityActivityLog := &activity.EntityActivityLog{ + Entities: make([]*activity.EntityRecord, 0), + } + emptyTokenCount := &activity.TokenCount{ + CountByNamespaceID: make(map[string]uint64), + } + a := &ActivityLog{ + core: core, + configOverrides: &core.activityLogConfig, + logger: logger, + view: view, + metrics: metrics, + nodeID: hostname, + newFragmentCh: make(chan struct{}, 1), + sendCh: make(chan struct{}, 1), // buffered so it can be triggered by fragment size + writeCh: make(chan struct{}, 1), // same for full segment + doneCh: make(chan struct{}, 1), + activeEntities: make(map[string]struct{}), + currentSegment: segmentInfo{ + startTimestamp: 0, + currentEntities: emptyEntityActivityLog, + tokenCount: emptyTokenCount, + 
entitySequenceNumber: 0, + }, + standbyFragmentsReceived: make([]*activity.LogFragment, 0), + } + + config, err := a.loadConfigOrDefault(core.activeContext) + if err != nil { + return nil, err + } + + a.SetConfigInit(config) + + a.queryStore = activity.NewPrecomputedQueryStore( + logger, + view.SubView(activityQueryBasePath), + config.RetentionMonths) + + return a, nil +} + +// saveCurrentSegmentToStorage updates the record of Entities or +// Non Entity Tokens in persistent storage +func (a *ActivityLog) saveCurrentSegmentToStorage(ctx context.Context, force bool) error { + // Prevent simultaneous changes to segment + a.l.Lock() + defer a.l.Unlock() + return a.saveCurrentSegmentToStorageLocked(ctx, force) +} + +// Must be called with l held. +func (a *ActivityLog) saveCurrentSegmentToStorageLocked(ctx context.Context, force bool) error { + defer a.metrics.MeasureSinceWithLabels([]string{"core", "activity", "segment_write"}, + time.Now(), []metricsutil.Label{}) + + // Swap out the pending fragments + a.fragmentLock.Lock() + localFragment := a.fragment + a.fragment = nil + standbys := a.standbyFragmentsReceived + a.standbyFragmentsReceived = make([]*activity.LogFragment, 0) + a.fragmentLock.Unlock() + + // If segment start time is zero, do not update or write + // (even if force is true). This can happen if activityLog is + // disabled after a save as been triggered. + if a.currentSegment.startTimestamp == 0 { + return nil + } + + // Measure the current fragment + if localFragment != nil { + a.metrics.IncrCounterWithLabels([]string{"core", "activity", "fragment_size"}, + float32(len(localFragment.Entities)), + []metricsutil.Label{ + {"type", "entity"}, + }) + a.metrics.IncrCounterWithLabels([]string{"core", "activity", "fragment_size"}, + float32(len(localFragment.NonEntityTokens)), + []metricsutil.Label{ + {"type", "direct_token"}, + }) + } + + // Collect new entities and new tokens. 
+ saveChanges := false + newEntities := make(map[string]*activity.EntityRecord) + for _, f := range append(standbys, localFragment) { + if f == nil { + continue + } + for _, e := range f.Entities { + // We could sort by timestamp to see which is first. + // We'll ignore that; the order of the append above means + // that we choose entries in localFragment over those + // from standby nodes. + newEntities[e.EntityID] = e + saveChanges = true + } + for ns, val := range f.NonEntityTokens { + a.currentSegment.tokenCount.CountByNamespaceID[ns] += val + saveChanges = true + } + } + + if !saveChanges { + return nil + } + + // Will all new entities fit? If not, roll over to a new segment. + available := activitySegmentEntityCapacity - len(a.currentSegment.currentEntities.Entities) + remaining := available - len(newEntities) + excess := 0 + if remaining < 0 { + excess = -remaining + } + + segmentEntities := a.currentSegment.currentEntities.Entities + excessEntities := make([]*activity.EntityRecord, 0, excess) + for _, record := range newEntities { + if available > 0 { + segmentEntities = append(segmentEntities, record) + available -= 1 + } else { + excessEntities = append(excessEntities, record) + } + } + a.currentSegment.currentEntities.Entities = segmentEntities + + err := a.saveCurrentSegmentInternal(ctx, force) + if err != nil { + // The current fragment(s) have already been placed into the in-memory + // segment, but we may lose any excess (in excessEntities). + // There isn't a good way to unwind the transaction on failure, + // so we may just lose some records. + return err + } + + if available <= 0 { + if a.currentSegment.entitySequenceNumber >= activityLogMaxSegmentPerMonth { + // Cannot send as Warn because it will repeat too often, + // and disabling/renabling would be complicated. 
+ a.logger.Trace("too many segments in current month", "dropped", len(excessEntities)) + return nil + } + + // Rotate to next segment + a.currentSegment.entitySequenceNumber += 1 + if len(excessEntities) > activitySegmentEntityCapacity { + a.logger.Warn("too many new active entities %v, dropping tail", len(excessEntities)) + excessEntities = excessEntities[:activitySegmentEntityCapacity] + } + a.currentSegment.currentEntities.Entities = excessEntities + err := a.saveCurrentSegmentInternal(ctx, force) + if err != nil { + return err + } + } + return nil +} + +func (a *ActivityLog) saveCurrentSegmentInternal(ctx context.Context, force bool) error { + entityPath := fmt.Sprintf("log/entity/%d/%d", a.currentSegment.startTimestamp, a.currentSegment.entitySequenceNumber) + // RFC (VLT-120) defines this as 1-indexed, but it should be 0-indexed + tokenPath := fmt.Sprintf("log/directtokens/%d/0", a.currentSegment.startTimestamp) + + // TODO: have a member function on segmentInfo struct to do the below two + // blocks + if len(a.currentSegment.currentEntities.Entities) > 0 || force { + entities, err := proto.Marshal(a.currentSegment.currentEntities) + if err != nil { + return err + } + + a.logger.Trace("writing segment", "path", entityPath) + err = a.view.Put(ctx, &logical.StorageEntry{ + Key: entityPath, + Value: entities, + }) + if err != nil { + return err + } + } + + if len(a.currentSegment.tokenCount.CountByNamespaceID) > 0 || force { + tokenCount, err := proto.Marshal(a.currentSegment.tokenCount) + if err != nil { + return err + } + + a.logger.Trace("writing segment", "path", tokenPath) + err = a.view.Put(ctx, &logical.StorageEntry{ + Key: tokenPath, + Value: tokenCount, + }) + if err != nil { + return err + } + } + + return nil +} + +// parseSegmentNumberFromPath returns the segment number from a path +// (and if it exists - it is the last element in the path) +func parseSegmentNumberFromPath(path string) (int, bool) { + // as long as both s and sep are not "", 
len(elems) >= 1 + elems := strings.Split(path, "/") + segmentNum, err := strconv.Atoi(elems[len(elems)-1]) + if err != nil { + return 0, false + } + + return segmentNum, true +} + +// availableLogs returns the start_time(s) associated with months for which logs exist, sorted last to first +func (a *ActivityLog) availableLogs(ctx context.Context) ([]time.Time, error) { + paths := make([]string, 0) + for _, basePath := range []string{activityEntityBasePath, activityTokenBasePath} { + p, err := a.view.List(ctx, basePath) + if err != nil { + return nil, err + } + + paths = append(paths, p...) + } + + pathSet := make(map[time.Time]struct{}) + out := make([]time.Time, 0) + for _, path := range paths { + // generate a set of unique start times + time, err := timeutil.ParseTimeFromPath(path) + if err != nil { + return nil, err + } + + if _, present := pathSet[time]; !present { + pathSet[time] = struct{}{} + out = append(out, time) + } + } + + sort.Slice(out, func(i, j int) bool { + // sort in reverse order to make processing most recent segment easier + return out[i].After(out[j]) + }) + + a.logger.Trace("scanned existing logs", "out", out) + + return out, nil +} + +func (a *ActivityLog) getMostRecentActivityLogSegment(ctx context.Context) ([]time.Time, error) { + logTimes, err := a.availableLogs(ctx) + if err != nil { + return nil, err + } + + return timeutil.GetMostRecentContiguousMonths(logTimes), nil +} + +// getLastEntitySegmentNumber returns the (non-negative) last segment number for the :startTime:, if it exists +func (a *ActivityLog) getLastEntitySegmentNumber(ctx context.Context, startTime time.Time) (uint64, bool, error) { + p, err := a.view.List(ctx, activityEntityBasePath+fmt.Sprint(startTime.Unix())+"/") + if err != nil { + return 0, false, err + } + + highestNum := -1 + for _, path := range p { + if num, ok := parseSegmentNumberFromPath(path); ok { + if num > highestNum { + highestNum = num + } + } + } + + if highestNum < 0 { + // numbers less than 0 are 
invalid. if a negative number is the highest value, there isn't a segment + return 0, false, nil + } + + return uint64(highestNum), true, nil +} + +// WalkEntitySegments loads each of the entity segments for a particular start time +func (a *ActivityLog) WalkEntitySegments(ctx context.Context, + startTime time.Time, + walkFn func(*activity.EntityActivityLog)) error { + + basePath := activityEntityBasePath + fmt.Sprint(startTime.Unix()) + "/" + pathList, err := a.view.List(ctx, basePath) + if err != nil { + return err + } + + for _, path := range pathList { + raw, err := a.view.Get(ctx, basePath+path) + if err != nil { + return err + } + if raw == nil { + a.logger.Warn("expected log segment not found", "startTime", startTime, "segment", path) + continue + } + out := &activity.EntityActivityLog{} + err = proto.Unmarshal(raw.Value, out) + if err != nil { + return fmt.Errorf("unable to parse segment %v%v: %w", basePath, path, err) + } + walkFn(out) + } + return nil +} + +// WalkTokenSegments loads each of the token segments (expected 1) for a particular start time +func (a *ActivityLog) WalkTokenSegments(ctx context.Context, + startTime time.Time, + walkFn func(*activity.TokenCount)) error { + + basePath := activityTokenBasePath + fmt.Sprint(startTime.Unix()) + "/" + pathList, err := a.view.List(ctx, basePath) + if err != nil { + return err + } + + for _, path := range pathList { + raw, err := a.view.Get(ctx, basePath+path) + if err != nil { + return err + } + if raw == nil { + a.logger.Warn("expected token segment not found", "startTime", startTime, "segment", path) + continue + } + out := &activity.TokenCount{} + err = proto.Unmarshal(raw.Value, out) + if err != nil { + return fmt.Errorf("unable to parse token segment %v%v: %w", basePath, path, err) + } + walkFn(out) + } + return nil +} + +// loadPriorEntitySegment populates the in-memory tracker for entity IDs that have +// been active "this month" +func (a *ActivityLog) loadPriorEntitySegment(ctx context.Context, 
startTime time.Time, sequenceNum uint64) error { + path := activityEntityBasePath + fmt.Sprint(startTime.Unix()) + "/" + strconv.FormatUint(sequenceNum, 10) + data, err := a.view.Get(ctx, path) + if err != nil { + return err + } + + out := &activity.EntityActivityLog{} + err = proto.Unmarshal(data.Value, out) + if err != nil { + return err + } + + a.l.RLock() + a.fragmentLock.Lock() + // Handle the (unlikely) case where the end of the month has been reached while background loading. + // Or the feature has been disabled. + if a.enabled && startTime.Unix() == a.currentSegment.startTimestamp { + for _, ent := range out.Entities { + a.activeEntities[ent.EntityID] = struct{}{} + } + } + a.fragmentLock.Unlock() + a.l.RUnlock() + + return nil +} + +// loadCurrentEntitySegment loads the most recent segment (for "this month") into memory +// (to append new entries), and to the activeEntities to avoid duplication +func (a *ActivityLog) loadCurrentEntitySegment(ctx context.Context, startTime time.Time, sequenceNum uint64) error { + path := activityEntityBasePath + fmt.Sprint(startTime.Unix()) + "/" + strconv.FormatUint(sequenceNum, 10) + data, err := a.view.Get(ctx, path) + if err != nil { + return err + } + + out := &activity.EntityActivityLog{} + err = proto.Unmarshal(data.Value, out) + if err != nil { + return err + } + + if !a.core.perfStandby { + a.currentSegment = segmentInfo{ + startTimestamp: startTime.Unix(), + currentEntities: &activity.EntityActivityLog{ + Entities: out.Entities, + }, + tokenCount: a.currentSegment.tokenCount, + entitySequenceNumber: sequenceNum, + } + } else { + // populate this for edge case checking (if end of month passes while background loading on standby) + a.currentSegment.startTimestamp = startTime.Unix() + } + + for _, ent := range out.Entities { + a.activeEntities[ent.EntityID] = struct{}{} + } + + return nil +} + +// tokenCountExists checks if there's a token log for :startTime: +// this function should be called with the lock held 
+func (a *ActivityLog) tokenCountExists(ctx context.Context, startTime time.Time) (bool, error) { + p, err := a.view.List(ctx, activityTokenBasePath+fmt.Sprint(startTime.Unix())+"/") + if err != nil { + return false, err + } + + for _, path := range p { + if num, ok := parseSegmentNumberFromPath(path); ok && num == 0 { + return true, nil + } + } + + return false, nil +} + +// loadTokenCount populates the in-memory representation of activity token count +// this function should be called with the lock held +func (a *ActivityLog) loadTokenCount(ctx context.Context, startTime time.Time) error { + tokenCountExists, err := a.tokenCountExists(ctx, startTime) + if err != nil { + return err + } + if !tokenCountExists { + return nil + } + + path := activityTokenBasePath + fmt.Sprint(startTime.Unix()) + "/0" + data, err := a.view.Get(ctx, path) + if err != nil { + return err + } + + out := &activity.TokenCount{} + err = proto.Unmarshal(data.Value, out) + if err != nil { + return err + } + + // An empty map is unmarshaled as nil + if out.CountByNamespaceID == nil { + out.CountByNamespaceID = make(map[string]uint64) + } + a.currentSegment.tokenCount = out + + return nil +} + +// entityBackgroundLoader loads entity activity log records for start_date :t: +func (a *ActivityLog) entityBackgroundLoader(ctx context.Context, wg *sync.WaitGroup, t time.Time, seqNums <-chan uint64) { + defer wg.Done() + for seqNum := range seqNums { + select { + case <-a.doneCh: + a.logger.Info("background processing told to halt while loading entities", "time", t, "sequence", seqNum) + return + default: + } + + err := a.loadPriorEntitySegment(ctx, t, seqNum) + if err != nil { + a.logger.Error("error loading entity activity log", "time", t, "sequence", seqNum, "err", err) + } + } +} + +// Initialize a new current segment, based on the current time. +// Call with fragmentLock and l held. 
+func (a *ActivityLog) startNewCurrentLogLocked() { + a.logger.Trace("initializing new log") + a.resetCurrentLog() + a.currentSegment.startTimestamp = time.Now().Unix() +} + +// Should be called with fragmentLock and l held. +func (a *ActivityLog) newMonthCurrentLogLocked(currentTime time.Time) { + a.logger.Trace("continuing log to new month") + a.resetCurrentLog() + monthStart := timeutil.StartOfMonth(currentTime.UTC()) + a.currentSegment.startTimestamp = monthStart.Unix() +} + +// Reset all the current segment state. +// Should be called with fragmentLock and l held. +func (a *ActivityLog) resetCurrentLog() { + emptyEntityActivityLog := &activity.EntityActivityLog{ + Entities: make([]*activity.EntityRecord, 0), + } + emptyTokenCount := &activity.TokenCount{ + CountByNamespaceID: make(map[string]uint64), + } + + a.currentSegment.startTimestamp = 0 + a.currentSegment.currentEntities = emptyEntityActivityLog + a.currentSegment.tokenCount = emptyTokenCount + a.currentSegment.entitySequenceNumber = 0 + + a.fragment = nil + a.activeEntities = make(map[string]struct{}) + a.standbyFragmentsReceived = make([]*activity.LogFragment, 0) +} + +func (a *ActivityLog) deleteLogWorker(startTimestamp int64, whenDone chan struct{}) { + ctx := namespace.RootContext(nil) + entityPath := fmt.Sprintf("%v%v/", activityEntityBasePath, startTimestamp) + tokenPath := fmt.Sprintf("%v%v/", activityTokenBasePath, startTimestamp) + + // TODO: handle seal gracefully, if we're still working? 
+ entitySegments, err := a.view.List(ctx, entityPath) + if err != nil { + a.logger.Error("could not list entity paths", "error", err) + return + } + for _, p := range entitySegments { + err = a.view.Delete(ctx, entityPath+p) + if err != nil { + a.logger.Error("could not delete entity log", "error", err) + } + } + + tokenSegments, err := a.view.List(ctx, tokenPath) + if err != nil { + a.logger.Error("could not list token paths", "error", err) + return + } + for _, p := range tokenSegments { + err = a.view.Delete(ctx, tokenPath+p) + if err != nil { + a.logger.Error("could not delete token log", "error", err) + } + } + + // Allow whoever started this as a goroutine to wait for it to finish. + close(whenDone) +} + +// refreshFromStoredLog loads entity segments and token counts into memory (for "this month" only) +// this will synchronously load the most recent entity segment (and the token counts) into memory, +// and then kick off a background task to load the rest of the segments +// +// This method is called during init so we don't acquire the normally-required locks in it. +func (a *ActivityLog) refreshFromStoredLog(ctx context.Context, wg *sync.WaitGroup) error { + decreasingLogTimes, err := a.getMostRecentActivityLogSegment(ctx) + if err != nil { + return err + } + if len(decreasingLogTimes) == 0 { + // If no logs exist, and we are enabled, then + // start with the current timestamp + if a.enabled { + a.startNewCurrentLogLocked() + } + return nil + } + mostRecent := decreasingLogTimes[0] + if !timeutil.IsCurrentMonth(mostRecent, time.Now().UTC()) { + // no activity logs to load for this month + // If we are enabled, interpret it as having missed + // the rotation. 
+ + if a.enabled { + a.logger.Trace("no log segments for current month", "mostRecent", mostRecent) + a.logger.Info("rotating activity log to new month") + a.newMonthCurrentLogLocked(time.Now().UTC()) + } + return nil + } + if !a.enabled { + a.logger.Warn("activity log exists but is disabled, cleaning up") + go a.deleteLogWorker(mostRecent.Unix(), make(chan struct{})) + return nil + } + + if !a.core.perfStandby { + err = a.loadTokenCount(ctx, mostRecent) + if err != nil { + return err + } + } + + lastSegment, segmentsExist, err := a.getLastEntitySegmentNumber(ctx, mostRecent) + if err != nil { + return err + } + if !segmentsExist { + a.logger.Trace("no entity segments for current month") + return nil + } + + err = a.loadCurrentEntitySegment(ctx, mostRecent, lastSegment) + if err != nil || lastSegment == 0 { + return err + } + lastSegment-- + + seqNums := make(chan uint64, lastSegment+1) + wg.Add(1) + go a.entityBackgroundLoader(ctx, wg, mostRecent, seqNums) + + for n := int(lastSegment); n >= 0; n-- { + seqNums <- uint64(n) + } + close(seqNums) + + return nil +} + +// This version is used during construction +func (a *ActivityLog) SetConfigInit(config activityConfig) { + switch config.Enabled { + case "enable": + a.enabled = true + case "default": + a.enabled = activityLogEnabledDefault + case "disable": + a.enabled = false + } + + if a.configOverrides.ForceEnable { + a.enabled = true + } + + a.defaultReportMonths = config.DefaultReportMonths + a.retentionMonths = config.RetentionMonths +} + +// This version reacts to user changes +func (a *ActivityLog) SetConfig(ctx context.Context, config activityConfig) { + a.l.Lock() + defer a.l.Unlock() + + // enabled is protected by fragmentLock + a.fragmentLock.Lock() + switch config.Enabled { + case "enable": + a.enabled = true + case "default": + a.enabled = activityLogEnabledDefault + case "disable": + a.enabled = false + } + + if !a.enabled && a.currentSegment.startTimestamp != 0 { + a.logger.Trace("deleting current 
segment") + a.deleteDone = make(chan struct{}) + go a.deleteLogWorker(a.currentSegment.startTimestamp, a.deleteDone) + a.resetCurrentLog() + } + + forceSave := false + if a.enabled && a.currentSegment.startTimestamp == 0 { + a.startNewCurrentLogLocked() + // Force a save so we can distinguish between + // + // Month N-1: present + // Month N: + // + // and + // + // Month N-1: present + // Month N: + forceSave = true + } + a.fragmentLock.Unlock() + + if forceSave { + // l is still held here + a.saveCurrentSegmentInternal(ctx, true) + } + + a.defaultReportMonths = config.DefaultReportMonths + a.retentionMonths = config.RetentionMonths + + // check for segments out of retention period, if it has changed + go a.retentionWorker(time.Now(), a.retentionMonths) +} + +func (a *ActivityLog) queriesAvailable(ctx context.Context) (bool, error) { + if a.queryStore == nil { + return false, nil + } + return a.queryStore.QueriesAvailable(ctx) } // setupActivityLog hooks up the singleton ActivityLog into Core. @@ -59,19 +861,350 @@ func (c *Core) setupActivityLog(ctx context.Context) error { logger := c.baseLogger.Named("activity") c.AddLogger(logger) - manager, err := NewActivityLog(ctx, logger, view) + manager, err := NewActivityLog(c, logger, view, c.metricSink) if err != nil { return err } c.activityLog = manager + + // load activity log for "this month" into memory + refreshCtx, cancelFunc := context.WithCancel(namespace.RootContext(nil)) + manager.activityCancel = cancelFunc + var wg sync.WaitGroup + err = manager.refreshFromStoredLog(refreshCtx, &wg) + if err != nil { + return err + } + + // Start the background worker, depending on type + // Lock already held here, can't use .PerfStandby() + // The workers need to know the current segment time. 
+ if c.perfStandby { + go manager.perfStandbyFragmentWorker() + } else { + go manager.activeFragmentWorker() + + // Check for any intent log, in the background + go manager.precomputedQueryWorker() + + // Catch up on garbage collection + go manager.retentionWorker(time.Now(), manager.retentionMonths) + } + + // Link the token store to this core + c.tokenStore.SetActivityLog(manager) + return nil } -func (a *ActivityLog) AddEntityToFragment(entityID string, namespaceID string, timestamp time.Time) { +// stopActivityLog removes the ActivityLog from Core +// and frees any resources. +func (c *Core) stopActivityLog() error { + if c.tokenStore != nil { + c.tokenStore.SetActivityLog(nil) + } + + // preSeal may run before startActivityLog got a chance to complete. + if c.activityLog != nil { + // Shut down background worker + close(c.activityLog.doneCh) + // cancel refreshing logs from storage + if c.activityLog.activityCancel != nil { + c.activityLog.activityCancel() + } + } + + c.activityLog = nil + return nil +} + +func (a *ActivityLog) StartOfNextMonth() time.Time { + a.l.RLock() + defer a.l.RUnlock() + var segmentStart time.Time + if a.currentSegment.startTimestamp == 0 { + segmentStart = time.Now().UTC() + } else { + segmentStart = time.Unix(a.currentSegment.startTimestamp, 0).UTC() + } + // Basing this on the segment start will mean we trigger EOM rollover when + // necessary because we were down. + return timeutil.StartOfNextMonth(segmentStart) +} + +// perfStandbyFragmentWorker handles scheduling fragments +// to send via RPC; it runs on perf standby nodes only. 
+func (a *ActivityLog) perfStandbyFragmentWorker() { + timer := time.NewTimer(time.Duration(0)) + fragmentWaiting := false + // Eat first event, so timer is stopped + <-timer.C + + endOfMonth := time.NewTimer(a.StartOfNextMonth().Sub(time.Now())) + if a.configOverrides.DisableTimers { + endOfMonth.Stop() + } + + sendFunc := func() { + ctx, cancel := context.WithTimeout(context.Background(), activityFragmentSendTimeout) + defer cancel() + err := a.sendCurrentFragment(ctx) + if err != nil { + a.logger.Warn("activity log fragment lost", "error", err) + } + } + + for { + select { + case <-a.doneCh: + // Shutting down activity log. + if fragmentWaiting && !timer.Stop() { + <-timer.C + } + if !endOfMonth.Stop() { + <-endOfMonth.C + } + return + case <-a.newFragmentCh: + // New fragment created, start the timer if not + // already running + if !fragmentWaiting { + fragmentWaiting = true + if !a.configOverrides.DisableTimers { + a.logger.Trace("reset fragment timer") + timer.Reset(activityFragmentStandbyTime) + } + } + case <-timer.C: + a.logger.Trace("sending fragment on timer expiration") + fragmentWaiting = false + sendFunc() + case <-a.sendCh: + a.logger.Trace("sending fragment on request") + // It might be that we get sendCh before fragmentCh + // if a fragment is created and then immediately fills + // up to its limit. So we attempt to send even if the timer's + // not running. + if fragmentWaiting { + fragmentWaiting = false + if !timer.Stop() { + <-timer.C + } + } + sendFunc() + case <-endOfMonth.C: + a.logger.Trace("sending fragment on end of month") + // Flush the current fragment, if any + if fragmentWaiting { + fragmentWaiting = false + if !timer.Stop() { + <-timer.C + } + } + sendFunc() + + // clear active entity set + a.fragmentLock.Lock() + a.activeEntities = make(map[string]struct{}) + a.fragmentLock.Unlock() + + // Set timer for next month. + // The current segment *probably* hasn't been set yet (via invalidation), + // so don't rely on it. 
+ target := timeutil.StartOfNextMonth(time.Now().UTC()) + endOfMonth.Reset(target.Sub(time.Now())) + } + } +} + +// activeFragmentWorker handles scheduling the write of the next +// segment. It runs on active nodes only. +func (a *ActivityLog) activeFragmentWorker() { + ticker := time.NewTicker(activitySegmentInterval) + + endOfMonth := time.NewTimer(a.StartOfNextMonth().Sub(time.Now())) + if a.configOverrides.DisableTimers { + endOfMonth.Stop() + } + + writeFunc := func() { + ctx, cancel := context.WithTimeout(context.Background(), activitySegmentWriteTimeout) + defer cancel() + err := a.saveCurrentSegmentToStorage(ctx, false) + if err != nil { + a.logger.Warn("activity log segment not saved, current fragment lost", "error", err) + } + } + + for { + select { + case <-a.doneCh: + // Shutting down activity log. + ticker.Stop() + return + case <-a.newFragmentCh: + // Just eat the message; the ticker is + // already running so we don't need to start it. + // (But we might change the behavior in the future.) + a.logger.Trace("new local fragment created") + continue + case <-ticker.C: + // It's harder to disable a Ticker so we'll just ignore it. + if a.configOverrides.DisableTimers { + continue + } + a.logger.Trace("writing segment on timer expiration") + writeFunc() + case <-a.writeCh: + a.logger.Trace("writing segment on request") + writeFunc() + + // Reset the schedule to wait 10 minutes from this forced write. + ticker.Stop() + ticker = time.NewTicker(activitySegmentInterval) + + // Simpler, but ticker.Reset was introduced in go 1.15: + // ticker.Reset(activitySegmentInterval) + case currentTime := <-endOfMonth.C: + err := a.HandleEndOfMonth(currentTime.UTC()) + if err != nil { + a.logger.Error("failed to perform end of month rotation", "error", err) + } + + // Garbage collect any segments or queries based on the immediate + // value of retentionMonths. 
+ a.l.RLock() + go a.retentionWorker(currentTime.UTC(), a.retentionMonths) + a.l.RUnlock() + + delta := a.StartOfNextMonth().Sub(time.Now()) + if delta < 20*time.Minute { + delta = 20 * time.Minute + } + a.logger.Trace("scheduling next month", "delta", delta) + endOfMonth.Reset(delta) + } + } + +} + +type ActivityIntentLog struct { + PreviousMonth int64 `json:"previous_month"` + NextMonth int64 `json:"next_month"` +} + +// Handle rotation to end-of-month +// currentTime is an argument for unit-testing purposes +func (a *ActivityLog) HandleEndOfMonth(currentTime time.Time) error { + ctx := namespace.RootContext(nil) + + // Hold lock to prevent segment or enable changing, + // disable will apply to *next* month. + a.l.Lock() + defer a.l.Unlock() + + a.fragmentLock.RLock() + // Don't bother if disabled + enabled := a.enabled + a.fragmentLock.RUnlock() + if !enabled { + return nil + } + + a.logger.Trace("starting end of month processing", "rolloverTime", currentTime) + + prevSegmentTimestamp := a.currentSegment.startTimestamp + nextSegmentTimestamp := timeutil.StartOfMonth(currentTime.UTC()).Unix() + + // Write out an intent log for the rotation with the current and new segment times. + intentLog := &ActivityIntentLog{ + PreviousMonth: prevSegmentTimestamp, + NextMonth: nextSegmentTimestamp, + } + entry, err := logical.StorageEntryJSON(activityIntentLogKey, intentLog) + if err != nil { + return err + } + err = a.view.Put(ctx, entry) + if err != nil { + return err + } + + // Save the current segment; this does not guarantee that fragment will be + // empty when it returns, but dropping some measurements is acceptable. + // We use force=true here in case an entry didn't appear this month + err = a.saveCurrentSegmentToStorageLocked(ctx, true) + + // Don't return this error, just log it, we are done with that segment anyway. 
+ if err != nil { + a.logger.Warn("last save of segment failed", "error", err) + } + + // Advance the log; no need to force a save here because we have + // the intent log written already. + // + // On recovery refreshFromStoredLock() will see we're no longer + // in the previous month, and recover by calling newMonthCurrentLog + // again and triggering the precomputed query. + a.fragmentLock.Lock() + a.newMonthCurrentLogLocked(currentTime) + a.fragmentLock.Unlock() + + // Work on precomputed queries in background + go a.precomputedQueryWorker() + + return nil +} + +// ResetActivityLog is used to extract the current fragment(s) during +// integration testing, so that it can be checked in a race-free way. +func (c *Core) ResetActivityLog() []*activity.LogFragment { + c.stateLock.RLock() + a := c.activityLog + c.stateLock.RUnlock() + if a == nil { + return nil + } + + allFragments := make([]*activity.LogFragment, 1) + a.fragmentLock.Lock() + allFragments[0] = a.fragment + a.fragment = nil + + allFragments = append(allFragments, a.standbyFragmentsReceived...) + a.standbyFragmentsReceived = make([]*activity.LogFragment, 0) + a.fragmentLock.Unlock() + return allFragments +} + +// AddEntityToFragment checks an entity ID for uniqueness and +// if not already present, adds it to the current fragment. +// The timestamp is a Unix timestamp *without* nanoseconds, as that +// is what token.CreationTime uses. 
+func (a *ActivityLog) AddEntityToFragment(entityID string, namespaceID string, timestamp int64) { + // Check whether entity ID already recorded + var present bool + + a.fragmentLock.RLock() + if a.enabled { + _, present = a.activeEntities[entityID] + } else { + present = true + } + a.fragmentLock.RUnlock() + if present { + return + } + + // Update current fragment with new active entity a.fragmentLock.Lock() defer a.fragmentLock.Unlock() - // TODO: check whether entity ID already recorded + // Re-check entity ID after re-acquiring lock + _, present = a.activeEntities[entityID] + if present { + return + } a.createCurrentFragment() @@ -79,14 +1212,19 @@ func (a *ActivityLog) AddEntityToFragment(entityID string, namespaceID string, t &activity.EntityRecord{ EntityID: entityID, NamespaceID: namespaceID, - Timestamp: timestamp.UnixNano(), + Timestamp: timestamp, }) + a.activeEntities[entityID] = struct{}{} } func (a *ActivityLog) AddTokenToFragment(namespaceID string) { a.fragmentLock.Lock() defer a.fragmentLock.Unlock() + if !a.enabled { + return + } + a.createCurrentFragment() a.fragment.NonEntityTokens[namespaceID] += 1 @@ -101,8 +1239,373 @@ func (a *ActivityLog) createCurrentFragment() { Entities: make([]*activity.EntityRecord, 0, 120), NonEntityTokens: make(map[string]uint64), } - a.fragmentCreation = time.Now() + a.fragmentCreation = time.Now().UTC() - // TODO: start a timer to send it, if we're a performance standby + // Signal that a new segment is available, start + // the timer to send it. 
+ a.newFragmentCh <- struct{}{} } } + +func (a *ActivityLog) receivedFragment(fragment *activity.LogFragment) { + a.logger.Trace("received fragment from standby", "node", fragment.OriginatingNode) + + a.fragmentLock.Lock() + defer a.fragmentLock.Unlock() + + if !a.enabled { + return + } + + for _, e := range fragment.Entities { + a.activeEntities[e.EntityID] = struct{}{} + } + + a.standbyFragmentsReceived = append(a.standbyFragmentsReceived, fragment) + + // TODO: check if current segment is full and should be written + +} + +type ClientCountResponse struct { + DistinctEntities int `json:"distinct_entities"` + NonEntityTokens int `json:"non_entity_tokens"` + Clients int `json:"clients"` +} + +type ClientCountInNamespace struct { + NamespaceID string `json:"namespace_id"` + NamespacePath string `json:"namespace_path"` + Counts ClientCountResponse `json:"counts"` +} + +// ActivityLogInjectResponse injects a precomputed query into storage for testing. +func (c *Core) ActivityLogInjectResponse(ctx context.Context, pq *activity.PrecomputedQuery) error { + c.stateLock.RLock() + a := c.activityLog + c.stateLock.RUnlock() + if a == nil { + return errors.New("nil activity log") + } + return a.queryStore.Put(ctx, pq) +} + +func (a *ActivityLog) includeInResponse(query *namespace.Namespace, record *namespace.Namespace) bool { + if record == nil { + // Deleted namespace, only include in root queries + return query.ID == namespace.RootNamespaceID + } + return record.HasParent(query) +} + +func (a *ActivityLog) DefaultStartTime(endTime time.Time) time.Time { + // If end time is September 30, then start time should be + // October 1st to get 12 months of data. 
+ a.l.RLock() + defer a.l.RUnlock() + + monthStart := timeutil.StartOfMonth(endTime) + return monthStart.AddDate(0, -a.defaultReportMonths+1, 0) +} + +func (a *ActivityLog) handleQuery(ctx context.Context, startTime, endTime time.Time) (map[string]interface{}, error) { + queryNS, err := namespace.FromContext(ctx) + if err != nil { + return nil, err + } + + pq, err := a.queryStore.Get(ctx, startTime, endTime) + if err != nil { + return nil, err + } + if pq == nil { + return nil, nil + } + + responseData := make(map[string]interface{}) + responseData["start_time"] = pq.StartTime.Format(time.RFC3339) + responseData["end_time"] = pq.EndTime.Format(time.RFC3339) + byNamespace := make([]*ClientCountInNamespace, 0) + + totalEntities := 0 + totalTokens := 0 + + for _, nsRecord := range pq.Namespaces { + ns, err := NamespaceByID(ctx, nsRecord.NamespaceID, a.core) + if err != nil { + return nil, err + } + if a.includeInResponse(queryNS, ns) { + var displayPath string + if ns == nil { + displayPath = fmt.Sprintf("deleted namespace %q", nsRecord.NamespaceID) + } else { + displayPath = ns.Path + } + byNamespace = append(byNamespace, &ClientCountInNamespace{ + NamespaceID: nsRecord.NamespaceID, + NamespacePath: displayPath, + Counts: ClientCountResponse{ + DistinctEntities: int(nsRecord.Entities), + NonEntityTokens: int(nsRecord.NonEntityTokens), + Clients: int(nsRecord.Entities + nsRecord.NonEntityTokens), + }, + }) + totalEntities += int(nsRecord.Entities) + totalTokens += int(nsRecord.NonEntityTokens) + } + } + + responseData["by_namespace"] = byNamespace + responseData["total"] = &ClientCountResponse{ + DistinctEntities: totalEntities, + NonEntityTokens: totalTokens, + Clients: totalEntities + totalTokens, + } + return responseData, nil +} + +type activityConfig struct { + // DefaultReportMonths are the default number of months that are returned on + // a report. The zero value uses the system default of 12. 
+ DefaultReportMonths int `json:"default_report_months"` + + // RetentionMonths defines the number of months we want to retain data. The + // zero value uses the system default of 24 months. + RetentionMonths int `json:"retention_months"` + + // Enabled is one of enable, disable, default. + Enabled string `json:"enabled"` +} + +func defaultActivityConfig() activityConfig { + return activityConfig{ + DefaultReportMonths: 12, + RetentionMonths: 24, + Enabled: "default", + } +} + +func (a *ActivityLog) loadConfigOrDefault(ctx context.Context) (activityConfig, error) { + // Load from storage + var config activityConfig + configRaw, err := a.view.Get(ctx, activityConfigKey) + if err != nil { + return config, err + } + if configRaw == nil { + return defaultActivityConfig(), nil + } + + if err := configRaw.DecodeJSON(&config); err != nil { + return config, err + } + + return config, nil +} + +func (a *ActivityLog) HandleTokenCreation(entry *logical.TokenEntry) { + // enabled state is checked in both of these functions, + // because we have to grab a mutex there anyway. + if entry.EntityID != "" { + a.AddEntityToFragment(entry.EntityID, entry.NamespaceID, entry.CreationTime) + } else { + a.AddTokenToFragment(entry.NamespaceID) + } +} + +// goroutine to process the request in the intent log, creating precomputed queries. +// We expect the return value won't be checked, so log errors as they occur +// (but for unit testing having the error return should help.) +func (a *ActivityLog) precomputedQueryWorker() error { + ctx, cancel := context.WithCancel(namespace.RootContext(nil)) + defer cancel() + + // Cancel the context if activity log is shut down. + // This will cause the next storage operation to fail. 
+ go func() { + select { + case <-a.doneCh: + cancel() + case <-ctx.Done(): + break + } + }() + + // Load the intent log + rawIntentLog, err := a.view.Get(ctx, activityIntentLogKey) + if err != nil { + a.logger.Warn("could not load intent log", "error", err) + return err + } + if rawIntentLog == nil { + a.logger.Trace("no intent log found") + return err + } + var intent ActivityIntentLog + err = json.Unmarshal(rawIntentLog.Value, &intent) + if err != nil { + a.logger.Warn("could not parse intent log", "error", err) + return err + } + + // currentMonth could change (from another month end) after we release the lock. + // But, it's not critical to correct operation; this is a check for intent logs that are + // too old, and startTimestamp should only go forward (unless it is zero.) + // If there's an intent log, finish it even if the feature is currently disabled. + a.l.RLock() + currentMonth := a.currentSegment.startTimestamp + // Base retention period on the month we are generating (even in the past)--- time.Now() + // would work but this will be easier to control in tests. 
+ retentionWindow := timeutil.MonthsPreviousTo(a.retentionMonths, time.Unix(intent.NextMonth, 0).UTC()) + a.l.RUnlock() + if currentMonth != 0 && intent.NextMonth != currentMonth { + a.logger.Warn("intent log does not match current segment", + "intent", intent.NextMonth, "current", currentMonth) + return errors.New("intent log is too far in the past") + } + + lastMonth := intent.PreviousMonth + a.logger.Info("computing queries", "month", lastMonth) + + times, err := a.getMostRecentActivityLogSegment(ctx) + if err != nil { + a.logger.Warn("could not list recent segments", "error", err) + return err + } + if len(times) == 0 { + a.logger.Warn("no months in storage") + a.view.Delete(ctx, activityIntentLogKey) + return errors.New("previous month not found") + } + if times[0].Unix() != lastMonth { + a.logger.Warn("last month not in storage", "latest", times[0].Unix()) + a.view.Delete(ctx, activityIntentLogKey) + return errors.New("previous month not found") + } + + // "times" is already in reverse order, start building the per-namespace maps + // from the last month backward + + type NamespaceCounts struct { + // entityID -> present + Entities map[string]struct{} + // count + Tokens uint64 + } + byNamespace := make(map[string]*NamespaceCounts) + + createNs := func(namespaceID string) { + if _, namespacePresent := byNamespace[namespaceID]; !namespacePresent { + byNamespace[namespaceID] = &NamespaceCounts{ + Entities: make(map[string]struct{}), + Tokens: 0, + } + } + } + + walkEntities := func(l *activity.EntityActivityLog) { + for _, e := range l.Entities { + createNs(e.NamespaceID) + byNamespace[e.NamespaceID].Entities[e.EntityID] = struct{}{} + } + } + walkTokens := func(l *activity.TokenCount) { + for nsID, v := range l.CountByNamespaceID { + createNs(nsID) + byNamespace[nsID].Tokens += v + } + } + endTime := timeutil.EndOfMonth(time.Unix(lastMonth, 0).UTC()) + + for _, startTime := range times { + // Do not work back further than the current retention window, + // 
which will just get deleted anyway. + if startTime.Before(retentionWindow) { + break + } + + err = a.WalkEntitySegments(ctx, startTime, walkEntities) + if err != nil { + a.logger.Warn("failed to load previous segments", "error", err) + return err + } + err = a.WalkTokenSegments(ctx, startTime, walkTokens) + if err != nil { + a.logger.Warn("failed to load previous token counts", "error", err) + return err + } + + // Save the work to date in a record + pq := &activity.PrecomputedQuery{ + StartTime: startTime, + EndTime: endTime, + Namespaces: make([]*activity.NamespaceRecord, 0, len(byNamespace)), + } + for nsID, counts := range byNamespace { + pq.Namespaces = append(pq.Namespaces, &activity.NamespaceRecord{ + NamespaceID: nsID, + Entities: uint64(len(counts.Entities)), + NonEntityTokens: counts.Tokens, + }) + } + + err = a.queryStore.Put(ctx, pq) + if err != nil { + a.logger.Warn("failed to store precomputed query", "error", err) + } + } + + // Delete the intent log + a.view.Delete(ctx, activityIntentLogKey) + + a.logger.Info("finished computing queries", "month", endTime) + + return nil +} + +// goroutine to delete any segments or precomputed queries older than +// the retention period. +// We expect the return value won't be checked, so log errors as they occur +// (but for unit testing having the error return should help.) +func (a *ActivityLog) retentionWorker(currentTime time.Time, retentionMonths int) error { + ctx, cancel := context.WithCancel(namespace.RootContext(nil)) + defer cancel() + + // Cancel the context if activity log is shut down. + // This will cause the next storage operation to fail. 
+ go func() { + select { + case <-a.doneCh: + cancel() + case <-ctx.Done(): + break + } + }() + + // everything >= the threshold is OK + retentionThreshold := timeutil.MonthsPreviousTo(retentionMonths, currentTime) + + available, err := a.availableLogs(ctx) + if err != nil { + a.logger.Warn("could not list segments", "error", err) + return err + } + for _, t := range available { + // One at a time seems OK + if t.Before(retentionThreshold) { + a.logger.Trace("deleting segments", "startTime", t) + a.deleteLogWorker(t.Unix(), make(chan struct{})) + } + } + + if a.queryStore != nil { + err = a.queryStore.DeleteQueriesBefore(ctx, retentionThreshold) + if err != nil { + a.logger.Warn("deletion of queries failed", "error", err) + } + return err + } + + return nil +} diff --git a/vault/activity_log_test.go b/vault/activity_log_test.go index 954c1dca68..06df4df7c8 100644 --- a/vault/activity_log_test.go +++ b/vault/activity_log_test.go @@ -1,14 +1,35 @@ package vault import ( + "context" + "encoding/json" + "errors" + "fmt" + "reflect" + "sort" + "strconv" + "strings" + "sync" "testing" "time" + + "github.com/go-test/deep" + "github.com/golang/protobuf/proto" + "github.com/hashicorp/vault/helper/namespace" + "github.com/hashicorp/vault/helper/timeutil" + "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/vault/activity" +) + +const ( + logPrefix = "sys/counters/activity/log/" ) func TestActivityLog_Creation(t *testing.T) { core, _, _ := TestCoreUnsealed(t) a := core.activityLog + a.enabled = true if a == nil { t.Fatal("no activity log found") @@ -24,7 +45,7 @@ func TestActivityLog_Creation(t *testing.T) { const namespace_id = "ns123" ts := time.Now() - a.AddEntityToFragment(entity_id, namespace_id, ts) + a.AddEntityToFragment(entity_id, namespace_id, ts.Unix()) if a.fragment == nil { t.Fatal("no fragment created") } @@ -52,8 +73,8 @@ func TestActivityLog_Creation(t *testing.T) { if er.NamespaceID != namespace_id { t.Errorf("mimatched namespace ID, %q vs 
%q", er.NamespaceID, namespace_id) } - if er.Timestamp != ts.UnixNano() { - t.Errorf("mimatched timestamp, %v vs %v", er.Timestamp, ts.UnixNano()) + if er.Timestamp != ts.Unix() { + t.Errorf("mimatched timestamp, %v vs %v", er.Timestamp, ts.Unix()) } // Reset and test the other code path @@ -72,5 +93,2227 @@ func TestActivityLog_Creation(t *testing.T) { if actual != 1 { t.Errorf("mismatched number of tokens, %v vs %v", actual, 1) } +} + +func checkExpectedEntitiesInMap(t *testing.T, a *ActivityLog, entityIDs []string) { + t.Helper() + + if len(a.activeEntities) != len(entityIDs) { + t.Fatalf("mismatched number of entities, expected %v got %v", len(entityIDs), a.activeEntities) + } + for _, e := range entityIDs { + if _, present := a.activeEntities[e]; !present { + t.Errorf("entity ID %q is missing", e) + } + } +} + +func TestActivityLog_UniqueEntities(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + a := core.activityLog + a.enabled = true + + id1 := "11111111-1111-1111-1111-111111111111" + t1 := time.Now() + + id2 := "22222222-2222-2222-2222-222222222222" + t2 := time.Now() + t3 := t2.Add(60 * time.Second) + + a.AddEntityToFragment(id1, "root", t1.Unix()) + a.AddEntityToFragment(id2, "root", t2.Unix()) + a.AddEntityToFragment(id2, "root", t3.Unix()) + a.AddEntityToFragment(id1, "root", t3.Unix()) + + if a.fragment == nil { + t.Fatal("no current fragment") + } + + if len(a.fragment.Entities) != 2 { + t.Fatalf("number of entities is %v", len(a.fragment.Entities)) + } + + for i, e := range a.fragment.Entities { + expectedID := id1 + expectedTime := t1.Unix() + expectedNS := "root" + if i == 1 { + expectedID = id2 + expectedTime = t2.Unix() + } + if e.EntityID != expectedID { + t.Errorf("%v: expected %q, got %q", i, expectedID, e.EntityID) + } + if e.NamespaceID != expectedNS { + t.Errorf("%v: expected %q, got %q", i, expectedNS, e.NamespaceID) + } + if e.Timestamp != expectedTime { + t.Errorf("%v: expected %v, got %v", i, expectedTime, e.Timestamp) + } + } + + 
checkExpectedEntitiesInMap(t, a, []string{id1, id2}) +} + +func readSegmentFromStorage(t *testing.T, c *Core, path string) *logical.StorageEntry { + t.Helper() + logSegment, err := c.barrier.Get(context.Background(), path) + if err != nil { + t.Fatal(err) + } + if logSegment == nil { + t.Fatalf("expected non-nil log segment at %q", path) + } + + return logSegment +} + +func expectMissingSegment(t *testing.T, c *Core, path string) { + t.Helper() + logSegment, err := c.barrier.Get(context.Background(), path) + if err != nil { + t.Fatal(err) + } + if logSegment != nil { + t.Fatalf("expected nil log segment at %q", path) + } +} + +func writeToStorage(t *testing.T, c *Core, path string, data []byte) { + t.Helper() + err := c.barrier.Put(context.Background(), &logical.StorageEntry{ + Key: path, + Value: data, + }) + + if err != nil { + t.Fatalf("Failed to write %s to %s", data, path) + } +} + +func expectedEntityIDs(t *testing.T, out *activity.EntityActivityLog, ids []string) { + t.Helper() + + if len(out.Entities) != len(ids) { + t.Fatalf("entity log expected length %v, actual %v", len(ids), len(out.Entities)) + } + + // Double loop, OK for small cases + for _, id := range ids { + found := false + for _, e := range out.Entities { + if e.EntityID == id { + found = true + break + } + } + if !found { + t.Errorf("did not find entity ID %v", id) + } + } +} + +// TODO setup predicate for what we expect (both positive and negative case) for testing +// factor things out into predicates and actions so test body is compact. 
+func TestActivityLog_SaveTokensToStorage(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + a := core.activityLog + a.enabled = true + // set a nonzero segment + a.currentSegment.startTimestamp = time.Now().Unix() + + nsIDs := [...]string{"ns1_id", "ns2_id", "ns3_id"} + path := fmt.Sprintf("%sdirecttokens/%d/0", logPrefix, a.currentSegment.startTimestamp) + + for i := 0; i < 3; i++ { + a.AddTokenToFragment(nsIDs[0]) + } + a.AddTokenToFragment(nsIDs[1]) + err := a.saveCurrentSegmentToStorage(context.Background(), false) + if err != nil { + t.Fatalf("got error writing tokens to storage: %v", err) + } + if a.fragment != nil { + t.Errorf("fragment was not reset after write to storage") + } + + protoSegment := readSegmentFromStorage(t, core, path) + out := &activity.TokenCount{} + err = proto.Unmarshal(protoSegment.Value, out) + if err != nil { + t.Fatalf("could not unmarshal protobuf: %v", err) + } + + if len(out.CountByNamespaceID) != 2 { + t.Fatalf("unexpected token length. Expected %d, got %d", 2, len(out.CountByNamespaceID)) + } + for i := 0; i < 2; i++ { + if _, ok := out.CountByNamespaceID[nsIDs[i]]; !ok { + t.Fatalf("namespace ID %s missing from token counts", nsIDs[i]) + } + } + if out.CountByNamespaceID[nsIDs[0]] != 3 { + t.Errorf("namespace ID %s has %d count, expected %d", nsIDs[0], out.CountByNamespaceID[nsIDs[0]], 3) + } + if out.CountByNamespaceID[nsIDs[1]] != 1 { + t.Errorf("namespace ID %s has %d count, expected %d", nsIDs[1], out.CountByNamespaceID[nsIDs[1]], 1) + } + + a.AddTokenToFragment(nsIDs[0]) + a.AddTokenToFragment(nsIDs[2]) + err = a.saveCurrentSegmentToStorage(context.Background(), false) + if err != nil { + t.Fatalf("got error writing tokens to storage: %v", err) + } + if a.fragment != nil { + t.Errorf("fragment was not reset after write to storage") + } + + protoSegment = readSegmentFromStorage(t, core, path) + out = &activity.TokenCount{} + err = proto.Unmarshal(protoSegment.Value, out) + if err != nil { + t.Fatalf("could not unmarshal 
protobuf: %v", err) + } + + if len(out.CountByNamespaceID) != 3 { + t.Fatalf("unexpected token length. Expected %d, got %d", 3, len(out.CountByNamespaceID)) + } + for i := 0; i < 3; i++ { + if _, ok := out.CountByNamespaceID[nsIDs[i]]; !ok { + t.Fatalf("namespace ID %s missing from token counts", nsIDs[i]) + } + } + if out.CountByNamespaceID[nsIDs[0]] != 4 { + t.Errorf("namespace ID %s has %d count, expected %d", nsIDs[0], out.CountByNamespaceID[nsIDs[0]], 4) + } + if out.CountByNamespaceID[nsIDs[1]] != 1 { + t.Errorf("namespace ID %s has %d count, expected %d", nsIDs[1], out.CountByNamespaceID[nsIDs[1]], 1) + } + if out.CountByNamespaceID[nsIDs[2]] != 1 { + t.Errorf("namespace ID %s has %d count, expected %d", nsIDs[2], out.CountByNamespaceID[nsIDs[2]], 1) + } +} + +func TestActivityLog_SaveEntitiesToStorage(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + a := core.activityLog + a.enabled = true + // set a nonzero segment + a.currentSegment.startTimestamp = time.Now().Unix() + + now := time.Now() + ids := []string{"11111111-1111-1111-1111-111111111111", "22222222-2222-2222-2222-222222222222", "33333333-2222-2222-2222-222222222222"} + times := [...]int64{ + now.Unix(), + now.Add(1 * time.Second).Unix(), + now.Add(2 * time.Second).Unix(), + } + path := fmt.Sprintf("%sentity/%d/0", logPrefix, a.currentSegment.startTimestamp) + + a.AddEntityToFragment(ids[0], "root", times[0]) + a.AddEntityToFragment(ids[1], "root2", times[1]) + err := a.saveCurrentSegmentToStorage(context.Background(), false) + if err != nil { + t.Fatalf("got error writing entities to storage: %v", err) + } + if a.fragment != nil { + t.Errorf("fragment was not reset after write to storage") + } + + protoSegment := readSegmentFromStorage(t, core, path) + out := &activity.EntityActivityLog{} + err = proto.Unmarshal(protoSegment.Value, out) + if err != nil { + t.Fatalf("could not unmarshal protobuf: %v", err) + } + expectedEntityIDs(t, out, ids[:2]) + + a.AddEntityToFragment(ids[0], "root", 
times[2]) + a.AddEntityToFragment(ids[2], "root", times[2]) + err = a.saveCurrentSegmentToStorage(context.Background(), false) + if err != nil { + t.Fatalf("got error writing segments to storage: %v", err) + } + + protoSegment = readSegmentFromStorage(t, core, path) + out = &activity.EntityActivityLog{} + err = proto.Unmarshal(protoSegment.Value, out) + if err != nil { + t.Fatalf("could not unmarshal protobuf: %v", err) + } + expectedEntityIDs(t, out, ids) +} + +func TestActivityLog_ReceivedFragment(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + a := core.activityLog + a.enabled = true + + ids := []string{ + "11111111-1111-1111-1111-111111111111", + "22222222-2222-2222-2222-222222222222", + } + + entityRecords := []*activity.EntityRecord{ + &activity.EntityRecord{ + EntityID: ids[0], + NamespaceID: "root", + Timestamp: time.Now().Unix(), + }, + &activity.EntityRecord{ + EntityID: ids[1], + NamespaceID: "root", + Timestamp: time.Now().Unix(), + }, + } + + fragment := &activity.LogFragment{ + OriginatingNode: "test-123", + Entities: entityRecords, + NonEntityTokens: make(map[string]uint64), + } + + if len(a.standbyFragmentsReceived) != 0 { + t.Fatalf("fragment already received") + } + + a.receivedFragment(fragment) + + checkExpectedEntitiesInMap(t, a, ids) + + if len(a.standbyFragmentsReceived) != 1 { + t.Fatalf("fragment count is %v, expected 1", len(a.standbyFragmentsReceived)) + } + + // Send a duplicate, should be stored but not change entity map + a.receivedFragment(fragment) + + checkExpectedEntitiesInMap(t, a, ids) + + if len(a.standbyFragmentsReceived) != 2 { + t.Fatalf("fragment count is %v, expected 2", len(a.standbyFragmentsReceived)) + } +} + +func TestActivityLog_availableLogsEmptyDirectory(t *testing.T) { + // verify that directory is empty, and nothing goes wrong + core, _, _ := TestCoreUnsealed(t) + a := core.activityLog + times, err := a.availableLogs(context.Background()) + + if err != nil { + t.Fatalf("error getting start_time(s) for empty 
activity log") + } + if len(times) != 0 { + t.Fatalf("invalid number of start_times returned. expected 0, got %d", len(times)) + } +} + +func TestActivityLog_availableLogs(t *testing.T) { + // set up a few files in storage + core, _, _ := TestCoreUnsealed(t) + a := core.activityLog + paths := [...]string{"entity/1111/1", "directtokens/1111/1", "directtokens/1000000/1", "entity/992/3", "directtokens/992/1"} + expectedTimes := [...]time.Time{time.Unix(1000000, 0), time.Unix(1111, 0), time.Unix(992, 0)} + + for _, path := range paths { + writeToStorage(t, core, logPrefix+path, []byte("test")) + } + + // verify above files are there, and dates in correct order + times, err := a.availableLogs(context.Background()) + if err != nil { + t.Fatalf("error getting start_time(s) for activity log") + } + + if len(times) != len(expectedTimes) { + t.Fatalf("invalid number of start_times returned. expected %d, got %d", len(expectedTimes), len(times)) + } + for i := range times { + if !times[i].Equal(expectedTimes[i]) { + t.Errorf("invalid time. 
expected %v, got %v", expectedTimes[i], times[i]) + } + } +} + +func TestActivityLog_MultipleFragmentsAndSegments(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + a := core.activityLog + + // enabled check is now inside AddEntityToFragment + a.enabled = true + // set a nonzero segment + a.currentSegment.startTimestamp = time.Now().Unix() + + // Stop timers for test purposes + close(a.doneCh) + + path0 := fmt.Sprintf("sys/counters/activity/log/entity/%d/0", a.currentSegment.startTimestamp) + path1 := fmt.Sprintf("sys/counters/activity/log/entity/%d/1", a.currentSegment.startTimestamp) + tokenPath := fmt.Sprintf("sys/counters/activity/log/directtokens/%d/0", a.currentSegment.startTimestamp) + + genID := func(i int) string { + return fmt.Sprintf("11111111-1111-1111-1111-%012d", i) + } + ts := time.Now().Unix() + + // First 7000 should fit in one segment + for i := 0; i < 7000; i++ { + a.AddEntityToFragment(genID(i), "root", ts) + } + + // Consume new fragment notification. + // The worker may have gotten it first, before processing + // the close! + select { + case <-a.newFragmentCh: + default: + } + + // Save incomplete segment + err := a.saveCurrentSegmentToStorage(context.Background(), false) + if err != nil { + t.Fatalf("got error writing entities to storage: %v", err) + } + + protoSegment0 := readSegmentFromStorage(t, core, path0) + entityLog0 := activity.EntityActivityLog{} + err = proto.Unmarshal(protoSegment0.Value, &entityLog0) + if err != nil { + t.Fatalf("could not unmarshal protobuf: %v", err) + } + if len(entityLog0.Entities) != 7000 { + t.Fatalf("unexpected entity length. 
Expected %d, got %d", 7000, len(entityLog0.Entities)) + } + + // 7000 more local entities + for i := 7000; i < 14000; i++ { + a.AddEntityToFragment(genID(i), "root", ts) + } + + // Simulated remote fragment with 100 duplicate entities + tokens1 := map[string]uint64{ + "root": 3, + "aaaaa": 4, + "bbbbb": 5, + } + fragment1 := &activity.LogFragment{ + OriginatingNode: "test-123", + Entities: make([]*activity.EntityRecord, 0, 100), + NonEntityTokens: tokens1, + } + for i := 7000; i < 7100; i++ { + fragment1.Entities = append(fragment1.Entities, &activity.EntityRecord{ + EntityID: genID(i), + NamespaceID: "root", + Timestamp: ts, + }) + } + + // Simulated remote fragment with 100 new entities + tokens2 := map[string]uint64{ + "root": 6, + "aaaaa": 7, + "bbbbb": 8, + } + fragment2 := &activity.LogFragment{ + OriginatingNode: "test-123", + Entities: make([]*activity.EntityRecord, 0, 100), + NonEntityTokens: tokens2, + } + for i := 14000; i < 14100; i++ { + fragment2.Entities = append(fragment2.Entities, &activity.EntityRecord{ + EntityID: genID(i), + NamespaceID: "root", + Timestamp: ts, + }) + } + a.receivedFragment(fragment1) + a.receivedFragment(fragment2) + + <-a.newFragmentCh + + err = a.saveCurrentSegmentToStorage(context.Background(), false) + if err != nil { + t.Fatalf("got error writing entities to storage: %v", err) + } + + if a.currentSegment.entitySequenceNumber != 1 { + t.Fatalf("expected sequence number 1, got %v", a.currentSegment.entitySequenceNumber) + } + + protoSegment0 = readSegmentFromStorage(t, core, path0) + err = proto.Unmarshal(protoSegment0.Value, &entityLog0) + if err != nil { + t.Fatalf("could not unmarshal protobuf: %v", err) + } + if len(entityLog0.Entities) != activitySegmentEntityCapacity { + t.Fatalf("unexpected entity length. 
Expected %d, got %d", activitySegmentEntityCapacity, + len(entityLog0.Entities)) + } + + protoSegment1 := readSegmentFromStorage(t, core, path1) + entityLog1 := activity.EntityActivityLog{} + err = proto.Unmarshal(protoSegment1.Value, &entityLog1) + if err != nil { + t.Fatalf("could not unmarshal protobuf: %v", err) + } + expectedCount := 14100 - activitySegmentEntityCapacity + if len(entityLog1.Entities) != expectedCount { + t.Fatalf("unexpected entity length. Expected %d, got %d", expectedCount, + len(entityLog1.Entities)) + } + + entityPresent := make(map[string]struct{}) + for _, e := range entityLog0.Entities { + entityPresent[e.EntityID] = struct{}{} + } + for _, e := range entityLog1.Entities { + entityPresent[e.EntityID] = struct{}{} + } + for i := 0; i < 14100; i++ { + expectedID := genID(i) + if _, present := entityPresent[expectedID]; !present { + t.Fatalf("entity ID %v = %v not present", i, expectedID) + } + } + + expectedNSCounts := map[string]uint64{ + "root": 9, + "aaaaa": 11, + "bbbbb": 13, + } + tokenSegment := readSegmentFromStorage(t, core, tokenPath) + tokenCount := activity.TokenCount{} + err = proto.Unmarshal(tokenSegment.Value, &tokenCount) + if err != nil { + t.Fatalf("could not unmarshal protobuf: %v", err) + } + + if !reflect.DeepEqual(expectedNSCounts, tokenCount.CountByNamespaceID) { + t.Fatalf("token counts are not equal, expected %v got %v", expectedNSCounts, tokenCount.CountByNamespaceID) + } +} + +func TestActivityLog_API_ConfigCRUD(t *testing.T) { + core, b, _ := testCoreSystemBackend(t) + view := core.systemBarrierView + + // Test reading the defaults + { + req := logical.TestRequest(t, logical.ReadOperation, "internal/counters/config") + req.Storage = view + resp, err := b.HandleRequest(namespace.RootContext(nil), req) + if err != nil { + t.Fatalf("err: %v", err) + } + defaults := map[string]interface{}{ + "default_report_months": 12, + "retention_months": 24, + "enabled": activityLogEnabledDefaultValue, + "queries_available": 
false, + } + + if diff := deep.Equal(resp.Data, defaults); len(diff) > 0 { + t.Fatalf("diff: %v", diff) + } + } + + // Check Error Cases + { + req := logical.TestRequest(t, logical.UpdateOperation, "internal/counters/config") + req.Storage = view + req.Data["default_report_months"] = 0 + _, err := b.HandleRequest(namespace.RootContext(nil), req) + if err == nil { + t.Fatal("expected error") + } + + req = logical.TestRequest(t, logical.UpdateOperation, "internal/counters/config") + req.Storage = view + req.Data["enabled"] = "bad-value" + _, err = b.HandleRequest(namespace.RootContext(nil), req) + if err == nil { + t.Fatal("expected error") + } + + req = logical.TestRequest(t, logical.UpdateOperation, "internal/counters/config") + req.Storage = view + req.Data["retention_months"] = 0 + req.Data["enabled"] = "enable" + _, err = b.HandleRequest(namespace.RootContext(nil), req) + if err == nil { + t.Fatal("expected error") + } + } + + // Test single key updates + { + req := logical.TestRequest(t, logical.UpdateOperation, "internal/counters/config") + req.Storage = view + req.Data["default_report_months"] = 1 + resp, err := b.HandleRequest(namespace.RootContext(nil), req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp != nil { + t.Fatalf("bad: %#v", resp) + } + + req = logical.TestRequest(t, logical.UpdateOperation, "internal/counters/config") + req.Storage = view + req.Data["retention_months"] = 2 + resp, err = b.HandleRequest(namespace.RootContext(nil), req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp != nil { + t.Fatalf("bad: %#v", resp) + } + + req = logical.TestRequest(t, logical.UpdateOperation, "internal/counters/config") + req.Storage = view + req.Data["enabled"] = "enable" + resp, err = b.HandleRequest(namespace.RootContext(nil), req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp != nil { + t.Fatalf("bad: %#v", resp) + } + + req = logical.TestRequest(t, logical.ReadOperation, "internal/counters/config") + req.Storage = 
view + resp, err = b.HandleRequest(namespace.RootContext(nil), req) + if err != nil { + t.Fatalf("err: %v", err) + } + expected := map[string]interface{}{ + "default_report_months": 1, + "retention_months": 2, + "enabled": "enable", + "queries_available": false, + } + + if diff := deep.Equal(resp.Data, expected); len(diff) > 0 { + t.Fatalf("diff: %v", diff) + } + } + + // Test updating all keys + { + req := logical.TestRequest(t, logical.UpdateOperation, "internal/counters/config") + req.Storage = view + req.Data["enabled"] = "default" + req.Data["retention_months"] = 24 + req.Data["default_report_months"] = 12 + resp, err := b.HandleRequest(namespace.RootContext(nil), req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp != nil { + t.Fatalf("bad: %#v", resp) + } + + req = logical.TestRequest(t, logical.ReadOperation, "internal/counters/config") + req.Storage = view + resp, err = b.HandleRequest(namespace.RootContext(nil), req) + if err != nil { + t.Fatalf("err: %v", err) + } + + defaults := map[string]interface{}{ + "default_report_months": 12, + "retention_months": 24, + "enabled": activityLogEnabledDefaultValue, + "queries_available": false, + } + + if diff := deep.Equal(resp.Data, defaults); len(diff) > 0 { + t.Fatalf("diff: %v", diff) + } + } +} + +func TestActivityLog_parseSegmentNumberFromPath(t *testing.T) { + testCases := []struct { + input string + expected int + expectExists bool + }{ + { + input: "path/to/log/5", + expected: 5, + expectExists: true, + }, + { + input: "/path/to/log/5", + expected: 5, + expectExists: true, + }, + { + input: "path/to/log/", + expected: 0, + expectExists: false, + }, + { + input: "path/to/log/foo", + expected: 0, + expectExists: false, + }, + { + input: "", + expected: 0, + expectExists: false, + }, + { + input: "5", + expected: 5, + expectExists: true, + }, + } + + for _, tc := range testCases { + result, ok := parseSegmentNumberFromPath(tc.input) + if result != tc.expected { + t.Errorf("expected: %d, got: %d for 
input %q", tc.expected, result, tc.input) + } + if ok != tc.expectExists { + t.Errorf("unexpected value presence. expected exists: %t, got: %t for input %q", tc.expectExists, ok, tc.input) + } + } +} + +func TestActivityLog_getLastEntitySegmentNumber(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + a := core.activityLog + paths := [...]string{"entity/992/0", "entity/1000/-1", "entity/1001/foo", "entity/1111/0", "entity/1111/1"} + for _, path := range paths { + writeToStorage(t, core, logPrefix+path, []byte("test")) + } + + testCases := []struct { + input int64 + expectedVal uint64 + expectExists bool + }{ + { + input: 992, + expectedVal: 0, + expectExists: true, + }, + { + input: 1000, + expectedVal: 0, + expectExists: false, + }, + { + input: 1001, + expectedVal: 0, + expectExists: false, + }, + { + input: 1111, + expectedVal: 1, + expectExists: true, + }, + { + input: 2222, + expectedVal: 0, + expectExists: false, + }, + } + + ctx := context.Background() + for _, tc := range testCases { + result, exists, err := a.getLastEntitySegmentNumber(ctx, time.Unix(tc.input, 0)) + if err != nil { + t.Fatalf("unexpected error for input %d: %v", tc.input, err) + } + if exists != tc.expectExists { + t.Errorf("expected result exists: %t, got: %t for input: %d", tc.expectExists, exists, tc.input) + } + if result != tc.expectedVal { + t.Errorf("expected: %d got: %d for input: %d", tc.expectedVal, result, tc.input) + } + } +} + +func TestActivityLog_tokenCountExists(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + a := core.activityLog + paths := [...]string{"directtokens/992/0", "directtokens/1001/foo", "directtokens/1111/0", "directtokens/2222/1"} + for _, path := range paths { + writeToStorage(t, core, logPrefix+path, []byte("test")) + } + + testCases := []struct { + input int64 + expectExists bool + }{ + { + input: 992, + expectExists: true, + }, + { + input: 1001, + expectExists: false, + }, + { + input: 1111, + expectExists: true, + }, + { + input: 2222, + 
expectExists: false, + }, + } + + ctx := context.Background() + for _, tc := range testCases { + exists, err := a.tokenCountExists(ctx, time.Unix(tc.input, 0)) + if err != nil { + t.Fatalf("unexpected error for input %d: %v", tc.input, err) + } + if exists != tc.expectExists { + t.Errorf("expected segment to exist: %t but got: %t for input: %d", tc.expectExists, exists, tc.input) + } + } +} + +// entityRecordsEqual compares the parts we care about from two activity entity record slices +// note: this makes a copy of the []*activity.EntityRecord so that misordered slices won't fail the comparison, +// but the function won't modify the order of the slices to compare +func entityRecordsEqual(t *testing.T, record1, record2 []*activity.EntityRecord) bool { + t.Helper() + + if record1 == nil { + return record2 == nil + } + if record2 == nil { + return record1 == nil + } + + if len(record1) != len(record2) { + return false + } + + // sort first on namespace, then on ID, then on timestamp + entityLessFn := func(e []*activity.EntityRecord, i, j int) bool { + ei := e[i] + ej := e[j] + + nsComp := strings.Compare(ei.NamespaceID, ej.NamespaceID) + if nsComp == -1 { + return true + } + if nsComp == 1 { + return false + } + + idComp := strings.Compare(ei.EntityID, ej.EntityID) + if idComp == -1 { + return true + } + if idComp == 1 { + return false + } + + return ei.Timestamp < ej.Timestamp + } + + entitiesCopy1 := make([]*activity.EntityRecord, len(record1)) + entitiesCopy2 := make([]*activity.EntityRecord, len(record2)) + copy(entitiesCopy1, record1) + copy(entitiesCopy2, record2) + + sort.Slice(entitiesCopy1, func(i, j int) bool { + return entityLessFn(entitiesCopy1, i, j) + }) + sort.Slice(entitiesCopy2, func(i, j int) bool { + return entityLessFn(entitiesCopy2, i, j) + }) + + for i, a := range entitiesCopy1 { + b := entitiesCopy2[i] + if a.EntityID != b.EntityID || a.NamespaceID != b.NamespaceID || a.Timestamp != b.Timestamp { + return false + } + } + + return true +} + 
+func activeEntitiesEqual(t *testing.T, active map[string]struct{}, test []*activity.EntityRecord) bool { + t.Helper() + + if len(active) != len(test) { + return false + } + + for _, ent := range test { + if _, ok := active[ent.EntityID]; !ok { + return false + } + } + + return true +} + +func (a *ActivityLog) resetEntitiesInMemory(t *testing.T) { + t.Helper() + + a.currentSegment = segmentInfo{ + startTimestamp: time.Time{}.Unix(), + currentEntities: &activity.EntityActivityLog{ + Entities: make([]*activity.EntityRecord, 0), + }, + tokenCount: a.currentSegment.tokenCount, + entitySequenceNumber: 0, + } + + a.activeEntities = make(map[string]struct{}) +} + +func TestActivityLog_loadCurrentEntitySegment(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + a := core.activityLog + + // we must verify that loadCurrentEntitySegment doesn't overwrite the in-memory token count + tokenRecords := make(map[string]uint64) + tokenRecords["test"] = 1 + tokenCount := &activity.TokenCount{ + CountByNamespaceID: tokenRecords, + } + a.currentSegment.tokenCount = tokenCount + + // setup in-storage data to load for testing + entityRecords := []*activity.EntityRecord{ + &activity.EntityRecord{ + EntityID: "11111111-1111-1111-1111-111111111111", + NamespaceID: "root", + Timestamp: time.Now().Unix(), + }, + &activity.EntityRecord{ + EntityID: "22222222-2222-2222-2222-222222222222", + NamespaceID: "root", + Timestamp: time.Now().Unix(), + }, + } + testEntities1 := &activity.EntityActivityLog{ + Entities: entityRecords[:1], + } + testEntities2 := &activity.EntityActivityLog{ + Entities: entityRecords[1:2], + } + testEntities3 := &activity.EntityActivityLog{ + Entities: entityRecords[:2], + } + + time1 := time.Date(2020, 4, 1, 0, 0, 0, 0, time.UTC).Unix() + time2 := time.Date(2020, 5, 1, 0, 0, 0, 0, time.UTC).Unix() + testCases := []struct { + time int64 + seqNum uint64 + path string + entities *activity.EntityActivityLog + }{ + { + time: time1, + seqNum: 0, + path: "entity/" + 
fmt.Sprint(time1) + "/0", + entities: testEntities1, + }, + { + // we want to verify that data from segment 0 hasn't been loaded + time: time1, + seqNum: 1, + path: "entity/" + fmt.Sprint(time1) + "/1", + entities: testEntities2, + }, + { + time: time2, + seqNum: 0, + path: "entity/" + fmt.Sprint(time2) + "/0", + entities: testEntities3, + }, + } + + for _, tc := range testCases { + data, err := proto.Marshal(tc.entities) + if err != nil { + t.Fatalf(err.Error()) + } + writeToStorage(t, core, logPrefix+tc.path, data) + } + + ctx := context.Background() + for _, tc := range testCases { + err := a.loadCurrentEntitySegment(ctx, time.Unix(tc.time, 0), tc.seqNum) + if err != nil { + t.Fatalf("got error loading data for %q: %v", tc.path, err) + } + if !reflect.DeepEqual(a.currentSegment.tokenCount.CountByNamespaceID, tokenCount.CountByNamespaceID) { + t.Errorf("this function should not wipe out the in-memory token count") + } + + // verify accurate data in in-memory current segment + if a.currentSegment.startTimestamp != tc.time { + t.Errorf("bad timestamp loaded. expected: %v, got: %v for path %q", tc.time, a.currentSegment.startTimestamp, tc.path) + } + if a.currentSegment.entitySequenceNumber != tc.seqNum { + t.Errorf("bad sequence number loaded. expected: %v, got: %v for path %q", tc.seqNum, a.currentSegment.entitySequenceNumber, tc.path) + } + if !entityRecordsEqual(t, a.currentSegment.currentEntities.Entities, tc.entities.Entities) { + t.Errorf("bad data loaded. expected: %v, got: %v for path %q", tc.entities.Entities, a.currentSegment.currentEntities, tc.path) + } + + if !activeEntitiesEqual(t, a.activeEntities, tc.entities.Entities) { + t.Errorf("bad data loaded into active entites. 
expected only set of EntityID from %v in %v for path %q", tc.entities.Entities, a.activeEntities, tc.path) + } + + a.resetEntitiesInMemory(t) + } +} + +func TestActivityLog_loadPriorEntitySegment(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + a := core.activityLog + a.enabled = true + + // setup in-storage data to load for testing + entityRecords := []*activity.EntityRecord{ + &activity.EntityRecord{ + EntityID: "11111111-1111-1111-1111-111111111111", + NamespaceID: "root", + Timestamp: time.Now().Unix(), + }, + &activity.EntityRecord{ + EntityID: "22222222-2222-2222-2222-222222222222", + NamespaceID: "root", + Timestamp: time.Now().Unix(), + }, + } + testEntities1 := &activity.EntityActivityLog{ + Entities: entityRecords[:1], + } + testEntities2 := &activity.EntityActivityLog{ + Entities: entityRecords[:2], + } + + time1 := time.Date(2020, 4, 1, 0, 0, 0, 0, time.UTC).Unix() + time2 := time.Date(2020, 5, 1, 0, 0, 0, 0, time.UTC).Unix() + testCases := []struct { + time int64 + seqNum uint64 + path string + entities *activity.EntityActivityLog + // set true if the in-memory active entities should be wiped because the next test case is a new month + // this also means that currentSegment.startTimestamp must be updated with :time: + refresh bool + }{ + { + time: time1, + seqNum: 0, + path: "entity/" + fmt.Sprint(time1) + "/0", + entities: testEntities1, + refresh: true, + }, + { + // verify that we don't have a duplicate (shouldn't be possible with the current implementation) + time: time1, + seqNum: 1, + path: "entity/" + fmt.Sprint(time1) + "/1", + entities: testEntities2, + refresh: true, + }, + { + time: time2, + seqNum: 0, + path: "entity/" + fmt.Sprint(time2) + "/0", + entities: testEntities2, + refresh: true, + }, + } + + for _, tc := range testCases { + data, err := proto.Marshal(tc.entities) + if err != nil { + t.Fatalf(err.Error()) + } + writeToStorage(t, core, logPrefix+tc.path, data) + } + + ctx := context.Background() + for _, tc := range testCases 
{ + if tc.refresh { + a.activeEntities = make(map[string]struct{}) + a.currentSegment.startTimestamp = tc.time + } + + err := a.loadPriorEntitySegment(ctx, time.Unix(tc.time, 0), tc.seqNum) + if err != nil { + t.Fatalf("got error loading data for %q: %v", tc.path, err) + } + + if !activeEntitiesEqual(t, a.activeEntities, tc.entities.Entities) { + t.Errorf("bad data loaded into active entites. expected only set of EntityID from %v in %v for path %q", tc.entities.Entities, a.activeEntities, tc.path) + } + } +} + +func TestActivityLog_loadTokenCount(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + a := core.activityLog + + // setup in-storage data to load for testing + tokenRecords := make(map[string]uint64) + for i := 1; i < 4; i++ { + nsID := "ns" + strconv.Itoa(i) + tokenRecords[nsID] = uint64(i) + } + tokenCount := &activity.TokenCount{ + CountByNamespaceID: tokenRecords, + } + + data, err := proto.Marshal(tokenCount) + if err != nil { + t.Fatalf(err.Error()) + } + + testCases := []struct { + time int64 + path string + }{ + { + time: 1111, + path: "directtokens/1111/0", + }, + { + time: 2222, + path: "directtokens/2222/0", + }, + } + + for _, tc := range testCases { + writeToStorage(t, core, logPrefix+tc.path, data) + } + + ctx := context.Background() + for _, tc := range testCases { + // a.currentSegment.tokenCount doesn't need to be wiped each iter since it happens in loadTokenSegment() + err := a.loadTokenCount(ctx, time.Unix(tc.time, 0)) + if err != nil { + t.Fatalf("got error loading data for %q: %v", tc.path, err) + } + if !reflect.DeepEqual(a.currentSegment.tokenCount.CountByNamespaceID, tokenRecords) { + t.Errorf("bad token count loaded. 
expected: %v got: %v for path %q", tokenRecords, a.currentSegment.tokenCount.CountByNamespaceID, tc.path) + } + } +} + +func TestActivityLog_StopAndRestart(t *testing.T) { + core, b, _ := testCoreSystemBackend(t) + sysView := core.systemBarrierView + + a := core.activityLog + ctx := namespace.RootContext(nil) + + // Disable, then enable, to exercise newly-enabled code + a.SetConfig(ctx, activityConfig{ + Enabled: "disable", + RetentionMonths: 12, + DefaultReportMonths: 12, + }) + + // Go through request to ensure config is persisted + req := logical.TestRequest(t, logical.UpdateOperation, "internal/counters/config") + req.Storage = sysView + req.Data["enabled"] = "enable" + resp, err := b.HandleRequest(namespace.RootContext(nil), req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp != nil { + t.Fatalf("bad: %#v", resp) + } + + // Simulate seal/unseal cycle + core.stopActivityLog() + core.setupActivityLog(ctx) + + a = core.activityLog + if a.currentSegment.tokenCount.CountByNamespaceID == nil { + t.Fatalf("nil token count map") + } + + a.AddEntityToFragment("1111-1111", "root", time.Now().Unix()) + a.AddTokenToFragment("root") + + err = a.saveCurrentSegmentToStorage(ctx, false) + if err != nil { + t.Fatal(err) + } + +} + +func setupActivityRecordsInStorage(t *testing.T, includeEntities, includeTokens bool) (*ActivityLog, []*activity.EntityRecord, map[string]uint64) { + t.Helper() + + core, _, _ := TestCoreUnsealed(t) + a := core.activityLog + now := time.Now().UTC() + monthsAgo := now.AddDate(0, -3, 0) + + var entityRecords []*activity.EntityRecord + if includeEntities { + entityRecords = []*activity.EntityRecord{ + &activity.EntityRecord{ + EntityID: "11111111-1111-1111-1111-111111111111", + NamespaceID: "root", + Timestamp: time.Now().Unix(), + }, + &activity.EntityRecord{ + EntityID: "22222222-2222-2222-2222-222222222222", + NamespaceID: "root", + Timestamp: time.Now().Unix(), + }, + &activity.EntityRecord{ + EntityID: 
"33333333-2222-2222-2222-222222222222", + NamespaceID: "root", + Timestamp: time.Now().Unix(), + }, + } + + testEntities1 := &activity.EntityActivityLog{ + Entities: entityRecords[:1], + } + entityData1, err := proto.Marshal(testEntities1) + if err != nil { + t.Fatalf(err.Error()) + } + testEntities2 := &activity.EntityActivityLog{ + Entities: entityRecords[1:2], + } + entityData2, err := proto.Marshal(testEntities2) + if err != nil { + t.Fatalf(err.Error()) + } + testEntities3 := &activity.EntityActivityLog{ + Entities: entityRecords[2:], + } + entityData3, err := proto.Marshal(testEntities3) + if err != nil { + t.Fatalf(err.Error()) + } + + writeToStorage(t, core, logPrefix+"entity/"+fmt.Sprint(monthsAgo.Unix())+"/0", entityData1) + writeToStorage(t, core, logPrefix+"entity/"+fmt.Sprint(now.Unix())+"/0", entityData2) + writeToStorage(t, core, logPrefix+"entity/"+fmt.Sprint(now.Unix())+"/1", entityData3) + } + + var tokenRecords map[string]uint64 + if includeTokens { + tokenRecords = make(map[string]uint64) + for i := 1; i < 4; i++ { + nsID := "ns" + strconv.Itoa(i) + tokenRecords[nsID] = uint64(i) + } + tokenCount := &activity.TokenCount{ + CountByNamespaceID: tokenRecords, + } + + tokenData, err := proto.Marshal(tokenCount) + if err != nil { + t.Fatalf(err.Error()) + } + + writeToStorage(t, core, logPrefix+"directtokens/"+fmt.Sprint(now.Unix())+"/0", tokenData) + } + + return a, entityRecords, tokenRecords +} + +func TestActivityLog_refreshFromStoredLog(t *testing.T) { + a, expectedEntityRecords, expectedTokenCounts := setupActivityRecordsInStorage(t, true, true) + a.enabled = true + + var wg sync.WaitGroup + err := a.refreshFromStoredLog(context.Background(), &wg) + if err != nil { + t.Fatalf("got error loading stored activity logs: %v", err) + } + wg.Wait() + + expectedActive := &activity.EntityActivityLog{ + Entities: expectedEntityRecords[1:], + } + expectedCurrent := &activity.EntityActivityLog{ + Entities: expectedEntityRecords[2:], + } + if 
!entityRecordsEqual(t, a.currentSegment.currentEntities.Entities, expectedCurrent.Entities) { + // we only expect the newest entity segment to be loaded (for the current month) + t.Errorf("bad activity entity logs loaded. expected: %v got: %v", expectedCurrent, a.currentSegment.currentEntities) + } + if !reflect.DeepEqual(a.currentSegment.tokenCount.CountByNamespaceID, expectedTokenCounts) { + // we expect all token counts to be loaded + t.Errorf("bad activity token counts loaded. expected: %v got: %v", expectedTokenCounts, a.currentSegment.tokenCount.CountByNamespaceID) + } + + if !activeEntitiesEqual(t, a.activeEntities, expectedActive.Entities) { + // we expect activeEntities to be loaded for the entire month + t.Errorf("bad data loaded into active entites. expected only set of EntityID from %v in %v", expectedActive.Entities, a.activeEntities) + } +} + +func TestActivityLog_refreshFromStoredLogOnStandby(t *testing.T) { + a, expectedEntityRecords, _ := setupActivityRecordsInStorage(t, true, true) + a.enabled = true + a.core.perfStandby = true + + var wg sync.WaitGroup + err := a.refreshFromStoredLog(context.Background(), &wg) + if err != nil { + t.Fatalf("got error loading stored activity logs: %v", err) + } + wg.Wait() + + expectedActive := &activity.EntityActivityLog{ + Entities: expectedEntityRecords[1:], + } + if !activeEntitiesEqual(t, a.activeEntities, expectedActive.Entities) { + // we expect activeEntities to be loaded for the entire month + t.Errorf("bad data loaded into active entites. expected only set of EntityID from %v in %v", expectedActive.Entities, a.activeEntities) + } + + // we expect nothing to be loaded to a.currentSegment (other than startTimestamp for end of month checking) + if len(a.currentSegment.currentEntities.Entities) > 0 { + t.Errorf("currentSegment entities should not be populated. 
got: %v", a.currentSegment.currentEntities) + } + if len(a.currentSegment.tokenCount.CountByNamespaceID) > 0 { + t.Errorf("currentSegment token counts should not be populated. got: %v", a.currentSegment.tokenCount.CountByNamespaceID) + } + if a.currentSegment.entitySequenceNumber != 0 { + t.Errorf("currentSegment sequence number should be 0. got: %v", a.currentSegment.entitySequenceNumber) + } +} + +func TestActivityLog_refreshFromStoredLogWithBackgroundLoadingCancelled(t *testing.T) { + a, expectedEntityRecords, expectedTokenCounts := setupActivityRecordsInStorage(t, true, true) + a.enabled = true + + var wg sync.WaitGroup + close(a.doneCh) + + err := a.refreshFromStoredLog(context.Background(), &wg) + if err != nil { + t.Fatalf("got error loading stored activity logs: %v", err) + } + wg.Wait() + + expected := &activity.EntityActivityLog{ + Entities: expectedEntityRecords[2:], + } + if !entityRecordsEqual(t, a.currentSegment.currentEntities.Entities, expected.Entities) { + // we only expect the newest entity segment to be loaded (for the current month) + t.Errorf("bad activity entity logs loaded. expected: %v got: %v", expected, a.currentSegment.currentEntities) + } + if !reflect.DeepEqual(a.currentSegment.tokenCount.CountByNamespaceID, expectedTokenCounts) { + // we expect all token counts to be loaded + t.Errorf("bad activity token counts loaded. expected: %v got: %v", expectedTokenCounts, a.currentSegment.tokenCount.CountByNamespaceID) + } + + if !activeEntitiesEqual(t, a.activeEntities, expected.Entities) { + // we only expect activeEntities to be loaded for the newest segment (for the current month) + t.Errorf("bad data loaded into active entites. 
expected only set of EntityID from %v in %v", expected.Entities, a.activeEntities) + } +} + +func TestActivityLog_refreshFromStoredLogContextCancelled(t *testing.T) { + a, _, _ := setupActivityRecordsInStorage(t, true, true) + + var wg sync.WaitGroup + ctx, cancelFn := context.WithCancel(context.Background()) + cancelFn() + + err := a.refreshFromStoredLog(ctx, &wg) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context cancelled error, got: %v", err) + } +} + +func TestActivityLog_refreshFromStoredLogNoTokens(t *testing.T) { + a, expectedEntityRecords, _ := setupActivityRecordsInStorage(t, true, false) + a.enabled = true + + var wg sync.WaitGroup + err := a.refreshFromStoredLog(context.Background(), &wg) + if err != nil { + t.Fatalf("got error loading stored activity logs: %v", err) + } + wg.Wait() + + expectedActive := &activity.EntityActivityLog{ + Entities: expectedEntityRecords[1:], + } + expectedCurrent := &activity.EntityActivityLog{ + Entities: expectedEntityRecords[2:], + } + if !entityRecordsEqual(t, a.currentSegment.currentEntities.Entities, expectedCurrent.Entities) { + // we expect all segments for the current month to be loaded + t.Errorf("bad activity entity logs loaded. expected: %v got: %v", expectedCurrent, a.currentSegment.currentEntities) + } + if !activeEntitiesEqual(t, a.activeEntities, expectedActive.Entities) { + t.Errorf("bad data loaded into active entites. expected only set of EntityID from %v in %v", expectedActive.Entities, a.activeEntities) + } + + // we expect no tokens + if len(a.currentSegment.tokenCount.CountByNamespaceID) > 0 { + t.Errorf("expected no token counts to be loaded. 
got: %v", a.currentSegment.tokenCount.CountByNamespaceID) + } +} + +func TestActivityLog_refreshFromStoredLogNoEntities(t *testing.T) { + a, _, expectedTokenCounts := setupActivityRecordsInStorage(t, false, true) + a.enabled = true + + var wg sync.WaitGroup + err := a.refreshFromStoredLog(context.Background(), &wg) + if err != nil { + t.Fatalf("got error loading stored activity logs: %v", err) + } + wg.Wait() + + if !reflect.DeepEqual(a.currentSegment.tokenCount.CountByNamespaceID, expectedTokenCounts) { + // we expect all token counts to be loaded + t.Errorf("bad activity token counts loaded. expected: %v got: %v", expectedTokenCounts, a.currentSegment.tokenCount.CountByNamespaceID) + } + + if len(a.currentSegment.currentEntities.Entities) > 0 { + t.Errorf("expected no current entity segment to be loaded. got: %v", a.currentSegment.currentEntities) + } + if len(a.activeEntities) > 0 { + t.Errorf("expected no active entity segment to be loaded. got: %v", a.activeEntities) + } +} + +func TestActivityLog_refreshFromStoredLogNoData(t *testing.T) { + a, _, _ := setupActivityRecordsInStorage(t, false, false) + a.enabled = true + + var wg sync.WaitGroup + err := a.refreshFromStoredLog(context.Background(), &wg) + if err != nil { + t.Fatalf("got error loading stored activity logs: %v", err) + } + wg.Wait() + + if len(a.currentSegment.currentEntities.Entities) > 0 { + t.Errorf("expected no current entity segment to be loaded. got: %v", a.currentSegment.currentEntities) + } + if len(a.activeEntities) > 0 { + t.Errorf("expected no active entity segment to be loaded. got: %v", a.activeEntities) + } + if len(a.currentSegment.tokenCount.CountByNamespaceID) > 0 { + t.Errorf("expected no token counts to be loaded. 
got: %v", a.currentSegment.tokenCount.CountByNamespaceID) + } +} + +func TestActivityLog_IncludeNamespace(t *testing.T) { + root := namespace.RootNamespace + a := &ActivityLog{} + + nsA := &namespace.Namespace{ + ID: "aaaaa", + Path: "a/", + } + nsC := &namespace.Namespace{ + ID: "ccccc", + Path: "c/", + } + nsAB := &namespace.Namespace{ + ID: "bbbbb", + Path: "a/b/", + } + + testCases := []struct { + QueryNS *namespace.Namespace + RecordNS *namespace.Namespace + Expected bool + }{ + {root, nil, true}, + {root, root, true}, + {root, nsA, true}, + {root, nsAB, true}, + {nsA, nsA, true}, + {nsA, nsAB, true}, + {nsAB, nsAB, true}, + + {nsA, root, false}, + {nsA, nil, false}, + {nsAB, root, false}, + {nsAB, nil, false}, + {nsAB, nsA, false}, + {nsC, nsA, false}, + {nsC, nsAB, false}, + } + + for _, tc := range testCases { + if a.includeInResponse(tc.QueryNS, tc.RecordNS) != tc.Expected { + t.Errorf("bad response for query %v record %v, expected %v", + tc.QueryNS, tc.RecordNS, tc.Expected) + } + } +} + +func TestActivityLog_DeleteWorker(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + a := core.activityLog + + paths := []string{ + "entity/1111/1", + "entity/1111/2", + "entity/1111/3", + "entity/1112/1", + "directtokens/1111/1", + "directtokens/1112/1", + } + for _, path := range paths { + writeToStorage(t, core, logPrefix+path, []byte("test")) + } + + doneCh := make(chan struct{}) + timeout := time.After(20 * time.Second) + + go a.deleteLogWorker(1111, doneCh) + select { + case <-doneCh: + break + case <-timeout: + t.Fatalf("timed out") + } + + // Check segments still present + readSegmentFromStorage(t, core, logPrefix+"entity/1112/1") + readSegmentFromStorage(t, core, logPrefix+"directtokens/1112/1") + + // Check other segments not present + expectMissingSegment(t, core, logPrefix+"entity/1111/1") + expectMissingSegment(t, core, logPrefix+"entity/1111/2") + expectMissingSegment(t, core, logPrefix+"entity/1111/3") + expectMissingSegment(t, core, 
logPrefix+"directtokens/1111/1") +} + +// Skip this test if too close to the end of a month! +// TODO: move testhelper? +func SkipAtEndOfMonth(t *testing.T) { + thisMonth := timeutil.StartOfMonth(time.Now().UTC()) + endOfMonth := timeutil.EndOfMonth(thisMonth) + if endOfMonth.Sub(time.Now()) < 10*time.Minute { + t.Skip("too close to end of month") + } +} + +func TestActivityLog_EnableDisable(t *testing.T) { + SkipAtEndOfMonth(t) + + core, b, _ := testCoreSystemBackend(t) + a := core.activityLog + view := core.systemBarrierView + ctx := namespace.RootContext(nil) + + enableRequest := func() { + t.Helper() + req := logical.TestRequest(t, logical.UpdateOperation, "internal/counters/config") + req.Storage = view + req.Data["enabled"] = "enable" + resp, err := b.HandleRequest(ctx, req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp != nil { + t.Fatalf("bad: %#v", resp) + } + } + disableRequest := func() { + t.Helper() + req := logical.TestRequest(t, logical.UpdateOperation, "internal/counters/config") + req.Storage = view + req.Data["enabled"] = "disable" + resp, err := b.HandleRequest(ctx, req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp != nil { + t.Fatalf("bad: %#v", resp) + } + } + + // enable (if not already) and write a segment + enableRequest() + + id1 := "11111111-1111-1111-1111-111111111111" + id2 := "22222222-2222-2222-2222-222222222222" + id3 := "33333333-3333-3333-3333-333333333333" + a.AddEntityToFragment(id1, "root", time.Now().Unix()) + a.AddEntityToFragment(id2, "root", time.Now().Unix()) + + a.currentSegment.startTimestamp -= 10 + seg1 := a.currentSegment.startTimestamp + err := a.saveCurrentSegmentToStorage(ctx, false) + if err != nil { + t.Fatal(err) + } + + // verify segment exists + path := fmt.Sprintf("%ventity/%v/0", logPrefix, seg1) + readSegmentFromStorage(t, core, path) + + // Add in-memory fragment + a.AddEntityToFragment(id3, "root", time.Now().Unix()) + + // disable and verify segment no longer exists + 
disableRequest() + + timeout := time.After(20 * time.Second) + select { + case <-a.deleteDone: + break + case <-timeout: + t.Fatalf("timed out") + } + + expectMissingSegment(t, core, path) + if a.currentSegment.startTimestamp != 0 { + t.Errorf("bad startTimestamp, expected 0 got %v", a.currentSegment.startTimestamp) + } + if len(a.currentSegment.currentEntities.Entities) != 0 { + t.Errorf("expected empty currentEntities, got %v", a.currentSegment.currentEntities.Entities) + } + if len(a.currentSegment.tokenCount.CountByNamespaceID) != 0 { + t.Errorf("expected empty tokens, got %v", a.currentSegment.tokenCount.CountByNamespaceID) + } + if len(a.activeEntities) != 0 { + t.Errorf("expected empty activeEntities, got %v", a.activeEntities) + } + if a.fragment != nil { + t.Errorf("expected nil fragment") + } + + // enable (if not already) which force-writes an empty segment + enableRequest() + + seg2 := a.currentSegment.startTimestamp + if seg1 >= seg2 { + t.Errorf("bad second segment timestamp, %v >= %v", seg1, seg2) + } + + // Verify empty segments are present + path = fmt.Sprintf("%ventity/%v/0", logPrefix, seg2) + readSegmentFromStorage(t, core, path) + + path = fmt.Sprintf("%vdirecttokens/%v/0", logPrefix, seg2) + readSegmentFromStorage(t, core, path) +} + +func TestActivityLog_EndOfMonth(t *testing.T) { + // We only want *fake* end of months, *real* ones are too scary. + SkipAtEndOfMonth(t) + + core, _, _ := TestCoreUnsealed(t) + a := core.activityLog + ctx := namespace.RootContext(nil) + + // Make sure we're enabled. 
+ a.SetConfig(ctx, activityConfig{ + Enabled: "enable", + RetentionMonths: 12, + DefaultReportMonths: 12, + }) + + id1 := "11111111-1111-1111-1111-111111111111" + id2 := "22222222-2222-2222-2222-222222222222" + id3 := "33333333-3333-3333-3333-333333333333" + a.AddEntityToFragment(id1, "root", time.Now().Unix()) + + month0 := time.Now().UTC() + segment0 := a.currentSegment.startTimestamp + month1 := month0.AddDate(0, 1, 0) + month2 := month0.AddDate(0, 2, 0) + + // Trigger end-of-month + a.HandleEndOfMonth(month1) + + // Check segment is present, with 1 entity + path := fmt.Sprintf("%ventity/%v/0", logPrefix, segment0) + protoSegment := readSegmentFromStorage(t, core, path) + out := &activity.EntityActivityLog{} + err := proto.Unmarshal(protoSegment.Value, out) + if err != nil { + t.Fatal(err) + } + + segment1 := a.currentSegment.startTimestamp + expectedTimestamp := timeutil.StartOfMonth(month1).Unix() + if segment1 != expectedTimestamp { + t.Errorf("expected segment timestamp %v got %v", expectedTimestamp, segment1) + } + + // Check intent log is present + intentRaw, err := core.barrier.Get(ctx, "sys/counters/activity/endofmonth") + if err != nil { + t.Fatal(err) + } + var intent ActivityIntentLog + err = intentRaw.DecodeJSON(&intent) + if err != nil { + t.Fatal(err) + } + + if intent.PreviousMonth != segment0 { + t.Errorf("expected previous month %v got %v", segment0, intent.PreviousMonth) + } + + if intent.NextMonth != segment1 { + t.Errorf("expected previous month %v got %v", segment1, intent.NextMonth) + } + + a.AddEntityToFragment(id2, "root", time.Now().Unix()) + + a.HandleEndOfMonth(month2) + segment2 := a.currentSegment.startTimestamp + + a.AddEntityToFragment(id3, "root", time.Now().Unix()) + + err = a.saveCurrentSegmentToStorage(ctx, false) + if err != nil { + t.Fatal(err) + } + + // Check all three segments still present, with correct entities + testCases := []struct { + SegmentTimestamp int64 + ExpectedEntityIDs []string + }{ + {segment0, 
[]string{id1}}, + {segment1, []string{id2}}, + {segment2, []string{id3}}, + } + + for i, tc := range testCases { + t.Logf("checking segment %v timestamp %v", i, tc.SegmentTimestamp) + path := fmt.Sprintf("%ventity/%v/0", logPrefix, tc.SegmentTimestamp) + protoSegment := readSegmentFromStorage(t, core, path) + out := &activity.EntityActivityLog{} + err = proto.Unmarshal(protoSegment.Value, out) + if err != nil { + t.Fatalf("could not unmarshal protobuf: %v", err) + } + expectedEntityIDs(t, out, tc.ExpectedEntityIDs) + } +} + +func TestActivityLog_SaveAfterDisable(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + ctx := namespace.RootContext(nil) + a := core.activityLog + a.SetConfig(ctx, activityConfig{ + Enabled: "enable", + RetentionMonths: 12, + DefaultReportMonths: 12, + }) + + a.AddEntityToFragment("1111-1111-11111111", "root", time.Now().Unix()) + startTimestamp := a.currentSegment.startTimestamp + + // This kicks off an asynchronous delete + a.SetConfig(ctx, activityConfig{ + Enabled: "disable", + RetentionMonths: 12, + DefaultReportMonths: 12, + }) + + timer := time.After(10 * time.Second) + select { + case <-timer: + t.Fatal("timeout waiting for delete to finish") + case <-a.deleteDone: + break + } + + // Segment should not be written even with force + err := a.saveCurrentSegmentToStorage(context.Background(), true) + if err != nil { + t.Fatal(err) + } + + path := logPrefix + "entity/0/0" + expectMissingSegment(t, core, path) + + path = fmt.Sprintf("%ventity/%v/0", logPrefix, startTimestamp) + expectMissingSegment(t, core, path) +} + +func TestActivityLog_Precompute(t *testing.T) { + SkipAtEndOfMonth(t) + + january := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + august := time.Date(2020, 8, 15, 12, 0, 0, 0, time.UTC) + september := timeutil.StartOfMonth(time.Date(2020, 9, 1, 0, 0, 0, 0, time.UTC)) + october := timeutil.StartOfMonth(time.Date(2020, 10, 1, 0, 0, 0, 0, time.UTC)) + november := timeutil.StartOfMonth(time.Date(2020, 11, 1, 0, 0, 0, 0, 
time.UTC)) + + core, _, _ := TestCoreUnsealed(t) + a := core.activityLog + ctx := namespace.RootContext(nil) + + // Generate overlapping sets of entity IDs from this list. + // january: 40-44 RRRRR + // first month: 0-19 RRRRRAAAAABBBBBRRRRR + // second month: 10-29 BBBBBRRRRRRRRRRCCCCC + // third month: 15-39 RRRRRRRRRRCCCCCRRRRRBBBBB + + entityRecords := make([]*activity.EntityRecord, 45) + entityNamespaces := []string{"root", "aaaaa", "bbbbb", "root", "root", "ccccc", "root", "bbbbb", "rrrrr"} + + for i := range entityRecords { + entityRecords[i] = &activity.EntityRecord{ + EntityID: fmt.Sprintf("111122222-3333-4444-5555-%012v", i), + NamespaceID: entityNamespaces[i/5], + Timestamp: time.Now().Unix(), + } + } + + toInsert := []struct { + StartTime int64 + Segment uint64 + Entities []*activity.EntityRecord + }{ + // January, should not be included + { + january.Unix(), + 0, + entityRecords[40:45], + }, + // Artifically split August and October + { // 1 + august.Unix(), + 0, + entityRecords[:13], + }, + { // 2 + august.Unix(), + 1, + entityRecords[13:20], + }, + { // 3 + september.Unix(), + 0, + entityRecords[10:30], + }, + { // 4 + october.Unix(), + 0, + entityRecords[15:40], + }, + { + october.Unix(), + 1, + entityRecords[15:40], + }, + { + october.Unix(), + 2, + entityRecords[17:23], + }, + } + + // Note that precomputedQuery worker doesn't filter + // for times <= the one it was asked to do. Is that a problem? + // Here, it means that we can't insert everything *first* and do multiple + // test cases, we have to write logs incrementally. 
+ doInsert := func(i int) { + segment := toInsert[i] + eal := &activity.EntityActivityLog{ + Entities: segment.Entities, + } + data, err := proto.Marshal(eal) + if err != nil { + t.Fatal(err) + } + path := fmt.Sprintf("%ventity/%v/%v", logPrefix, segment.StartTime, segment.Segment) + writeToStorage(t, core, path, data) + } + + expectedCounts := []struct { + StartTime time.Time + EndTime time.Time + ByNamespace map[string]int + }{ + // First test case + { + august, + timeutil.EndOfMonth(august), + map[string]int{ + "root": 10, + "aaaaa": 5, + "bbbbb": 5, + }, + }, + // Second test case + { + august, + timeutil.EndOfMonth(september), + map[string]int{ + "root": 15, + "aaaaa": 5, + "bbbbb": 5, + "ccccc": 5, + }, + }, + { + september, + timeutil.EndOfMonth(september), + map[string]int{ + "root": 10, + "bbbbb": 5, + "ccccc": 5, + }, + }, + // Third test case + { + august, + timeutil.EndOfMonth(october), + map[string]int{ + "root": 20, + "aaaaa": 5, + "bbbbb": 10, + "ccccc": 5, + }, + }, + { + september, + timeutil.EndOfMonth(october), + map[string]int{ + "root": 15, + "bbbbb": 10, + "ccccc": 5, + }, + }, + { + october, + timeutil.EndOfMonth(october), + map[string]int{ + "root": 15, + "bbbbb": 5, + "ccccc": 5, + }, + }, + } + + checkPrecomputedQuery := func(i int) { + t.Helper() + pq, err := a.queryStore.Get(ctx, expectedCounts[i].StartTime, expectedCounts[i].EndTime) + if err != nil { + t.Fatal(err) + } + if pq == nil { + t.Errorf("empty result for %v -- %v", expectedCounts[i].StartTime, expectedCounts[i].EndTime) + } + if len(pq.Namespaces) != len(expectedCounts[i].ByNamespace) { + t.Errorf("mismatched number of namespaces, expected %v got %v", + len(expectedCounts[i].ByNamespace), len(pq.Namespaces)) + } + for _, nsRecord := range pq.Namespaces { + val, ok := expectedCounts[i].ByNamespace[nsRecord.NamespaceID] + if !ok { + t.Errorf("unexpected namespace %v", nsRecord.NamespaceID) + continue + } + if uint64(val) != nsRecord.Entities { + t.Errorf("wrong number of 
entities in %v: expected %v, got %v", + nsRecord.NamespaceID, val, nsRecord.Entities) + } + } + if !pq.StartTime.Equal(expectedCounts[i].StartTime) { + t.Errorf("mismatched start time: expected %v got %v", + expectedCounts[i].StartTime, pq.StartTime) + } + if !pq.EndTime.Equal(expectedCounts[i].EndTime) { + t.Errorf("mismatched end time: expected %v got %v", + expectedCounts[i].EndTime, pq.EndTime) + } + } + + testCases := []struct { + InsertUpTo int // index in the toInsert array + PrevMonth int64 + NextMonth int64 + ExpectedUpTo int // index in the expectedCounts array + }{ + { + 2, // jan-august + august.Unix(), + september.Unix(), + 0, // august-august + }, + { + 3, // jan-sept + september.Unix(), + october.Unix(), + 2, // august-september + }, + { + 6, // jan-oct + october.Unix(), + november.Unix(), + 5, // august-september + }, + } + + inserted := -1 + for _, tc := range testCases { + t.Logf("tc %+v", tc) + + // Persists across loops + for inserted < tc.InsertUpTo { + inserted += 1 + t.Logf("inserting segment %v", inserted) + doInsert(inserted) + } + + intent := &ActivityIntentLog{ + PreviousMonth: tc.PrevMonth, + NextMonth: tc.NextMonth, + } + data, err := json.Marshal(intent) + if err != nil { + t.Fatal(err) + } + writeToStorage(t, core, "sys/counters/activity/endofmonth", data) + + // Pretend we've successfully rolled over to the following month + a.l.Lock() + a.currentSegment.startTimestamp = tc.NextMonth + a.l.Unlock() + + err = a.precomputedQueryWorker() + if err != nil { + t.Fatal(err) + } + + expectMissingSegment(t, core, "sys/counters/activity/endofmonth") + + for i := 0; i <= tc.ExpectedUpTo; i++ { + checkPrecomputedQuery(i) + } + + } +} + +type BlockingInmemStorage struct { +} + +func (b *BlockingInmemStorage) List(ctx context.Context, prefix string) ([]string, error) { + <-ctx.Done() + return nil, errors.New("fake implementation") +} + +func (b *BlockingInmemStorage) Get(ctx context.Context, key string) (*logical.StorageEntry, error) { + 
<-ctx.Done() + return nil, errors.New("fake implementation") +} + +func (b *BlockingInmemStorage) Put(ctx context.Context, entry *logical.StorageEntry) error { + <-ctx.Done() + return errors.New("fake implementation") +} + +func (b *BlockingInmemStorage) Delete(ctx context.Context, key string) error { + <-ctx.Done() + return errors.New("fake implementation") +} + +func TestActivityLog_PrecomputeCancel(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + a := core.activityLog + + // Substitute in a new view + a.view = NewBarrierView(&BlockingInmemStorage{}, "test") + + core.stopActivityLog() + + done := make(chan struct{}) + + // This will block if the shutdown didn't work. + go func() { + a.precomputedQueryWorker() + close(done) + }() + + timeout := time.After(5 * time.Second) + + select { + case <-done: + break + case <-timeout: + t.Fatalf("timeout waiting for worker to finish") + } + +} + +func TestActivityLog_NextMonthStart(t *testing.T) { + SkipAtEndOfMonth(t) + + now := time.Now().UTC() + year, month, _ := now.Date() + computedStart := time.Date(year, month, 1, 0, 0, 0, 0, time.UTC).AddDate(0, 1, 0) + + testCases := []struct { + SegmentStart int64 + ExpectedTime time.Time + }{ + { + 0, + computedStart, + }, + { + time.Date(2021, 2, 12, 13, 14, 15, 0, time.UTC).Unix(), + time.Date(2021, 3, 1, 0, 0, 0, 0, time.UTC), + }, + { + time.Date(2021, 3, 1, 0, 0, 0, 0, time.UTC).Unix(), + time.Date(2021, 4, 1, 0, 0, 0, 0, time.UTC), + }, + } + + core, _, _ := TestCoreUnsealed(t) + a := core.activityLog + + for _, tc := range testCases { + t.Logf("segmentStart=%v", tc.SegmentStart) + a.l.Lock() + a.currentSegment.startTimestamp = tc.SegmentStart + a.l.Unlock() + + actual := a.StartOfNextMonth() + if !actual.Equal(tc.ExpectedTime) { + t.Errorf("expected %v, got %v", tc.ExpectedTime, actual) + } + } +} + +func TestActivityLog_Deletion(t *testing.T) { + SkipAtEndOfMonth(t) + + core, _, _ := TestCoreUnsealed(t) + a := core.activityLog + + times := []time.Time{ + 
time.Date(2019, 1, 15, 1, 2, 3, 0, time.UTC), // 0 + time.Date(2019, 3, 15, 1, 2, 3, 0, time.UTC), + time.Date(2019, 4, 1, 0, 0, 0, 0, time.UTC), + time.Date(2019, 5, 1, 0, 0, 0, 0, time.UTC), + time.Date(2019, 6, 1, 0, 0, 0, 0, time.UTC), + time.Date(2019, 7, 1, 0, 0, 0, 0, time.UTC), // 5 + time.Date(2019, 8, 1, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 1, 0, 0, 0, 0, time.UTC), + time.Date(2019, 10, 1, 0, 0, 0, 0, time.UTC), + time.Date(2019, 11, 1, 0, 0, 0, 0, time.UTC), // <-- 12 months starts here + time.Date(2019, 12, 1, 0, 0, 0, 0, time.UTC), // 10 + time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(2020, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2020, 3, 1, 0, 0, 0, 0, time.UTC), + time.Date(2020, 4, 1, 0, 0, 0, 0, time.UTC), + time.Date(2020, 5, 1, 0, 0, 0, 0, time.UTC), // 15 + time.Date(2020, 6, 1, 0, 0, 0, 0, time.UTC), + time.Date(2020, 7, 1, 0, 0, 0, 0, time.UTC), + time.Date(2020, 8, 1, 0, 0, 0, 0, time.UTC), + time.Date(2020, 9, 1, 0, 0, 0, 0, time.UTC), + time.Date(2020, 10, 1, 0, 0, 0, 0, time.UTC), // 20 + time.Date(2020, 11, 1, 0, 0, 0, 0, time.UTC), + } + + novIndex := len(times) - 1 + paths := make([][]string, len(times)) + + for i, start := range times { + // no entities in some months, just for fun + for j := 0; j < (i+3)%5; j++ { + entityPath := fmt.Sprintf("%ventity/%v/%v", logPrefix, start.Unix(), j) + paths[i] = append(paths[i], entityPath) + writeToStorage(t, core, entityPath, []byte("test")) + } + tokenPath := fmt.Sprintf("%vdirecttokens/%v/0", logPrefix, start.Unix()) + paths[i] = append(paths[i], tokenPath) + writeToStorage(t, core, tokenPath, []byte("test")) + + // No queries for November yet + if i < novIndex { + for _, endTime := range times[i+1 : novIndex] { + queryPath := fmt.Sprintf("sys/counters/activity/queries/%v/%v", start.Unix(), endTime.Unix()) + paths[i] = append(paths[i], queryPath) + writeToStorage(t, core, queryPath, []byte("test")) + } + } + } + + checkPresent := func(i int) { + t.Helper() + for _, p := 
range paths[i] { + readSegmentFromStorage(t, core, p) + } + } + + checkAbsent := func(i int) { + t.Helper() + for _, p := range paths[i] { + expectMissingSegment(t, core, p) + } + } + + t.Log("24 months") + now := times[len(times)-1] + err := a.retentionWorker(now, 24) + if err != nil { + t.Fatal(err) + } + for i := range times { + checkPresent(i) + } + + t.Log("12 months") + err = a.retentionWorker(now, 12) + if err != nil { + t.Fatal(err) + } + for i := 0; i <= 8; i++ { + checkAbsent(i) + } + for i := 9; i <= 21; i++ { + checkPresent(i) + } + + t.Log("1 month") + err = a.retentionWorker(now, 1) + if err != nil { + t.Fatal(err) + } + for i := 0; i <= 19; i++ { + checkAbsent(i) + } + checkPresent(20) + checkPresent(21) + + t.Log("0 months") + err = a.retentionWorker(now, 0) + if err != nil { + t.Fatal(err) + } + for i := 0; i <= 20; i++ { + checkAbsent(i) + } + checkPresent(21) } diff --git a/vault/activity_log_util.go b/vault/activity_log_util.go new file mode 100644 index 0000000000..4a187e8727 --- /dev/null +++ b/vault/activity_log_util.go @@ -0,0 +1,10 @@ +// +build !enterprise + +package vault + +import "context" + +// sendCurrentFragment is a no-op on OSS +func (a *ActivityLog) sendCurrentFragment(ctx context.Context) error { + return nil +} diff --git a/vault/auth.go b/vault/auth.go index 11efc11829..5b53947166 100644 --- a/vault/auth.go +++ b/vault/auth.go @@ -425,7 +425,7 @@ func (c *Core) taintCredEntry(ctx context.Context, path string, updateStorage bo // Taint the entry from the auth table // We do this on the original since setting the taint operates // on the entries which a shallow clone shares anyways - entry, err := c.auth.setTaint(ctx, strings.TrimPrefix(path, credentialRoutePrefix), true) + entry, err := c.auth.setTaint(ctx, strings.TrimPrefix(path, credentialRoutePrefix), true, mountStateUnmounting) if err != nil { return err } diff --git a/vault/core.go b/vault/core.go index d8a1abbc77..30084f48e3 100644 --- a/vault/core.go +++ b/vault/core.go 
@@ -532,6 +532,8 @@ type Core struct { quotaManager *quotas.Manager clusterHeartbeatInterval time.Duration + + activityLogConfig ActivityLogCoreConfig } // CoreConfig is used to parameterize a core @@ -631,6 +633,9 @@ type CoreConfig struct { ClusterNetworkLayer cluster.NetworkLayer ClusterHeartbeatInterval time.Duration + + // Activity log controls + ActivityLogConfig ActivityLogCoreConfig } // GetServiceRegistration returns the config's ServiceRegistration, or nil if it does @@ -770,9 +775,9 @@ func NewCore(conf *CoreConfig) (*Core, error) { postUnsealStarted: new(uint32), raftJoinDoneCh: make(chan struct{}), clusterHeartbeatInterval: clusterHeartbeatInterval, + activityLogConfig: conf.ActivityLogConfig, } c.standbyStopCh.Store(make(chan struct{})) - atomic.StoreUint32(c.sealed, 1) c.metricSink.SetGaugeWithLabels([]string{"core", "unsealed"}, 0, nil) @@ -2061,6 +2066,9 @@ func (c *Core) preSeal() error { if err := c.stopExpiration(); err != nil { result = multierror.Append(result, errwrap.Wrapf("error stopping expiration: {{err}}", err)) } + if err := c.stopActivityLog(); err != nil { + result = multierror.Append(result, errwrap.Wrapf("error stopping activity log: {{err}}", err)) + } if err := c.teardownCredentials(context.Background()); err != nil { result = multierror.Append(result, errwrap.Wrapf("error tearing down credentials: {{err}}", err)) } diff --git a/vault/core_metrics.go b/vault/core_metrics.go index ecd041e4be..10e42ddffc 100644 --- a/vault/core_metrics.go +++ b/vault/core_metrics.go @@ -68,8 +68,8 @@ func (c *Core) metricsLoop(stopCh chan struct{}) { } c.stateLock.RUnlock() case <-identityCountTimer: - // Only emit on active node - if c.PerfStandby() { + // Only emit on active node of cluster that is not a DR cecondary. + if standby, _ := c.Standby(); standby || c.IsDRSecondary() { break } @@ -196,10 +196,11 @@ func (c *Core) emitMetrics(stopCh chan struct{}) { }, } - // Disable collection if configured, or if we're a performance standby. 
+ // Disable collection if configured, or if we're a performance standby + // node or DR secondary cluster. if c.MetricSink().GaugeInterval == time.Duration(0) { c.logger.Info("usage gauge collection is disabled") - } else if !c.PerfStandby() { + } else if standby, _ := c.Standby(); !standby && !c.IsDRSecondary() { for _, init := range metricsInit { if init.DisableEnvVar != "" { if os.Getenv(init.DisableEnvVar) != "" { diff --git a/vault/core_util.go b/vault/core_util.go index 2e95dc0f3c..4296c462ba 100644 --- a/vault/core_util.go +++ b/vault/core_util.go @@ -13,6 +13,11 @@ import ( "github.com/hashicorp/vault/vault/replication" ) +const ( + activityLogEnabledDefault = false + activityLogEnabledDefaultValue = "default-disabled" +) + type entCore struct{} type entCoreConfig struct{} diff --git a/vault/external_tests/quotas/quotas_test.go b/vault/external_tests/quotas/quotas_test.go index 624c41e53a..696889045b 100644 --- a/vault/external_tests/quotas/quotas_test.go +++ b/vault/external_tests/quotas/quotas_test.go @@ -127,7 +127,7 @@ func waitForRemovalOrTimeout(c *api.Client, path string, tick, to time.Duration) } } -func TestQuotas_RateLimitQuota_DupName(t *testing.T) { +func TestQuotas_RateLimit_DupName(t *testing.T) { conf, opts := teststorage.ClusterSetup(coreConfig, nil, nil) cluster := vault.NewTestCluster(t, conf, opts) cluster.Start() diff --git a/vault/logical_system.go b/vault/logical_system.go index 647f9886db..dc37f2995d 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -167,6 +167,7 @@ func NewSystemBackend(core *Core, logger log.Logger) *SystemBackend { b.Backend.Paths = append(b.Backend.Paths, b.monitorPath()) b.Backend.Paths = append(b.Backend.Paths, b.hostInfoPath()) b.Backend.Paths = append(b.Backend.Paths, b.quotasPaths()...) + b.Backend.Paths = append(b.Backend.Paths, b.rootActivityPaths()...) if core.rawEnabled { b.Backend.Paths = append(b.Backend.Paths, b.rawPaths()...) 
@@ -4391,4 +4392,12 @@ This path responds to the following HTTP methods. The information that gets collected includes host hardware information, and CPU, disk, and memory utilization`, }, + "activity-query": { + "Query the historical count of clients.", + "Query the historical count of clients.", + }, + "activity-config": { + "Control the collection and reporting of client counts.", + "Control the collection and reporting of client counts.", + }, } diff --git a/vault/logical_system_activity.go b/vault/logical_system_activity.go new file mode 100644 index 0000000000..9fac23e733 --- /dev/null +++ b/vault/logical_system_activity.go @@ -0,0 +1,220 @@ +package vault + +import ( + "context" + "net/http" + "path" + "strings" + "time" + + "github.com/hashicorp/vault/helper/timeutil" + "github.com/hashicorp/vault/sdk/framework" + "github.com/hashicorp/vault/sdk/logical" +) + +// activityQueryPath is available in every namespace +func (b *SystemBackend) activityQueryPath() *framework.Path { + return &framework.Path{ + Pattern: "internal/counters/activity$", + Fields: map[string]*framework.FieldSchema{ + "start_time": &framework.FieldSchema{ + Type: framework.TypeTime, + Description: "Start of query interval", + }, + "end_time": &framework.FieldSchema{ + Type: framework.TypeTime, + Description: "End of query interval", + }, + }, + HelpSynopsis: strings.TrimSpace(sysHelp["activity-query"][0]), + HelpDescription: strings.TrimSpace(sysHelp["activity-query"][1]), + + Operations: map[logical.Operation]framework.OperationHandler{ + logical.ReadOperation: &framework.PathOperation{ + Callback: b.handleClientMetricQuery, + Summary: "Report the client count metrics, for this namespace and all child namespaces.", + }, + }, + } +} + +// rootActivityPaths are available only in the root namespace +func (b *SystemBackend) rootActivityPaths() []*framework.Path { + return []*framework.Path{ + b.activityQueryPath(), + { + Pattern: "internal/counters/config$", + Fields: 
map[string]*framework.FieldSchema{ + "default_report_months": &framework.FieldSchema{ + Type: framework.TypeInt, + Default: 12, + Description: "Number of months to report if no start date specified.", + }, + "retention_months": &framework.FieldSchema{ + Type: framework.TypeInt, + Default: 24, + Description: "Number of months of client data to retain. Setting to 0 will clear all existing data.", + }, + "enabled": &framework.FieldSchema{ + Type: framework.TypeString, + Default: "default", + Description: "Enable or disable collection of client count: enable, disable, or default.", + }, + }, + HelpSynopsis: strings.TrimSpace(sysHelp["activity-config"][0]), + HelpDescription: strings.TrimSpace(sysHelp["activity-config"][1]), + Operations: map[logical.Operation]framework.OperationHandler{ + logical.ReadOperation: &framework.PathOperation{ + Callback: b.handleActivityConfigRead, + Summary: "Read the client count tracking configuration.", + }, + logical.UpdateOperation: &framework.PathOperation{ + Callback: b.handleActivityConfigUpdate, + Summary: "Enable or disable collection of client count, set retention period, or set default reporting period.", + }, + }, + }, + } +} + +func (b *SystemBackend) handleClientMetricQuery(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + a := b.Core.activityLog + if a == nil { + return logical.ErrorResponse("no activity log present"), nil + } + + startTime := d.Get("start_time").(time.Time) + endTime := d.Get("end_time").(time.Time) + + // If a specific endTime is used, then respect that + // otherwise we want to give the latest N months, so go back to the start + // of the previous month + // + // Also convert any user inputs to UTC to avoid + // problems later. 
+ if endTime.IsZero() { + endTime = timeutil.EndOfMonth(time.Now().UTC().AddDate(0, -1, 0)) + } else { + endTime = endTime.UTC() + } + if startTime.IsZero() { + startTime = a.DefaultStartTime(endTime) + } else { + startTime = startTime.UTC() + } + if startTime.After(endTime) { + return logical.ErrorResponse("start_time is later than end_time"), nil + } + + results, err := a.handleQuery(ctx, startTime, endTime) + if err != nil { + return nil, err + } + if results == nil { + resp204, err := logical.RespondWithStatusCode(nil, req, http.StatusNoContent) + return resp204, err + } + + return &logical.Response{ + Data: results, + }, nil +} + +func (b *SystemBackend) handleActivityConfigRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + a := b.Core.activityLog + if a == nil { + return logical.ErrorResponse("no activity log present"), nil + } + + config, err := a.loadConfigOrDefault(ctx) + if err != nil { + return nil, err + } + + qa, err := a.queriesAvailable(ctx) + if err != nil { + return nil, err + } + + if config.Enabled == "default" { + config.Enabled = activityLogEnabledDefaultValue + } + + return &logical.Response{ + Data: map[string]interface{}{ + "default_report_months": config.DefaultReportMonths, + "retention_months": config.RetentionMonths, + "enabled": config.Enabled, + "queries_available": qa, + }, + }, nil +} + +func (b *SystemBackend) handleActivityConfigUpdate(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + a := b.Core.activityLog + if a == nil { + return logical.ErrorResponse("no activity log present"), nil + } + + config, err := a.loadConfigOrDefault(ctx) + if err != nil { + return nil, err + } + + { + // Parse the default report months + if defaultReportMonthsRaw, ok := d.GetOk("default_report_months"); ok { + config.DefaultReportMonths = defaultReportMonthsRaw.(int) + } + + if config.DefaultReportMonths <= 0 { + return 
logical.ErrorResponse("default_report_months must be greater than 0"), logical.ErrInvalidRequest + } + } + + { + // Parse the retention months + if retentionMonthsRaw, ok := d.GetOk("retention_months"); ok { + config.RetentionMonths = retentionMonthsRaw.(int) + } + + if config.RetentionMonths < 0 { + return logical.ErrorResponse("retention_months must be greater than or equal to 0"), logical.ErrInvalidRequest + } + } + + { + // Parse the enabled setting + if enabledRaw, ok := d.GetOk("enabled"); ok { + config.Enabled = enabledRaw.(string) + } + switch config.Enabled { + case "default", "enable", "disable": + default: + return logical.ErrorResponse("enabled must be one of \"default\", \"enable\", \"disable\""), logical.ErrInvalidRequest + } + } + + enabled := config.Enabled == "enable" + if !enabled && config.Enabled == "default" { + enabled = activityLogEnabledDefault + } + + if enabled && config.RetentionMonths == 0 { + return logical.ErrorResponse("retention_months cannot be 0 while enabled"), logical.ErrInvalidRequest + } + + // Store the config + entry, err := logical.StorageEntryJSON(path.Join(activitySubPath, activityConfigKey), config) + if err != nil { + return nil, err + } + if err := req.Storage.Put(ctx, entry); err != nil { + return nil, err + } + + // Set the new config on the activity log + a.SetConfig(ctx, config) + + return nil, nil +} diff --git a/vault/mount.go b/vault/mount.go index 1ed994e7f6..0a1d5cc5ca 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -186,7 +186,7 @@ func (t *MountTable) shallowClone() *MountTable { // setTaint is used to set the taint on given entry Accepts either the mount // entry's path or namespace + path, i.e. 
/secret/ or /token/ -func (t *MountTable) setTaint(ctx context.Context, path string, value bool) (*MountEntry, error) { +func (t *MountTable) setTaint(ctx context.Context, path string, tainted bool, mountState string) (*MountEntry, error) { n := len(t.Entries) ns, err := namespace.FromContext(ctx) if err != nil { @@ -194,7 +194,8 @@ func (t *MountTable) setTaint(ctx context.Context, path string, value bool) (*Mo } for i := 0; i < n; i++ { if entry := t.Entries[i]; entry.Path == path && entry.Namespace().ID == ns.ID { - t.Entries[i].Tainted = value + t.Entries[i].Tainted = tainted + t.Entries[i].MountState = mountState return t.Entries[i], nil } } @@ -253,6 +254,8 @@ func (t *MountTable) sortEntriesByPathDepth() *MountTable { return t } +const mountStateUnmounting = "unmounting" + // MountEntry is used to represent a mount table entry type MountEntry struct { Table string `json:"table"` // The table it belongs to @@ -268,6 +271,7 @@ type MountEntry struct { SealWrap bool `json:"seal_wrap"` // Whether to wrap CSPs ExternalEntropyAccess bool `json:"external_entropy_access,omitempty"` // Whether to allow external entropy source access Tainted bool `json:"tainted,omitempty"` // Set as a Write-Ahead flag for unmount/remount + MountState string `json:"mount_state,omitempty"` // The current mount state. 
The only non-empty mount state right now is "unmounting" NamespaceID string `json:"namespace_id"` // namespace contains the populated namespace @@ -641,7 +645,7 @@ func (c *Core) unmountInternal(ctx context.Context, path string, updateStorage b entry := c.router.MatchingMountEntry(ctx, path) // Mark the entry as tainted - if err := c.taintMountEntry(ctx, path, updateStorage); err != nil { + if err := c.taintMountEntry(ctx, path, updateStorage, true); err != nil { c.logger.Error("failed to taint mount entry for path being unmounted", "error", err, "path", path) return err } @@ -759,13 +763,18 @@ func (c *Core) removeMountEntry(ctx context.Context, path string, updateStorage } // taintMountEntry is used to mark an entry in the mount table as tainted -func (c *Core) taintMountEntry(ctx context.Context, path string, updateStorage bool) error { +func (c *Core) taintMountEntry(ctx context.Context, path string, updateStorage, unmounting bool) error { c.mountsLock.Lock() defer c.mountsLock.Unlock() + mountState := "" + if unmounting { + mountState = mountStateUnmounting + } + // As modifying the taint of an entry affects shallow clones, // we simply use the original - entry, err := c.mounts.setTaint(ctx, path, true) + entry, err := c.mounts.setTaint(ctx, path, true, mountState) if err != nil { return err } @@ -860,7 +869,7 @@ func (c *Core) remount(ctx context.Context, src, dst string, updateStorage bool) } // Mark the entry as tainted - if err := c.taintMountEntry(ctx, src, updateStorage); err != nil { + if err := c.taintMountEntry(ctx, src, updateStorage, false); err != nil { return err } @@ -988,9 +997,38 @@ func (c *Core) loadMounts(ctx context.Context) error { } } - // Note that this is only designed to work with singletons, as it checks by - // type only. + // If this node is a performance standby we do not want to attempt to + // upgrade the mount table, this will be the active node's responsibility. 
+ if !c.perfStandby { + err := c.runMountUpdates(ctx, needPersist) + if err != nil { + c.logger.Error("failed to run mount table upgrades", "error", err) + return err + } + } + for _, entry := range c.mounts.Entries { + if entry.NamespaceID == "" { + entry.NamespaceID = namespace.RootNamespaceID + } + ns, err := NamespaceByID(ctx, entry.NamespaceID, c) + if err != nil { + return err + } + if ns == nil { + return namespace.ErrNoNamespace + } + entry.namespace = ns + + // Sync values to the cache + entry.SyncCache() + } + return nil +} + +// Note that this is only designed to work with singletons, as it checks by +// type only. +func (c *Core) runMountUpdates(ctx context.Context, needPersist bool) error { // Upgrade to typed mount table if c.mounts.Type == "" { c.mounts.Type = mountTableType @@ -1022,6 +1060,10 @@ func (c *Core) loadMounts(ctx context.Context) error { // Upgrade to table-scoped entries for _, entry := range c.mounts.Entries { + if !c.PR1103disabled && entry.Type == mountTypeNSCubbyhole && !entry.Local && !c.ReplicationState().HasState(consts.ReplicationPerformanceSecondary|consts.ReplicationDRSecondary) { + entry.Local = true + needPersist = true + } if entry.Type == cubbyholeMountType && !entry.Local { entry.Local = true needPersist = true @@ -1051,17 +1093,6 @@ func (c *Core) loadMounts(ctx context.Context) error { entry.NamespaceID = namespace.RootNamespaceID needPersist = true } - ns, err := NamespaceByID(ctx, entry.NamespaceID, c) - if err != nil { - return err - } - if ns == nil { - return namespace.ErrNoNamespace - } - entry.namespace = ns - - // Sync values to the cache - entry.SyncCache() } // Done if we have restored the mount table and we don't need diff --git a/vault/namespaces.go b/vault/namespaces.go index 5b9f31b94c..335c90896b 100644 --- a/vault/namespaces.go +++ b/vault/namespaces.go @@ -10,6 +10,10 @@ var ( NamespaceByID func(context.Context, string, *Core) (*namespace.Namespace, error) = namespaceByID ) +const ( + 
mountTypeNSCubbyhole = "ns_cubbyhole" ) + func namespaceByID(ctx context.Context, nsID string, c *Core) (*namespace.Namespace, error) { if nsID == namespace.RootNamespaceID { return namespace.RootNamespace, nil diff --git a/vault/testing.go b/vault/testing.go index c675be7446..9473bfb27b 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -1468,12 +1468,14 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te coreConfig.CounterSyncInterval = base.CounterSyncInterval coreConfig.RecoveryMode = base.RecoveryMode + coreConfig.ActivityLogConfig = base.ActivityLogConfig + testApplyEntBaseConfig(coreConfig, base) } if coreConfig.ClusterName == "" { coreConfig.ClusterName = t.Name() } if coreConfig.ClusterHeartbeatInterval == 0 { // Set this lower so that state populates quickly to standby nodes coreConfig.ClusterHeartbeatInterval = 2 * time.Second diff --git a/vault/token_store.go b/vault/token_store.go index a12f19bb9c..38964c7c18 100644 --- a/vault/token_store.go +++ b/vault/token_store.go @@ -485,7 +485,8 @@ type TokenStore struct { parentBarrierView *BarrierView rolesBarrierView *BarrierView - expiration *ExpirationManager + expiration *ExpirationManager + activityLog *ActivityLog cubbyholeBackend *CubbyholeBackend @@ -657,6 +658,12 @@ func (ts *TokenStore) SetExpirationManager(exp *ExpirationManager) { ts.expiration = exp } +// SetActivityLog injects the activity log to which all new +// token creation events
are reported. +func (ts *TokenStore) SetActivityLog(a *ActivityLog) { + ts.activityLog = a +} + // SaltID is used to apply a salt and hash to an ID to make sure its not reversible func (ts *TokenStore) SaltID(ctx context.Context, id string) (string, error) { ns, err := namespace.FromContext(ctx) @@ -862,6 +869,11 @@ func (ts *TokenStore) create(ctx context.Context, entry *logical.TokenEntry) err return err } + // Update the activity log + if ts.activityLog != nil { + ts.activityLog.HandleTokenCreation(entry) + } + return ts.storeCommon(ctx, entry, true) case logical.TokenTypeBatch: @@ -905,6 +917,11 @@ func (ts *TokenStore) create(ctx context.Context, entry *logical.TokenEntry) err entry.ID = fmt.Sprintf("%s.%s", entry.ID, tokenNS.ID) } + // Update the activity log + if ts.activityLog != nil { + ts.activityLog.HandleTokenCreation(entry) + } + return nil default: diff --git a/vendor/github.com/hashicorp/vault/sdk/helper/awsutil/region.go b/vendor/github.com/hashicorp/vault/sdk/helper/awsutil/region.go index 727c3b9104..93456cdddf 100644 --- a/vendor/github.com/hashicorp/vault/sdk/helper/awsutil/region.go +++ b/vendor/github.com/hashicorp/vault/sdk/helper/awsutil/region.go @@ -69,5 +69,6 @@ func GetRegion(configuredRegion string) (string, error) { if err != nil { return "", errwrap.Wrapf("unable to retrieve region from instance metadata: {{err}}", err) } + return region, nil } diff --git a/website/pages/api-docs/secret/transform/index.mdx b/website/pages/api-docs/secret/transform/index.mdx index 61b60f2d6d..5da864203d 100644 --- a/website/pages/api-docs/secret/transform/index.mdx +++ b/website/pages/api-docs/secret/transform/index.mdx @@ -90,6 +90,12 @@ This endpoint lists all existing roles in the secrets engine. | :----- | :---------------- | | `LIST` | `/transform/role` | +### Parameters + +- `filter` `(string: "*")` – + If provided, only returns role names that match the given glob. + + ### Sample Request ```shell-session