diff --git a/rest/agent/agent.py b/rest/agent/agent.py
index ef6ecb4f..9ab6133c 100644
--- a/rest/agent/agent.py
+++ b/rest/agent/agent.py
@@ -4,6 +4,8 @@
 
 from openai import AsyncOpenAI
 
+from rest.agent.chunk.semantic import semantic_chunk
+
 try:
     from rest.dao.ee.mongodb_dao import TraceRootMongoDBClient
 except ImportError:
@@ -13,7 +15,6 @@
 from copy import deepcopy
 from typing import Any, Tuple
 
-from rest.agent.chunk.sequential import sequential_chunk
 from rest.agent.context.tree import SpanNode
 from rest.agent.filter.feature import (
     SpanFeature,
@@ -467,8 +468,8 @@ async def chat_with_context_chunks_streaming(
 
     def get_context_messages(self, context: str) -> list[str]:
         r"""Get the context message."""
-        # TODO: Make this more efficient.
-        context_chunks = list(sequential_chunk(context))
+
+        context_chunks = list(semantic_chunk(context))
         if len(context_chunks) == 1:
             return [
                 (
diff --git a/rest/agent/chat.py b/rest/agent/chat.py
index d711c1b3..278cd3a1 100644
--- a/rest/agent/chat.py
+++ b/rest/agent/chat.py
@@ -11,7 +11,7 @@
 except ImportError:
     from rest.dao.mongodb_dao import TraceRootMongoDBClient
 
-from rest.agent.chunk.sequential import sequential_chunk
+from rest.agent.chunk.semantic import semantic_chunk
 from rest.agent.context.tree import SpanNode
 from rest.agent.filter.feature import log_feature_selector, span_feature_selector
 from rest.agent.filter.structure import filter_log_node, log_node_selector
@@ -420,8 +420,8 @@ async def _update_streaming_record(
     def get_context_messages(self, context: str) -> list[str]:
         r"""Get the context message.
         """
-        # Make this more efficient.
-        context_chunks = list(sequential_chunk(context))
+
+        context_chunks = list(semantic_chunk(context))
         if len(context_chunks) == 1:
             return [
                 (
diff --git a/rest/agent/chunk/semantic.py b/rest/agent/chunk/semantic.py
new file mode 100644
index 00000000..03af6878
--- /dev/null
+++ b/rest/agent/chunk/semantic.py
@@ -0,0 +1,204 @@
+import json
+from typing import Any, Dict, Iterator
+
+CHUNK_SIZE = 200_000
+
+
+def semantic_chunk(text: str, chunk_size: int = CHUNK_SIZE) -> Iterator[str]:
+    """Hierarchical span-aware chunking that preserves structure.
+
+    Strategy:
+    1. Keep entire tree intact if it fits
+    2. If too large, separate child spans while maintaining hierarchy
+    3. Include parent context when splitting
+    4. Only flatten as last resort for huge individual spans
+
+    Args:
+        text: JSON string (single span or span tree)
+        chunk_size: Target max size per chunk
+
+    Yields:
+        JSON strings with preserved hierarchical structure
+    """
+    data = json.loads(text)
+    full_size = len(text)
+
+    # Case 1: Perfect - everything fits
+    if full_size <= chunk_size:
+        yield text
+        return
+
+    # Case 2: Need to split - use hierarchical splitting
+    if isinstance(data, dict) and "span_id" in data:
+        yield from _hierarchical_split(data, chunk_size)
+    elif isinstance(data, list):
+        # Array of spans - batch them intelligently
+        yield from _batch_spans_list(data, chunk_size)
+    else:
+        # Unknown structure, yield as-is
+        yield text
+
+
+def _hierarchical_split(span: Dict[str, Any], chunk_size: int) -> Iterator[str]:
+    """Split span hierarchically, preserving parent-child relationships.
+
+    Args:
+        span: The span dict to split
+        chunk_size: Maximum chunk size
+    """
+    # Separate parent data from children
+    parent_data = {}
+    children = []
+
+    for key, value in span.items():
+        if isinstance(value, dict) and "span_id" in value:
+            children.append((key, value))
+        else:
+            parent_data[key] = value
+
+    parent_size = len(json.dumps(parent_data, indent=2))
+
+    # Case A: Parent alone is too big - split its logs
+    if parent_size > chunk_size:
+        yield from _split_large_span_logs(parent_data, chunk_size)
+
+        # Yield children separately with parent context
+        for child_key, child_data in children:
+            child_with_context = _add_parent_context(child_data, parent_data)
+            yield from _hierarchical_split(child_with_context, chunk_size)
+        return
+
+    # Case B: Parent fits, try grouping with children (preserving order)
+    current_chunk = parent_data.copy()
+
+    for child_key, child_data in children:
+        # Calculate size if we add this child
+        test_chunk = current_chunk.copy()
+        test_chunk[child_key] = child_data
+        test_size = len(json.dumps(test_chunk, indent=2))
+
+        # Check if child fits in current chunk
+        if test_size <= chunk_size:
+            current_chunk[child_key] = child_data
+        else:
+            # Child doesn't fit - yield current chunk first to preserve order
+            if len(current_chunk) > len(parent_data):
+                # We have some children, yield them
+                yield json.dumps(current_chunk, indent=2)
+                current_chunk = parent_data.copy()
+
+            # Now handle the child that didn't fit
+            child_size = len(json.dumps(child_data, indent=2))
+
+            if child_size > chunk_size:
+                # Child is too big - recursively split it
+                child_with_context = _add_parent_context(child_data, parent_data)
+                yield from _hierarchical_split(child_with_context, chunk_size)
+            else:
+                # Child fits alone, yield it with parent context
+                chunk_with_context = {
+                    "_parent_context": {
+                        "span_id": parent_data.get("span_id"),
+                        "func_full_name": parent_data.get("func_full_name")
+                    },
+                    child_key: child_data
+                }
+                yield json.dumps(chunk_with_context, indent=2)
+
+    # Yield final chunk if it has content
+    if len(current_chunk) > len(parent_data) or not children:
+        # Has children or parent-only (no children at all)
+        yield json.dumps(current_chunk, indent=2)
+
+
+def _add_parent_context(child: Dict[str, Any], parent: Dict[str, Any]) -> Dict[str, Any]:
+    """Add parent context metadata to a child span."""
+    context = {
+        "_parent_span_id": parent.get("span_id"),
+        "_parent_function": parent.get("func_full_name")
+    }
+    return {**context, **child}
+
+
+def _split_large_span_logs(span: Dict[str, Any], chunk_size: int) -> Iterator[str]:
+    """Split a span with too many logs into multiple chunks.
+
+    Keeps metadata in each chunk, splits logs into groups.
+    """
+    # Separate metadata from logs
+    metadata = {}
+    logs = {}
+
+    for key, value in span.items():
+        if key.startswith("log_"):
+            logs[key] = value
+        else:
+            metadata[key] = value
+
+    # If no logs, yield metadata as-is
+    if not logs:
+        yield json.dumps(span, indent=2)
+        return
+
+    # Split logs into batches
+    log_items = sorted(logs.items(), key=lambda x: int(x[0].split('_')[1]))
+
+    current_log_batch = {}
+    for log_key, log_value in log_items:
+        current_log_batch[log_key] = log_value
+
+        # Check size with metadata
+        chunk_data = {**metadata, **current_log_batch}
+        chunk_str = json.dumps(chunk_data, indent=2)
+
+        if len(chunk_str) > chunk_size and len(current_log_batch) > 1:
+            # Yield without this log
+            current_log_batch.pop(log_key)
+            yield json.dumps({**metadata, **current_log_batch}, indent=2)
+
+            # Start new batch with just this log
+            current_log_batch = {log_key: log_value}
+
+    # Yield final batch (in case the size check never triggered for the remaining logs)
+    if current_log_batch:
+        yield json.dumps({**metadata, **current_log_batch}, indent=2)
+
+
+def _batch_spans_list(spans: list, chunk_size: int) -> Iterator[str]:
+    """Batch a list of spans into chunks."""
+    current_batch = []
+    current_size = 0
+
+    for span in spans:
+        span_str = json.dumps(span, indent=2)
+        span_size = len(span_str)
+
+        # If single span is too big, split it
+        if span_size > chunk_size:
+            # Yield current batch first
+            if current_batch:
+                yield json.dumps(current_batch, indent=2)
+                current_batch = []
+                current_size = 0
+
+            # Split the large span
+            if isinstance(span, dict) and "span_id" in span:
+                yield from _hierarchical_split(span, chunk_size)
+            else:
+                yield span_str
+            continue
+
+        # Check if adding this span exceeds limit
+        if current_batch and current_size + span_size > chunk_size:
+            # Yield current batch
+            yield json.dumps(current_batch, indent=2)
+            current_batch = []
+            current_size = 0
+
+        # Add span to current batch
+        current_batch.append(span)
+        current_size += span_size
+
+    # Yield final batch
+    if current_batch:
+        yield json.dumps(current_batch, indent=2)