mirror of
https://github.com/grafana/grafana.git
synced 2026-02-03 20:49:50 -05:00
* Chore: Unify token exchange round trippers * Remove the conditional provider for now * Remove unnecessary strategy * test cleanup * Lint
259 lines
8.7 KiB
Go
259 lines
8.7 KiB
Go
package clientauth
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/grafana/authlib/authn"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
type fakeExchanger struct {
|
|
resp *authn.TokenExchangeResponse
|
|
err error
|
|
gotReq *authn.TokenExchangeRequest
|
|
}
|
|
|
|
func (f *fakeExchanger) Exchange(_ context.Context, req authn.TokenExchangeRequest) (*authn.TokenExchangeResponse, error) {
|
|
f.gotReq = &req
|
|
return f.resp, f.err
|
|
}
|
|
|
|
// roundTripperFunc adapts an ordinary function to the http.RoundTripper
// interface so tests can declare stub transports inline.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// RoundTrip forwards r to the wrapped function and returns whatever it returns.
func (fn roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	return fn(r)
}
|
|
|
|
func TestTokenExchangeRoundTripper_SetsAccessTokenHeader(t *testing.T) {
|
|
exchanger := &fakeExchanger{resp: &authn.TokenExchangeResponse{Token: "test-token-123"}}
|
|
|
|
var capturedHeader string
|
|
transport := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
|
capturedHeader = r.Header.Get("X-Access-Token")
|
|
rr := httptest.NewRecorder()
|
|
rr.WriteHeader(http.StatusOK)
|
|
return rr.Result(), nil
|
|
})
|
|
|
|
rt := newTokenExchangeRoundTripperWithStrategies(exchanger, transport, NewStaticNamespaceProvider("test-namespace"), NewStaticAudienceProvider("test-audience"))
|
|
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.org", nil)
|
|
|
|
resp, err := rt.RoundTrip(req)
|
|
require.NoError(t, err)
|
|
if resp != nil {
|
|
_ = resp.Body.Close()
|
|
}
|
|
|
|
// Clean up response
|
|
_ = resp.Body.Close()
|
|
|
|
require.Equal(t, "Bearer test-token-123", capturedHeader)
|
|
}
|
|
|
|
func TestTokenExchangeRoundTripper_PropagatesExchangeError(t *testing.T) {
|
|
expectedErr := errors.New("token exchange failed")
|
|
exchanger := &fakeExchanger{err: expectedErr}
|
|
|
|
transport := roundTripperFunc(func(_ *http.Request) (*http.Response, error) {
|
|
t.Fatal("transport should not be called on exchange error")
|
|
return nil, nil
|
|
})
|
|
|
|
rt := newTokenExchangeRoundTripperWithStrategies(exchanger, transport, NewStaticNamespaceProvider("test-namespace"), NewStaticAudienceProvider("test-audience"))
|
|
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.org", nil)
|
|
|
|
resp, err := rt.RoundTrip(req)
|
|
require.Error(t, err)
|
|
if resp != nil {
|
|
_ = resp.Body.Close()
|
|
}
|
|
require.ErrorContains(t, err, "failed to exchange token")
|
|
require.ErrorIs(t, err, expectedErr)
|
|
}
|
|
|
|
func TestTokenExchangeRoundTripper_SendsCorrectAudienceAndNamespace(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
audience string
|
|
namespace string
|
|
expectedAudiences []string
|
|
expectedNamespace string
|
|
}{
|
|
{
|
|
name: "single audience with wildcard namespace",
|
|
audience: "folder.grafana.app",
|
|
namespace: "*",
|
|
expectedAudiences: []string{"folder.grafana.app"},
|
|
expectedNamespace: "*",
|
|
},
|
|
{
|
|
name: "different audience with wildcard namespace",
|
|
audience: "dashboard.grafana.app",
|
|
namespace: "*",
|
|
expectedAudiences: []string{"dashboard.grafana.app"},
|
|
expectedNamespace: "*",
|
|
},
|
|
{
|
|
name: "audience with specific namespace",
|
|
audience: "test-audience",
|
|
namespace: "test-namespace",
|
|
expectedAudiences: []string{"test-audience"},
|
|
expectedNamespace: "test-namespace",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
exchanger := &fakeExchanger{resp: &authn.TokenExchangeResponse{Token: "token"}}
|
|
transport := roundTripperFunc(func(_ *http.Request) (*http.Response, error) {
|
|
rr := httptest.NewRecorder()
|
|
rr.WriteHeader(http.StatusOK)
|
|
return rr.Result(), nil
|
|
})
|
|
|
|
rt := newTokenExchangeRoundTripperWithStrategies(exchanger, transport, NewStaticNamespaceProvider(tt.namespace), NewStaticAudienceProvider(tt.audience))
|
|
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.org", nil)
|
|
|
|
resp, err := rt.RoundTrip(req)
|
|
require.NoError(t, err)
|
|
if resp != nil {
|
|
_ = resp.Body.Close()
|
|
}
|
|
|
|
require.NotNil(t, exchanger.gotReq)
|
|
require.Equal(t, tt.expectedAudiences, exchanger.gotReq.Audiences)
|
|
require.Equal(t, tt.expectedNamespace, exchanger.gotReq.Namespace)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestTokenExchangeRoundTripper_DoesNotMutateOriginalRequest(t *testing.T) {
|
|
exchanger := &fakeExchanger{resp: &authn.TokenExchangeResponse{Token: "token"}}
|
|
transport := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
|
rr := httptest.NewRecorder()
|
|
rr.WriteHeader(http.StatusOK)
|
|
return rr.Result(), nil
|
|
})
|
|
|
|
rt := newTokenExchangeRoundTripperWithStrategies(exchanger, transport, NewStaticNamespaceProvider("namespace"), NewStaticAudienceProvider("audience"))
|
|
originalReq, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.org", nil)
|
|
|
|
// Ensure original request has no X-Access-Token header
|
|
originalReq.Header.Set("X-Custom-Header", "original-value")
|
|
require.Empty(t, originalReq.Header.Get("X-Access-Token"))
|
|
|
|
resp, err := rt.RoundTrip(originalReq)
|
|
require.NoError(t, err)
|
|
_ = resp.Body.Close()
|
|
|
|
// Original request should not have been mutated
|
|
require.Empty(t, originalReq.Header.Get("X-Access-Token"))
|
|
require.Equal(t, "original-value", originalReq.Header.Get("X-Custom-Header"))
|
|
}
|
|
|
|
func TestTokenExchangeRoundTripper_PropagatesTransportError(t *testing.T) {
|
|
exchanger := &fakeExchanger{resp: &authn.TokenExchangeResponse{Token: "token"}}
|
|
expectedErr := errors.New("transport error")
|
|
transport := roundTripperFunc(func(_ *http.Request) (*http.Response, error) {
|
|
return nil, expectedErr
|
|
})
|
|
|
|
rt := newTokenExchangeRoundTripperWithStrategies(exchanger, transport, NewStaticNamespaceProvider("namespace"), NewStaticAudienceProvider("audience"))
|
|
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.org", nil)
|
|
|
|
resp, err := rt.RoundTrip(req)
|
|
require.Error(t, err)
|
|
if resp != nil {
|
|
_ = resp.Body.Close()
|
|
}
|
|
require.ErrorIs(t, err, expectedErr)
|
|
}
|
|
|
|
func TestNewTokenExchangeTransportWrapper(t *testing.T) {
|
|
exchanger := &fakeExchanger{resp: &authn.TokenExchangeResponse{Token: "wrapped-token"}}
|
|
|
|
var capturedHeader string
|
|
baseTransport := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
|
capturedHeader = r.Header.Get("X-Access-Token")
|
|
rr := httptest.NewRecorder()
|
|
rr.WriteHeader(http.StatusOK)
|
|
return rr.Result(), nil
|
|
})
|
|
|
|
wrapper := NewStaticTokenExchangeTransportWrapper(exchanger, "test-audience", "test-namespace")
|
|
wrappedTransport := wrapper(baseTransport)
|
|
|
|
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.org", nil)
|
|
resp, err := wrappedTransport.RoundTrip(req)
|
|
require.NoError(t, err)
|
|
_ = resp.Body.Close()
|
|
|
|
require.Equal(t, "Bearer wrapped-token", capturedHeader)
|
|
require.NotNil(t, exchanger.gotReq)
|
|
require.Equal(t, []string{"test-audience"}, exchanger.gotReq.Audiences)
|
|
require.Equal(t, "test-namespace", exchanger.gotReq.Namespace)
|
|
}
|
|
|
|
func TestTokenExchangeRoundTripperWithStrategies(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
namespaceProvider NamespaceProvider
|
|
audienceProvider AudienceProvider
|
|
expectedNamespace string
|
|
expectedAudiences []string
|
|
expectedHeader string
|
|
}{
|
|
{
|
|
name: "static providers with bearer prefix",
|
|
namespaceProvider: NewStaticNamespaceProvider("*"),
|
|
audienceProvider: NewStaticAudienceProvider("folder.grafana.app"),
|
|
expectedNamespace: "*",
|
|
expectedAudiences: []string{"folder.grafana.app"},
|
|
expectedHeader: "Bearer test-token",
|
|
},
|
|
{
|
|
name: "multiple audiences",
|
|
namespaceProvider: NewStaticNamespaceProvider("*"),
|
|
audienceProvider: NewStaticAudienceProvider("audience1", "audience2"),
|
|
expectedNamespace: "*",
|
|
expectedAudiences: []string{"audience1", "audience2"},
|
|
expectedHeader: "Bearer test-token",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
exchanger := &fakeExchanger{resp: &authn.TokenExchangeResponse{Token: "test-token"}}
|
|
|
|
var capturedHeader string
|
|
transport := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
|
capturedHeader = r.Header.Get("X-Access-Token")
|
|
rr := httptest.NewRecorder()
|
|
rr.WriteHeader(http.StatusOK)
|
|
return rr.Result(), nil
|
|
})
|
|
|
|
rt := newTokenExchangeRoundTripperWithStrategies(
|
|
exchanger,
|
|
transport,
|
|
tt.namespaceProvider,
|
|
tt.audienceProvider,
|
|
)
|
|
|
|
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.org", nil)
|
|
resp, err := rt.RoundTrip(req)
|
|
require.NoError(t, err)
|
|
if resp != nil {
|
|
_ = resp.Body.Close()
|
|
}
|
|
|
|
require.Equal(t, tt.expectedHeader, capturedHeader)
|
|
require.NotNil(t, exchanger.gotReq)
|
|
require.Equal(t, tt.expectedAudiences, exchanger.gotReq.Audiences)
|
|
require.Equal(t, tt.expectedNamespace, exchanger.gotReq.Namespace)
|
|
})
|
|
}
|
|
}
|