Add PlacementCycleState to WAS scheduler framework

Added PlacementCycleState as a third state scope for WAS, alongside pod-level CycleState and PodGroupCycleState.

This is foundational plumbing only — plugin adoption is a follow-up.

Signed-off-by: Travis O'Neal <wtravisoneal@gmail.com>
This commit is contained in:
Travis O'Neal 2026-04-08 14:00:25 -04:00 committed by Travis O'Neal
parent abb213f9e9
commit 070cb9ec48
6 changed files with 343 additions and 10 deletions

View file

@ -46,6 +46,10 @@ type CycleState struct {
// or doesn't belong to any pod group.
// This field can only be non-nil when GenericWorkload feature flag is enabled.
podGroupCycleState fwk.PodGroupCycleState
// placementCycleState contains the CycleState for the current Placement being evaluated.
// If set to nil, it means this pod is not being scheduled within a placement context.
// This field can only be non-nil when GenericWorkload feature flag is enabled.
placementCycleState fwk.PlacementCycleState
}
// NewCycleState initializes a new CycleState and returns its pointer.
@ -113,6 +117,14 @@ func (c *CycleState) GetPodGroupSchedulingCycle() fwk.PodGroupCycleState {
return c.podGroupCycleState
}
// GetPlacementCycleState returns the PlacementCycleState stored on this CycleState,
// or nil if none has been set.
func (c *CycleState) GetPlacementCycleState() fwk.PlacementCycleState {
return c.placementCycleState
}
// SetPlacementCycleState stores placementCycleState on this CycleState, replacing any
// previously stored value. Passing nil clears it.
func (c *CycleState) SetPlacementCycleState(placementCycleState fwk.PlacementCycleState) {
c.placementCycleState = placementCycleState
}
// SetSkipAllPostFilterPlugins records flag in the skipAllPostFilterPlugins field.
func (c *CycleState) SetSkipAllPostFilterPlugins(flag bool) {
c.skipAllPostFilterPlugins = flag
}
@ -140,6 +152,7 @@ func (c *CycleState) Clone() fwk.CycleState {
copy.skipPreBindPlugins = c.skipPreBindPlugins
copy.parallelPreBindPlugins = c.parallelPreBindPlugins
copy.podGroupCycleState = c.podGroupCycleState
copy.placementCycleState = c.placementCycleState
copy.skipAllPostFilterPlugins = c.skipAllPostFilterPlugins
return copy

View file

@ -193,3 +193,89 @@ func TestCycleStateClone(t *testing.T) {
})
}
}
// TestPlacementCycleState covers the PlacementCycleState accessors on CycleState:
// the default value, a set/get round trip, clearing with nil, and how Clone
// carries the field over (shared reference, or nil when unset).
func TestPlacementCycleState(t *testing.T) {
	t.Run("nil by default", func(t *testing.T) {
		if got := NewCycleState().GetPlacementCycleState(); got != nil {
			t.Errorf("expected nil PlacementCycleState on fresh CycleState")
		}
	})

	t.Run("set and get", func(t *testing.T) {
		placement := NewCycleState()
		placement.Write("testkey", &fakeData{data: "placementdata"})

		podState := NewCycleState()
		podState.SetPlacementCycleState(placement)

		stored := podState.GetPlacementCycleState()
		if stored == nil {
			t.Fatal("expected non-nil PlacementCycleState after Set")
		}
		value, err := stored.Read("testkey")
		if err != nil {
			t.Fatalf("unexpected error reading from PlacementCycleState: %v", err)
		}
		if actual := value.(*fakeData).data; actual != "placementdata" {
			t.Errorf("expected 'placementdata', got %q", actual)
		}
	})

	t.Run("set to nil clears", func(t *testing.T) {
		podState := NewCycleState()
		// Install a state first so that the nil assignment has something to clear.
		podState.SetPlacementCycleState(NewCycleState())
		podState.SetPlacementCycleState(nil)
		if got := podState.GetPlacementCycleState(); got != nil {
			t.Errorf("expected nil PlacementCycleState after setting to nil")
		}
	})

	t.Run("clone preserves reference", func(t *testing.T) {
		placement := NewCycleState()
		placement.Write("pkey", &fakeData{data: "placement-data"})

		original := NewCycleState()
		original.Write(key, &fakeData{data: "pod-data"})
		original.SetPlacementCycleState(placement)

		duplicate := original.Clone().(*CycleState)

		// The clone is expected to point at the very same PlacementCycleState.
		viaClone := duplicate.GetPlacementCycleState()
		if viaClone == nil {
			t.Fatal("cloned state should have non-nil PlacementCycleState")
		}
		value, err := viaClone.Read("pkey")
		if err != nil {
			t.Fatalf("unexpected error reading from cloned PlacementCycleState: %v", err)
		}
		if actual := value.(*fakeData).data; actual != "placement-data" {
			t.Errorf("expected 'placement-data', got %q", actual)
		}

		// Because the reference is shared (as with podGroupCycleState), a write made
		// through the clone must be observable through the original.
		viaClone.Write("newkey", &fakeData{data: "new"})
		written, err := original.GetPlacementCycleState().Read("newkey")
		if err != nil {
			t.Fatalf("write via clone's PlacementCycleState should be visible from original: %v", err)
		}
		if actual := written.(*fakeData).data; actual != "new" {
			t.Errorf("expected 'new', got %q", actual)
		}
	})

	t.Run("clone with nil placement state", func(t *testing.T) {
		original := NewCycleState()
		original.Write(key, &fakeData{data: "data"})

		// PlacementCycleState is intentionally left unset on the original.
		duplicate := original.Clone().(*CycleState)
		if got := duplicate.GetPlacementCycleState(); got != nil {
			t.Errorf("cloned state should have nil PlacementCycleState when original has nil")
		}
	})
}

