From 070cb9ec4875f5bdaaa73fd293cfb5a809e8780d Mon Sep 17 00:00:00 2001 From: Travis O'Neal Date: Wed, 8 Apr 2026 14:00:25 -0400 Subject: [PATCH] Add PlacementCycleState to WAS scheduler framework MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added PlacementCycleState as a third state scope for WAS, alongside pod-level CycleState and PodGroupCycleState. This is foundational plumbing only — plugin adoption is a follow-up. Signed-off-by: Travis O'Neal --- pkg/scheduler/framework/cycle_state.go | 13 ++ pkg/scheduler/framework/cycle_state_test.go | 86 ++++++++ pkg/scheduler/schedule_one_podgroup.go | 25 ++- pkg/scheduler/schedule_one_podgroup_test.go | 193 +++++++++++++++++- .../kube-scheduler/framework/cycle_state.go | 33 +++ .../k8s.io/kube-scheduler/framework/types.go | 3 + 6 files changed, 343 insertions(+), 10 deletions(-) diff --git a/pkg/scheduler/framework/cycle_state.go b/pkg/scheduler/framework/cycle_state.go index 31b78725356..32bc015c5a8 100644 --- a/pkg/scheduler/framework/cycle_state.go +++ b/pkg/scheduler/framework/cycle_state.go @@ -46,6 +46,10 @@ type CycleState struct { // or doesn't belong to any pod group. // This field can only be non-nil when GenericWorkload feature flag is enabled. podGroupCycleState fwk.PodGroupCycleState + // placementCycleState contains the CycleState for the current Placement being evaluated. + // If set to nil, it means this pod is not being scheduled within a placement context. + // This field can only be non-nil when GenericWorkload feature flag is enabled. + placementCycleState fwk.PlacementCycleState } // NewCycleState initializes a new CycleState and returns its pointer. @@ -113,6 +117,14 @@ func (c *CycleState) GetPodGroupSchedulingCycle() fwk.PodGroupCycleState { return c.podGroupCycleState } +func (c *CycleState) GetPlacementCycleState() fwk.PlacementCycleState { + return c.placementCycleState +} + +func (c *CycleState) SetPlacementCycleState(placementCycleState fwk.PlacementCycleState) { + c.placementCycleState = placementCycleState +} + func (c *CycleState) SetSkipAllPostFilterPlugins(flag bool) { c.skipAllPostFilterPlugins = flag } @@ -140,6 +152,7 @@ func (c *CycleState) Clone() fwk.CycleState { copy.skipPreBindPlugins = c.skipPreBindPlugins copy.parallelPreBindPlugins = c.parallelPreBindPlugins copy.podGroupCycleState = c.podGroupCycleState + copy.placementCycleState = c.placementCycleState copy.skipAllPostFilterPlugins = c.skipAllPostFilterPlugins return copy diff --git a/pkg/scheduler/framework/cycle_state_test.go b/pkg/scheduler/framework/cycle_state_test.go index 70b70085fdf..e72b220e94d 100644 --- a/pkg/scheduler/framework/cycle_state_test.go +++ b/pkg/scheduler/framework/cycle_state_test.go @@ -193,3 +193,89 @@ func TestCycleStateClone(t *testing.T) { }) } } + +func TestPlacementCycleState(t *testing.T) { + t.Run("nil by default", func(t *testing.T) { + state := NewCycleState() + if state.GetPlacementCycleState() != nil { + t.Errorf("expected nil PlacementCycleState on fresh CycleState") + } + }) + + t.Run("set and get", func(t *testing.T) { + state := NewCycleState() + placementState := NewCycleState() + placementState.Write("testkey", &fakeData{data: "placementdata"}) + + state.SetPlacementCycleState(placementState) + + got := state.GetPlacementCycleState() + if got == nil { + t.Fatal("expected non-nil PlacementCycleState after Set") + } + + data, err := got.Read("testkey") + if err != nil { + t.Fatalf("unexpected error reading from PlacementCycleState: %v", err) + } + if data.(*fakeData).data != "placementdata" { + t.Errorf("expected 'placementdata', got %q", data.(*fakeData).data) + } + }) + + t.Run("set to nil clears", func(t *testing.T) { + state := NewCycleState() + state.SetPlacementCycleState(NewCycleState()) + state.SetPlacementCycleState(nil) + + if state.GetPlacementCycleState() != nil { + t.Errorf("expected nil PlacementCycleState after setting to nil") + } + }) + + t.Run("clone preserves reference", func(t *testing.T) { + state := NewCycleState() + state.Write(key, &fakeData{data: "pod-data"}) + + placementState := NewCycleState() + placementState.Write("pkey", &fakeData{data: "placement-data"}) + state.SetPlacementCycleState(placementState) + + cloned := state.Clone().(*CycleState) + + // The cloned state should reference the same PlacementCycleState. + if cloned.GetPlacementCycleState() == nil { + t.Fatal("cloned state should have non-nil PlacementCycleState") + } + + data, err := cloned.GetPlacementCycleState().Read("pkey") + if err != nil { + t.Fatalf("unexpected error reading from cloned PlacementCycleState: %v", err) + } + if data.(*fakeData).data != "placement-data" { + t.Errorf("expected 'placement-data', got %q", data.(*fakeData).data) + } + + // Writes to the PlacementCycleState via the clone should be visible from the original, + // since it's a shared reference (same as podGroupCycleState behavior). + cloned.GetPlacementCycleState().Write("newkey", &fakeData{data: "new"}) + newData, err := state.GetPlacementCycleState().Read("newkey") + if err != nil { + t.Fatalf("write via clone's PlacementCycleState should be visible from original: %v", err) + } + if newData.(*fakeData).data != "new" { + t.Errorf("expected 'new', got %q", newData.(*fakeData).data) + } + }) + + t.Run("clone with nil placement state", func(t *testing.T) { + state := NewCycleState() + state.Write(key, &fakeData{data: "data"}) + // Do not set PlacementCycleState — leave nil. + + cloned := state.Clone().(*CycleState) + if cloned.GetPlacementCycleState() != nil { + t.Errorf("cloned state should have nil PlacementCycleState when original has nil") + } + }) +} diff --git a/pkg/scheduler/schedule_one_podgroup.go b/pkg/scheduler/schedule_one_podgroup.go index 78c363ad428..50440c5b10d 100644 --- a/pkg/scheduler/schedule_one_podgroup.go +++ b/pkg/scheduler/schedule_one_podgroup.go @@ -182,7 +182,7 @@ type podSchedulingContext struct { } // initPodSchedulingContext initializes the scheduling context of a single pod for pod group scheduling cycle. -func initPodSchedulingContext(ctx context.Context, pod *v1.Pod, podGroupState *framework.CycleState, postFilterMode podGroupPostFilterMode) *podSchedulingContext { +func initPodSchedulingContext(ctx context.Context, pod *v1.Pod, podGroupState *framework.CycleState, placementCycleState fwk.PlacementCycleState, postFilterMode podGroupPostFilterMode) *podSchedulingContext { logger := klog.FromContext(ctx) // TODO(knelasevero): Remove duplicated keys from log entry calls // When contextualized logging hits GA @@ -202,6 +202,8 @@ func initPodSchedulingContext(ctx context.Context, pod *v1.Pod, podGroupState *f // Marks this cycle as a pod group scheduling cycle. state.SetPodGroupSchedulingCycle(podGroupState) + // Set the placement cycle state so per-pod plugins can access placement-scoped data. + state.SetPlacementCycleState(placementCycleState) // Skip post filters if requested. switch postFilterMode { @@ -358,6 +360,9 @@ type podGroupAlgorithmResult struct { // waitingOnPreemption indicates whether this pod group requires or is waiting for preemption to complete. // This can only be set to true when the status is Unschedulable. waitingOnPreemption bool + // placementCycleState holds the state accumulated during simulation of a specific placement. + // Only set for successful placements that proceed to the scoring phase. + placementCycleState fwk.PlacementCycleState } // podGroupSchedulingDefaultAlgorithm runs the default algorithm for scheduling a pod group. @@ -365,23 +370,24 @@ type podGroupAlgorithmResult struct { // If a pod requires preemption to be schedulable, subsequent pods in the algorithm // treat that pod as already scheduled on that node with victims being already removed in memory. func (sched *Scheduler) podGroupSchedulingDefaultAlgorithm(ctx context.Context, schedFwk framework.Framework, podGroupCycleState *framework.CycleState, podGroupInfo *framework.QueuedPodGroupInfo, postFilterMode podGroupPostFilterMode) podGroupAlgorithmResult { + placementCycleState := framework.NewCycleState() + placementCycleState.SetRecordPluginMetrics(true) + placementCycleState.SetPodGroupSchedulingCycle(podGroupCycleState) + result := podGroupAlgorithmResult{ podResults: make([]algorithmResult, 0, len(podGroupInfo.QueuedPodInfos)), status: fwk.NewStatus(fwk.Unschedulable).WithError(errPodGroupUnschedulable), waitingOnPreemption: false, + placementCycleState: placementCycleState, } - placementCycleState := framework.NewCycleState() - placementCycleState.SetRecordPluginMetrics(true) - placementCycleState.SetPodGroupSchedulingCycle(podGroupCycleState) - logger := klog.FromContext(ctx) logger.V(5).Info("Running a pod group scheduling algorithm", "podGroup", klog.KObj(podGroupInfo), "unscheduledPodsCount", len(podGroupInfo.QueuedPodInfos)) requiresPreemption := false anyScheduled := false for _, podInfo := range podGroupInfo.QueuedPodInfos { - podResult, revertFn := sched.podGroupPodSchedulingAlgorithm(ctx, schedFwk, podGroupCycleState, podGroupInfo, podInfo, postFilterMode) + podResult, revertFn := sched.podGroupPodSchedulingAlgorithm(ctx, schedFwk, podGroupCycleState, placementCycleState, podGroupInfo, podInfo, postFilterMode) result.podResults = append(result.podResults, podResult) if revertFn != nil { // We unreserve the pod at the end of the whole algorithm (via defer) because it should be ultimately returned to the queue, @@ -435,9 +441,9 @@ func (sched *Scheduler) podGroupSchedulingDefaultAlgorithm(ctx context.Context, // podGroupPodSchedulingAlgorithm runs a scheduling algorithm for individual pod from a pod group. // It returns the algorithm result together with the revert function. -func (sched *Scheduler) podGroupPodSchedulingAlgorithm(ctx context.Context, schedFwk framework.Framework, podGroupCycleState *framework.CycleState, podGroupInfo *framework.QueuedPodGroupInfo, podInfo *framework.QueuedPodInfo, postFilterMode podGroupPostFilterMode) (algorithmResult, func()) { +func (sched *Scheduler) podGroupPodSchedulingAlgorithm(ctx context.Context, schedFwk framework.Framework, podGroupCycleState *framework.CycleState, placementCycleState fwk.PlacementCycleState, podGroupInfo *framework.QueuedPodGroupInfo, podInfo *framework.QueuedPodInfo, postFilterMode podGroupPostFilterMode) (algorithmResult, func()) { pod := podInfo.Pod - podCtx := initPodSchedulingContext(ctx, pod, podGroupCycleState, postFilterMode) + podCtx := initPodSchedulingContext(ctx, pod, podGroupCycleState, placementCycleState, postFilterMode) logger := podCtx.logger ctx = klog.NewContext(ctx, logger) start := time.Now() @@ -506,7 +512,7 @@ func completePodGroupAlgorithmResult(ctx context.Context, podGroupInfo *framewor pInfo := podGroupInfo.QueuedPodInfos[i] newResults[i] = algorithmResult{ pod: pInfo.Pod, - podCtx: initPodSchedulingContext(ctx, pInfo.Pod, podGroupState, postFilterMode), + podCtx: initPodSchedulingContext(ctx, pInfo.Pod, podGroupState, nil, postFilterMode), status: podGroupResult.status.Clone(), } } @@ -783,6 +789,7 @@ func makePodGroupAssignments(successfulResults map[*fwk.Placement]*podGroupAlgor placementPodGroupAssignments = append(placementPodGroupAssignments, &fwk.PodGroupAssignments{ Placement: placement, ProposedAssignments: proposedAssignments, + PlacementCycleState: result.placementCycleState, }) } return placementPodGroupAssignments diff --git a/pkg/scheduler/schedule_one_podgroup_test.go b/pkg/scheduler/schedule_one_podgroup_test.go index 330abad102e..6d827007811 100644 --- a/pkg/scheduler/schedule_one_podgroup_test.go +++ b/pkg/scheduler/schedule_one_podgroup_test.go @@ -1751,7 +1751,7 @@ func TestSubmitPodGroupAlgorithmResult(t *testing.T) { for i := range tt.algorithmResult.podResults { pod := podGroupInfo.QueuedPodInfos[i].Pod - podCtx := initPodSchedulingContext(ctx, pod, podGroupCycleState, runAllPostFilters) + podCtx := initPodSchedulingContext(ctx, pod, podGroupCycleState, nil, runAllPostFilters) tt.algorithmResult.podResults[i].podCtx = podCtx } @@ -2538,6 +2538,7 @@ func TestPodGroupSchedulingPlacementAlgorithm(t *testing.T) { ScheduleResult{}, fwk.Status{}), cmpopts.IgnoreFields(algorithmResult{}, "podCtx", "schedulingDuration"), + cmpopts.IgnoreFields(podGroupAlgorithmResult{}, "placementCycleState"), statusCmpOpt, } @@ -2702,6 +2703,196 @@ func TestPodGroupSchedulingPlacementAlgorithm_Scoring(t *testing.T) { } } +// placementStateTracker is a fake plugin that writes to PlacementCycleState during Filter +// and reads from it during ScorePlacement, to verify the lifecycle of PlacementCycleState. +type placementStateTracker struct { + name string + mu sync.Mutex + // scoreReadValues records what value was read from PlacementCycleState + // during each ScorePlacement call, keyed by placement name. + scoreReadValues map[string]string + // generatePlacementsResult defines the placements to generate. + generatePlacementsResult map[string][]string +} + +type placementStateData struct { + value string +} + +func (d *placementStateData) Clone() fwk.StateData { return d } + +var placementStateKey fwk.StateKey = "placementStateTracker" + +var _ fwk.FilterPlugin = &placementStateTracker{} +var _ fwk.PlacementGeneratePlugin = &placementStateTracker{} +var _ fwk.PlacementScorePlugin = &placementStateTracker{} + +func (p *placementStateTracker) Name() string { return p.name } + +func (p *placementStateTracker) Filter(ctx context.Context, state fwk.CycleState, pod *v1.Pod, nodeInfo fwk.NodeInfo) *fwk.Status { + placementState := state.GetPlacementCycleState() + if placementState == nil { + return fwk.NewStatus(fwk.Error, "PlacementCycleState is nil during Filter") + } + + // Write the node name as a marker so ScorePlacement can verify + // which placement's state it received. + placementState.Write(placementStateKey, &placementStateData{value: nodeInfo.Node().Name}) + return nil +} + +func (p *placementStateTracker) ScorePlacement(ctx context.Context, state fwk.PodGroupCycleState, podGroup fwk.PodGroupInfo, placement *fwk.PodGroupAssignments) (int64, *fwk.Status) { + if placement.PlacementCycleState == nil { + return 0, fwk.NewStatus(fwk.Error, "PlacementCycleState is nil during ScorePlacement") + } + + data, err := placement.PlacementCycleState.Read(placementStateKey) + p.mu.Lock() + defer p.mu.Unlock() + if err != nil { + p.scoreReadValues[placement.Name] = "" + } else { + p.scoreReadValues[placement.Name] = data.(*placementStateData).value + } + return 1, nil +} + +func (p *placementStateTracker) PlacementScoreExtensions() fwk.PlacementScoreExtensions { + return nil +} + +func (p *placementStateTracker) GeneratePlacements(ctx context.Context, state fwk.PodGroupCycleState, podGroup fwk.PodGroupInfo, parentPlacement *fwk.Placement) (*fwk.GeneratePlacementsResult, *fwk.Status) { + parentNodes := map[string]fwk.NodeInfo{} + for _, node := range parentPlacement.Nodes { + parentNodes[node.Node().Name] = node + } + + placements := make([]*fwk.Placement, 0, len(p.generatePlacementsResult)) + for placementName, nodeNames := range p.generatePlacementsResult { + placement := &fwk.Placement{Name: placementName} + for _, nodeName := range nodeNames { + placement.Nodes = append(placement.Nodes, parentNodes[nodeName]) + } + placements = append(placements, placement) + } + return &fwk.GeneratePlacementsResult{Placements: placements}, nil +} + +func TestPlacementCycleStateLifecycle(t *testing.T) { + featuregatetesting.SetFeatureGatesDuringTest(t, utilfeature.DefaultFeatureGate, featuregatetesting.FeatureOverrides{ + features.TopologyAwareWorkloadScheduling: true, + features.GenericWorkload: true, + }) + + // A single scenario exercises both isolation and continuity: + // - Filter writes a node-name marker into PlacementCycleState during each placement's simulation. + // - ScorePlacement reads from PodGroupAssignments.PlacementCycleState after all simulations. + // Assertions verify: + // 1. Each placement's scorer reads only the value its own simulation wrote (isolation). + // 2. The value is readable at all during scoring (continuity from simulation to scoring). + + nodes := []*v1.Node{ + st.MakeNode().Name("node1").Obj(), + st.MakeNode().Name("node2").Obj(), + } + podGroupPod := st.MakePod().Name("foo").UID("foo").PodGroupName("pg").Obj() + + logger, ctx := ktesting.NewTestContext(t) + + informerFactory := informers.NewSharedInformerFactory(clientsetfake.NewClientset(), 0) + queue := internalqueue.NewSchedulingQueue(nil, informerFactory) + + tracker := &placementStateTracker{ + name: "StateTracker", + scoreReadValues: make(map[string]string), + generatePlacementsResult: map[string][]string{ + "placementA": {nodes[0].Name}, + "placementB": {nodes[1].Name}, + }, + } + + registry := []tf.RegisterPluginFunc{ + tf.RegisterPlacementGeneratePlugin(tracker.Name(), func(_ context.Context, _ runtime.Object, _ fwk.Handle) (fwk.Plugin, error) { + return tracker, nil + }), + tf.RegisterPlacementScorePlugin(tracker.Name(), func(_ context.Context, _ runtime.Object, _ fwk.Handle) (fwk.Plugin, error) { + return tracker, nil + }, 1), + tf.RegisterFilterPlugin(tracker.Name(), func(_ context.Context, _ runtime.Object, _ fwk.Handle) (fwk.Plugin, error) { + return tracker, nil + }), + } + + snapshot := internalcache.NewEmptySnapshot() + schedFwk, err := tf.NewFramework(ctx, + append(registry, + tf.RegisterQueueSortPlugin(queuesort.Name, queuesort.New), + tf.RegisterBindPlugin(defaultbinder.Name, defaultbinder.New), + ), + "test-scheduler", + frameworkruntime.WithInformerFactory(informerFactory), + frameworkruntime.WithSnapshotSharedLister(snapshot), + frameworkruntime.WithPodNominator(queue), + ) + if err != nil { + t.Fatalf("Failed to create framework: %v", err) + } + + cache := internalcache.New(ctx, nil, true) + for _, node := range nodes { + cache.AddNode(logger, node) + } + + sched := &Scheduler{ + Cache: cache, + nodeInfoSnapshot: snapshot, + SchedulingQueue: queue, + Profiles: profile.Map{"test-scheduler": schedFwk}, + } + sched.SchedulePod = sched.schedulePod + + if err := sched.Cache.UpdateSnapshot(logger, sched.nodeInfoSnapshot); err != nil { + t.Fatalf("Failed to update snapshot: %v", err) + } + + pgInfo := &framework.QueuedPodGroupInfo{ + QueuedPodInfos: []*framework.QueuedPodInfo{ + {PodInfo: &framework.PodInfo{Pod: podGroupPod}}, + }, + PodGroupInfo: &framework.PodGroupInfo{ + UnscheduledPods: []*v1.Pod{podGroupPod}, + }, + } + + result := sched.podGroupSchedulingPlacementAlgorithm(ctx, schedFwk, framework.NewCycleState(), pgInfo, runAllPostFilters) + if !result.status.IsSuccess() { + t.Fatalf("Expected success, got: %v", result.status) + } + + tracker.mu.Lock() + defer tracker.mu.Unlock() + + // Continuity: ScorePlacement must have been called and able to read simulation data. + for _, placementName := range []string{"placementA", "placementB"} { + readValue, ok := tracker.scoreReadValues[placementName] + if !ok { + t.Fatalf("placement %s: ScorePlacement was not called", placementName) + } + if readValue == "" { + t.Fatalf("placement %s: ScorePlacement could not read state written during simulation", placementName) + } + } + + // Isolation: each placement's scorer must read only what its own simulation wrote. + // placementA simulated on node1, placementB simulated on node2. + if v := tracker.scoreReadValues["placementA"]; v != "node1" { + t.Errorf("placementA: scorer read %q, want %q (isolation violation if it read placementB's value)", v, "node1") + } + if v := tracker.scoreReadValues["placementB"]; v != "node2" { + t.Errorf("placementB: scorer read %q, want %q (isolation violation if it read placementA's value)", v, "node2") + } +} + type fakeDefaultPreemption struct { *fakePodGroupPlugin } diff --git a/staging/src/k8s.io/kube-scheduler/framework/cycle_state.go b/staging/src/k8s.io/kube-scheduler/framework/cycle_state.go index 4d103b2e39c..fdf41701ba6 100644 --- a/staging/src/k8s.io/kube-scheduler/framework/cycle_state.go +++ b/staging/src/k8s.io/kube-scheduler/framework/cycle_state.go @@ -104,6 +104,39 @@ type CycleState interface { // SetPodGroupSchedulingCycle sets the cycle state of the PodGroup for a Pod. // This should be only used when GenericWorkload feature flag is enabled. SetPodGroupSchedulingCycle(PodGroupCycleState) + // GetPlacementCycleState gets the cycle state of the current Placement for a Pod. + // Returns nil if this pod is not being scheduled within a placement context. + // This should be only used when GenericWorkload feature flag is enabled. + GetPlacementCycleState() PlacementCycleState + // SetPlacementCycleState sets the cycle state of the current Placement for a Pod. + // This should be only used when GenericWorkload feature flag is enabled. + SetPlacementCycleState(PlacementCycleState) +} + +// PlacementCycleState provides a mechanism for plugins to store and retrieve arbitrary data +// scoped to a single placement candidate within a pod group scheduling cycle. +// Data stored in PlacementCycleState is shared across all pods scheduled within the same +// placement iteration and is preserved through to the placement scoring phase via PodGroupAssignments. +// PlacementCycleState does not provide any data protection, as all plugins are assumed to be +// trusted. +type PlacementCycleState interface { + // ShouldRecordPluginMetrics returns whether metrics.PluginExecutionDuration metrics + // should be recorded. + // This function is mostly for the scheduling framework runtime, plugins usually don't have to use it. + ShouldRecordPluginMetrics() bool + // Read retrieves data with the given "key" from PlacementCycleState. If the key is not + // present, ErrNotFound is returned. + // + // See PlacementCycleState for notes on concurrency. + Read(key StateKey) (StateData, error) + // Write stores the given "val" in PlacementCycleState with the given "key". + // + // See PlacementCycleState for notes on concurrency. + Write(key StateKey, val StateData) + // Delete deletes data with the given key from PlacementCycleState. + // + // See PlacementCycleState for notes on concurrency. + Delete(key StateKey) } // PodGroupCycleState provides a mechanism for plugins that operate on pod groups to store and retrieve arbitrary data. diff --git a/staging/src/k8s.io/kube-scheduler/framework/types.go b/staging/src/k8s.io/kube-scheduler/framework/types.go index c5df461b3d7..d6e2f062bf5 100644 --- a/staging/src/k8s.io/kube-scheduler/framework/types.go +++ b/staging/src/k8s.io/kube-scheduler/framework/types.go @@ -682,6 +682,9 @@ type PodGroupAssignments struct { // during the pod group scheduling cycle. // The pods are guaranteed to also be present in the PodGroupInfo. ProposedAssignments []ProposedAssignment + // PlacementCycleState holds the state that was accumulated during the simulation of this placement. + // Placement score plugins can use this to access per-placement cached data. + PlacementCycleState PlacementCycleState } // NodeAllocatableDRAClaimState holds information about a node allocatable resource DRA claim's allocation on a node.