|
1 | 1 | """API endpoints for collections.""" |
2 | 2 |
|
3 | | -from typing import List |
| 3 | +import asyncio |
| 4 | +import json |
| 5 | +from typing import Any, Dict, List |
4 | 6 |
|
5 | 7 | from fastapi import BackgroundTasks, Depends, HTTPException, Path, Query |
| 8 | +from fastapi.responses import StreamingResponse |
| 9 | +from qdrant_client.http.models import Filter as QdrantFilter |
6 | 10 | from sqlalchemy.ext.asyncio import AsyncSession |
7 | 11 |
|
8 | 12 | from airweave import crud, schemas |
|
17 | 21 | from airweave.core.collection_service import collection_service |
18 | 22 | from airweave.core.guard_rail_service import GuardRailService |
19 | 23 | from airweave.core.logging import ContextualLogger |
| 24 | +from airweave.core.pubsub import core_pubsub |
20 | 25 | from airweave.core.shared_models import ActionType |
21 | 26 | from airweave.core.source_connection_service import source_connection_service |
22 | 27 | from airweave.core.sync_service import sync_service |
23 | 28 | from airweave.core.temporal_service import temporal_service |
| 29 | +from airweave.db.session import AsyncSessionLocal |
24 | 30 | from airweave.schemas.search import ResponseType, SearchRequest |
25 | 31 | from airweave.search.search_service_v2 import search_service_v2 as search_service |
26 | 32 |
|
@@ -175,7 +181,7 @@ async def search( |
175 | 181 | ), |
176 | 182 | examples=["raw", "completion"], |
177 | 183 | ), |
178 | | - limit: int = Query(20, ge=1, le=1000, description="Maximum number of results to return"), |
| 184 | + limit: int = Query(100, ge=1, le=1000, description="Maximum number of results to return"), |
179 | 185 | offset: int = Query(0, ge=0, description="Number of results to skip for pagination"), |
180 | 186 | recency_bias: float | None = Query( |
181 | 187 | None, |
@@ -451,3 +457,184 @@ async def refresh_all_source_connections( |
451 | 457 | await guard_rail.increment(ActionType.SYNCS) |
452 | 458 |
|
453 | 459 | return sync_jobs |
| 460 | + |
| 461 | + |
| 462 | +@router.get("/internal/filter-schema") |
| 463 | +async def get_filter_schema() -> Dict[str, Any]: |
| 464 | + """Get the JSON schema for Qdrant filter validation. |
| 465 | +
|
| 466 | + This endpoint returns the JSON schema that can be used to validate |
| 467 | + filter objects in the frontend. |
| 468 | + """ |
| 469 | + # Get the Pydantic model's JSON schema |
| 470 | + schema = QdrantFilter.model_json_schema() |
| 471 | + |
| 472 | + # Simplify the schema to make it more frontend-friendly |
| 473 | + # Remove some internal fields that might confuse the validator |
| 474 | + if "$defs" in schema: |
| 475 | + # Keep definitions but clean them up |
| 476 | + for _def_name, def_schema in schema.get("$defs", {}).items(): |
| 477 | + # Remove discriminator fields that might cause issues |
| 478 | + if "discriminator" in def_schema: |
| 479 | + del def_schema["discriminator"] |
| 480 | + |
| 481 | + return schema |
| 482 | + |
| 483 | + |
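A minimal client-side sketch of how the filter-schema endpoint above might be consumed, assuming the httpx and jsonschema packages on the client; the base URL, the /collections mount prefix, and the omission of auth headers are assumptions, not part of this diff. Pydantic v2's model_json_schema() targets JSON Schema draft 2020-12, so Draft202012Validator is the matching validator.

import httpx
from jsonschema import Draft202012Validator

def filter_errors(candidate: dict, base_url: str = "http://localhost:8001") -> list[str]:
    # Fetch the schema published by the endpoint above (path prefix assumed).
    schema = httpx.get(f"{base_url}/collections/internal/filter-schema").json()
    # Collect human-readable messages instead of raising, so a frontend
    # can surface them next to its filter editor.
    return [err.message for err in Draft202012Validator(schema).iter_errors(candidate)]

# Hypothetical usage with a simple "must match" filter
print(filter_errors({"must": [{"key": "source_name", "match": {"value": "GitHub"}}]}) or "filter is valid")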
| 484 | +@router.post("/{readable_id}/search/stream") |
| 485 | +async def stream_search_collection_advanced( # noqa: C901 - streaming orchestration is acceptable |
| 486 | + readable_id: str = Path( |
| 487 | + ..., description="The unique readable identifier of the collection to search" |
| 488 | + ), |
| 489 | + search_request: SearchRequest = ..., |
| 490 | + db: AsyncSession = Depends(deps.get_db), |
| 491 | + ctx: ApiContext = Depends(deps.get_context), |
| 492 | + guard_rail: GuardRailService = Depends(deps.get_guard_rail_service), |
| 493 | +) -> StreamingResponse: |
| 494 | + """Server-Sent Events (SSE) streaming endpoint for advanced search. |
| 495 | +
|
| 496 | + This endpoint initializes a streaming session for a search request, |
| 497 | + subscribes to a Redis Pub/Sub channel scoped by a generated request ID, |
| 498 | + and relays events to the client. It concurrently runs the real search |
| 499 | + pipeline, which will publish lifecycle and data events to the same channel. |
| 500 | + """ |
| 501 | + # Use the request ID from ApiContext for end-to-end tracing |
| 502 | + request_id = ctx.request_id |
| 503 | + ctx.logger.info( |
| 504 | + f"[SearchStream] Starting stream for collection '{readable_id}' id={request_id}" |
| 505 | + ) |
| 506 | + |
| 507 | + # Ensure the organization is allowed to perform queries |
| 508 | + await guard_rail.is_allowed(ActionType.QUERIES) |
| 509 | + |
| 510 | + # Subscribe to the search:<request_id> channel |
| 511 | + pubsub = await core_pubsub.subscribe("search", request_id) |
| 512 | + |
| 513 | + # Start real search in the background; it will publish to Redis |
| 514 | + async def _run_search() -> None: |
| 515 | + try: |
| 516 | + # Use a dedicated DB session for the background task to avoid |
| 517 | + # sharing the request-scoped session across tasks |
| 518 | + async with AsyncSessionLocal() as search_db: |
| 519 | + await search_service.search_with_request( |
| 520 | + search_db, |
| 521 | + readable_id=readable_id, |
| 522 | + search_request=search_request, |
| 523 | + ctx=ctx, |
| 524 | + request_id=request_id, |
| 525 | + ) |
| 526 | + except Exception as e: |
| 527 | + # Emit error to stream so clients get notified |
| 528 | + await core_pubsub.publish( |
| 529 | + "search", |
| 530 | + request_id, |
| 531 | + {"type": "error", "message": str(e)}, |
| 532 | + ) |
| 533 | + |
| 534 | + search_task = asyncio.create_task(_run_search()) |
| 535 | + |
| 536 | + async def event_stream(): # noqa: C901 - streaming loop requires structured branching |
| 537 | + try: |
| 538 | + # Initial connected event with request_id and timestamp |
| 539 | + import datetime as _dt |
| 540 | + |
| 541 | + connected_event = { |
| 542 | + "type": "connected", |
| 543 | + "request_id": request_id, |
| 544 | + "ts": _dt.datetime.now(_dt.timezone.utc).isoformat(), |
| 545 | + } |
| 546 | + yield f"data: {json.dumps(connected_event)}\n\n" |
| 547 | + |
| 548 | + # Heartbeat every 30 seconds to keep the connection alive |
| 549 | + last_heartbeat = asyncio.get_event_loop().time() |
| 550 | + heartbeat_interval = 30 |
| 551 | + |
| 552 | + async for message in pubsub.listen(): |
| 553 | + # Heartbeat |
| 554 | + now = asyncio.get_event_loop().time() |
| 555 | + if now - last_heartbeat > heartbeat_interval: |
| 556 | + import datetime as _dt |
| 557 | + |
| 558 | + heartbeat_event = { |
| 559 | + "type": "heartbeat", |
| 560 | + "ts": _dt.datetime.now(_dt.timezone.utc).isoformat(), |
| 561 | + } |
| 562 | + yield f"data: {json.dumps(heartbeat_event)}\n\n" |
| 563 | + last_heartbeat = now |
| 564 | + |
| 565 | + if message["type"] == "message": |
| 566 | + data = message["data"] |
| 567 | + # Relay the raw pubsub payload |
| 568 | + yield f"data: {data}\n\n" |
| 569 | + |
| 570 | + # If a done event arrives, stop streaming |
| 571 | + try: |
| 572 | + parsed = json.loads(data) |
| 573 | + if isinstance(parsed, dict) and parsed.get("type") == "done": |
| 574 | + ctx.logger.info( |
| 575 | + f"[SearchStream] Done event received for search:{request_id}. " |
| 576 | + "Closing stream" |
| 577 | + ) |
| 578 | + # Increment usage upon completion |
| 579 | + try: |
| 580 | + await guard_rail.increment(ActionType.QUERIES) |
| 581 | + except Exception: |
| 582 | + pass |
| 583 | + break |
| 584 | + except Exception: |
| 585 | + # Non-JSON payloads are ignored for termination logic |
| 586 | + pass |
| 587 | + |
| 588 | + elif message["type"] == "subscribe": |
| 589 | + ctx.logger.info(f"[SearchStream] Subscribed to channel search:{request_id}") |
| 590 | + else: |
| 591 | +                # For other pubsub message types, still emit a heartbeat if one is due |
| 592 | + current = asyncio.get_event_loop().time() |
| 593 | + if current - last_heartbeat > heartbeat_interval: |
| 594 | + heartbeat_event = { |
| 595 | + "type": "heartbeat", |
| 596 | + "ts": _dt.datetime.now(_dt.timezone.utc).isoformat(), |
| 597 | + } |
| 598 | + yield f"data: {json.dumps(heartbeat_event)}\n\n" |
| 599 | + last_heartbeat = current |
| 600 | + |
| 601 | + except asyncio.CancelledError: |
| 602 | + ctx.logger.info(f"[SearchStream] Cancelled stream id={request_id}") |
| 603 | + except Exception as e: |
| 604 | + ctx.logger.error(f"[SearchStream] Error id={request_id}: {str(e)}") |
| 605 | + import datetime as _dt |
| 606 | + |
| 607 | + error_event = { |
| 608 | + "type": "error", |
| 609 | + "message": str(e), |
| 610 | + "ts": _dt.datetime.now(_dt.timezone.utc).isoformat(), |
| 611 | + } |
| 612 | + yield f"data: {json.dumps(error_event)}\n\n" |
| 613 | + finally: |
| 614 | + # Ensure background task is cancelled if still running |
| 615 | + if not search_task.done(): |
| 616 | + search_task.cancel() |
| 617 | + try: |
| 618 | + await search_task |
| 619 | + except Exception: |
| 620 | + pass |
| 621 | + # Clean up pubsub connection |
| 622 | + try: |
| 623 | + await pubsub.close() |
| 624 | + ctx.logger.info( |
| 625 | + f"[SearchStream] Closed pubsub subscription for search:{request_id}" |
| 626 | + ) |
| 627 | + except Exception: |
| 628 | + pass |
| 629 | + |
| 630 | + return StreamingResponse( |
| 631 | + event_stream(), |
| 632 | + media_type="text/event-stream", |
| 633 | + headers={ |
| 634 | + "Cache-Control": "no-cache, no-transform", |
| 635 | + "Connection": "keep-alive", |
| 636 | + "X-Accel-Buffering": "no", |
| 637 | + "Content-Type": "text/event-stream", |
| 638 | + "Access-Control-Allow-Origin": "*", |
| 639 | + }, |
| 640 | + ) |
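A minimal consumer sketch for the streaming endpoint above, assuming httpx on the client; the base URL, the /collections mount prefix, the request body shape (a single "query" field on SearchRequest), and the absence of auth headers are assumptions. The wire format follows the handler: each event arrives as one "data: <json>\n\n" SSE frame, and a "done" (or "error") event ends the stream.

import asyncio
import json

import httpx

async def consume_stream(readable_id: str, query: str) -> None:
    # Long-lived SSE connection, so disable the default read timeout.
    url = f"http://localhost:8001/collections/{readable_id}/search/stream"
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", url, json={"query": query}) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue  # blank lines separate SSE frames
                try:
                    event = json.loads(line[len("data: "):])
                except json.JSONDecodeError:
                    continue  # the relay forwards raw payloads; skip non-JSON ones
                print(event.get("type"), event)
                if event.get("type") in ("done", "error"):
                    break

asyncio.run(consume_stream("my-collection", "what changed last week?"))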