Skip to content

Commit 52161f0

Browse files
authored
Merge pull request #1040 from airweave-ai/fix/embedders-vector-size
Fix: Embedder vector size
2 parents feb0f0d + 3d1d0e4 commit 52161f0

File tree

12 files changed

+383
-36
lines changed

12 files changed

+383
-36
lines changed

backend/airweave/core/collection_service.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ async def _create(
4444
uow: UnitOfWork,
4545
) -> schemas.Collection:
4646
"""Create a new collection."""
47+
from airweave.platform.destinations.collection_strategy import get_default_vector_size
48+
4749
# Check if the collection already exists
4850
try:
4951
existing_collection = await crud.collection.get_by_readable_id(
@@ -57,20 +59,39 @@ async def _create(
5759
status_code=400, detail="Collection with this readable_id already exists"
5860
)
5961

60-
collection = await crud.collection.create(db, obj_in=collection_in, ctx=ctx, uow=uow)
62+
# Determine vector size and embedding model for this collection
63+
vector_size = get_default_vector_size()
64+
65+
# Determine embedding model name based on vector size
66+
from airweave.platform.destinations.collection_strategy import (
67+
get_openai_embedding_model_for_vector_size,
68+
)
69+
70+
try:
71+
embedding_model_name = get_openai_embedding_model_for_vector_size(vector_size)
72+
except ValueError:
73+
# For non-OpenAI vector sizes (e.g., 384), use a generic name
74+
embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
75+
76+
# Add vector_size and embedding_model_name to collection data
77+
collection_data = collection_in.model_dump()
78+
collection_data["vector_size"] = vector_size
79+
collection_data["embedding_model_name"] = embedding_model_name
80+
81+
collection = await crud.collection.create(db, obj_in=collection_data, ctx=ctx, uow=uow)
6182
await uow.session.flush()
6283

63-
# Create Qdrant destination with organization context
64-
# Vector size is auto-detected based on embedding model configuration
84+
# Create Qdrant destination with explicit vector size
6585
qdrant_destination = await QdrantDestination.create(
6686
credentials=None, # Native Qdrant uses settings
6787
config=None,
6888
collection_id=collection.id,
6989
organization_id=ctx.organization.id,
90+
vector_size=vector_size,
7091
logger=ctx.logger,
7192
)
7293

73-
# Setup the physical shared collection (auto-detects vector size)
94+
# Setup the physical shared collection
7495
await qdrant_destination.setup_collection()
7596

7697
return schemas.Collection.model_validate(collection, from_attributes=True)

backend/airweave/models/collection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import TYPE_CHECKING, List
44

5-
from sqlalchemy import String
5+
from sqlalchemy import Integer, String
66
from sqlalchemy.orm import Mapped, mapped_column, relationship
77

88
from airweave.models._base import OrganizationBase, UserMixin
@@ -19,6 +19,8 @@ class Collection(OrganizationBase, UserMixin):
1919

2020
name: Mapped[str] = mapped_column(String, nullable=False)
2121
readable_id: Mapped[str] = mapped_column(String, nullable=False, unique=True)
22+
vector_size: Mapped[int] = mapped_column(Integer, nullable=False)
23+
embedding_model_name: Mapped[str] = mapped_column(String, nullable=False)
2224
# Status is now ephemeral - removed from database model
2325

2426
# Relationships

backend/airweave/platform/destinations/collection_strategy.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
All collections now use shared physical collections in Qdrant:
44
- 384-dim vectors → airweave_shared_minilm_l6_v2 (local model)
55
- 1536-dim vectors → airweave_shared_text_embedding_3_small (OpenAI)
6+
- 3072-dim vectors → airweave_shared_text_embedding_3_large (OpenAI)
67
78
Tenant isolation is achieved via airweave_collection_id payload filtering.
89
"""
@@ -31,13 +32,38 @@ def get_physical_collection_name(vector_size: int | None = None) -> str:
3132
Returns:
3233
Physical collection name in Qdrant:
3334
- "airweave_shared_text_embedding_3_large" for 3072-dim vectors
35+
- "airweave_shared_text_embedding_3_small" for 1536-dim vectors
3436
- "airweave_shared_minilm_l6_v2" for 384-dim vectors (default for other sizes)
3537
"""
3638
if vector_size is None:
3739
vector_size = get_default_vector_size()
3840

39-
return (
40-
"airweave_shared_text_embedding_3_large"
41-
if vector_size == 3072
42-
else "airweave_shared_minilm_l6_v2"
43-
)
41+
if vector_size == 3072:
42+
return "airweave_shared_text_embedding_3_large"
43+
elif vector_size == 1536:
44+
return "airweave_shared_text_embedding_3_small"
45+
else:
46+
return "airweave_shared_minilm_l6_v2"
47+
48+
49+
def get_openai_embedding_model_for_vector_size(vector_size: int) -> str:
50+
"""Get OpenAI embedding model name for given vector dimensions.
51+
52+
Args:
53+
vector_size: Vector dimensions (3072 or 1536)
54+
55+
Returns:
56+
- "text-embedding-3-large" for 3072-dim
57+
- "text-embedding-3-small" for 1536-dim
58+
59+
Raises:
60+
ValueError: For vector sizes that don't use OpenAI models (e.g., 384 uses local model)
61+
"""
62+
if vector_size == 3072:
63+
return "text-embedding-3-large"
64+
elif vector_size == 1536:
65+
return "text-embedding-3-small"
66+
else:
67+
raise ValueError(
68+
f"No OpenAI model for vector_size {vector_size}. Only 3072 and 1536 use OpenAI models."
69+
)

backend/airweave/platform/embedders/openai.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,40 +15,64 @@
1515

1616

1717
class DenseEmbedder(BaseEmbedder):
18-
"""Singleton dense embedder using OpenAI text-embedding-3-large (3072 dims).
18+
"""OpenAI dense embedder with dynamic model selection (non-singleton).
19+
20+
IMPORTANT: No longer a singleton! Each collection may use different embedding models,
21+
so we create fresh instances with the correct model for each sync/search operation.
1922
2023
Features:
21-
- Singleton shared across all syncs in pod
24+
- Dynamic model selection based on vector_size (3072 or 1536)
2225
- Batch processing with OpenAI limits (2048 texts/request, 300K tokens/request)
2326
- 5 concurrent requests max
24-
- Rate limiting with OpenAIRateLimiter singleton
27+
- Rate limiting with OpenAIRateLimiter singleton (shared across instances)
2528
- Automatic retry on transient errors (via AsyncOpenAI client)
2629
- Fail-fast on any API errors (no silent failures)
2730
"""
2831

29-
MODEL_NAME = "text-embedding-3-large"
30-
VECTOR_DIMENSIONS = 3072
3132
MAX_TOKENS_PER_TEXT = 8192 # OpenAI limit per text
3233
MAX_BATCH_SIZE = 2048 # OpenAI limit per request
3334
MAX_TOKENS_PER_REQUEST = 300000 # OpenAI limit
3435
MAX_CONCURRENT_REQUESTS = 5
3536

36-
def __init__(self):
37-
"""Initialize OpenAI embedder (once per pod)."""
38-
if self._initialized:
39-
return
37+
def __new__(cls, vector_size: int = None):
38+
"""Override singleton pattern from BaseEmbedder - create fresh instances."""
39+
return object.__new__(cls)
40+
41+
def __init__(self, vector_size: int = None):
42+
"""Initialize OpenAI embedder for specific vector dimensions.
4043
44+
Args:
45+
vector_size: Vector dimensions to determine model:
46+
- 3072: text-embedding-3-large
47+
- 1536: text-embedding-3-small
48+
- None: defaults to 3072 (large model)
49+
"""
4150
if not settings.OPENAI_API_KEY:
4251
raise SyncFailureError("OPENAI_API_KEY required for dense embeddings")
4352

53+
# Fail-fast: vector_size should always be provided from collection
54+
# Only allow None for backward compatibility, but warn
55+
if vector_size is None:
56+
# Fallback to large model but this shouldn't happen
57+
self.MODEL_NAME = "text-embedding-3-large"
58+
self.VECTOR_DIMENSIONS = 3072
59+
else:
60+
# Select model and dimensions based on vector_size
61+
from airweave.platform.destinations.collection_strategy import (
62+
get_openai_embedding_model_for_vector_size,
63+
)
64+
65+
self.MODEL_NAME = get_openai_embedding_model_for_vector_size(vector_size)
66+
self.VECTOR_DIMENSIONS = vector_size
67+
68+
# Create fresh client instance
4469
self._client = AsyncOpenAI(
4570
api_key=settings.OPENAI_API_KEY,
4671
timeout=1200.0, # 20 min timeout for high concurrency
4772
max_retries=2,
4873
)
49-
self._rate_limiter = OpenAIRateLimiter() # Singleton
74+
self._rate_limiter = OpenAIRateLimiter() # This singleton is still OK (shared rate limit)
5075
self._tokenizer = tiktoken.get_encoding("cl100k_base")
51-
self._initialized = True
5276

5377
async def embed_many(self, texts: List[str], sync_context: SyncContext) -> List[List[float]]:
5478
"""Embed batch of texts using OpenAI text-embedding-3-large.

backend/airweave/platform/sync/entity_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1167,9 +1167,10 @@ async def _embed_entities(
11671167
sparse_texts.append(json.dumps(entity_dict, sort_keys=True))
11681168

11691169
# Compute dense embeddings (always required)
1170+
# Create embedder with collection's vector_size (creates fresh instance)
11701171
from airweave.platform.embedders import DenseEmbedder
11711172

1172-
dense_embedder = DenseEmbedder()
1173+
dense_embedder = DenseEmbedder(vector_size=sync_context.collection.vector_size)
11731174
dense_embeddings = await dense_embedder.embed_many(dense_texts, sync_context)
11741175

11751176
# Compute sparse embeddings (only if destination supports keyword index)

backend/airweave/platform/sync/factory.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -951,7 +951,7 @@ async def _get_integration_credential(
951951
return credential
952952

953953
@classmethod
954-
async def _create_destination_instances(
954+
async def _create_destination_instances( # noqa: C901
955955
cls,
956956
db: AsyncSession,
957957
sync: schemas.Sync,
@@ -997,12 +997,17 @@ async def _create_destination_instances(
997997
destination_schema = schemas.Destination.model_validate(destination_model)
998998
destination_class = resource_locator.get_destination(destination_schema)
999999

1000+
# Fail-fast: vector_size must be set
1001+
if collection.vector_size is None:
1002+
raise ValueError(f"Collection {collection.id} has no vector_size set.")
1003+
10001004
# Native Qdrant: no credentials (uses settings)
10011005
destination = await destination_class.create(
10021006
credentials=None,
10031007
config=None,
10041008
collection_id=collection.id,
10051009
organization_id=collection.organization_id,
1010+
vector_size=collection.vector_size,
10061011
logger=logger,
10071012
)
10081013

backend/airweave/schemas/collection.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,22 @@ class CollectionInDBBase(CollectionBase):
166166
"once the collection is created."
167167
),
168168
)
169+
vector_size: int = Field(
170+
...,
171+
description=(
172+
"Vector dimensions used by this collection. Determines which embedding model "
173+
"is used: 3072 (text-embedding-3-large), 1536 (text-embedding-3-small), "
174+
"or 384 (MiniLM-L6-v2)."
175+
),
176+
)
177+
embedding_model_name: str = Field(
178+
...,
179+
description=(
180+
"Name of the embedding model used for this collection "
181+
"(e.g., 'text-embedding-3-large', 'text-embedding-3-small'). "
182+
"This ensures queries use the same model as the indexed data."
183+
),
184+
)
169185
created_at: datetime = Field(
170186
...,
171187
description="Timestamp when the collection was created (ISO 8601 format).",

backend/airweave/search/defaults.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,16 @@ provider_models:
2323
name: "gpt-5-nano"
2424
tokenizer: "cl100k_base"
2525
context_window: 400000
26+
embedding_large:
27+
name: "text-embedding-3-large"
28+
tokenizer: "cl100k_base"
29+
dimensions: 3072
30+
max_tokens: 8191
31+
embedding_small:
32+
name: "text-embedding-3-small"
33+
tokenizer: "cl100k_base"
34+
dimensions: 1536
35+
max_tokens: 8191
2636
embedding:
2737
name: "text-embedding-3-large"
2838
tokenizer: "cl100k_base"

0 commit comments

Comments
 (0)