View file

@ -182,7 +182,7 @@ type podSchedulingContext struct {
}
// initPodSchedulingContext initializes the scheduling context of a single pod for pod group scheduling cycle.
func initPodSchedulingContext(ctx context.Context, pod *v1.Pod, podGroupState *framework.CycleState, postFilterMode podGroupPostFilterMode) *podSchedulingContext {
func initPodSchedulingContext(ctx context.Context, pod *v1.Pod, podGroupState *framework.CycleState, placementCycleState fwk.PlacementCycleState, postFilterMode podGroupPostFilterMode) *podSchedulingContext {
logger := klog.FromContext(ctx)
// TODO(knelasevero): Remove duplicated keys from log entry calls
// When contextualized logging hits GA
@ -202,6 +202,8 @@ func initPodSchedulingContext(ctx context.Context, pod *v1.Pod, podGroupState *f
// Marks this cycle as a pod group scheduling cycle.
state.SetPodGroupSchedulingCycle(podGroupState)
// Set the placement cycle state so per-pod plugins can access placement-scoped data.
state.SetPlacementCycleState(placementCycleState)
// Skip post filters if requested.
switch postFilterMode {
@ -358,6 +360,9 @@ type podGroupAlgorithmResult struct {
// waitingOnPreemption indicates whether this pod group requires or is waiting for preemption to complete.
// This can only be set to true when the status is Unschedulable.
waitingOnPreemption bool
// placementCycleState holds the state accumulated during simulation of a specific placement.
// podGroupSchedulingDefaultAlgorithm populates it on every result it returns, but it is only
// consumed for successful placements that proceed to the scoring phase.
placementCycleState fwk.PlacementCycleState
}
// podGroupSchedulingDefaultAlgorithm runs the default algorithm for scheduling a pod group.
@ -365,23 +370,24 @@ type podGroupAlgorithmResult struct {
// If a pod requires preemption to be schedulable, subsequent pods in the algorithm
// treat that pod as already scheduled on that node with victims being already removed in memory.
func (sched *Scheduler) podGroupSchedulingDefaultAlgorithm(ctx context.Context, schedFwk framework.Framework, podGroupCycleState *framework.CycleState, podGroupInfo *framework.QueuedPodGroupInfo, postFilterMode podGroupPostFilterMode) podGroupAlgorithmResult {
placementCycleState := framework.NewCycleState()
placementCycleState.SetRecordPluginMetrics(true)
placementCycleState.SetPodGroupSchedulingCycle(podGroupCycleState)
result := podGroupAlgorithmResult{
podResults: make([]algorithmResult, 0, len(podGroupInfo.QueuedPodInfos)),
status: fwk.NewStatus(fwk.Unschedulable).WithError(errPodGroupUnschedulable),
waitingOnPreemption: false,
placementCycleState: placementCycleState,
}
placementCycleState := framework.NewCycleState()
placementCycleState.SetRecordPluginMetrics(true)
placementCycleState.SetPodGroupSchedulingCycle(podGroupCycleState)
logger := klog.FromContext(ctx)
logger.V(5).Info("Running a pod group scheduling algorithm", "podGroup", klog.KObj(podGroupInfo), "unscheduledPodsCount", len(podGroupInfo.QueuedPodInfos))
requiresPreemption := false
anyScheduled := false
for _, podInfo := range podGroupInfo.QueuedPodInfos {
podResult, revertFn := sched.podGroupPodSchedulingAlgorithm(ctx, schedFwk, podGroupCycleState, podGroupInfo, podInfo, postFilterMode)
podResult, revertFn := sched.podGroupPodSchedulingAlgorithm(ctx, schedFwk, podGroupCycleState, placementCycleState, podGroupInfo, podInfo, postFilterMode)
result.podResults = append(result.podResults, podResult)
if revertFn != nil {
// We unreserve the pod at the end of the whole algorithm (via defer) because it should be ultimately returned to the queue,
@ -435,9 +441,9 @@ func (sched *Scheduler) podGroupSchedulingDefaultAlgorithm(ctx context.Context,
// podGroupPodSchedulingAlgorithm runs a scheduling algorithm for individual pod from a pod group.
// It returns the algorithm result together with the revert function.
func (sched *Scheduler) podGroupPodSchedulingAlgorithm(ctx context.Context, schedFwk framework.Framework, podGroupCycleState *framework.CycleState, podGroupInfo *framework.QueuedPodGroupInfo, podInfo *framework.QueuedPodInfo, postFilterMode podGroupPostFilterMode) (algorithmResult, func()) {
func (sched *Scheduler) podGroupPodSchedulingAlgorithm(ctx context.Context, schedFwk framework.Framework, podGroupCycleState *framework.CycleState, placementCycleState fwk.PlacementCycleState, podGroupInfo *framework.QueuedPodGroupInfo, podInfo *framework.QueuedPodInfo, postFilterMode podGroupPostFilterMode) (algorithmResult, func()) {
pod := podInfo.Pod
podCtx := initPodSchedulingContext(ctx, pod, podGroupCycleState, postFilterMode)
podCtx := initPodSchedulingContext(ctx, pod, podGroupCycleState, placementCycleState, postFilterMode)
logger := podCtx.logger
ctx = klog.NewContext(ctx, logger)
start := time.Now()
@ -506,7 +512,7 @@ func completePodGroupAlgorithmResult(ctx context.Context, podGroupInfo *framewor
pInfo := podGroupInfo.QueuedPodInfos[i]
newResults[i] = algorithmResult{
pod: pInfo.Pod,
podCtx: initPodSchedulingContext(ctx, pInfo.Pod, podGroupState, postFilterMode),
podCtx: initPodSchedulingContext(ctx, pInfo.Pod, podGroupState, nil, postFilterMode),
status: podGroupResult.status.Clone(),
}
}
@ -783,6 +789,7 @@ func makePodGroupAssignments(successfulResults map[*fwk.Placement]*podGroupAlgor
placementPodGroupAssignments = append(placementPodGroupAssignments, &fwk.PodGroupAssignments{
Placement: placement,
ProposedAssignments: proposedAssignments,
PlacementCycleState: result.placementCycleState,
})
}
return placementPodGroupAssignments

View file

@ -1751,7 +1751,7 @@ func TestSubmitPodGroupAlgorithmResult(t *testing.T) {
for i := range tt.algorithmResult.podResults {
pod := podGroupInfo.QueuedPodInfos[i].Pod
podCtx := initPodSchedulingContext(ctx, pod, podGroupCycleState, runAllPostFilters)
podCtx := initPodSchedulingContext(ctx, pod, podGroupCycleState, nil, runAllPostFilters)
tt.algorithmResult.podResults[i].podCtx = podCtx
}
@ -2538,6 +2538,7 @@ func TestPodGroupSchedulingPlacementAlgorithm(t *testing.T) {
ScheduleResult{},
fwk.Status{}),
cmpopts.IgnoreFields(algorithmResult{}, "podCtx", "schedulingDuration"),
cmpopts.IgnoreFields(podGroupAlgorithmResult{}, "placementCycleState"),
statusCmpOpt,
}
@ -2702,6 +2703,196 @@ func TestPodGroupSchedulingPlacementAlgorithm_Scoring(t *testing.T) {
}
}
// placementStateTracker is a fake plugin that writes to PlacementCycleState during Filter
// and reads from it during ScorePlacement, to verify the lifecycle of PlacementCycleState.
// It also implements GeneratePlacements so that a test can dictate which placements exist.
type placementStateTracker struct {
// name is the plugin name returned by Name.
name string
// mu guards scoreReadValues, which ScorePlacement writes and the test reads afterwards.
mu sync.Mutex
// scoreReadValues records what value was read from PlacementCycleState
// during each ScorePlacement call, keyed by placement name.
// The recorded value is "<not found>" when the read returned an error.
scoreReadValues map[string]string
// generatePlacementsResult defines the placements to generate:
// each key is a placement name and each value lists the names of the nodes in that placement.
generatePlacementsResult map[string][]string
}
// placementStateData is the StateData that placementStateTracker stores under placementStateKey.
type placementStateData struct {
// value is the marker written during Filter (the name of the node being filtered).
value string
}
// Clone returns the receiver itself rather than a copy, so a clone shares the same data.
func (d *placementStateData) Clone() fwk.StateData { return d }
// placementStateKey is the key under which placementStateTracker stores its data in PlacementCycleState.
var placementStateKey fwk.StateKey = "placementStateTracker"
// Compile-time checks that placementStateTracker satisfies each plugin interface it is registered as.
var _ fwk.FilterPlugin = &placementStateTracker{}
var _ fwk.PlacementGeneratePlugin = &placementStateTracker{}
var _ fwk.PlacementScorePlugin = &placementStateTracker{}
// Name returns the configured plugin name.
func (p *placementStateTracker) Name() string { return p.name }
// Filter stores the name of the node being filtered in the PlacementCycleState attached
// to state, under placementStateKey, and returns a nil status.
// If state carries no PlacementCycleState, it returns an Error status instead.
func (p *placementStateTracker) Filter(ctx context.Context, state fwk.CycleState, pod *v1.Pod, nodeInfo fwk.NodeInfo) *fwk.Status {
	if perPlacement := state.GetPlacementCycleState(); perPlacement != nil {
		// The node name acts as a marker that lets ScorePlacement check
		// which placement's state it was handed.
		marker := &placementStateData{value: nodeInfo.Node().Name}
		perPlacement.Write(placementStateKey, marker)
		return nil
	}
	return fwk.NewStatus(fwk.Error, "PlacementCycleState is nil during Filter")
}
// ScorePlacement reads placementStateKey from placement.PlacementCycleState and records the
// result in scoreReadValues under the placement's name: the stored value on success, or
// "<not found>" when the read fails. It always scores 1 in that case.
// If placement.PlacementCycleState is nil, it returns score 0 with an Error status
// and records nothing.
func (p *placementStateTracker) ScorePlacement(ctx context.Context, state fwk.PodGroupCycleState, podGroup fwk.PodGroupInfo, placement *fwk.PodGroupAssignments) (int64, *fwk.Status) {
	perPlacement := placement.PlacementCycleState
	if perPlacement == nil {
		return 0, fwk.NewStatus(fwk.Error, "PlacementCycleState is nil during ScorePlacement")
	}
	stored, readErr := perPlacement.Read(placementStateKey)

	p.mu.Lock()
	defer p.mu.Unlock()

	observed := "<not found>"
	if readErr == nil {
		observed = stored.(*placementStateData).value
	}
	p.scoreReadValues[placement.Name] = observed
	return 1, nil
}
// PlacementScoreExtensions returns nil; this fake provides no score extensions.
func (p *placementStateTracker) PlacementScoreExtensions() fwk.PlacementScoreExtensions {
return nil
}
// GeneratePlacements builds one Placement per entry of generatePlacementsResult.
// Each configured node name is resolved against the nodes of parentPlacement; the looked-up
// entry is appended as-is, so a name that is absent from the parent yields the map's zero value.
// Placements are produced in map iteration order, and the returned status is always nil.
func (p *placementStateTracker) GeneratePlacements(ctx context.Context, state fwk.PodGroupCycleState, podGroup fwk.PodGroupInfo, parentPlacement *fwk.Placement) (*fwk.GeneratePlacementsResult, *fwk.Status) {
	// Index the parent's nodes by name for the lookups below.
	byName := make(map[string]fwk.NodeInfo, len(parentPlacement.Nodes))
	for _, info := range parentPlacement.Nodes {
		byName[info.Node().Name] = info
	}

	generated := make([]*fwk.Placement, 0, len(p.generatePlacementsResult))
	for name, members := range p.generatePlacementsResult {
		candidate := &fwk.Placement{Name: name}
		for i := range members {
			candidate.Nodes = append(candidate.Nodes, byName[members[i]])
		}
		generated = append(generated, candidate)
	}
	return &fwk.GeneratePlacementsResult{Placements: generated}, nil
}
// TestPlacementCycleStateLifecycle runs podGroupSchedulingPlacementAlgorithm with a
// placementStateTracker registered as placement generator, filter and placement scorer,
// and checks that the value written to PlacementCycleState during Filter is the value
// ScorePlacement later reads for the same placement.
func TestPlacementCycleStateLifecycle(t *testing.T) {
featuregatetesting.SetFeatureGatesDuringTest(t, utilfeature.DefaultFeatureGate, featuregatetesting.FeatureOverrides{
features.TopologyAwareWorkloadScheduling: true,
features.GenericWorkload: true,
})
// A single scenario exercises both isolation and continuity:
// - Filter writes a node-name marker into PlacementCycleState during each placement's simulation.
// - ScorePlacement reads from PodGroupAssignments.PlacementCycleState after all simulations.
// Assertions verify:
// 1. Each placement's scorer reads only the value its own simulation wrote (isolation).
// 2. The value is readable at all during scoring (continuity from simulation to scoring).
nodes := []*v1.Node{
st.MakeNode().Name("node1").Obj(),
st.MakeNode().Name("node2").Obj(),
}
podGroupPod := st.MakePod().Name("foo").UID("foo").PodGroupName("pg").Obj()
logger, ctx := ktesting.NewTestContext(t)
informerFactory := informers.NewSharedInformerFactory(clientsetfake.NewClientset(), 0)
queue := internalqueue.NewSchedulingQueue(nil, informerFactory)
// Two single-node placements, so the marker written by Filter identifies the placement.
tracker := &placementStateTracker{
name: "StateTracker",
scoreReadValues: make(map[string]string),
generatePlacementsResult: map[string][]string{
"placementA": {nodes[0].Name},
"placementB": {nodes[1].Name},
},
}
// The same tracker instance is registered for all three extension points,
// so every factory returns the shared pointer.
registry := []tf.RegisterPluginFunc{
tf.RegisterPlacementGeneratePlugin(tracker.Name(), func(_ context.Context, _ runtime.Object, _ fwk.Handle) (fwk.Plugin, error) {
return tracker, nil
}),
tf.RegisterPlacementScorePlugin(tracker.Name(), func(_ context.Context, _ runtime.Object, _ fwk.Handle) (fwk.Plugin, error) {
return tracker, nil
}, 1),
tf.RegisterFilterPlugin(tracker.Name(), func(_ context.Context, _ runtime.Object, _ fwk.Handle) (fwk.Plugin, error) {
return tracker, nil
}),
}
snapshot := internalcache.NewEmptySnapshot()
schedFwk, err := tf.NewFramework(ctx,
append(registry,
tf.RegisterQueueSortPlugin(queuesort.Name, queuesort.New),
tf.RegisterBindPlugin(defaultbinder.Name, defaultbinder.New),
),
"test-scheduler",
frameworkruntime.WithInformerFactory(informerFactory),
frameworkruntime.WithSnapshotSharedLister(snapshot),
frameworkruntime.WithPodNominator(queue),
)
if err != nil {
t.Fatalf("Failed to create framework: %v", err)
}
cache := internalcache.New(ctx, nil, true)
for _, node := range nodes {
cache.AddNode(logger, node)
}
sched := &Scheduler{
Cache: cache,
nodeInfoSnapshot: snapshot,
SchedulingQueue: queue,
Profiles: profile.Map{"test-scheduler": schedFwk},
}
sched.SchedulePod = sched.schedulePod
// Refresh the snapshot from the cache after the nodes were added above.
if err := sched.Cache.UpdateSnapshot(logger, sched.nodeInfoSnapshot); err != nil {
t.Fatalf("Failed to update snapshot: %v", err)
}
pgInfo := &framework.QueuedPodGroupInfo{
QueuedPodInfos: []*framework.QueuedPodInfo{
{PodInfo: &framework.PodInfo{Pod: podGroupPod}},
},
PodGroupInfo: &framework.PodGroupInfo{
UnscheduledPods: []*v1.Pod{podGroupPod},
},
}
result := sched.podGroupSchedulingPlacementAlgorithm(ctx, schedFwk, framework.NewCycleState(), pgInfo, runAllPostFilters)
if !result.status.IsSuccess() {
t.Fatalf("Expected success, got: %v", result.status)
}
// Hold the tracker's lock while inspecting the values recorded by ScorePlacement.
tracker.mu.Lock()
defer tracker.mu.Unlock()
// Continuity: ScorePlacement must have been called and able to read simulation data.
for _, placementName := range []string{"placementA", "placementB"} {
readValue, ok := tracker.scoreReadValues[placementName]
if !ok {
t.Fatalf("placement %s: ScorePlacement was not called", placementName)
}
if readValue == "<not found>" {
t.Fatalf("placement %s: ScorePlacement could not read state written during simulation", placementName)
}
}
// Isolation: each placement's scorer must read only what its own simulation wrote.
// placementA simulated on node1, placementB simulated on node2.
if v := tracker.scoreReadValues["placementA"]; v != "node1" {
t.Errorf("placementA: scorer read %q, want %q (isolation violation if it read placementB's value)", v, "node1")
}
if v := tracker.scoreReadValues["placementB"]; v != "node2" {
t.Errorf("placementB: scorer read %q, want %q (isolation violation if it read placementA's value)", v, "node2")
}
}
// fakeDefaultPreemption embeds *fakePodGroupPlugin.
// NOTE(review): presumably a test stand-in for the default preemption plugin — confirm against its methods.
type fakeDefaultPreemption struct {
*fakePodGroupPlugin
}

View file

@ -104,6 +104,39 @@ type CycleState interface {
// SetPodGroupSchedulingCycle sets the cycle state of the PodGroup for a Pod.
// This should be only used when GenericWorkload feature flag is enabled.
SetPodGroupSchedulingCycle(PodGroupCycleState)
// GetPlacementCycleState gets the cycle state of the current Placement for a Pod.
// Returns nil if this pod is not being scheduled within a placement context.
// This should be only used when GenericWorkload feature flag is enabled.
GetPlacementCycleState() PlacementCycleState
// SetPlacementCycleState sets the cycle state of the current Placement for a Pod.
// This should be only used when GenericWorkload feature flag is enabled.
SetPlacementCycleState(PlacementCycleState)
}
// PlacementCycleState provides a mechanism for plugins to store and retrieve arbitrary data
// scoped to a single placement candidate within a pod group scheduling cycle.
// Data stored in PlacementCycleState is shared across all pods scheduled within the same
// placement iteration and is preserved through to the placement scoring phase via PodGroupAssignments.
// PlacementCycleState does not provide any data protection, as all plugins are assumed to be
// trusted.
//
// NOTE(review): the method comments below refer here for notes on concurrency, but this
// comment does not state any concurrency guarantees; confirm them against the concrete
// implementation before relying on concurrent Read/Write/Delete.
type PlacementCycleState interface {
// ShouldRecordPluginMetrics returns whether metrics.PluginExecutionDuration metrics
// should be recorded.
// This function is mostly for the scheduling framework runtime, plugins usually don't have to use it.
ShouldRecordPluginMetrics() bool
// Read retrieves data with the given "key" from PlacementCycleState. If the key is not
// present, ErrNotFound is returned.
//
// See PlacementCycleState for notes on concurrency.
Read(key StateKey) (StateData, error)
// Write stores the given "val" in PlacementCycleState with the given "key".
//
// See PlacementCycleState for notes on concurrency.
Write(key StateKey, val StateData)
// Delete deletes data with the given key from PlacementCycleState.
//
// See PlacementCycleState for notes on concurrency.
Delete(key StateKey)
}
// PodGroupCycleState provides a mechanism for plugins that operate on pod groups to store and retrieve arbitrary data.

View file

@ -682,6 +682,9 @@ type PodGroupAssignments struct {
// during the pod group scheduling cycle.
// The pods are guaranteed to also be present in the PodGroupInfo.
ProposedAssignments []ProposedAssignment
// PlacementCycleState holds the state that was accumulated during the simulation of this placement.
// Placement score plugins can use this to access per-placement cached data.
PlacementCycleState PlacementCycleState
}
// NodeAllocatableDRAClaimState holds information about a node allocatable resource DRA claim's allocation on a node.