From 14c6d99b8e9abf6801838d1ca2ee2e91396dabc0 Mon Sep 17 00:00:00 2001 From: Morten Torkildsen Date: Mon, 5 Jan 2026 23:32:08 +0000 Subject: [PATCH] DRA: Add integration tests for Partitionable Devices --- pkg/scheduler/testing/wrappers.go | 16 ++ test/integration/dra/dra_test.go | 8 +- test/integration/dra/helpers_test.go | 8 +- .../dra/partitionable_devices_test.go | 172 ++++++++++++++++++ 4 files changed, 201 insertions(+), 3 deletions(-) create mode 100644 test/integration/dra/partitionable_devices_test.go diff --git a/pkg/scheduler/testing/wrappers.go b/pkg/scheduler/testing/wrappers.go index 224e6c1fe18..422c8a9251e 100644 --- a/pkg/scheduler/testing/wrappers.go +++ b/pkg/scheduler/testing/wrappers.go @@ -1336,6 +1336,16 @@ func MakeResourceSlice(nodeName, driverName string) *ResourceSliceWrapper { return wrapper } +func MakeResourceSliceWithPerDeviceNodeSelection(namePrefix, driverName string) *ResourceSliceWrapper { + wrapper := new(ResourceSliceWrapper) + wrapper.Name = namePrefix + "-" + driverName + wrapper.Spec.PerDeviceNodeSelection = ptr.To(true) + wrapper.Spec.Pool.Name = namePrefix + wrapper.Spec.Pool.ResourceSliceCount = 1 + wrapper.Spec.Driver = driverName + return wrapper +} + // FromResourceSlice creates a ResourceSlice wrapper from some existing object. func FromResourceSlice(other *resourceapi.ResourceSlice) *ResourceSliceWrapper { return &ResourceSliceWrapper{*other.DeepCopy()} @@ -1353,6 +1363,8 @@ func (wrapper *ResourceSliceWrapper) Devices(names ...string) *ResourceSliceWrap return wrapper } +type NodeName string + // Device extends the devices field of the inner object. // The device must have a name and may have arbitrary additional fields. func (wrapper *ResourceSliceWrapper) Device(name string, otherFields ...any) *ResourceSliceWrapper { @@ -1365,6 +1377,10 @@ func (wrapper *ResourceSliceWrapper) Device(name string, otherFields ...any) *Re device.Capacity = typedField case resourceapi.DeviceTaint: device.Taints = append(device.Taints, typedField) + case NodeName: + device.NodeName = (*string)(&typedField) + case *v1.NodeSelector: + device.NodeSelector = typedField default: panic(fmt.Sprintf("expected a type which matches a field in BasicDevice, got %T", field)) } diff --git a/test/integration/dra/dra_test.go b/test/integration/dra/dra_test.go index 347006d70df..db430029d1b 100644 --- a/test/integration/dra/dra_test.go +++ b/test/integration/dra/dra_test.go @@ -163,6 +163,7 @@ func TestDRA(t *testing.T) { }, f: func(tCtx ktesting.TContext) { tCtx.Run("AdminAccess", func(tCtx ktesting.TContext) { testAdminAccess(tCtx, false) }) + tCtx.Run("PartitionableDevices", func(tCtx ktesting.TContext) { testPartitionableDevices(tCtx, false) }) tCtx.Run("PrioritizedList", func(tCtx ktesting.TContext) { testPrioritizedList(tCtx, false) }) tCtx.Run("Pod", func(tCtx ktesting.TContext) { testPod(tCtx, true) }) tCtx.Run("PublishResourceSlices", func(tCtx ktesting.TContext) { @@ -235,6 +236,7 @@ func TestDRA(t *testing.T) { tCtx.Run("Convert", testConvert) tCtx.Run("ControllerManagerMetrics", testControllerManagerMetrics) tCtx.Run("DeviceBindingConditions", func(tCtx ktesting.TContext) { testDeviceBindingConditions(tCtx, true) }) + tCtx.Run("PartitionableDevices", func(tCtx ktesting.TContext) { testPartitionableDevices(tCtx, true) }) tCtx.Run("PrioritizedList", func(tCtx ktesting.TContext) { testPrioritizedList(tCtx, true) }) tCtx.Run("PrioritizedListScoring", func(tCtx ktesting.TContext) { testPrioritizedListScoring(tCtx) }) tCtx.Run("PublishResourceSlices", func(tCtx 
ktesting.TContext) { testPublishResourceSlices(tCtx, true) })
@@ -295,10 +297,14 @@ func TestDRA(t *testing.T) {
 
 func createNodes(tCtx ktesting.TContext) {
 	for i := 0; i < numNodes; i++ {
+		nodeName := fmt.Sprintf("worker-%d", i)
 		// Create node.
 		node := &v1.Node{
 			ObjectMeta: metav1.ObjectMeta{
-				Name: fmt.Sprintf("worker-%d", i),
+				Name: nodeName,
+				Labels: map[string]string{
+					"kubernetes.io/hostname": nodeName,
+				},
 			},
 		}
 		node, err := tCtx.Client().CoreV1().Nodes().Create(tCtx, node, metav1.CreateOptions{FieldValidation: "Strict"})
diff --git a/test/integration/dra/helpers_test.go b/test/integration/dra/helpers_test.go
index d83689851e5..da1f5adefa0 100644
--- a/test/integration/dra/helpers_test.go
+++ b/test/integration/dra/helpers_test.go
@@ -162,11 +162,14 @@ func createPod(tCtx ktesting.TContext, namespace string, suffix string, pod *v1.
 	return pod
 }
 
-func waitForPodScheduled(tCtx ktesting.TContext, namespace, podName string) {
+func waitForPodScheduled(tCtx ktesting.TContext, namespace, podName string) *v1.Pod {
 	tCtx.Helper()
 
+	var pod *v1.Pod
 	tCtx.Eventually(func(tCtx ktesting.TContext) (*v1.Pod, error) {
-		return tCtx.Client().CoreV1().Pods(namespace).Get(tCtx, podName, metav1.GetOptions{})
+		p, err := tCtx.Client().CoreV1().Pods(namespace).Get(tCtx, podName, metav1.GetOptions{})
+		pod = p
+		return p, err
 	}).WithTimeout(60*time.Second).Should(
 		gomega.HaveField("Status.Conditions", gomega.ContainElement(
 			gomega.And(
@@ -176,6 +179,7 @@ func waitForPodScheduled(tCtx ktesting.TContext, namespace, podName string) {
 			)),
 		"Pod %s should have been scheduled.", podName,
 	)
+	return pod
 }
 
 func deleteAndWait[T any](tCtx ktesting.TContext, del func(context.Context, string, metav1.DeleteOptions) error, get func(context.Context, string, metav1.GetOptions) (T, error), name string) {
diff --git a/test/integration/dra/partitionable_devices_test.go b/test/integration/dra/partitionable_devices_test.go
new file mode 100644
index 00000000000..87c2323f32f
--- /dev/null
+++ b/test/integration/dra/partitionable_devices_test.go
@@ -0,0 +1,172 @@
+/*
+Copyright The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package dra
+
+import (
+	"fmt"
+
+	"github.com/onsi/gomega"
+	"github.com/stretchr/testify/require"
+	v1 "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	st "k8s.io/kubernetes/pkg/scheduler/testing"
+	"k8s.io/kubernetes/test/utils/ktesting"
+)
+
+func testPartitionableDevices(tCtx ktesting.TContext, enabled bool) {
+	if enabled {
+		tCtx.Run("PerDeviceNodeSelection", testPerDeviceNodeSelection)
+		tCtx.Run("MultiHostDevice", testPartitionableDevicesWithMultiHostDevice)
+	} else {
+		testDisabled(tCtx)
+	}
+}
+
+// testDisabled verifies that creating a ResourceSlice that uses per-device
+// node selection (the perDeviceNodeSelection field) fails when the
+// Partitionable Devices feature is disabled.
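+// The slice comes from MakeResourceSliceWithPerDeviceNodeSelection, which sets
+// spec.perDeviceNodeSelection to true without any pool-level node selector, so
+// API validation is expected to reject it while the feature gate is off.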
+func testDisabled(tCtx ktesting.TContext) {
+	namespace := createTestNamespace(tCtx, nil)
+	_, driverName := createTestClass(tCtx, namespace)
+
+	slice := st.MakeResourceSliceWithPerDeviceNodeSelection("slice", driverName)
+	_, err := tCtx.Client().ResourceV1().ResourceSlices().Create(tCtx, slice.Obj(), metav1.CreateOptions{})
+	require.Error(tCtx, err, "creating a ResourceSlice with perDeviceNodeSelection should fail while the Partitionable Devices feature is disabled")
+}
+
+// testPerDeviceNodeSelection verifies that pods are scheduled
+// on the correct nodes when they are allocated devices that
+// specify node selection using the perDeviceNodeSelection field
+// that was introduced as part of the Partitionable Devices
+// feature.
+func testPerDeviceNodeSelection(tCtx ktesting.TContext) {
+	namespace := createTestNamespace(tCtx, nil)
+	class, driverName := createTestClass(tCtx, namespace)
+
+	nodes, err := tCtx.Client().CoreV1().Nodes().List(tCtx, metav1.ListOptions{})
+	tCtx.ExpectNoError(err, "list nodes")
+
+	// Publish one device per node, pinned to that node through the device's nodeName field.
+	slice := st.MakeResourceSliceWithPerDeviceNodeSelection("slice", driverName)
+	for _, node := range nodes.Items {
+		slice.Device(fmt.Sprintf("device-for-%s", node.Name), st.NodeName(node.Name))
+	}
+	createSlice(tCtx, slice.Obj())
+
+	startScheduler(tCtx)
+
+	// For every node, create a claim and a pod; the pod must land on the node
+	// that provides the device allocated to its claim.
+	for i := range nodes.Items {
+		claim := st.MakeResourceClaim().
+			Name(fmt.Sprintf("claim-%d", i)).
+			Namespace(namespace).
+			Request(class.Name).
+			Obj()
+		createdClaim, err := tCtx.Client().ResourceV1().ResourceClaims(namespace).Create(tCtx, claim, metav1.CreateOptions{})
+		tCtx.ExpectNoError(err, fmt.Sprintf("claim name %q", createdClaim.Name))
+
+		pod := st.MakePod().Name(podName).Namespace(namespace).
+			Container("my-container").
+			Obj()
+		createdPod := createPod(tCtx, namespace, fmt.Sprintf("-%d", i), pod, claim)
+
+		scheduledPod := waitForPodScheduled(tCtx, namespace, createdPod.Name)
+
+		allocatedClaim, err := tCtx.Client().ResourceV1().ResourceClaims(namespace).Get(tCtx, createdClaim.Name, metav1.GetOptions{})
+		tCtx.ExpectNoError(err, fmt.Sprintf("get claim %q", createdClaim.Name))
+		tCtx.Expect(allocatedClaim).To(gomega.HaveField("Status.Allocation", gomega.Not(gomega.BeNil())), "Claim should have been allocated.")
+
+		nodeName := scheduledPod.Spec.NodeName
+		expectedAllocatedDevice := fmt.Sprintf("device-for-%s", nodeName)
+		tCtx.Expect(allocatedClaim.Status.Allocation.Devices.Results[0].Device).To(gomega.Equal(expectedAllocatedDevice))
+	}
+}
+
+// testPartitionableDevicesWithMultiHostDevice verifies that multiple pods sharing
+// a ResourceClaim that is allocated a multi-host device get scheduled correctly
+// on the nodes selected by the node selector on the device.
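+// The device below carries a nodeSelector that matches four nodes by their
+// kubernetes.io/hostname label; required pod anti-affinity on that same label
+// spreads the four pods across distinct nodes, so the test expects the set of
+// scheduled nodes to match the set of nodes selected for the device.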
+func testPartitionableDevicesWithMultiHostDevice(tCtx ktesting.TContext) {
+	namespace := createTestNamespace(tCtx, nil)
+	class, driverName := createTestClass(tCtx, namespace)
+
+	nodes, err := tCtx.Client().CoreV1().Nodes().List(tCtx, metav1.ListOptions{})
+	tCtx.ExpectNoError(err, "list nodes")
+
+	minNodeCount := 4
+	if nodeCount := len(nodes.Items); nodeCount < minNodeCount {
+		tCtx.Fatalf("found only %d nodes, need at least %d", nodeCount, minNodeCount)
+	}
+
+	deviceNodes := []string{
+		nodes.Items[0].Name,
+		nodes.Items[1].Name,
+		nodes.Items[2].Name,
+		nodes.Items[3].Name,
+	}
+
+	slice := st.MakeResourceSliceWithPerDeviceNodeSelection("slice", driverName)
+	slice.Device("multi-host-device", &v1.NodeSelector{
+		NodeSelectorTerms: []v1.NodeSelectorTerm{{
+			MatchExpressions: []v1.NodeSelectorRequirement{{
+				Key:      "kubernetes.io/hostname",
+				Operator: v1.NodeSelectorOpIn,
+				Values:   deviceNodes,
+			}},
+		}},
+	})
+	createSlice(tCtx, slice.Obj())
+
+	startScheduler(tCtx)
+
+	claim := st.MakeResourceClaim().
+		Name("multi-host-claim").
+		Namespace(namespace).
+		Request(class.Name).
+		Obj()
+	createdClaim, err := tCtx.Client().ResourceV1().ResourceClaims(namespace).Create(tCtx, claim, metav1.CreateOptions{})
+	tCtx.ExpectNoError(err, fmt.Sprintf("claim name %q", createdClaim.Name))
+
+	labelKey := "app"
+	labelValue := "multiHost"
+	labelSelector := &metav1.LabelSelector{
+		MatchExpressions: []metav1.LabelSelectorRequirement{
+			{
+				Key:      labelKey,
+				Operator: metav1.LabelSelectorOpIn,
+				Values:   []string{labelValue},
+			},
+		},
+	}
+	var podNames []string
+	for i := range 4 {
+		pod := st.MakePod().
+			Name(podName).
+			Namespace(namespace).
+			Labels(map[string]string{labelKey: labelValue}).
+			PodAntiAffinity("kubernetes.io/hostname", labelSelector, st.PodAntiAffinityWithRequiredReq).
+			Container("my-container").
+			Obj()
+		createdPod := createPod(tCtx, namespace, fmt.Sprintf("-%d", i), pod, claim)
+		podNames = append(podNames, createdPod.Name)
+	}
+
+	var scheduledOnNodes []string
+	for _, podName := range podNames {
+		scheduledPod := waitForPodScheduled(tCtx, namespace, podName)
+		scheduledOnNodes = append(scheduledOnNodes, scheduledPod.Spec.NodeName)
+	}
+
+	tCtx.Expect(scheduledOnNodes).To(gomega.ConsistOf(deviceNodes))
+}