11import asyncio
22import logging
3- import os
43from collections import deque
54from datetime import datetime , timezone
65from typing import Any
@@ -106,11 +105,11 @@ def __init__(
106105 self .cache = SimpleMemoryCache (ttl = 60 * 10 )
107106 self ._setup_routes ()
108107
109- def get_observe_provider (self , request : Request ) -> ObservabilityProvider :
108+ async def get_observe_provider (self , request : Request ) -> ObservabilityProvider :
110109 """Get observability provider based on request.
111110
112111 For local mode, always use the default Jaeger provider.
113- For non-local mode, use default AWS provider unless request specifies Tencent .
112+ For non-local mode, fetch provider configuration from request params and MongoDB .
114113
115114 Args:
116115 request: FastAPI request object
@@ -121,15 +120,82 @@ def get_observe_provider(self, request: Request) -> ObservabilityProvider:
121120 if self .local_mode :
122121 return self .default_observe_provider
123122
124- # Check if request requires Tencent provider
125- # For now, check REST_OBSERVE_PROVIDER env var
126- # TODO: Check user settings from request to determine provider
127- provider_type = os .getenv ("REST_OBSERVE_PROVIDER" , "aws" ).lower ()
123+ # Extract provider parameters from request
124+ query_params = request .query_params
125+ trace_provider = query_params .get ("trace_provider" , "aws" )
126+ log_provider = query_params .get ("log_provider" , "aws" )
127+ trace_region = query_params .get ("trace_region" )
128+ log_region = query_params .get ("log_region" )
128129
129- if provider_type == "tencent" :
130- return ObservabilityProvider .create_tencent_provider ()
131- else :
132- return self .default_observe_provider
130+ # Get user email to fetch MongoDB config
131+ user_email , _ , _ = get_user_credentials (request )
132+
133+ # Prepare configurations
134+ trace_config : dict [str , Any ] = {}
135+ log_config : dict [str , Any ] = {}
136+
137+ # For Tencent, fetch credentials from MongoDB
138+ if trace_provider == "tencent" :
139+ trace_provider_config = await self .db_client .get_trace_provider_config (
140+ user_email
141+ )
142+ if trace_provider_config and trace_provider_config .get ("tencentTraceConfig" ):
143+ tencent_config = trace_provider_config ["tencentTraceConfig" ]
144+ trace_config = {
145+ "region" : trace_region or tencent_config .get ("region" ,
146+ "ap-hongkong" ),
147+ "secret_id" : tencent_config .get ("secretId" ),
148+ "secret_key" : tencent_config .get ("secretKey" ),
149+ "apm_instance_id" : tencent_config .get ("apmInstanceId" ),
150+ }
151+ else :
152+ # Fallback to region only if no MongoDB config
153+ trace_config = {"region" : trace_region or "ap-hongkong" }
154+ elif trace_provider == "aws" :
155+ trace_config = {"region" : trace_region }
156+ elif trace_provider == "jaeger" :
157+ # Fetch jaeger config from MongoDB if available
158+ trace_provider_config = await self .db_client .get_trace_provider_config (
159+ user_email
160+ )
161+ if trace_provider_config and trace_provider_config .get ("jaegerTraceConfig" ):
162+ jaeger_config = trace_provider_config ["jaegerTraceConfig" ]
163+ trace_config = {"url" : jaeger_config .get ("endpoint" )}
164+ else :
165+ trace_config = {}
166+
167+ if log_provider == "tencent" :
168+ log_provider_config = await self .db_client .get_log_provider_config (user_email )
169+ if log_provider_config and log_provider_config .get ("tencentLogConfig" ):
170+ tencent_config = log_provider_config ["tencentLogConfig" ]
171+ log_config = {
172+ "region" : log_region or tencent_config .get ("region" ,
173+ "ap-hongkong" ),
174+ "secret_id" : tencent_config .get ("secretId" ),
175+ "secret_key" : tencent_config .get ("secretKey" ),
176+ "cls_topic_id" : tencent_config .get ("clsTopicId" ),
177+ }
178+ else :
179+ # Fallback to region only if no MongoDB config
180+ log_config = {"region" : log_region or "ap-hongkong" }
181+ elif log_provider == "aws" :
182+ log_config = {"region" : log_region }
183+ elif log_provider == "jaeger" :
184+ # Fetch jaeger config from MongoDB if available
185+ log_provider_config = await self .db_client .get_log_provider_config (user_email )
186+ if log_provider_config and log_provider_config .get ("jaegerLogConfig" ):
187+ jaeger_config = log_provider_config ["jaegerLogConfig" ]
188+ log_config = {"url" : jaeger_config .get ("endpoint" )}
189+ else :
190+ log_config = {}
191+
192+ # Create and return the provider
193+ return ObservabilityProvider .create (
194+ trace_provider = trace_provider ,
195+ log_provider = log_provider ,
196+ trace_config = trace_config ,
197+ log_config = log_config ,
198+ )
133199
134200 def _setup_routes (self ):
135201 r"""Set up API routes"""
@@ -341,7 +407,7 @@ async def list_traces(
341407 # If trace_id is provided, fetch that specific trace directly
342408 if trace_id :
343409 try :
344- observe_provider = self .get_observe_provider (request )
410+ observe_provider = await self .get_observe_provider (request )
345411
346412 # Use the new get_trace_by_id method which handles everything
347413 trace = await observe_provider .trace_client .get_trace_by_id (
@@ -435,7 +501,7 @@ async def list_traces(
435501 return resp .model_dump ()
436502
437503 try :
438- observe_provider = self .get_observe_provider (request )
504+ observe_provider = await self .get_observe_provider (request )
439505 traces : list [Trace ] = await observe_provider .trace_client .get_recent_traces (
440506 start_time = start_time ,
441507 end_time = end_time ,
@@ -554,7 +620,7 @@ async def get_logs_by_trace_id(
554620 return resp .model_dump ()
555621
556622 try :
557- observe_provider = self .get_observe_provider (request )
623+ observe_provider = await self .get_observe_provider (request )
558624 logs : TraceLogs = await observe_provider .log_client .get_logs_by_trace_id (
559625 trace_id = req_data .trace_id ,
560626 start_time = req_data .start_time ,
@@ -721,7 +787,7 @@ async def post_chat(
721787 is_github_pr = github_related .is_github_pr
722788
723789 # Get the trace #######################################################
724- observe_provider = self .get_observe_provider (request )
790+ observe_provider = await self .get_observe_provider (request )
725791 selected_trace : Trace | None = None
726792
727793 # If we have a trace_id, fetch it directly
@@ -777,7 +843,7 @@ async def post_chat(
777843 keys = (trace_id , start_time , end_time , log_group_name )
778844 logs : TraceLogs | None = await self .cache .get (keys )
779845 if logs is None :
780- observe_provider = self .get_observe_provider (request )
846+ observe_provider = await self .get_observe_provider (request )
781847 logs = await observe_provider .log_client .get_logs_by_trace_id (
782848 trace_id = trace_id ,
783849 start_time = start_time ,
@@ -1014,7 +1080,7 @@ async def get_traces_and_logs_since_date(
10141080
10151081 try :
10161082 # Get comprehensive traces and logs data since the specified date
1017- observe_provider = self .get_observe_provider (request )
1083+ observe_provider = await self .get_observe_provider (request )
10181084 usage_data = await get_user_traces_and_logs_since_payment (
10191085 user_sub = user_sub ,
10201086 last_payment_date = req_data .since_date ,
@@ -1086,7 +1152,7 @@ async def _filter_traces_by_log_content(
10861152 # from semicolon-separated logs
10871153
10881154 # Single query to get all matching trace IDs
1089- observe_provider = self .get_observe_provider (request )
1155+ observe_provider = await self .get_observe_provider (request )
10901156 matching_trace_ids = \
10911157 await observe_provider .log_client .get_trace_ids_from_logs (
10921158 start_time = start_time ,
0 commit comments