grafana/pkg/clientauth/roundtripper_test.go
Gabriel MABILLE 93566ce4ef
Chore: Unify token exchange round trippers (#115609)
* Chore: Unify token exchange round trippers

* Remove the conditional provider for now

* Remove unnecessary strategy

* test cleanup

* Lint
2026-01-05 11:23:35 +01:00

259 lines
8.7 KiB
Go

package clientauth
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/grafana/authlib/authn"
"github.com/stretchr/testify/require"
)
// fakeExchanger is a test double for the token exchanger consumed by the
// round tripper. Every Exchange call returns the canned resp/err pair, and the
// most recent request is remembered so tests can inspect what was asked for.
type fakeExchanger struct {
	resp   *authn.TokenExchangeResponse // returned by every Exchange call
	err    error                        // returned by every Exchange call
	gotReq *authn.TokenExchangeRequest  // copy of the last request; nil until Exchange runs
}

// Exchange stores a copy of req in gotReq and returns the configured response
// and error. The context argument is ignored.
func (f *fakeExchanger) Exchange(_ context.Context, req authn.TokenExchangeRequest) (*authn.TokenExchangeResponse, error) {
	captured := req
	f.gotReq = &captured
	return f.resp, f.err
}
// roundTripperFunc adapts an ordinary function to the http.RoundTripper
// interface (in the spirit of http.HandlerFunc), letting tests declare stub
// transports inline.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// RoundTrip satisfies http.RoundTripper by delegating to the wrapped function.
func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	return fn(req)
}
// TestTokenExchangeRoundTripper_SetsAccessTokenHeader verifies that the round
// tripper exchanges a token and forwards it to the wrapped transport as
// "Bearer <token>" in the X-Access-Token header.
func TestTokenExchangeRoundTripper_SetsAccessTokenHeader(t *testing.T) {
	exchanger := &fakeExchanger{resp: &authn.TokenExchangeResponse{Token: "test-token-123"}}

	var capturedHeader string
	transport := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
		capturedHeader = r.Header.Get("X-Access-Token")
		rr := httptest.NewRecorder()
		rr.WriteHeader(http.StatusOK)
		return rr.Result(), nil
	})

	rt := newTokenExchangeRoundTripperWithStrategies(exchanger, transport, NewStaticNamespaceProvider("test-namespace"), NewStaticAudienceProvider("test-audience"))

	req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.org", nil)
	require.NoError(t, err)

	resp, err := rt.RoundTrip(req)
	require.NoError(t, err)
	require.NotNil(t, resp)
	// Close the body exactly once; resp is known to be non-nil at this point.
	defer func() { _ = resp.Body.Close() }()

	require.Equal(t, "Bearer test-token-123", capturedHeader)
}
func TestTokenExchangeRoundTripper_PropagatesExchangeError(t *testing.T) {
expectedErr := errors.New("token exchange failed")
exchanger := &fakeExchanger{err: expectedErr}
transport := roundTripperFunc(func(_ *http.Request) (*http.Response, error) {
t.Fatal("transport should not be called on exchange error")
return nil, nil
})
rt := newTokenExchangeRoundTripperWithStrategies(exchanger, transport, NewStaticNamespaceProvider("test-namespace"), NewStaticAudienceProvider("test-audience"))
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.org", nil)
resp, err := rt.RoundTrip(req)
require.Error(t, err)
if resp != nil {
_ = resp.Body.Close()
}
require.ErrorContains(t, err, "failed to exchange token")
require.ErrorIs(t, err, expectedErr)
}
func TestTokenExchangeRoundTripper_SendsCorrectAudienceAndNamespace(t *testing.T) {
tests := []struct {
name string
audience string
namespace string
expectedAudiences []string
expectedNamespace string
}{
{
name: "single audience with wildcard namespace",
audience: "folder.grafana.app",
namespace: "*",
expectedAudiences: []string{"folder.grafana.app"},
expectedNamespace: "*",
},
{
name: "different audience with wildcard namespace",
audience: "dashboard.grafana.app",
namespace: "*",
expectedAudiences: []string{"dashboard.grafana.app"},
expectedNamespace: "*",
},
{
name: "audience with specific namespace",
audience: "test-audience",
namespace: "test-namespace",
expectedAudiences: []string{"test-audience"},
expectedNamespace: "test-namespace",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
exchanger := &fakeExchanger{resp: &authn.TokenExchangeResponse{Token: "token"}}
transport := roundTripperFunc(func(_ *http.Request) (*http.Response, error) {
rr := httptest.NewRecorder()
rr.WriteHeader(http.StatusOK)
return rr.Result(), nil
})
rt := newTokenExchangeRoundTripperWithStrategies(exchanger, transport, NewStaticNamespaceProvider(tt.namespace), NewStaticAudienceProvider(tt.audience))
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.org", nil)
resp, err := rt.RoundTrip(req)
require.NoError(t, err)
if resp != nil {
_ = resp.Body.Close()
}
require.NotNil(t, exchanger.gotReq)
require.Equal(t, tt.expectedAudiences, exchanger.gotReq.Audiences)
require.Equal(t, tt.expectedNamespace, exchanger.gotReq.Namespace)
})
}
}
// TestTokenExchangeRoundTripper_DoesNotMutateOriginalRequest verifies that the
// round tripper adds X-Access-Token to a copy of the request rather than to
// the caller's request. The http.RoundTripper contract forbids modifying the
// request a RoundTrip implementation is given.
func TestTokenExchangeRoundTripper_DoesNotMutateOriginalRequest(t *testing.T) {
	exchanger := &fakeExchanger{resp: &authn.TokenExchangeResponse{Token: "token"}}

	var forwarded *http.Request
	transport := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
		forwarded = r
		rr := httptest.NewRecorder()
		rr.WriteHeader(http.StatusOK)
		return rr.Result(), nil
	})

	rt := newTokenExchangeRoundTripperWithStrategies(exchanger, transport, NewStaticNamespaceProvider("namespace"), NewStaticAudienceProvider("audience"))

	originalReq, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.org", nil)
	require.NoError(t, err)
	originalReq.Header.Set("X-Custom-Header", "original-value")
	// Ensure the original request has no X-Access-Token header before the call.
	require.Empty(t, originalReq.Header.Get("X-Access-Token"))

	resp, err := rt.RoundTrip(originalReq)
	require.NoError(t, err)
	require.NotNil(t, resp)
	defer func() { _ = resp.Body.Close() }()

	// The transport must have received a distinct request that carries both
	// the token and the caller's headers. Without this, the mutation checks
	// below could pass simply because no header was ever set.
	require.NotNil(t, forwarded)
	require.NotSame(t, originalReq, forwarded)
	require.Equal(t, "Bearer token", forwarded.Header.Get("X-Access-Token"))
	require.Equal(t, "original-value", forwarded.Header.Get("X-Custom-Header"))

	// The original request should not have been mutated.
	require.Empty(t, originalReq.Header.Get("X-Access-Token"))
	require.Equal(t, "original-value", originalReq.Header.Get("X-Custom-Header"))
}
func TestTokenExchangeRoundTripper_PropagatesTransportError(t *testing.T) {
exchanger := &fakeExchanger{resp: &authn.TokenExchangeResponse{Token: "token"}}
expectedErr := errors.New("transport error")
transport := roundTripperFunc(func(_ *http.Request) (*http.Response, error) {
return nil, expectedErr
})
rt := newTokenExchangeRoundTripperWithStrategies(exchanger, transport, NewStaticNamespaceProvider("namespace"), NewStaticAudienceProvider("audience"))
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.org", nil)
resp, err := rt.RoundTrip(req)
require.Error(t, err)
if resp != nil {
_ = resp.Body.Close()
}
require.ErrorIs(t, err, expectedErr)
}
// TestNewTokenExchangeTransportWrapper exercises the static wrapper
// constructor. Wrapping a base transport must produce a transport that
// requests a token for the configured audience and namespace and forwards it
// as a bearer X-Access-Token header.
func TestNewTokenExchangeTransportWrapper(t *testing.T) {
	exchanger := &fakeExchanger{resp: &authn.TokenExchangeResponse{Token: "wrapped-token"}}

	var seenToken string
	base := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
		seenToken = r.Header.Get("X-Access-Token")
		rec := httptest.NewRecorder()
		rec.WriteHeader(http.StatusOK)
		return rec.Result(), nil
	})

	wrap := NewStaticTokenExchangeTransportWrapper(exchanger, "test-audience", "test-namespace")
	wrapped := wrap(base)

	req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.org", nil)
	resp, err := wrapped.RoundTrip(req)
	require.NoError(t, err)
	_ = resp.Body.Close()

	require.Equal(t, "Bearer wrapped-token", seenToken)

	got := exchanger.gotReq
	require.NotNil(t, got)
	require.Equal(t, []string{"test-audience"}, got.Audiences)
	require.Equal(t, "test-namespace", got.Namespace)
}
// TestTokenExchangeRoundTripperWithStrategies drives the round tripper through
// different namespace/audience provider combinations. For each one it checks
// the exchange request contents and the header forwarded to the transport.
func TestTokenExchangeRoundTripperWithStrategies(t *testing.T) {
	type strategyCase struct {
		name              string
		namespaceProvider NamespaceProvider
		audienceProvider  AudienceProvider
		expectedNamespace string
		expectedAudiences []string
		expectedHeader    string
	}

	cases := []strategyCase{
		{
			name:              "static providers with bearer prefix",
			namespaceProvider: NewStaticNamespaceProvider("*"),
			audienceProvider:  NewStaticAudienceProvider("folder.grafana.app"),
			expectedNamespace: "*",
			expectedAudiences: []string{"folder.grafana.app"},
			expectedHeader:    "Bearer test-token",
		},
		{
			name:              "multiple audiences",
			namespaceProvider: NewStaticNamespaceProvider("*"),
			audienceProvider:  NewStaticAudienceProvider("audience1", "audience2"),
			expectedNamespace: "*",
			expectedAudiences: []string{"audience1", "audience2"},
			expectedHeader:    "Bearer test-token",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			exchanger := &fakeExchanger{resp: &authn.TokenExchangeResponse{Token: "test-token"}}

			// Records the header the round tripper attached before delegating.
			var header string
			transport := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
				header = r.Header.Get("X-Access-Token")
				rec := httptest.NewRecorder()
				rec.WriteHeader(http.StatusOK)
				return rec.Result(), nil
			})

			rt := newTokenExchangeRoundTripperWithStrategies(
				exchanger,
				transport,
				tc.namespaceProvider,
				tc.audienceProvider,
			)

			req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.org", nil)
			resp, err := rt.RoundTrip(req)
			require.NoError(t, err)
			if resp != nil {
				_ = resp.Body.Close()
			}

			require.Equal(t, tc.expectedHeader, header)

			got := exchanger.gotReq
			require.NotNil(t, got)
			require.Equal(t, tc.expectedAudiences, got.Audiences)
			require.Equal(t, tc.expectedNamespace, got.Namespace)
		})
	}
}