Skip to content

Commit a547d6a

Browse files
BrianPark314BrianPark314YuhanLiu11
authored
feature/KV-cache-aware-routing (vllm-project#550)
* feat: add kv aware routing Signed-off-by: BrianPark314 <[email protected]> * feat: add kv aware routing Signed-off-by: BrianPark314 <[email protected]> * feat: update kv aware logic Signed-off-by: BrianPark314 <[email protected]> * chore: remove unnecessary comment Signed-off-by: BrianPark314 <[email protected]> --------- Signed-off-by: BrianPark314 <[email protected]> Co-authored-by: BrianPark314 <[email protected]> Co-authored-by: Yuhan Liu <[email protected]>
1 parent 5a6f1e4 commit a547d6a

File tree

2 files changed

+134
-0
lines changed

2 files changed

+134
-0
lines changed

src/gateway_inference_extension/Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ RUN git clone https://github.com/kubernetes-sigs/gateway-api-inference-extension
2828
git apply scheduler.patch && \
2929
cd ../../../.. && \
3030
cp /src/roundrobin_picker.go gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker/roundrobin_picker.go && \
31+
cp /src/kv_aware_picker.go gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker/kv_aware_picker.go && \
3132
mkdir -p /src/pkg/ && \
3233
cp -r gateway-api-inference-extension/pkg/epp/ /src/pkg/epp && \
3334
cp gateway-api-inference-extension/go.mod /src && \
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
package picker
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"fmt"
7+
"net/http"
8+
"sort"
9+
"strconv"
10+
"strings"
11+
"sync/atomic"
12+
"time"
13+
14+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
15+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
16+
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
17+
)
18+
19+
// KvAwarePicker attempts to route requests to the pod that already holds
20+
// the longest matching KV cache. If no information is available it falls
21+
// back to a round robin selection.
22+
//
23+
// NOTE: The actual lookup against the LMCache controller is left as a TODO
24+
// as the Go library for LMCache is not yet available. The code structure
25+
// mirrors the Python implementation found in routing_logic.KvawareRouter.
26+
var _ plugins.Picker = &KvAwarePicker{}
27+
28+
type KvAwarePicker struct {
29+
currentIndex uint64
30+
controllerAddr string
31+
threshold int
32+
instToPod map[string]*types.ScoredPod
33+
httpClient *http.Client
34+
}
35+
36+
func NewKvAwarePicker(addr string, threshold int) *KvAwarePicker {
37+
return &KvAwarePicker{
38+
controllerAddr: addr,
39+
threshold: threshold,
40+
instToPod: make(map[string]*types.ScoredPod),
41+
httpClient: &http.Client{Timeout: 2 * time.Second},
42+
}
43+
}
44+
45+
func (p *KvAwarePicker) Name() string { return "kvaware" }
46+
47+
func (p *KvAwarePicker) Pick(ctx *types.SchedulingContext, scoredPods []*types.ScoredPod) *types.Result {
48+
if len(scoredPods) == 0 {
49+
return &types.Result{}
50+
}
51+
52+
prompt := ctx.Request.Prompt
53+
model := ctx.Request.Model
54+
55+
inst, tokens, err := p.lookupInstance(model, prompt)
56+
if err == nil && inst != "" {
57+
if tokens >= len(strings.Fields(prompt))-p.threshold {
58+
if _, ok := p.instToPod[inst]; !ok {
59+
for _, pod := range scoredPods {
60+
ip := pod.GetPod().Status.PodIP
61+
iid, err := p.queryInstance(ip)
62+
if err == nil && iid != "" {
63+
p.instToPod[iid] = pod
64+
}
65+
}
66+
}
67+
if target, ok := p.instToPod[inst]; ok {
68+
ctx.Logger.V(logutil.DEBUG).Info(fmt.Sprintf("KvAwarePicker routed to %s", inst))
69+
return &types.Result{TargetPod: target}
70+
}
71+
}
72+
}
73+
74+
// Fallback to round robin routing when no KV cache information is
75+
// available. Sort candidates for deterministic behavior across schedulers.
76+
sort.Slice(scoredPods, func(i, j int) bool {
77+
return scoredPods[i].GetPod().NamespacedName.String() <
78+
scoredPods[j].GetPod().NamespacedName.String()
79+
})
80+
index := int(atomic.AddUint64(&p.currentIndex, 1) - 1)
81+
index = index % len(scoredPods)
82+
ctx.Logger.V(logutil.DEBUG).Info(fmt.Sprintf(
83+
"KvAwarePicker falling back to round robin, index %d of %d", index, len(scoredPods)))
84+
return &types.Result{TargetPod: scoredPods[index]}
85+
}
86+
87+
// lookupInstance queries the LMCache controller for the instance containing the
88+
// longest prefix match for the given prompt. It returns the instance ID and the
89+
// number of matched tokens.
90+
func (p *KvAwarePicker) lookupInstance(model, prompt string) (string, int, error) {
91+
body, err := json.Marshal(map[string]string{"model": model, "prompt": prompt})
92+
if err != nil {
93+
return "", 0, err
94+
}
95+
url := fmt.Sprintf("http://%s/lookup", p.controllerAddr)
96+
resp, err := p.httpClient.Post(url, "application/json", bytes.NewReader(body))
97+
if err != nil {
98+
return "", 0, err
99+
}
100+
defer resp.Body.Close()
101+
if resp.StatusCode != http.StatusOK {
102+
return "", 0, fmt.Errorf("unexpected status %d", resp.StatusCode)
103+
}
104+
var data struct {
105+
InstanceID string `json:"instance_id"`
106+
Tokens int `json:"tokens"`
107+
}
108+
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
109+
return "", 0, err
110+
}
111+
return data.InstanceID, data.Tokens, nil
112+
}
113+
114+
// queryInstance resolves the instance ID for the given pod IP. It returns an
115+
// empty string if the controller does not recognize the pod.
116+
func (p *KvAwarePicker) queryInstance(ip string) (string, error) {
117+
url := fmt.Sprintf("http://%s/query?ip=%s", p.controllerAddr, ip)
118+
resp, err := p.httpClient.Get(url)
119+
if err != nil {
120+
return "", err
121+
}
122+
defer resp.Body.Close()
123+
if resp.StatusCode != http.StatusOK {
124+
return "", fmt.Errorf("unexpected status %d", resp.StatusCode)
125+
}
126+
var data struct {
127+
InstanceID string `json:"instance_id"`
128+
}
129+
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
130+
return "", err
131+
}
132+
return data.InstanceID, nil
133+
}

0 commit comments

Comments
 (0)