Skip to content

Commit 41cd66f

Browse files
authored
Merge pull request #761 from airweave-ai/feat/rework-ui-search
Feat/rework UI search
2 parents 9cfb785 + 4473261 commit 41cd66f

File tree

88 files changed

+13140
-795
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

88 files changed

+13140
-795
lines changed

.cursor/rules/backend-rules.mdc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ alwaysApply: false
88

99
### API Layer
1010
- **FastAPI Application** (`main.py`): Entry point for HTTP requests
11-
- **API Routes** (`api/v1/endpoints/`): RESTful endpoints organized by resource
11+
- **API Routes** (`endpoints/`): RESTful endpoints organized by resource
1212
- **Dependencies** (`api/deps.py`): Authentication and request validation
1313

1414
### Service Layer
@@ -74,7 +74,7 @@ If necessary (like editing ), refer [sync-architecture.mdc](mdc:.cursor/rules/sy
7474
- Background processing with Redis workers (upcoming)
7575

7676
### API Convention
77-
- RESTful endpoints in `api/v1/endpoints/` -> the version is not part of the endpoint. It's just host.com/{endpoint}!
77+
- RESTful endpoints in `endpoints/` -> the version is not part of the endpoint. It's just host.com/{endpoint}!
7878
- Consistent response structures
7979
- One router per resource type
8080
- Logger injected via `ctx` dependency for contextual logging

.vscode/launch.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@
5858
"console": "integratedTerminal",
5959
"cwd": "${workspaceFolder}/backend",
6060
"env": {
61-
"PYTHONPATH": "${workspaceFolder}/backend"
61+
"PYTHONPATH": "${workspaceFolder}/backend",
62+
// "LOG_LEVEL": "DEBUG"
6263
},
6364
"justMyCode": false,
6465
"envFile": "${workspaceFolder}/.env"

backend/airweave/api/v1/api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
destinations,
1313
embedding_models,
1414
entities,
15+
entity_counts,
1516
file_retrieval,
1617
health,
1718
organizations,
@@ -48,6 +49,7 @@
4849
api_router.include_router(white_label.router, prefix="/white-labels", tags=["white-labels"])
4950
api_router.include_router(dag.router, prefix="/dag", tags=["dag"])
5051
api_router.include_router(entities.router, prefix="/entities", tags=["entities"])
52+
api_router.include_router(entity_counts.router, prefix="/entity-counts", tags=["entity-counts"])
5153
api_router.include_router(transformers.router, prefix="/transformers", tags=["transformers"])
5254
api_router.include_router(file_retrieval.router, prefix="/files", tags=["files"])
5355

backend/airweave/api/v1/endpoints/collections.py

Lines changed: 189 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
"""API endpoints for collections."""
22

3-
from typing import List
3+
import asyncio
4+
import json
5+
from typing import Any, Dict, List
46

57
from fastapi import BackgroundTasks, Depends, HTTPException, Path, Query
8+
from fastapi.responses import StreamingResponse
9+
from qdrant_client.http.models import Filter as QdrantFilter
610
from sqlalchemy.ext.asyncio import AsyncSession
711

812
from airweave import crud, schemas
@@ -17,10 +21,12 @@
1721
from airweave.core.collection_service import collection_service
1822
from airweave.core.guard_rail_service import GuardRailService
1923
from airweave.core.logging import ContextualLogger
24+
from airweave.core.pubsub import core_pubsub
2025
from airweave.core.shared_models import ActionType
2126
from airweave.core.source_connection_service import source_connection_service
2227
from airweave.core.sync_service import sync_service
2328
from airweave.core.temporal_service import temporal_service
29+
from airweave.db.session import AsyncSessionLocal
2430
from airweave.schemas.search import ResponseType, SearchRequest
2531
from airweave.search.search_service_v2 import search_service_v2 as search_service
2632

@@ -175,7 +181,7 @@ async def search(
175181
),
176182
examples=["raw", "completion"],
177183
),
178-
limit: int = Query(20, ge=1, le=1000, description="Maximum number of results to return"),
184+
limit: int = Query(100, ge=1, le=1000, description="Maximum number of results to return"),
179185
offset: int = Query(0, ge=0, description="Number of results to skip for pagination"),
180186
recency_bias: float | None = Query(
181187
None,
@@ -451,3 +457,184 @@ async def refresh_all_source_connections(
451457
await guard_rail.increment(ActionType.SYNCS)
452458

453459
return sync_jobs
460+
461+
@router.get("/internal/filter-schema")
async def get_filter_schema() -> Dict[str, Any]:
    """Return the JSON schema used to validate Qdrant filter objects.

    The schema is derived from the Qdrant client's Pydantic ``Filter`` model,
    so the frontend can validate filter payloads before submitting them to
    the search endpoints.
    """
    schema = QdrantFilter.model_json_schema()

    # Strip pydantic discriminator metadata from the shared definitions;
    # it tends to confuse frontend JSON-schema validators.
    for definition in schema.get("$defs", {}).values():
        definition.pop("discriminator", None)

    return schema
483+
@router.post("/{readable_id}/search/stream")
async def stream_search_collection_advanced(  # noqa: C901 - streaming orchestration is acceptable
    readable_id: str = Path(
        ..., description="The unique readable identifier of the collection to search"
    ),
    search_request: SearchRequest = ...,
    db: AsyncSession = Depends(deps.get_db),
    ctx: ApiContext = Depends(deps.get_context),
    guard_rail: GuardRailService = Depends(deps.get_guard_rail_service),
) -> StreamingResponse:
    """Server-Sent Events (SSE) streaming endpoint for advanced search.

    Initializes a streaming session for a search request, subscribes to a
    Redis Pub/Sub channel scoped by the request ID, and relays events to
    the client. The real search pipeline runs concurrently and publishes
    lifecycle and data events to the same channel.

    Errors are not raised to the caller; they are emitted on the stream as
    ``error`` events so SSE clients are notified in-band.
    """
    # Hoisted here once; the original re-imported datetime in three places.
    import datetime as _dt

    # Use the request ID from ApiContext for end-to-end tracing
    request_id = ctx.request_id
    ctx.logger.info(
        f"[SearchStream] Starting stream for collection '{readable_id}' id={request_id}"
    )

    # Ensure the organization is allowed to perform queries
    await guard_rail.is_allowed(ActionType.QUERIES)

    # Subscribe to the search:<request_id> channel
    pubsub = await core_pubsub.subscribe("search", request_id)

    def _utc_now_iso() -> str:
        """ISO-8601 UTC timestamp for stream events."""
        return _dt.datetime.now(_dt.timezone.utc).isoformat()

    def _heartbeat_payload() -> str:
        """Serialized heartbeat event (deduplicates the two emission sites)."""
        return json.dumps({"type": "heartbeat", "ts": _utc_now_iso()})

    # Start real search in the background; it will publish to Redis
    async def _run_search() -> None:
        try:
            # Use a dedicated DB session for the background task to avoid
            # sharing the request-scoped session across tasks
            async with AsyncSessionLocal() as search_db:
                await search_service.search_with_request(
                    search_db,
                    readable_id=readable_id,
                    search_request=search_request,
                    ctx=ctx,
                    request_id=request_id,
                )
        except Exception as e:
            # Emit error to stream so clients get notified
            await core_pubsub.publish(
                "search",
                request_id,
                {"type": "error", "message": str(e)},
            )

    search_task = asyncio.create_task(_run_search())

    async def event_stream():  # noqa: C901 - streaming loop requires structured branching
        try:
            # Initial connected event with request_id and timestamp
            connected_event = {
                "type": "connected",
                "request_id": request_id,
                "ts": _utc_now_iso(),
            }
            yield f"data: {json.dumps(connected_event)}\n\n"

            # Heartbeat every 30 seconds to keep the connection alive.
            # NOTE(review): heartbeats are only evaluated when pubsub.listen()
            # yields a message, so a fully idle channel emits none — confirm
            # whether an idle-timeout heartbeat is required.
            loop = asyncio.get_running_loop()  # get_event_loop() is deprecated in coroutines
            last_heartbeat = loop.time()
            heartbeat_interval = 30

            async for message in pubsub.listen():
                # Heartbeat
                now = loop.time()
                if now - last_heartbeat > heartbeat_interval:
                    yield f"data: {_heartbeat_payload()}\n\n"
                    last_heartbeat = now

                if message["type"] == "message":
                    data = message["data"]
                    # Relay the raw pubsub payload
                    yield f"data: {data}\n\n"

                    # If a done event arrives, stop streaming
                    try:
                        parsed = json.loads(data)
                        if isinstance(parsed, dict) and parsed.get("type") == "done":
                            ctx.logger.info(
                                f"[SearchStream] Done event received for search:{request_id}. "
                                "Closing stream"
                            )
                            # Increment usage upon completion (best-effort)
                            try:
                                await guard_rail.increment(ActionType.QUERIES)
                            except Exception:
                                pass
                            break
                    except Exception:
                        # Non-JSON payloads are ignored for termination logic
                        pass

                elif message["type"] == "subscribe":
                    ctx.logger.info(f"[SearchStream] Subscribed to channel search:{request_id}")
                else:
                    # If no messages of interest, still consider heartbeats
                    current = loop.time()
                    if current - last_heartbeat > heartbeat_interval:
                        yield f"data: {_heartbeat_payload()}\n\n"
                        last_heartbeat = current

        except asyncio.CancelledError:
            ctx.logger.info(f"[SearchStream] Cancelled stream id={request_id}")
        except Exception as e:
            ctx.logger.error(f"[SearchStream] Error id={request_id}: {str(e)}")
            error_event = {
                "type": "error",
                "message": str(e),
                "ts": _utc_now_iso(),
            }
            yield f"data: {json.dumps(error_event)}\n\n"
        finally:
            # Ensure background task is cancelled if still running
            if not search_task.done():
                search_task.cancel()
                try:
                    await search_task
                except BaseException:
                    # CancelledError derives from BaseException on 3.8+; a bare
                    # `except Exception` here would let the cancellation escape
                    # the generator's cleanup and abort the response teardown.
                    pass
            # Clean up pubsub connection
            try:
                await pubsub.close()
                ctx.logger.info(
                    f"[SearchStream] Closed pubsub subscription for search:{request_id}"
                )
            except Exception:
                pass

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache, no-transform",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
            "Content-Type": "text/event-stream",
            "Access-Control-Allow-Origin": "*",
        },
    )

