mirror of https://github.com/hashicorp/vault.git, synced 2026-02-03 20:40:45 -05:00
events: Remove subscriptions on timeout and cancel (#19185)
When subscriptions are too slow to accept messages on their channels, we should remove them from the fanout so that dead subscriptions don't use up resources. In addition, when a subscription is explicitly canceled, we should clean up after it and remove the corresponding pipeline. We also add a new metric, `events.subscriptions`, to keep track of the number of active subscriptions; this also helps us test the cleanup behavior. Co-authored-by: Tom Proctor <tomhjp@users.noreply.github.com>
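For context, a minimal sketch of how a consumer sees the new behavior, built only from the calls visible in this diff (NewEventBus, Start, SetSendTimeout, Subscribe, SendInternal). The import paths for the in-repo packages and the "someType" event type are assumptions for illustration, not part of this commit:

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/hashicorp/vault/helper/namespace" // assumed import path
	"github.com/hashicorp/vault/sdk/logical"
	"github.com/hashicorp/vault/vault/eventbus" // assumed import path
)

func main() {
	bus, err := eventbus.NewEventBus(nil)
	if err != nil {
		panic(err)
	}
	bus.Start()
	// After this change, a subscriber that does not accept a send within
	// this window is closed and removed from the fanout.
	bus.SetSendTimeout(60 * time.Second)

	ctx := context.Background()
	eventType := logical.EventType("someType") // hypothetical event type

	ch, cancel, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType)
	if err != nil {
		panic(err)
	}
	// cancel is now the node's Close: it removes the pipeline and decrements
	// the events.subscriptions gauge rather than only canceling the context.
	defer cancel()

	event, err := logical.NewEvent()
	if err != nil {
		panic(err)
	}
	if err := bus.SendInternal(ctx, namespace.RootNamespace, nil, eventType, event); err != nil {
		panic(err)
	}

	select {
	case recv := <-ch:
		fmt.Println("received event", recv.Event.ID())
	case <-time.After(time.Second):
		fmt.Println("timed out waiting for event")
	}
}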
parent 100ec9a700
commit db822c05d4

2 changed files with 175 additions and 6 deletions
vault/eventbus/bus.go
@@ -5,9 +5,11 @@ import (
 	"errors"
 	"net/url"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"time"
 
+	"github.com/armon/go-metrics"
 	"github.com/hashicorp/eventlogger"
 	"github.com/hashicorp/eventlogger/formatter_filters/cloudevents"
 	"github.com/hashicorp/go-hclog"
@@ -17,9 +19,13 @@ import (
 	"google.golang.org/protobuf/types/known/timestamppb"
 )
 
-var ErrNotStarted = errors.New("event broker has not been started")
+const defaultTimeout = 60 * time.Second
 
-var cloudEventsFormatterFilter *cloudevents.FormatterFilter
+var (
+	ErrNotStarted              = errors.New("event broker has not been started")
+	cloudEventsFormatterFilter *cloudevents.FormatterFilter
+	subscriptions              atomic.Int64 // keeps track of event subscription count in all event buses
+)
 
 // EventBus contains the main logic of running an event broker for Vault.
 // Start() must be called before the EventBus will accept events for sending.
@@ -28,6 +34,7 @@ type EventBus struct {
 	broker          *eventlogger.Broker
 	started         atomic.Bool
 	formatterNodeID eventlogger.NodeID
+	timeout         time.Duration
 }
 
 type pluginEventBus struct {
@@ -42,6 +49,13 @@ type asyncChanNode struct {
 	ch        chan *logical.EventReceived
 	namespace *namespace.Namespace
 	logger    hclog.Logger
+
+	// used to close the connection
+	closeOnce  sync.Once
+	cancelFunc context.CancelFunc
+	pipelineID eventlogger.PipelineID
+	eventType  eventlogger.EventType
+	broker     *eventlogger.Broker
 }
 
 var (
@@ -79,6 +93,10 @@ func (bus *EventBus) SendInternal(ctx context.Context, ns *namespace.Namespace,
 		Timestamp: timestamppb.New(time.Now()),
 	}
 	bus.logger.Info("Sending event", "event", eventReceived)
+
+	// We can't easily know when the Send is complete, so we can't call the cancel function.
+	// But, it is called automatically after bus.timeout, so there won't be any leak as long as bus.timeout is not too long.
+	ctx, _ = context.WithTimeout(ctx, bus.timeout)
 	_, err := bus.broker.Send(ctx, eventlogger.EventType(eventType), eventReceived)
 	if err != nil {
 		// if no listeners for this event type are registered, that's okay, the event
@@ -142,6 +160,7 @@ func NewEventBus(logger hclog.Logger) (*EventBus, error) {
 		logger:          logger,
 		broker:          broker,
 		formatterNodeID: formatterNodeID,
+		timeout:         defaultTimeout,
 	}, nil
 }
 
@@ -178,7 +197,18 @@ func (bus *EventBus) Subscribe(ctx context.Context, ns *namespace.Namespace, eve
 		defer cancel()
 		return nil, nil, err
 	}
-	return asyncNode.ch, cancel, nil
+	addSubscriptions(1)
+	// add info needed to cancel the subscription
+	asyncNode.pipelineID = eventlogger.PipelineID(pipelineID)
+	asyncNode.eventType = eventlogger.EventType(eventType)
+	asyncNode.cancelFunc = cancel
+	return asyncNode.ch, asyncNode.Close, nil
+}
+
+// SetSendTimeout sets the timeout of sending events. If the events are not accepted by the
+// underlying channel before this timeout, then the channel is closed.
+func (bus *EventBus) SetSendTimeout(timeout time.Duration) {
+	bus.timeout = timeout
 }
 
 func newAsyncNode(ctx context.Context, namespace *namespace.Namespace, logger hclog.Logger) *asyncChanNode {
@@ -190,8 +220,21 @@ func newAsyncNode(ctx context.Context, namespace *namespace.Namespace, logger hc
 		}
 	}
 
+// Close tells the bus to stop sending us events.
+func (node *asyncChanNode) Close() {
+	node.closeOnce.Do(func() {
+		defer node.cancelFunc()
+		if node.broker != nil {
+			err := node.broker.RemovePipeline(node.eventType, node.pipelineID)
+			if err != nil {
+				node.logger.Warn("Error removing pipeline for closing node", "error", err)
+			}
+		}
+		addSubscriptions(-1)
+	})
+}
+
 func (node *asyncChanNode) Process(ctx context.Context, e *eventlogger.Event) (*eventlogger.Event, error) {
-	// TODO: add timeout on sending to node.ch
 	// sends to the channel async in another goroutine
 	go func() {
 		eventRecv := e.Payload.(*logical.EventReceived)
@@ -200,12 +243,17 @@ func (node *asyncChanNode) Process(ctx context.Context, e *eventlogger.Event) (*
 		if eventRecv.Namespace != node.namespace.Path {
 			return
 		}
+		var timeout bool
 		select {
 		case node.ch <- eventRecv:
 		case <-ctx.Done():
-			return
+			timeout = errors.Is(ctx.Err(), context.DeadlineExceeded)
 		case <-node.ctx.Done():
-			return
+			timeout = errors.Is(node.ctx.Err(), context.DeadlineExceeded)
 		}
+		if timeout {
+			node.logger.Info("Subscriber took too long to process event, closing", "ID", eventRecv.Event.ID())
+			node.Close()
+		}
 	}()
 	return e, nil
@@ -218,3 +266,7 @@ func (node *asyncChanNode) Reopen() error {
 func (node *asyncChanNode) Type() eventlogger.NodeType {
 	return eventlogger.NodeTypeSink
 }
+
+func addSubscriptions(delta int64) {
+	metrics.SetGauge([]string{"events", "subscriptions"}, float32(subscriptions.Add(delta)))
+}

vault/eventbus/bus_test.go
@@ -2,6 +2,8 @@ package eventbus
 
 import (
 	"context"
+	"fmt"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -9,6 +11,7 @@ import (
 	"github.com/hashicorp/vault/sdk/logical"
 )
 
+// TestBusBasics tests that basic event sending and subscribing function.
 func TestBusBasics(t *testing.T) {
 	bus, err := NewEventBus(nil)
 	if err != nil {
@@ -62,6 +65,7 @@ func TestBusBasics(t *testing.T) {
 	}
 }
 
+// TestNamespaceFiltering verifies that events for other namespaces are filtered out by the bus.
 func TestNamespaceFiltering(t *testing.T) {
 	bus, err := NewEventBus(nil)
 	if err != nil {
@@ -121,6 +125,7 @@ func TestNamespaceFiltering(t *testing.T) {
 	}
 }
 
+// TestBus2Subscriptions verifies that events of different types are successfully routed to the correct subscribers.
 func TestBus2Subscriptions(t *testing.T) {
 	bus, err := NewEventBus(nil)
 	if err != nil {
@@ -180,3 +185,115 @@ func TestBus2Subscriptions(t *testing.T) {
 		t.Error("Timeout waiting for event2")
 	}
 }
+
+// TestBusSubscriptionsCancel verifies that canceled subscriptions are cleaned up.
+func TestBusSubscriptionsCancel(t *testing.T) {
+	testCases := []struct {
+		cancel bool
+	}{
+		{cancel: true},
+		{cancel: false},
+	}
+
+	for _, tc := range testCases {
+		t.Run(fmt.Sprintf("cancel=%v", tc.cancel), func(t *testing.T) {
+			subscriptions.Store(0)
+			bus, err := NewEventBus(nil)
+			if err != nil {
+				t.Fatal(err)
+			}
+			ctx := context.Background()
+			if !tc.cancel {
+				// set the timeout very short to make the test faster if we aren't canceling explicitly
+				bus.SetSendTimeout(100 * time.Millisecond)
+			}
+			bus.Start()
+
+			// create and stop a bunch of subscriptions
+			const create = 100
+			const stop = 50
+
+			eventType := logical.EventType("someType")
+
+			var channels []<-chan *logical.EventReceived
+			var cancels []context.CancelFunc
+			stopped := atomic.Int32{}
+
+			received := atomic.Int32{}
+
+			for i := 0; i < create; i++ {
+				ch, cancelFunc, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType)
+				if err != nil {
+					t.Fatal(err)
+				}
+				t.Cleanup(cancelFunc)
+				channels = append(channels, ch)
+				cancels = append(cancels, cancelFunc)
+
+				go func(i int32) {
+					<-ch // always receive one message
+					received.Add(1)
+					// continue receiving messages as long as we are not stopped
+					for i < int32(stop) {
+						<-ch
+						received.Add(1)
+					}
+					if tc.cancel {
+						cancelFunc() // stop explicitly to unsubscribe
+					}
+					stopped.Add(1)
+				}(int32(i))
+			}
+
+			// check that all channels receive a message
+			event, err := logical.NewEvent()
+			if err != nil {
+				t.Fatal(err)
+			}
+			err = bus.SendInternal(ctx, namespace.RootNamespace, nil, eventType, event)
+			if err != nil {
+				t.Error(err)
+			}
+			waitFor(t, 1*time.Second, func() bool { return received.Load() == int32(create) })
+			waitFor(t, 1*time.Second, func() bool { return stopped.Load() == int32(stop) })
+
+			// send another message, but half should stop receiving
+			event, err = logical.NewEvent()
+			if err != nil {
+				t.Fatal(err)
+			}
+			err = bus.SendInternal(ctx, namespace.RootNamespace, nil, eventType, event)
+			if err != nil {
+				t.Error(err)
+			}
+			waitFor(t, 1*time.Second, func() bool { return received.Load() == int32(create*2-stop) })
+			// the sends should time out and the subscriptions should drop when cancelFunc is called or the context cancels
+			waitFor(t, 1*time.Second, func() bool { return subscriptions.Load() == int64(create-stop) })
+		})
+	}
+}
+
+// waitFor waits for a condition to be true, up to the maximum timeout.
+// It waits with a capped exponential backoff starting at 1ms.
+// It is guaranteed to try f() at least once.
+func waitFor(t *testing.T, maxWait time.Duration, f func() bool) {
+	t.Helper()
+	start := time.Now()
+
+	if f() {
+		return
+	}
+	sleepAmount := 1 * time.Millisecond
+	for time.Now().Sub(start) <= maxWait {
+		left := time.Now().Sub(start)
+		sleepAmount = sleepAmount * 2
+		if sleepAmount > left {
+			sleepAmount = left
+		}
+		time.Sleep(sleepAmount)
+		if f() {
+			return
+		}
+	}
+	t.Error("Timeout waiting for condition")
+}