diff --git a/changelog/21942.txt b/changelog/21942.txt new file mode 100644 index 0000000000..4e2828efb4 --- /dev/null +++ b/changelog/21942.txt @@ -0,0 +1,3 @@ +```release-note:improvement +openapi: Fix generation of correct fields in some rarer cases +``` diff --git a/sdk/framework/openapi.go b/sdk/framework/openapi.go index be2452f4d6..1bb3b7361a 100644 --- a/sdk/framework/openapi.go +++ b/sdk/framework/openapi.go @@ -229,7 +229,7 @@ func documentPath(p *Path, backend *Backend, requestResponsePrefix string, doc * // Convert optional parameters into distinct patterns to be processed independently. forceUnpublished := false - paths, err := expandPattern(p.Pattern) + paths, captures, err := expandPattern(p.Pattern) if err != nil { if errors.Is(err, errUnsupportableRegexpOperationForOpenAPI) { // Pattern cannot be transformed into sensible OpenAPI paths. In this case, we override the later @@ -270,26 +270,14 @@ func documentPath(p *Path, backend *Backend, requestResponsePrefix string, doc * // Process path and header parameters, which are common to all operations. // Body fields will be added to individual operations. - pathFields, bodyFields := splitFields(p.Fields, path) + pathFields, queryFields, bodyFields := splitFields(p.Fields, path, captures) for name, field := range pathFields { - location := "path" - required := true - - if field == nil { - continue - } - - if field.Query { - location = "query" - required = false - } - t := convertType(field.Type) p := OASParameter{ Name: name, Description: cleanString(field.Description), - In: location, + In: "path", Schema: &OASSchema{ Type: t.baseType, Pattern: t.pattern, @@ -297,7 +285,7 @@ func documentPath(p *Path, backend *Backend, requestResponsePrefix string, doc * Default: field.Default, DisplayAttrs: withoutOperationHints(field.DisplayAttrs), }, - Required: required, + Required: true, Deprecated: field.Deprecated, } pi.Parameters = append(pi.Parameters, p) @@ -342,8 +330,12 @@ func documentPath(p *Path, backend *Backend, requestResponsePrefix string, doc * op.Deprecated = props.Deprecated op.OperationID = operationID - // Add any fields not present in the path as body parameters for POST. - if opType == logical.CreateOperation || opType == logical.UpdateOperation { + switch opType { + // For the operation types which map to POST/PUT methods, and so allow for request body parameters, + // prepare the request body definition + case logical.CreateOperation: + fallthrough + case logical.UpdateOperation: s := &OASSchema{ Type: "object", Properties: make(map[string]*OASSchema), @@ -357,27 +349,14 @@ func documentPath(p *Path, backend *Backend, requestResponsePrefix string, doc * continue } - openapiField := convertType(field.Type) - if field.Required { - s.Required = append(s.Required, name) - } + addFieldToOASSchema(s, name, field) + } - p := OASSchema{ - Type: openapiField.baseType, - Description: cleanString(field.Description), - Format: openapiField.format, - Pattern: openapiField.pattern, - Enum: field.AllowedValues, - Default: field.Default, - Deprecated: field.Deprecated, - DisplayAttrs: withoutOperationHints(field.DisplayAttrs), - } - if openapiField.baseType == "array" { - p.Items = &OASSchema{ - Type: openapiField.items, - } - } - s.Properties[name] = &p + // Contrary to what one might guess, fields marked with "Query: true" are only query fields when the + // request method is one which does not allow for a request body - they are still body fields when + // dealing with a POST/PUT request. + for name, field := range queryFields { + addFieldToOASSchema(s, name, field) } // Make the ordering deterministic, so that the generated OpenAPI spec document, observed over several @@ -426,12 +405,12 @@ func documentPath(p *Path, backend *Backend, requestResponsePrefix string, doc * }, } } - } - // LIST is represented as GET with a `list` query parameter. Code later on in this function will assign - // list operations to a path with an extra trailing slash, ensuring they do not collide with read - // operations. - if opType == logical.ListOperation { + // For the operation types which map to HTTP methods without a request body, populate query parameters + case logical.ListOperation: + // LIST is represented as GET with a `list` query parameter. Code later on in this function will assign + // list operations to a path with an extra trailing slash, ensuring they do not collide with read + // operations. op.Parameters = append(op.Parameters, OASParameter{ Name: "list", Description: "Must be set to `true`", @@ -439,6 +418,27 @@ func documentPath(p *Path, backend *Backend, requestResponsePrefix string, doc * In: "query", Schema: &OASSchema{Type: "string", Enum: []interface{}{"true"}}, }) + fallthrough + case logical.DeleteOperation: + fallthrough + case logical.ReadOperation: + for name, field := range queryFields { + t := convertType(field.Type) + p := OASParameter{ + Name: name, + Description: cleanString(field.Description), + In: "query", + Schema: &OASSchema{ + Type: t.baseType, + Pattern: t.pattern, + Enum: field.AllowedValues, + Default: field.Default, + DisplayAttrs: withoutOperationHints(field.DisplayAttrs), + }, + Deprecated: field.Deprecated, + } + op.Parameters = append(op.Parameters, p) + } } // Add tags based on backend type @@ -612,6 +612,31 @@ func documentPath(p *Path, backend *Backend, requestResponsePrefix string, doc * return nil } +func addFieldToOASSchema(s *OASSchema, name string, field *FieldSchema) { + openapiField := convertType(field.Type) + if field.Required { + s.Required = append(s.Required, name) + } + + p := OASSchema{ + Type: openapiField.baseType, + Description: cleanString(field.Description), + Format: openapiField.format, + Pattern: openapiField.pattern, + Enum: field.AllowedValues, + Default: field.Default, + Deprecated: field.Deprecated, + DisplayAttrs: withoutOperationHints(field.DisplayAttrs), + } + if openapiField.baseType == "array" { + p.Items = &OASSchema{ + Type: openapiField.items, + } + } + + s.Properties[name] = &p +} + // specialPathMatch checks whether the given path matches one of the special // paths, taking into account * and + wildcards (e.g. foo/+/bar/*) func specialPathMatch(path string, specialPaths []string) bool { @@ -776,8 +801,9 @@ func constructOperationID( } // expandPattern expands a regex pattern by generating permutations of any optional parameters -// and changing named parameters into their {openapi} equivalents. -func expandPattern(pattern string) ([]string, error) { +// and changing named parameters into their {openapi} equivalents. It also returns the names of all capturing groups +// observed in the pattern. +func expandPattern(pattern string) (paths []string, captures map[string]struct{}, err error) { // Happily, the Go regexp library exposes its underlying "parse to AST" functionality, so we can rely on that to do // the hard work of interpreting the regexp syntax. rx, err := syntax.Parse(pattern, syntax.Perl) @@ -787,12 +813,12 @@ func expandPattern(pattern string) ([]string, error) { panic(err) } - paths, err := collectPathsFromRegexpAST(rx) + paths, captures, err = collectPathsFromRegexpAST(rx) if err != nil { - return nil, err + return nil, nil, err } - return paths, nil + return paths, captures, nil } type pathCollector struct { @@ -813,23 +839,28 @@ type pathCollector struct { // // Each named capture group - i.e. (?Psomething here) - is replaced with an OpenAPI parameter - i.e. {name} - and // the subtree of regexp AST inside the parameter is completely skipped. -func collectPathsFromRegexpAST(rx *syntax.Regexp) ([]string, error) { - pathCollectors, err := collectPathsFromRegexpASTInternal(rx, []*pathCollector{{}}) +func collectPathsFromRegexpAST(rx *syntax.Regexp) (paths []string, captures map[string]struct{}, err error) { + captures = make(map[string]struct{}) + pathCollectors, err := collectPathsFromRegexpASTInternal(rx, []*pathCollector{{}}, captures) if err != nil { - return nil, err + return nil, nil, err } - paths := make([]string, 0, len(pathCollectors)) + paths = make([]string, 0, len(pathCollectors)) for _, collector := range pathCollectors { if collector.conditionalSlashAppendedAtLength != collector.Len() { paths = append(paths, collector.String()) } } - return paths, nil + return paths, captures, nil } var errUnsupportableRegexpOperationForOpenAPI = errors.New("path regexp uses an operation that cannot be translated to an OpenAPI pattern") -func collectPathsFromRegexpASTInternal(rx *syntax.Regexp, appendingTo []*pathCollector) ([]*pathCollector, error) { +func collectPathsFromRegexpASTInternal( + rx *syntax.Regexp, + appendingTo []*pathCollector, + captures map[string]struct{}, +) ([]*pathCollector, error) { var err error // Depending on the type of this regexp AST node (its Op, i.e. operation), figure out whether it contributes any @@ -856,7 +887,7 @@ func collectPathsFromRegexpASTInternal(rx *syntax.Regexp, appendingTo []*pathCol // those pieces. case syntax.OpConcat: for _, child := range rx.Sub { - appendingTo, err = collectPathsFromRegexpASTInternal(child, appendingTo) + appendingTo, err = collectPathsFromRegexpASTInternal(child, appendingTo, captures) if err != nil { return nil, err } @@ -887,7 +918,7 @@ func collectPathsFromRegexpASTInternal(rx *syntax.Regexp, appendingTo []*pathCol childAppendingTo = append(childAppendingTo, newCollector) } } - childAppendingTo, err = collectPathsFromRegexpASTInternal(child, childAppendingTo) + childAppendingTo, err = collectPathsFromRegexpASTInternal(child, childAppendingTo, captures) if err != nil { return nil, err } @@ -905,7 +936,7 @@ func collectPathsFromRegexpASTInternal(rx *syntax.Regexp, appendingTo []*pathCol newCollector.conditionalSlashAppendedAtLength = collector.conditionalSlashAppendedAtLength childAppendingTo = append(childAppendingTo, newCollector) } - childAppendingTo, err = collectPathsFromRegexpASTInternal(child, childAppendingTo) + childAppendingTo, err = collectPathsFromRegexpASTInternal(child, childAppendingTo, captures) if err != nil { return nil, err } @@ -927,7 +958,7 @@ func collectPathsFromRegexpASTInternal(rx *syntax.Regexp, appendingTo []*pathCol // In Vault, an unnamed capturing group is not actually used for capturing. // We treat it exactly the same as OpConcat. for _, child := range rx.Sub { - appendingTo, err = collectPathsFromRegexpASTInternal(child, appendingTo) + appendingTo, err = collectPathsFromRegexpASTInternal(child, appendingTo, captures) if err != nil { return nil, err } @@ -940,6 +971,7 @@ func collectPathsFromRegexpASTInternal(rx *syntax.Regexp, appendingTo []*pathCol builder.WriteString(rx.Name) builder.WriteRune('}') } + captures[rx.Name] = struct{}{} } // Any other kind of operation is a problem, and will trigger an error, resulting in the pattern being left out of @@ -1041,29 +1073,37 @@ func cleanString(s string) string { return s } -// splitFields partitions fields into path and body groups -// The input pattern is expected to have been run through expandPattern, -// with paths parameters denotes in {braces}. -func splitFields(allFields map[string]*FieldSchema, pattern string) (pathFields, bodyFields map[string]*FieldSchema) { +// splitFields partitions fields into path, query and body groups. It uses information on capturing groups previously +// collected by expandPattern, which is necessary to correctly match the treatment in (*Backend).HandleRequest: +// a field counts as a path field if it appears in any capture in the regex, and if that capture was inside an +// alternation or optional part of the regex which does not survive in the OpenAPI path pattern currently being +// processed, that field should NOT be rendered to the OpenAPI spec AT ALL. +func splitFields( + allFields map[string]*FieldSchema, + openAPIPathPattern string, + captures map[string]struct{}, +) (pathFields, queryFields, bodyFields map[string]*FieldSchema) { pathFields = make(map[string]*FieldSchema) + queryFields = make(map[string]*FieldSchema) bodyFields = make(map[string]*FieldSchema) - for _, match := range pathFieldsRe.FindAllStringSubmatch(pattern, -1) { + for _, match := range pathFieldsRe.FindAllStringSubmatch(openAPIPathPattern, -1) { name := match[1] pathFields[name] = allFields[name] } for name, field := range allFields { - if _, ok := pathFields[name]; !ok { + // Any field which relates to a regex capture was already processed above, if it needed to be. + if _, ok := captures[name]; !ok { if field.Query { - pathFields[name] = field + queryFields[name] = field } else { bodyFields[name] = field } } } - return pathFields, bodyFields + return pathFields, queryFields, bodyFields } // withoutOperationHints returns a copy of the given DisplayAttributes without diff --git a/sdk/framework/openapi_test.go b/sdk/framework/openapi_test.go index e37d5f5119..0b6185b4d6 100644 --- a/sdk/framework/openapi_test.go +++ b/sdk/framework/openapi_test.go @@ -160,13 +160,13 @@ func TestOpenAPI_ExpandPattern(t *testing.T) { } for i, test := range tests { - out, err := expandPattern(test.inPattern) + paths, _, err := expandPattern(test.inPattern) if err != nil { t.Fatal(err) } - sort.Strings(out) - if !reflect.DeepEqual(out, test.outPathlets) { - t.Fatalf("Test %d: Expected %v got %v", i, test.outPathlets, out) + sort.Strings(paths) + if !reflect.DeepEqual(paths, test.outPathlets) { + t.Fatalf("Test %d: Expected %v got %v", i, test.outPathlets, paths) } } } @@ -188,7 +188,7 @@ func TestOpenAPI_ExpandPattern_ReturnsError(t *testing.T) { } for i, test := range tests { - _, err := expandPattern(test.inPattern) + _, _, err := expandPattern(test.inPattern) if err != test.outError { t.Fatalf("Test %d: Expected %q got %q", i, test.outError, err) } @@ -196,31 +196,50 @@ func TestOpenAPI_ExpandPattern_ReturnsError(t *testing.T) { } func TestOpenAPI_SplitFields(t *testing.T) { + paths, captures, err := expandPattern("some/" + GenericNameRegex("a") + "/path" + OptionalParamRegex("e")) + if err != nil { + t.Fatal(err) + } + fields := map[string]*FieldSchema{ "a": {Description: "path"}, "b": {Description: "body"}, "c": {Description: "body"}, "d": {Description: "body"}, "e": {Description: "path"}, + "f": {Description: "query", Query: true}, } - pathFields, bodyFields := splitFields(fields, "some/{a}/path/{e}") + for index, path := range paths { + pathFields, queryFields, bodyFields := splitFields(fields, path, captures) - lp := len(pathFields) - lb := len(bodyFields) - l := len(fields) - if lp+lb != l { - t.Fatalf("split length error: %d + %d != %d", lp, lb, l) - } - - for name, field := range pathFields { - if field.Description != "path" { - t.Fatalf("expected field %s to be in 'path', found in %s", name, field.Description) + numPath := len(pathFields) + numQuery := len(queryFields) + numBody := len(bodyFields) + numExpectedDiscarded := 0 + // The first path generated is expected to be the one omitting the optional parameter field "e" + if index == 0 { + numExpectedDiscarded = 1 } - } - for name, field := range bodyFields { - if field.Description != "body" { - t.Fatalf("expected field %s to be in 'body', found in %s", name, field.Description) + l := len(fields) + if numPath+numQuery+numBody+numExpectedDiscarded != l { + t.Fatalf("split length error: %d + %d + %d + %d != %d", numPath, numQuery, numBody, numExpectedDiscarded, l) + } + + for name, field := range pathFields { + if field.Description != "path" { + t.Fatalf("expected field %s to be in 'path', found in %s", name, field.Description) + } + } + for name, field := range queryFields { + if field.Description != "query" { + t.Fatalf("expected field %s to be in 'query', found in %s", name, field.Description) + } + } + for name, field := range bodyFields { + if field.Description != "body" { + t.Fatalf("expected field %s to be in 'body', found in %s", name, field.Description) + } } } }