From 899ebd4affe4b2857ae738023ff7629a6ab44541 Mon Sep 17 00:00:00 2001 From: John-Michael Faircloth Date: Thu, 1 Aug 2024 11:43:54 -0500 Subject: [PATCH] db/postgres: add feature flag protected sslinline configuration (#27871) * adds sslinline option to postgres conn string * for database secrets type postgres, inspects the connection string for sslinline and generates a tlsconfig from the connection string. * support fallback hosts * remove broken multihost test * bootstrap container with cert material * overwrite pg config and set key file perms * add feature flag check * add tests * add license and comments * test all ssl modes * add test cases for dsn (key/value) connection strings * add fallback test cases * fix error formatting * add test for multi-host when using pgx native conn url parsing --------- Co-authored-by: Branden Horiuchi --- builtin/logical/database/backend_test.go | 2 +- .../postgresql/postgresqlhelper.go | 231 ++++++--- .../database/postgresql/postgresql_test.go | 403 ++++++++++----- sdk/database/helper/connutil/postgres.go | 466 ++++++++++++++++++ sdk/database/helper/connutil/sql.go | 10 +- sdk/helper/pluginutil/env.go | 5 + 6 files changed, 923 insertions(+), 194 deletions(-) create mode 100644 sdk/database/helper/connutil/postgres.go diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index 8cfa8535cb..a861470002 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -1451,7 +1451,7 @@ func TestBackend_ConnectionURL_redacted(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cleanup, u := postgreshelper.PrepareTestContainerWithPassword(t, "13.4-buster", tt.password) + cleanup, u := postgreshelper.PrepareTestContainerWithPassword(t, tt.password) t.Cleanup(cleanup) p, err := url.Parse(u) diff --git a/helper/testhelpers/postgresql/postgresqlhelper.go b/helper/testhelpers/postgresql/postgresqlhelper.go index f0aa1203bd..cf144b192a 100644 --- a/helper/testhelpers/postgresql/postgresqlhelper.go +++ b/helper/testhelpers/postgresql/postgresqlhelper.go @@ -11,22 +11,29 @@ import ( "os" "testing" + "github.com/hashicorp/vault/helper/testhelpers/certhelpers" + "github.com/hashicorp/vault/sdk/database/helper/connutil" "github.com/hashicorp/vault/sdk/helper/docker" ) -const postgresVersion = "13.4-buster" +const ( + defaultPGImage = "docker.mirror.hashicorp.services/postgres" + defaultPGVersion = "13.4-buster" + defaultPGPass = "secret" +) func defaultRunOpts(t *testing.T) docker.RunOptions { return docker.RunOptions{ ContainerName: "postgres", - ImageRepo: "docker.mirror.hashicorp.services/postgres", - ImageTag: postgresVersion, + ImageRepo: defaultPGImage, + ImageTag: defaultPGVersion, Env: []string{ - "POSTGRES_PASSWORD=secret", + "POSTGRES_PASSWORD=" + defaultPGPass, "POSTGRES_DB=database", }, - Ports: []string{"5432/tcp"}, - DoNotAutoRemove: false, + Ports: []string{"5432/tcp"}, + DoNotAutoRemove: false, + OmitLogTimestamps: true, LogConsumer: func(s string) { if t.Failed() { t.Logf("container logs: %s", s) @@ -36,7 +43,13 @@ func defaultRunOpts(t *testing.T) docker.RunOptions { } func PrepareTestContainer(t *testing.T) (func(), string) { - _, cleanup, url, _ := prepareTestContainer(t, defaultRunOpts(t), "secret", true, false) + _, cleanup, url, _ := prepareTestContainer(t, defaultRunOpts(t), defaultPGPass, true, false, false) + + return cleanup, url +} + +func PrepareTestContainerMultiHost(t *testing.T) (func(), string) { + _, cleanup, url, _ := prepareTestContainer(t, defaultRunOpts(t), defaultPGPass, true, false, true) return cleanup, url } @@ -45,90 +58,138 @@ func PrepareTestContainer(t *testing.T) (func(), string) { // admin user configured so that we can safely call rotate-root without // rotating the root DB credentials func PrepareTestContainerWithVaultUser(t *testing.T, ctx context.Context) (func(), string) { - runner, cleanup, url, id := prepareTestContainer(t, defaultRunOpts(t), "secret", true, false) + runner, cleanup, url, id := prepareTestContainer(t, defaultRunOpts(t), defaultPGPass, true, false, false) cmd := []string{"psql", "-U", "postgres", "-c", "CREATE USER vaultadmin WITH LOGIN PASSWORD 'vaultpass' SUPERUSER"} - _, err := runner.RunCmdInBackground(ctx, id, cmd) - if err != nil { - t.Fatalf("Could not run command (%v) in container: %v", cmd, err) - } + mustRunCommand(t, ctx, runner, id, cmd) return cleanup, url } -func PrepareTestContainerWithSSL(t *testing.T, ctx context.Context, version string) (func(), string) { +// PrepareTestContainerWithSSL will setup a test container with SSL enabled so +// that we can test client certificate authentication. +func PrepareTestContainerWithSSL(t *testing.T, ctx context.Context, sslMode string, useFallback bool) (func(), string) { runOpts := defaultRunOpts(t) - runOpts.Cmd = []string{"-c", "log_statement=all"} - runner, cleanup, url, id := prepareTestContainer(t, runOpts, "secret", true, false) - - content := "echo 'hostssl all all all cert clientcert=verify-ca' > /var/lib/postgresql/data/pg_hba.conf" - // Copy the ssl init script into the newly running container. - buildCtx := docker.NewBuildContext() - buildCtx["ssl-conf.sh"] = docker.PathContentsFromBytes([]byte(content)) - if err := runner.CopyTo(id, "/usr/local/bin", buildCtx); err != nil { - t.Fatalf("Could not copy ssl init script into container: %v", err) - } - - // run the ssl init script to overwrite the pg_hba.conf file and set it to - // require SSL for each connection - cmd := []string{"bash", "/usr/local/bin/ssl-conf.sh"} - _, err := runner.RunCmdInBackground(ctx, id, cmd) + runner, err := docker.NewServiceRunner(runOpts) if err != nil { - t.Fatalf("Could not run command (%v) in container: %v", cmd, err) + t.Fatalf("Could not provision docker service runner: %s", err) } - // reload so the config changes take effect - cmd = []string{"psql", "-U", "postgres", "-c", "SELECT pg_reload_conf()"} - _, err = runner.RunCmdInBackground(ctx, id, cmd) + // first we connect with username/password because ssl is not enabled yet + svc, id, err := runner.StartNewService(context.Background(), true, false, connectPostgres(defaultPGPass, runOpts.ImageRepo, false)) if err != nil { - t.Fatalf("Could not run command (%v) in container: %v", cmd, err) + t.Fatalf("Could not start docker Postgres: %s", err) } - return cleanup, url + // Create certificates for postgres authentication + caCert := certhelpers.NewCert(t, + certhelpers.CommonName("ca"), + certhelpers.IsCA(true), + certhelpers.SelfSign(), + ) + serverCert := certhelpers.NewCert(t, + certhelpers.CommonName("server"), + certhelpers.DNS("localhost"), + certhelpers.Parent(caCert), + ) + clientCert := certhelpers.NewCert(t, + certhelpers.CommonName("postgres"), + certhelpers.DNS("localhost"), + certhelpers.Parent(caCert), + ) + + bCtx := docker.NewBuildContext() + bCtx["ca.crt"] = docker.PathContentsFromBytes(caCert.CombinedPEM()) + bCtx["server.crt"] = docker.PathContentsFromBytes(serverCert.CombinedPEM()) + bCtx["server.key"] = &docker.FileContents{ + Data: serverCert.PrivateKeyPEM(), + Mode: 0o600, + // postgres uid + UID: 999, + } + + // https://www.postgresql.org/docs/current/auth-pg-hba-conf.html + clientAuthConfig := "echo 'hostssl all all all cert clientcert=verify-ca' > /var/lib/postgresql/data/pg_hba.conf" + bCtx["ssl-conf.sh"] = docker.PathContentsFromString(clientAuthConfig) + pgConfig := ` +cat << EOF > /var/lib/postgresql/data/postgresql.conf +# PostgreSQL configuration file +listen_addresses = '*' +max_connections = 100 +shared_buffers = 128MB +dynamic_shared_memory_type = posix +max_wal_size = 1GB +min_wal_size = 80MB +ssl = on +ssl_ca_file = '/var/lib/postgresql/ca.crt' +ssl_cert_file = '/var/lib/postgresql/server.crt' +ssl_key_file= '/var/lib/postgresql/server.key' +EOF +` + bCtx["pg-conf.sh"] = docker.PathContentsFromString(pgConfig) + + err = runner.CopyTo(id, "/var/lib/postgresql/", bCtx) + if err != nil { + t.Fatalf("failed to copy to container: %v", err) + } + + // overwrite the postgresql.conf config file with our ssl settings + mustRunCommand(t, ctx, runner, id, + []string{"bash", "/var/lib/postgresql/pg-conf.sh"}) + + // overwrite the pg_hba.conf file and set it to require SSL for each connection + mustRunCommand(t, ctx, runner, id, + []string{"bash", "/var/lib/postgresql/ssl-conf.sh"}) + + // reload so the config changes take effect and ssl is enabled + mustRunCommand(t, ctx, runner, id, + []string{"psql", "-U", "postgres", "-c", "SELECT pg_reload_conf()"}) + + if sslMode == "disable" { + // return non-tls connection url + return svc.Cleanup, svc.Config.URL().String() + } + + sslConfig, err := connectPostgresSSL( + t, + svc.Config.URL().Host, + sslMode, + string(caCert.CombinedPEM()), + string(clientCert.CombinedPEM()), + string(clientCert.PrivateKeyPEM()), + useFallback, + ) + if err != nil { + svc.Cleanup() + t.Fatalf("failed to connect to postgres container via SSL: %v", err) + } + return svc.Cleanup, sslConfig.URL().String() } -func PrepareTestContainerWithPassword(t *testing.T, version, password string) (func(), string) { +func PrepareTestContainerWithPassword(t *testing.T, password string) (func(), string) { runOpts := defaultRunOpts(t) runOpts.Env = []string{ "POSTGRES_PASSWORD=" + password, "POSTGRES_DB=database", } - _, cleanup, url, _ := prepareTestContainer(t, runOpts, password, true, false) + _, cleanup, url, _ := prepareTestContainer(t, runOpts, password, true, false, false) return cleanup, url } -func PrepareTestContainerRepmgr(t *testing.T, name, version string, envVars []string) (*docker.Runner, func(), string, string) { - runOpts := defaultRunOpts(t) - runOpts.ImageRepo = "docker.mirror.hashicorp.services/bitnami/postgresql-repmgr" - runOpts.ImageTag = version - runOpts.Env = append(envVars, - "REPMGR_PARTNER_NODES=psql-repl-node-0,psql-repl-node-1", - "REPMGR_PRIMARY_HOST=psql-repl-node-0", - "REPMGR_PASSWORD=repmgrpass", - "POSTGRESQL_PASSWORD=secret") - runOpts.DoNotAutoRemove = true - - return prepareTestContainer(t, runOpts, "secret", false, true) -} - -func prepareTestContainer(t *testing.T, runOpts docker.RunOptions, password string, addSuffix, forceLocalAddr bool, +func prepareTestContainer(t *testing.T, runOpts docker.RunOptions, password string, addSuffix, forceLocalAddr, useFallback bool, ) (*docker.Runner, func(), string, string) { if os.Getenv("PG_URL") != "" { return nil, func() {}, "", os.Getenv("PG_URL") } - if runOpts.ImageRepo == "bitnami/postgresql-repmgr" { - runOpts.NetworkID = os.Getenv("POSTGRES_MULTIHOST_NET") - } - runner, err := docker.NewServiceRunner(runOpts) if err != nil { t.Fatalf("Could not start docker Postgres: %s", err) } - svc, containerID, err := runner.StartNewService(context.Background(), addSuffix, forceLocalAddr, connectPostgres(password, runOpts.ImageRepo)) + svc, containerID, err := runner.StartNewService(context.Background(), addSuffix, forceLocalAddr, connectPostgres(password, runOpts.ImageRepo, useFallback)) if err != nil { t.Fatalf("Could not start docker Postgres: %s", err) } @@ -136,12 +197,55 @@ func prepareTestContainer(t *testing.T, runOpts docker.RunOptions, password stri return runner, svc.Cleanup, svc.Config.URL().String(), containerID } -func connectPostgres(password, repo string) docker.ServiceAdapter { +// connectPostgresSSL is used to verify the connection of our test container +// and construct the connection string that is used in tests. +// +// NOTE: The RawQuery component of the url sets the custom sslinline field and +// inlines the certificate material in the sslrootcert, sslcert, and sslkey +// fields. This feature will be removed in a future version of the SDK. +func connectPostgresSSL(t *testing.T, host, sslMode, caCert, clientCert, clientKey string, useFallback bool) (docker.ServiceConfig, error) { + if useFallback { + // set the first host to a bad address so we can test the fallback logic + host = "localhost:55," + host + } + u := url.URL{ + Scheme: "postgres", + User: url.User("postgres"), + Host: host, + Path: "postgres", + RawQuery: url.Values{ + "sslmode": {sslMode}, + "sslinline": {"true"}, + "sslrootcert": {caCert}, + "sslcert": {clientCert}, + "sslkey": {clientKey}, + }.Encode(), + } + + // TODO: remove this deprecated function call in a future SDK version + db, err := connutil.OpenPostgres("pgx", u.String()) + if err != nil { + return nil, err + } + defer db.Close() + + if err = db.Ping(); err != nil { + return nil, err + } + return docker.NewServiceURL(u), nil +} + +func connectPostgres(password, repo string, useFallback bool) docker.ServiceAdapter { return func(ctx context.Context, host string, port int) (docker.ServiceConfig, error) { + hostAddr := fmt.Sprintf("%s:%d", host, port) + if useFallback { + // set the first host to a bad address so we can test the fallback logic + hostAddr = "localhost:55," + hostAddr + } u := url.URL{ Scheme: "postgres", User: url.UserPassword("postgres", password), - Host: fmt.Sprintf("%s:%d", host, port), + Host: hostAddr, Path: "postgres", RawQuery: "sslmode=disable", } @@ -170,3 +274,14 @@ func RestartContainer(t *testing.T, ctx context.Context, runner *docker.Runner, t.Fatalf("Could not restart docker Postgres: %s", err) } } + +func mustRunCommand(t *testing.T, ctx context.Context, runner *docker.Runner, containerID string, cmd []string) { + t.Helper() + _, stderr, retcode, err := runner.RunCmdWithOutput(ctx, containerID, cmd) + if err != nil { + t.Fatalf("Could not run command (%v) in container: %v", cmd, err) + } + if retcode != 0 || len(stderr) != 0 { + t.Fatalf("exit code: %v, stderr: %v", retcode, string(stderr)) + } +} diff --git a/plugins/database/postgresql/postgresql_test.go b/plugins/database/postgresql/postgresql_test.go index 23b04788bb..0ca347dab4 100644 --- a/plugins/database/postgresql/postgresql_test.go +++ b/plugins/database/postgresql/postgresql_test.go @@ -17,7 +17,7 @@ import ( dbtesting "github.com/hashicorp/vault/sdk/database/dbplugin/v5/testing" "github.com/hashicorp/vault/sdk/database/helper/connutil" "github.com/hashicorp/vault/sdk/database/helper/dbutil" - "github.com/hashicorp/vault/sdk/helper/docker" + "github.com/hashicorp/vault/sdk/helper/pluginutil" "github.com/hashicorp/vault/sdk/helper/template" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -58,6 +58,274 @@ func TestPostgreSQL_Initialize(t *testing.T) { } } +// TestPostgreSQL_InitializeMultiHost tests the functionality of Postgres's +// multi-host connection strings. +func TestPostgreSQL_InitializeMultiHost(t *testing.T) { + cleanup, connURL := postgresql.PrepareTestContainerMultiHost(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + "max_open_connections": 5, + } + + req := dbplugin.InitializeRequest{ + Config: connectionDetails, + VerifyConnection: true, + } + + db := new() + dbtesting.AssertInitialize(t, db, req) + + if !db.Initialized { + t.Fatal("Database should be initialized") + } + + if err := db.Close(); err != nil { + t.Fatalf("err: %s", err) + } +} + +// TestPostgreSQL_InitializeSSLFeatureFlag tests that the VAULT_PLUGIN_USE_POSTGRES_SSLINLINE +// flag guards against unwanted usage of the deprecated SSL client authentication path. +// TODO: remove this when we remove the underlying feature in a future SDK version +func TestPostgreSQL_InitializeSSLFeatureFlag(t *testing.T) { + // set the flag to true so we can call PrepareTestContainerWithSSL + // which does a validation check on the connection + t.Setenv(pluginutil.PluginUsePostgresSSLInline, "true") + + cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, context.Background(), "verify-ca", false) + t.Cleanup(cleanup) + + type testCase struct { + env string + wantErr bool + expectedError string + } + + tests := map[string]testCase{ + "feature flag is true": { + env: "true", + wantErr: false, + expectedError: "", + }, + "feature flag is unset or empty": { + env: "", + wantErr: true, + // this error is expected because the env var unset means we are + // using pgx's native connection string parsing which does not + // support inlining of the certificate material in the sslrootcert, + // sslcert, and sslkey fields + expectedError: "error verifying connection", + }, + "feature flag is false": { + env: "false", + wantErr: true, + expectedError: "failed to open postgres connection with deprecated funtion", + }, + "feature flag is invalid": { + env: "foo", + wantErr: true, + expectedError: "failed to open postgres connection with deprecated funtion", + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + // update the env var with the value we are testing + t.Setenv(pluginutil.PluginUsePostgresSSLInline, test.env) + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + "max_open_connections": 5, + } + + req := dbplugin.InitializeRequest{ + Config: connectionDetails, + VerifyConnection: true, + } + + db := new() + _, err := dbtesting.VerifyInitialize(t, db, req) + if test.wantErr && err == nil { + t.Fatal("expected error, got nil") + } else if test.wantErr && !strings.Contains(err.Error(), test.expectedError) { + t.Fatalf("got: %s, want: %s", err.Error(), test.expectedError) + } + + if !test.wantErr && !db.Initialized { + t.Fatal("Database should be initialized") + } + + if err := db.Close(); err != nil { + t.Fatalf("err: %s", err) + } + // unset for the next test case + os.Unsetenv(pluginutil.PluginUsePostgresSSLInline) + }) + } +} + +// TestPostgreSQL_InitializeSSL tests that we can successfully authenticate +// with a postgres server via ssl with a URL connection string or DSN (key/value) +// for each ssl mode. +// TODO: remove this when we remove the underlying feature in a future SDK version +func TestPostgreSQL_InitializeSSL(t *testing.T) { + // required to enable the sslinline custom parsing + t.Setenv(pluginutil.PluginUsePostgresSSLInline, "true") + + type testCase struct { + sslMode string + useDSN bool + useFallback bool + wantErr bool + expectedError string + } + + tests := map[string]testCase{ + "disable sslmode": { + sslMode: "disable", + wantErr: true, + expectedError: "error verifying connection", + }, + "allow sslmode": { + sslMode: "allow", + wantErr: false, + }, + "prefer sslmode": { + sslMode: "prefer", + wantErr: false, + }, + "require sslmode": { + sslMode: "require", + wantErr: false, + }, + "verify-ca sslmode": { + sslMode: "verify-ca", + wantErr: false, + }, + "disable sslmode with DSN": { + sslMode: "disable", + useDSN: true, + wantErr: true, + expectedError: "error verifying connection", + }, + "allow sslmode with DSN": { + sslMode: "allow", + useDSN: true, + wantErr: false, + }, + "prefer sslmode with DSN": { + sslMode: "prefer", + useDSN: true, + wantErr: false, + }, + "require sslmode with DSN": { + sslMode: "require", + useDSN: true, + wantErr: false, + }, + "verify-ca sslmode with DSN": { + sslMode: "verify-ca", + useDSN: true, + wantErr: false, + }, + "disable sslmode with fallback": { + sslMode: "disable", + useFallback: true, + wantErr: true, + expectedError: "error verifying connection", + }, + "allow sslmode with fallback": { + sslMode: "allow", + useFallback: true, + }, + "prefer sslmode with fallback": { + sslMode: "prefer", + useFallback: true, + }, + "require sslmode with fallback": { + sslMode: "require", + useFallback: true, + }, + "verify-ca sslmode with fallback": { + sslMode: "verify-ca", + useFallback: true, + }, + "disable sslmode with DSN with fallback": { + sslMode: "disable", + useDSN: true, + useFallback: true, + wantErr: true, + expectedError: "error verifying connection", + }, + "allow sslmode with DSN with fallback": { + sslMode: "allow", + useDSN: true, + useFallback: true, + wantErr: false, + }, + "prefer sslmode with DSN with fallback": { + sslMode: "prefer", + useDSN: true, + useFallback: true, + wantErr: false, + }, + "require sslmode with DSN with fallback": { + sslMode: "require", + useDSN: true, + useFallback: true, + wantErr: false, + }, + "verify-ca sslmode with DSN with fallback": { + sslMode: "verify-ca", + useDSN: true, + useFallback: true, + wantErr: false, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, context.Background(), test.sslMode, test.useFallback) + t.Cleanup(cleanup) + + if test.useDSN { + var err error + connURL, err = dbutil.ParseURL(connURL) + if err != nil { + t.Fatal(err) + } + } + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + "max_open_connections": 5, + } + + req := dbplugin.InitializeRequest{ + Config: connectionDetails, + VerifyConnection: true, + } + + db := new() + _, err := dbtesting.VerifyInitialize(t, db, req) + if test.wantErr && err == nil { + t.Fatal("expected error, got nil") + } else if test.wantErr && !strings.Contains(err.Error(), test.expectedError) { + t.Fatalf("got: %s, want: %s", err.Error(), test.expectedError) + } + + if !test.wantErr && !db.Initialized { + t.Fatal("Database should be initialized") + } + + if err := db.Close(); err != nil { + t.Fatalf("err: %s", err) + } + }) + } +} + func TestPostgreSQL_InitializeWithStringVals(t *testing.T) { db, cleanup := getPostgreSQL(t, map[string]interface{}{ "max_open_connections": "5", @@ -1268,139 +1536,6 @@ func TestNewUser_CloudGCP(t *testing.T) { } } -// This is a long-running integration test which tests the functionality of Postgres's multi-host -// connection strings. It uses two Postgres containers preconfigured with Replication Manager -// provided by Bitnami. This test currently does not run in CI and must be run manually. This is -// due to the test length, as it requires multiple sleep calls to ensure cluster setup and -// primary node failover occurs before the test steps continue. -// -// To run the test, set the environment variable POSTGRES_MULTIHOST_NET to the value of -// a docker network you've preconfigured, e.g. -// 'docker network create -d bridge postgres-repmgr' -// 'export POSTGRES_MULTIHOST_NET=postgres-repmgr' -func TestPostgreSQL_Repmgr(t *testing.T) { - _, exists := os.LookupEnv("POSTGRES_MULTIHOST_NET") - if !exists { - t.Skipf("POSTGRES_MULTIHOST_NET not set, skipping test") - } - - // Run two postgres-repmgr containers in a replication cluster - db0, runner0, url0, container0 := testPostgreSQL_Repmgr_Container(t, "psql-repl-node-0") - _, _, url1, _ := testPostgreSQL_Repmgr_Container(t, "psql-repl-node-1") - - ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second) - defer cancel() - - time.Sleep(10 * time.Second) - - // Write a read role to the cluster - _, err := db0.NewUser(ctx, dbplugin.NewUserRequest{ - Statements: dbplugin.Statements{ - Commands: []string{ - `CREATE ROLE "ro" NOINHERIT; - GRANT SELECT ON ALL TABLES IN SCHEMA public TO "ro";`, - }, - }, - }) - if err != nil { - t.Fatalf("no error expected, got: %s", err) - } - - // Open a connection to both databases using the multihost connection string - connectionDetails := map[string]interface{}{ - "connection_url": fmt.Sprintf("postgresql://{{username}}:{{password}}@%s,%s/postgres?target_session_attrs=read-write", getHost(url0), getHost(url1)), - "username": "postgres", - "password": "secret", - } - req := dbplugin.InitializeRequest{ - Config: connectionDetails, - VerifyConnection: true, - } - - db := new() - dbtesting.AssertInitialize(t, db, req) - if !db.Initialized { - t.Fatal("Database should be initialized") - } - defer db.Close() - - // Add a user to the cluster, then stop the primary container - if err = testPostgreSQL_Repmgr_AddUser(ctx, db); err != nil { - t.Fatalf("no error expected, got: %s", err) - } - postgresql.StopContainer(t, ctx, runner0, container0) - - // Try adding a new user immediately - expect failure as the database - // cluster is still switching primaries - err = testPostgreSQL_Repmgr_AddUser(ctx, db) - if !strings.HasSuffix(err.Error(), "ValidateConnect failed (read only connection)") { - t.Fatalf("expected error was not received, got: %s", err) - } - - time.Sleep(20 * time.Second) - - // Try adding a new user again which should succeed after the sleep - // as the primary failover should have finished. Then, restart - // the first container which should become a secondary DB. - if err = testPostgreSQL_Repmgr_AddUser(ctx, db); err != nil { - t.Fatalf("no error expected, got: %s", err) - } - postgresql.RestartContainer(t, ctx, runner0, container0) - - time.Sleep(10 * time.Second) - - // A final new user to add, which should succeed after the secondary joins. - if err = testPostgreSQL_Repmgr_AddUser(ctx, db); err != nil { - t.Fatalf("no error expected, got: %s", err) - } - - if err := db.Close(); err != nil { - t.Fatalf("err: %s", err) - } -} - -func testPostgreSQL_Repmgr_Container(t *testing.T, name string) (*PostgreSQL, *docker.Runner, string, string) { - envVars := []string{ - "REPMGR_NODE_NAME=" + name, - "REPMGR_NODE_NETWORK_NAME=" + name, - } - - runner, cleanup, connURL, containerID := postgresql.PrepareTestContainerRepmgr(t, name, "13.4.0", envVars) - t.Cleanup(cleanup) - - connectionDetails := map[string]interface{}{ - "connection_url": connURL, - } - req := dbplugin.InitializeRequest{ - Config: connectionDetails, - VerifyConnection: true, - } - db := new() - dbtesting.AssertInitialize(t, db, req) - if !db.Initialized { - t.Fatal("Database should be initialized") - } - - if err := db.Close(); err != nil { - t.Fatalf("err: %s", err) - } - - return db, runner, connURL, containerID -} - -func testPostgreSQL_Repmgr_AddUser(ctx context.Context, db *PostgreSQL) error { - _, err := db.NewUser(ctx, dbplugin.NewUserRequest{ - Statements: dbplugin.Statements{ - Commands: []string{ - `CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}' INHERIT; - GRANT ro TO "{{name}}";`, - }, - }, - }) - - return err -} - func getHost(url string) string { splitCreds := strings.Split(url, "@")[1] diff --git a/sdk/database/helper/connutil/postgres.go b/sdk/database/helper/connutil/postgres.go new file mode 100644 index 0000000000..7d96376bd2 --- /dev/null +++ b/sdk/database/helper/connutil/postgres.go @@ -0,0 +1,466 @@ +// Copyright (c) 2019-2021 Jack Christensen + +// MIT License + +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: + +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +// Copied from https://github.com/jackc/pgconn/blob/1860f4e57204614f40d05a5c76a43e8d80fde9da/config.go + +package connutil + +import ( + "context" + "crypto/tls" + "crypto/x509" + "database/sql" + "encoding/pem" + "errors" + "fmt" + "math" + "net" + "net/url" + "os" + "strconv" + "strings" + + "github.com/hashicorp/vault/sdk/helper/pluginutil" + "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/stdlib" +) + +// OpenPostgres parses the connection string and opens a connection to the database. +// +// If sslinline is set, strips the connection string of all ssl settings and +// creates a TLS config based on the settings provided, then uses the +// RegisterConnConfig function to create a new connection. This is necessary +// because the pgx driver does not support the sslinline parameter and instead +// expects to source ssl material from the file system. +// +// Deprecated: OpenPostgres will be removed in a future version of the Vault SDK. +func OpenPostgres(driverName, connString string) (*sql.DB, error) { + if ok, _ := strconv.ParseBool(os.Getenv(pluginutil.PluginUsePostgresSSLInline)); !ok { + return nil, fmt.Errorf("failed to open postgres connection with deprecated funtion, set feature flag to enable") + } + + var options pgconn.ParseConfigOptions + + settings := make(map[string]string) + if connString != "" { + var err error + // connString may be a database URL or a DSN + if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { + settings, err = parsePostgresURLSettings(connString) + if err != nil { + return nil, fmt.Errorf("failed to parse as URL: %w", err) + } + } else { + settings, err = parsePostgresDSNSettings(connString) + if err != nil { + return nil, fmt.Errorf("failed to parse as DSN: %w", err) + } + } + } + + // get the inline flag + sslInline := settings["sslinline"] == "true" + + // if sslinline is not set, open a regular connection + if !sslInline { + return sql.Open(driverName, connString) + } + + // generate a new DSN without the ssl settings + newConnStr := []string{"sslmode=disable"} + for k, v := range settings { + switch k { + case "sslinline", "sslcert", "sslkey", "sslrootcert", "sslmode": + continue + } + + newConnStr = append(newConnStr, fmt.Sprintf("%s='%s'", k, v)) + } + + // parse the updated config + config, err := pgx.ParseConfig(strings.Join(newConnStr, " ")) + if err != nil { + return nil, err + } + + // create a TLS config + fallbacks := []*pgconn.FallbackConfig{} + + hosts := strings.Split(settings["host"], ",") + ports := strings.Split(settings["port"], ",") + + for i, host := range hosts { + var portStr string + if i < len(ports) { + portStr = ports[i] + } else { + portStr = ports[0] + } + + port, err := parsePort(portStr) + if err != nil { + return nil, fmt.Errorf("invalid port: %w", err) + } + + var tlsConfigs []*tls.Config + + // Ignore TLS settings if Unix domain socket like libpq + if network, _ := pgconn.NetworkAddress(host, port); network == "unix" { + tlsConfigs = append(tlsConfigs, nil) + } else { + var err error + tlsConfigs, err = configPostgresTLS(settings, host, options) + if err != nil { + return nil, fmt.Errorf("failed to configure TLS: %w", err) + } + } + + for _, tlsConfig := range tlsConfigs { + fallbacks = append(fallbacks, &pgconn.FallbackConfig{ + Host: host, + Port: port, + TLSConfig: tlsConfig, + }) + } + } + + config.Host = fallbacks[0].Host + config.Port = fallbacks[0].Port + config.TLSConfig = fallbacks[0].TLSConfig + config.Fallbacks = fallbacks[1:] + + return sql.Open(driverName, stdlib.RegisterConnConfig(config)) +} + +// configPostgresTLS uses libpq's TLS parameters to construct []*tls.Config. It is +// necessary to allow returning multiple TLS configs as sslmode "allow" and +// "prefer" allow fallback. +// +// Copied from https://github.com/jackc/pgconn/blob/1860f4e57204614f40d05a5c76a43e8d80fde9da/config.go +// and modified to read ssl material by value instead of file location. +func configPostgresTLS(settings map[string]string, thisHost string, parseConfigOptions pgconn.ParseConfigOptions) ([]*tls.Config, error) { + host := thisHost + sslmode := settings["sslmode"] + sslrootcert := settings["sslrootcert"] + sslcert := settings["sslcert"] + sslkey := settings["sslkey"] + sslpassword := settings["sslpassword"] + sslsni := settings["sslsni"] + + // Match libpq default behavior + if sslmode == "" { + sslmode = "prefer" + } + if sslsni == "" { + sslsni = "1" + } + + tlsConfig := &tls.Config{} + + switch sslmode { + case "disable": + return []*tls.Config{nil}, nil + case "allow", "prefer": + tlsConfig.InsecureSkipVerify = true + case "require": + // According to PostgreSQL documentation, if a root CA file exists, + // the behavior of sslmode=require should be the same as that of verify-ca + // + // See https://www.postgresql.org/docs/12/libpq-ssl.html + if sslrootcert != "" { + goto nextCase + } + tlsConfig.InsecureSkipVerify = true + break + nextCase: + fallthrough + case "verify-ca": + // Don't perform the default certificate verification because it + // will verify the hostname. Instead, verify the server's + // certificate chain ourselves in VerifyPeerCertificate and + // ignore the server name. This emulates libpq's verify-ca + // behavior. + // + // See https://github.com/golang/go/issues/21971#issuecomment-332693931 + // and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate + // for more info. + tlsConfig.InsecureSkipVerify = true + tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error { + certs := make([]*x509.Certificate, len(certificates)) + for i, asn1Data := range certificates { + cert, err := x509.ParseCertificate(asn1Data) + if err != nil { + return errors.New("failed to parse certificate from server: " + err.Error()) + } + certs[i] = cert + } + + // Leave DNSName empty to skip hostname verification. + opts := x509.VerifyOptions{ + Roots: tlsConfig.RootCAs, + Intermediates: x509.NewCertPool(), + } + // Skip the first cert because it's the leaf. All others + // are intermediates. + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + _, err := certs[0].Verify(opts) + return err + } + case "verify-full": + tlsConfig.ServerName = host + default: + return nil, errors.New("sslmode is invalid") + } + + if sslrootcert != "" { + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM([]byte(sslrootcert)) { + return nil, errors.New("unable to add CA to cert pool") + } + + tlsConfig.RootCAs = caCertPool + tlsConfig.ClientCAs = caCertPool + } + + if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { + return nil, errors.New(`both "sslcert" and "sslkey" are required`) + } + + if sslcert != "" && sslkey != "" { + block, _ := pem.Decode([]byte(sslkey)) + var pemKey []byte + var decryptedKey []byte + var decryptedError error + // If PEM is encrypted, attempt to decrypt using pass phrase + if x509.IsEncryptedPEMBlock(block) { + // Attempt decryption with pass phrase + // NOTE: only supports RSA (PKCS#1) + if sslpassword != "" { + decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) + } + // if sslpassword not provided or has decryption error when use it + // try to find sslpassword with callback function + if sslpassword == "" || decryptedError != nil { + if parseConfigOptions.GetSSLPassword != nil { + sslpassword = parseConfigOptions.GetSSLPassword(context.Background()) + } + if sslpassword == "" { + return nil, fmt.Errorf("unable to find sslpassword") + } + } + decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) + // Should we also provide warning for PKCS#1 needed? + if decryptedError != nil { + return nil, fmt.Errorf("unable to decrypt key: %w", decryptedError) + } + + pemBytes := pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: decryptedKey, + } + pemKey = pem.EncodeToMemory(&pemBytes) + } else { + pemKey = pem.EncodeToMemory(block) + } + + cert, err := tls.X509KeyPair([]byte(sslcert), pemKey) + if err != nil { + return nil, fmt.Errorf("unable to load cert: %w", err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + } + + // Set Server Name Indication (SNI), if enabled by connection parameters. + // Per RFC 6066, do not set it if the host is a literal IP address (IPv4 + // or IPv6). + if sslsni == "1" && net.ParseIP(host) == nil { + tlsConfig.ServerName = host + } + + switch sslmode { + case "allow": + return []*tls.Config{nil, tlsConfig}, nil + case "prefer": + return []*tls.Config{tlsConfig, nil}, nil + case "require", "verify-ca", "verify-full": + return []*tls.Config{tlsConfig}, nil + default: + panic("BUG: bad sslmode should already have been caught") + } +} + +func parsePort(s string) (uint16, error) { + port, err := strconv.ParseUint(s, 10, 16) + if err != nil { + return 0, err + } + if port < 1 || port > math.MaxUint16 { + return 0, errors.New("outside range") + } + return uint16(port), nil +} + +var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1} + +func parsePostgresURLSettings(connString string) (map[string]string, error) { + settings := make(map[string]string) + + url, err := url.Parse(connString) + if err != nil { + return nil, err + } + + if url.User != nil { + settings["user"] = url.User.Username() + if password, present := url.User.Password(); present { + settings["password"] = password + } + } + + // Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port. + var hosts []string + var ports []string + for _, host := range strings.Split(url.Host, ",") { + if host == "" { + continue + } + if isIPOnly(host) { + hosts = append(hosts, strings.Trim(host, "[]")) + continue + } + h, p, err := net.SplitHostPort(host) + if err != nil { + return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err) + } + if h != "" { + hosts = append(hosts, h) + } + if p != "" { + ports = append(ports, p) + } + } + if len(hosts) > 0 { + settings["host"] = strings.Join(hosts, ",") + } + if len(ports) > 0 { + settings["port"] = strings.Join(ports, ",") + } + + database := strings.TrimLeft(url.Path, "/") + if database != "" { + settings["database"] = database + } + + nameMap := map[string]string{ + "dbname": "database", + } + + for k, v := range url.Query() { + if k2, present := nameMap[k]; present { + k = k2 + } + + settings[k] = v[0] + } + + return settings, nil +} + +func parsePostgresDSNSettings(s string) (map[string]string, error) { + settings := make(map[string]string) + + nameMap := map[string]string{ + "dbname": "database", + } + + for len(s) > 0 { + var key, val string + eqIdx := strings.IndexRune(s, '=') + if eqIdx < 0 { + return nil, errors.New("invalid dsn") + } + + key = strings.Trim(s[:eqIdx], " \t\n\r\v\f") + s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f") + if len(s) == 0 { + } else if s[0] != '\'' { + end := 0 + for ; end < len(s); end++ { + if asciiSpace[s[end]] == 1 { + break + } + if s[end] == '\\' { + end++ + if end == len(s) { + return nil, errors.New("invalid backslash") + } + } + } + val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) + if end == len(s) { + s = "" + } else { + s = s[end+1:] + } + } else { // quoted string + s = s[1:] + end := 0 + for ; end < len(s); end++ { + if s[end] == '\'' { + break + } + if s[end] == '\\' { + end++ + } + } + if end == len(s) { + return nil, errors.New("unterminated quoted string in connection info string") + } + val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) + if end == len(s) { + s = "" + } else { + s = s[end+1:] + } + } + + if k, ok := nameMap[key]; ok { + key = k + } + + if key == "" { + return nil, errors.New("invalid dsn") + } + + settings[key] = val + } + + return settings, nil +} + +func isIPOnly(host string) bool { + return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":") +} diff --git a/sdk/database/helper/connutil/sql.go b/sdk/database/helper/connutil/sql.go index 7f119bdfaa..548cc83d38 100644 --- a/sdk/database/helper/connutil/sql.go +++ b/sdk/database/helper/connutil/sql.go @@ -8,6 +8,7 @@ import ( "database/sql" "fmt" "net/url" + "os" "strings" "sync" "time" @@ -17,6 +18,7 @@ import ( "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/sdk/database/dbplugin" "github.com/hashicorp/vault/sdk/database/helper/dbutil" + "github.com/hashicorp/vault/sdk/helper/pluginutil" "github.com/mitchellh/mapstructure" ) @@ -218,7 +220,13 @@ func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, er } var err error - c.db, err = sql.Open(driverName, conn) + if driverName == "pgx" && os.Getenv(pluginutil.PluginUsePostgresSSLInline) != "" { + // TODO: remove this deprecated function call in a future SDK version + c.db, err = OpenPostgres(driverName, conn) + } else { + c.db, err = sql.Open(driverName, conn) + } + if err != nil { return nil, err } diff --git a/sdk/helper/pluginutil/env.go b/sdk/helper/pluginutil/env.go index 515baa1f76..ea05e8462c 100644 --- a/sdk/helper/pluginutil/env.go +++ b/sdk/helper/pluginutil/env.go @@ -44,6 +44,11 @@ const ( // colliding plugin-specific environment variables. Otherwise, plugin-specific // environment variables take precedence over Vault process environment variables. PluginUseLegacyEnvLayering = "VAULT_PLUGIN_USE_LEGACY_ENV_LAYERING" + + // PluginUsePostgresSSLInline enables the usage of a custom sslinline + // configuration as a shim to the pgx posgtres library. + // Deprecated: VAULT_PLUGIN_USE_POSTGRES_SSLINLINE will be removed in a future version of the Vault SDK. + PluginUsePostgresSSLInline = "VAULT_PLUGIN_USE_POSTGRES_SSLINLINE" ) // OptionallyEnableMlock determines if mlock should be called, and if so enables