backend/airweave/api/v1/endpoints/entities.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,6 @@ async def create_entity_definition(
3333
return await crud.entity_definition.create(db, obj_in=definition, ctx=ctx)
3434

3535

36-
@router.put("/definitions/{definition_id}", response_model=schemas.EntityDefinition)
37-
async def update_entity_definition(
38-
definition_id: UUID,
39-
definition: schemas.EntityDefinitionUpdate,
40-
db: AsyncSession = Depends(deps.get_db),
41-
ctx: ApiContext = Depends(deps.get_context),
42-
) -> schemas.EntityDefinition:
43-
"""Update an entity definition."""
44-
db_obj = await crud.entity_definition.get(db, id=definition_id)
45-
return await crud.entity_definition.update(db, db_obj=db_obj, obj_in=definition, ctx=ctx)
46-
47-
4836
@router.post("/definitions/by-ids/", response_model=List[schemas.EntityDefinition])
4937
async def get_entity_definitions_by_ids(
5038
ids: List[UUID],
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
"""Entity counts API endpoints."""
2+
3+
from typing import List
4+
from uuid import UUID
5+
6+
from fastapi import APIRouter, Depends, HTTPException
7+
from sqlalchemy.ext.asyncio import AsyncSession
8+
9+
from airweave import crud
10+
from airweave.api import deps
11+
from airweave.api.context import ApiContext
12+
from airweave.schemas.entity_count import EntityCountWithDefinition
13+
14+
router = APIRouter()
15+
16+
@router.get("/syncs/{sync_id}/counts", response_model=List[EntityCountWithDefinition])
async def get_entity_counts_for_sync(
    sync_id: UUID,
    db: AsyncSession = Depends(deps.get_db),
    ctx: ApiContext = Depends(deps.get_context),
) -> List[EntityCountWithDefinition]:
    """Return per-entity-type counts for a sync, enriched with definition details.

    Entities are grouped by entity type; each row carries the count together
    with metadata about the corresponding entity definition.
    """
    # Access check: the sync must exist and be visible to the caller's org.
    sync_row = await crud.sync.get(db, id=sync_id, ctx=ctx)
    if not sync_row:
        raise HTTPException(
            status_code=404,
            detail=f"Sync {sync_id} not found",
        )

    # Counts grouped by entity type, joined with definition metadata.
    return await crud.entity_count.get_counts_per_sync_and_type(db, sync_id)
40+
41+
@router.get("/syncs/{sync_id}/total-count", response_model=int)
async def get_total_entity_count_for_sync(
    sync_id: UUID,
    db: AsyncSession = Depends(deps.get_db),
    ctx: ApiContext = Depends(deps.get_context),
) -> int:
    """Return the total entity count across all entity types for a sync."""
    # Access check: the sync must exist and be visible to the caller's org.
    sync_row = await crud.sync.get(db, id=sync_id, ctx=ctx)
    if not sync_row:
        raise HTTPException(
            status_code=404,
            detail=f"Sync {sync_id} not found",
        )

    # Aggregate count over every entity type tracked for this sync.
    return await crud.entity_count.get_total_count_by_sync(db, sync_id)

0 commit comments

Comments
 (0)