diff --git a/api/sys_auth.go b/api/sys_auth.go index e814412191..8337ae8c02 100644 --- a/api/sys_auth.go +++ b/api/sys_auth.go @@ -12,6 +12,39 @@ import ( "github.com/mitchellh/mapstructure" ) +func (c *Sys) GetAuth(path string) (*AuthMount, error) { + return c.GetAuthWithContext(context.Background(), path) +} + +func (c *Sys) GetAuthWithContext(ctx context.Context, path string) (*AuthMount, error) { + ctx, cancelFunc := c.c.withConfiguredTimeout(ctx) + defer cancelFunc() + + r := c.c.NewRequest(http.MethodGet, fmt.Sprintf("/v1/sys/auth/%s", path)) + + resp, err := c.c.rawRequestWithContext(ctx, r) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + secret, err := ParseSecret(resp.Body) + if err != nil { + return nil, err + } + if secret == nil || secret.Data == nil { + return nil, errors.New("data from server response is empty") + } + + mount := AuthMount{} + err = mapstructure.Decode(secret.Data, &mount) + if err != nil { + return nil, err + } + + return &mount, nil +} + func (c *Sys) ListAuth() (map[string]*AuthMount, error) { return c.ListAuthWithContext(context.Background()) } diff --git a/api/sys_mounts.go b/api/sys_mounts.go index b9f4f8f6f8..64529986af 100644 --- a/api/sys_mounts.go +++ b/api/sys_mounts.go @@ -13,6 +13,39 @@ import ( "github.com/mitchellh/mapstructure" ) +func (c *Sys) GetMount(path string) (*MountOutput, error) { + return c.GetMountWithContext(context.Background(), path) +} + +func (c *Sys) GetMountWithContext(ctx context.Context, path string) (*MountOutput, error) { + ctx, cancelFunc := c.c.withConfiguredTimeout(ctx) + defer cancelFunc() + + r := c.c.NewRequest(http.MethodGet, fmt.Sprintf("/v1/sys/mounts/%s", path)) + + resp, err := c.c.rawRequestWithContext(ctx, r) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + secret, err := ParseSecret(resp.Body) + if err != nil { + return nil, err + } + if secret == nil || secret.Data == nil { + return nil, errors.New("data from server response is empty") + } + + mount := MountOutput{} + err = mapstructure.Decode(secret.Data, &mount) + if err != nil { + return nil, err + } + + return &mount, nil +} + func (c *Sys) ListMounts() (map[string]*MountOutput, error) { return c.ListMountsWithContext(context.Background()) } diff --git a/changelog/25499.txt b/changelog/25499.txt new file mode 100644 index 0000000000..f2ef3e54aa --- /dev/null +++ b/changelog/25499.txt @@ -0,0 +1,3 @@ +```release-note:improvement +api: Add wrapper functions for GET /sys/mounts/:path and GET /sys/auth/:path +``` diff --git a/vault/external_tests/api/sys_auth_test.go b/vault/external_tests/api/sys_auth_test.go new file mode 100644 index 0000000000..46e715de87 --- /dev/null +++ b/vault/external_tests/api/sys_auth_test.go @@ -0,0 +1,90 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package api + +import ( + "testing" + + "github.com/hashicorp/vault/api" +) + +// TestGetAuth tests that we can get a single auth mount +func TestGetAuth(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + mountName string + authInput *api.EnableAuthOptions + expected *api.AuthMount + shouldMount bool + expectErr bool + }{ + { + name: "get-default-auth-mount-success", + mountName: "token", + authInput: nil, + expected: &api.AuthMount{ + Type: "token", + }, + shouldMount: false, + expectErr: false, + }, + { + name: "get-manual-auth-mount-success", + mountName: "userpass", + authInput: &api.EnableAuthOptions{ + Type: "userpass", + }, + expected: &api.AuthMount{ + Type: "userpass", + }, + shouldMount: true, + expectErr: false, + }, + { + name: "error-not-found", + mountName: "not-found", + authInput: nil, + expected: &api.AuthMount{ + Type: "not-found", + }, + shouldMount: false, + expectErr: true, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if tc.shouldMount { + err := client.Sys().EnableAuthWithOptions(tc.mountName, tc.authInput) + if err != nil { + t.Fatal(err) + } + } + + mount, err := client.Sys().GetAuth(tc.mountName) + if !tc.expectErr && err != nil { + t.Fatal(err) + } + + if !tc.expectErr { + if tc.expected.Type != mount.Type || tc.expected.PluginVersion != mount.PluginVersion { + t.Errorf("mount did not match: expected %+v but got %+v", tc.expected, mount) + } + } else { + if err == nil { + t.Errorf("expected error but got nil") + } + } + }) + } +} diff --git a/vault/external_tests/api/sys_mounts_test.go b/vault/external_tests/api/sys_mounts_test.go new file mode 100644 index 0000000000..bdf8a0b16b --- /dev/null +++ b/vault/external_tests/api/sys_mounts_test.go @@ -0,0 +1,90 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package api + +import ( + "testing" + + "github.com/hashicorp/vault/api" +) + +// TestGetMount tests that we can get a single secret mount +func TestGetMount(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + mountName string + mountInput *api.MountInput + expected *api.MountOutput + shouldMount bool + expectErr bool + }{ + { + name: "get-default-mount-success", + mountName: "secret", + mountInput: nil, + expected: &api.MountOutput{ + Type: "kv", + }, + shouldMount: false, + expectErr: false, + }, + { + name: "get-manual-mount-success", + mountName: "pki", + mountInput: &api.MountInput{ + Type: "pki", + }, + expected: &api.MountOutput{ + Type: "pki", + }, + shouldMount: true, + expectErr: false, + }, + { + name: "error-not-found", + mountName: "not-found", + mountInput: nil, + expected: &api.MountOutput{ + Type: "not-found", + }, + shouldMount: false, + expectErr: true, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if tc.shouldMount { + err := client.Sys().Mount(tc.mountName, tc.mountInput) + if err != nil { + t.Fatal(err) + } + } + + mount, err := client.Sys().GetMount(tc.mountName) + if !tc.expectErr && err != nil { + t.Fatal(err) + } + + if !tc.expectErr { + if tc.expected.Type != mount.Type || tc.expected.PluginVersion != mount.PluginVersion { + t.Errorf("mount did not match: expected %+v but got %+v", tc.expected, mount) + } + } else { + if err == nil { + t.Errorf("expected error but got nil") + } + } + }) + } +}