Skip to content

Commit 56672d9

Browse files
authored
Merge pull request #255 from traceroot-ai/feat/support_tencent_cloud_backend
feat: support tencent cloud backend
2 parents 723f537 + ecf7e37 commit 56672d9

File tree

8 files changed

+328
-65
lines changed

8 files changed

+328
-65
lines changed

rest/config/traces_and_logs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
class GetTracesAndLogsSinceDateRequest(BaseModel):
1212
"""Request model for getting traces and logs since a specific date."""
1313
since_date: datetime
14+
trace_provider: str = "aws"
15+
log_provider: str = "aws"
16+
trace_region: str | None = None
17+
log_region: str | None = None
1418

1519
@field_validator('since_date')
1620
@classmethod

rest/dao/mongodb_dao.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,33 @@ async def get_credentials_by_token(
9090
) -> dict[str,
9191
Any] | None:
9292
pass
93+
94+
async def get_trace_provider_config(
95+
self,
96+
user_email: str,
97+
) -> dict[str,
98+
Any] | None:
99+
"""Get trace provider configuration for a user from MongoDB.
100+
101+
Args:
102+
user_email (str): The user's email
103+
104+
Returns:
105+
dict[str, Any] | None: The trace provider config if found, None otherwise
106+
"""
107+
return None
108+
109+
async def get_log_provider_config(
110+
self,
111+
user_email: str,
112+
) -> dict[str,
113+
Any] | None:
114+
"""Get log provider configuration for a user from MongoDB.
115+
116+
Args:
117+
user_email (str): The user's email
118+
119+
Returns:
120+
dict[str, Any] | None: The log provider config if found, None otherwise
121+
"""
122+
return None

rest/routers/explore.py

Lines changed: 84 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import asyncio
22
import logging
3-
import os
43
from collections import deque
54
from datetime import datetime, timezone
65
from typing import Any
@@ -106,11 +105,11 @@ def __init__(
106105
self.cache = SimpleMemoryCache(ttl=60 * 10)
107106
self._setup_routes()
108107

109-
def get_observe_provider(self, request: Request) -> ObservabilityProvider:
108+
async def get_observe_provider(self, request: Request) -> ObservabilityProvider:
110109
"""Get observability provider based on request.
111110
112111
For local mode, always use the default Jaeger provider.
113-
For non-local mode, use default AWS provider unless request specifies Tencent.
112+
For non-local mode, fetch provider configuration from request params and MongoDB.
114113
115114
Args:
116115
request: FastAPI request object
@@ -121,15 +120,82 @@ def get_observe_provider(self, request: Request) -> ObservabilityProvider:
121120
if self.local_mode:
122121
return self.default_observe_provider
123122

124-
# Check if request requires Tencent provider
125-
# For now, check REST_OBSERVE_PROVIDER env var
126-
# TODO: Check user settings from request to determine provider
127-
provider_type = os.getenv("REST_OBSERVE_PROVIDER", "aws").lower()
123+
# Extract provider parameters from request
124+
query_params = request.query_params
125+
trace_provider = query_params.get("trace_provider", "aws")
126+
log_provider = query_params.get("log_provider", "aws")
127+
trace_region = query_params.get("trace_region")
128+
log_region = query_params.get("log_region")
128129

129-
if provider_type == "tencent":
130-
return ObservabilityProvider.create_tencent_provider()
131-
else:
132-
return self.default_observe_provider
130+
# Get user email to fetch MongoDB config
131+
user_email, _, _ = get_user_credentials(request)
132+
133+
# Prepare configurations
134+
trace_config: dict[str, Any] = {}
135+
log_config: dict[str, Any] = {}
136+
137+
# For Tencent, fetch credentials from MongoDB
138+
if trace_provider == "tencent":
139+
trace_provider_config = await self.db_client.get_trace_provider_config(
140+
user_email
141+
)
142+
if trace_provider_config and trace_provider_config.get("tencentTraceConfig"):
143+
tencent_config = trace_provider_config["tencentTraceConfig"]
144+
trace_config = {
145+
"region": trace_region or tencent_config.get("region",
146+
"ap-hongkong"),
147+
"secret_id": tencent_config.get("secretId"),
148+
"secret_key": tencent_config.get("secretKey"),
149+
"apm_instance_id": tencent_config.get("apmInstanceId"),
150+
}
151+
else:
152+
# Fallback to region only if no MongoDB config
153+
trace_config = {"region": trace_region or "ap-hongkong"}
154+
elif trace_provider == "aws":
155+
trace_config = {"region": trace_region}
156+
elif trace_provider == "jaeger":
157+
# Fetch jaeger config from MongoDB if available
158+
trace_provider_config = await self.db_client.get_trace_provider_config(
159+
user_email
160+
)
161+
if trace_provider_config and trace_provider_config.get("jaegerTraceConfig"):
162+
jaeger_config = trace_provider_config["jaegerTraceConfig"]
163+
trace_config = {"url": jaeger_config.get("endpoint")}
164+
else:
165+
trace_config = {}
166+
167+
if log_provider == "tencent":
168+
log_provider_config = await self.db_client.get_log_provider_config(user_email)
169+
if log_provider_config and log_provider_config.get("tencentLogConfig"):
170+
tencent_config = log_provider_config["tencentLogConfig"]
171+
log_config = {
172+
"region": log_region or tencent_config.get("region",
173+
"ap-hongkong"),
174+
"secret_id": tencent_config.get("secretId"),
175+
"secret_key": tencent_config.get("secretKey"),
176+
"cls_topic_id": tencent_config.get("clsTopicId"),
177+
}
178+
else:
179+
# Fallback to region only if no MongoDB config
180+
log_config = {"region": log_region or "ap-hongkong"}
181+
elif log_provider == "aws":
182+
log_config = {"region": log_region}
183+
elif log_provider == "jaeger":
184+
# Fetch jaeger config from MongoDB if available
185+
log_provider_config = await self.db_client.get_log_provider_config(user_email)
186+
if log_provider_config and log_provider_config.get("jaegerLogConfig"):
187+
jaeger_config = log_provider_config["jaegerLogConfig"]
188+
log_config = {"url": jaeger_config.get("endpoint")}
189+
else:
190+
log_config = {}
191+
192+
# Create and return the provider
193+
return ObservabilityProvider.create(
194+
trace_provider=trace_provider,
195+
log_provider=log_provider,
196+
trace_config=trace_config,
197+
log_config=log_config,
198+
)
133199

134200
def _setup_routes(self):
135201
r"""Set up API routes"""
@@ -341,7 +407,7 @@ async def list_traces(
341407
# If trace_id is provided, fetch that specific trace directly
342408
if trace_id:
343409
try:
344-
observe_provider = self.get_observe_provider(request)
410+
observe_provider = await self.get_observe_provider(request)
345411

346412
# Use the new get_trace_by_id method which handles everything
347413
trace = await observe_provider.trace_client.get_trace_by_id(
@@ -435,7 +501,7 @@ async def list_traces(
435501
return resp.model_dump()
436502

437503
try:
438-
observe_provider = self.get_observe_provider(request)
504+
observe_provider = await self.get_observe_provider(request)
439505
traces: list[Trace] = await observe_provider.trace_client.get_recent_traces(
440506
start_time=start_time,
441507
end_time=end_time,
@@ -554,7 +620,7 @@ async def get_logs_by_trace_id(
554620
return resp.model_dump()
555621

556622
try:
557-
observe_provider = self.get_observe_provider(request)
623+
observe_provider = await self.get_observe_provider(request)
558624
logs: TraceLogs = await observe_provider.log_client.get_logs_by_trace_id(
559625
trace_id=req_data.trace_id,
560626
start_time=req_data.start_time,
@@ -721,7 +787,7 @@ async def post_chat(
721787
is_github_pr = github_related.is_github_pr
722788

723789
# Get the trace #######################################################
724-
observe_provider = self.get_observe_provider(request)
790+
observe_provider = await self.get_observe_provider(request)
725791
selected_trace: Trace | None = None
726792

727793
# If we have a trace_id, fetch it directly
@@ -777,7 +843,7 @@ async def post_chat(
777843
keys = (trace_id, start_time, end_time, log_group_name)
778844
logs: TraceLogs | None = await self.cache.get(keys)
779845
if logs is None:
780-
observe_provider = self.get_observe_provider(request)
846+
observe_provider = await self.get_observe_provider(request)
781847
logs = await observe_provider.log_client.get_logs_by_trace_id(
782848
trace_id=trace_id,
783849
start_time=start_time,
@@ -1014,7 +1080,7 @@ async def get_traces_and_logs_since_date(
10141080

10151081
try:
10161082
# Get comprehensive traces and logs data since the specified date
1017-
observe_provider = self.get_observe_provider(request)
1083+
observe_provider = await self.get_observe_provider(request)
10181084
usage_data = await get_user_traces_and_logs_since_payment(
10191085
user_sub=user_sub,
10201086
last_payment_date=req_data.since_date,
@@ -1086,7 +1152,7 @@ async def _filter_traces_by_log_content(
10861152
# from semicolon-separated logs
10871153

10881154
# Single query to get all matching trace IDs
1089-
observe_provider = self.get_observe_provider(request)
1155+
observe_provider = await self.get_observe_provider(request)
10901156
matching_trace_ids = \
10911157
await observe_provider.log_client.get_trace_ids_from_logs(
10921158
start_time=start_time,

rest/service/provider.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,12 @@ def create_log_client(
5151
if provider == ObservabilityProviderType.AWS:
5252
return AWSLogClient(aws_region=kwargs.get('region'))
5353
elif provider == ObservabilityProviderType.TENCENT:
54-
return TencentLogClient(tencent_region=kwargs.get('region'))
54+
return TencentLogClient(
55+
tencent_region=kwargs.get('region'),
56+
secret_id=kwargs.get('secret_id'),
57+
secret_key=kwargs.get('secret_key'),
58+
cls_topic_id=kwargs.get('cls_topic_id')
59+
)
5560
elif provider == ObservabilityProviderType.JAEGER:
5661
return JaegerLogClient(jaeger_url=kwargs.get('url'))
5762
else:
@@ -79,7 +84,12 @@ def create_trace_client(
7984
if provider == ObservabilityProviderType.AWS:
8085
return AWSTraceClient(aws_region=kwargs.get('region'))
8186
elif provider == ObservabilityProviderType.TENCENT:
82-
return TencentTraceClient(tencent_region=kwargs.get('region'))
87+
return TencentTraceClient(
88+
tencent_region=kwargs.get('region'),
89+
secret_id=kwargs.get('secret_id'),
90+
secret_key=kwargs.get('secret_key'),
91+
apm_instance_id=kwargs.get('apm_instance_id')
92+
)
8393
elif provider == ObservabilityProviderType.JAEGER:
8494
return JaegerTraceClient(jaeger_url=kwargs.get('url'))
8595
else:

rest/utils/encryption.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import base64
2+
import hashlib
3+
import os
4+
from typing import Optional
5+
6+
from cryptography.hazmat.backends import default_backend
7+
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
8+
9+
10+
def get_encryption_secret() -> str:
11+
"""Get encryption secret from environment variable.
12+
13+
Returns:
14+
str: The encryption secret key
15+
"""
16+
secret = os.getenv("SECRET_ENCRYPT_KEY")
17+
if not secret:
18+
# Fallback to "LOCAL" for local development
19+
secret = "LOCAL"
20+
return secret
21+
22+
23+
def get_key(secret: str) -> bytes:
24+
"""Generate AES-256 key from secret using SHA256.
25+
26+
Args:
27+
secret (str): The secret key string
28+
29+
Returns:
30+
bytes: 32-byte key for AES-256
31+
"""
32+
return hashlib.sha256(secret.encode()).digest()
33+
34+
35+
def decrypt_value(encrypted_value: str) -> Optional[str]:
36+
"""Decrypt a value using AES-256-CBC.
37+
38+
Args:
39+
encrypted_value (str): Base64-encoded encrypted value (IV + ciphertext)
40+
41+
Returns:
42+
Optional[str]: Decrypted string value, or None if decryption fails
43+
"""
44+
if not encrypted_value:
45+
return None
46+
47+
# Always use "LOCAL" to match frontend behavior
48+
# Client-side encryption cannot truly be secret in browser
49+
secrets_to_try = [
50+
"LOCAL", # Primary key used by frontend
51+
]
52+
53+
for secret in secrets_to_try:
54+
try:
55+
key = get_key(secret)
56+
57+
# Decode from base64
58+
data = base64.b64decode(encrypted_value)
59+
60+
# Extract IV (first 16 bytes) and encrypted data
61+
iv = data[:16]
62+
encrypted = data[16:]
63+
64+
# Create cipher and decrypt
65+
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
66+
decryptor = cipher.decryptor()
67+
decrypted = decryptor.update(encrypted) + decryptor.finalize()
68+
69+
# Remove PKCS7 padding
70+
pad_len = decrypted[-1]
71+
72+
# Validate padding length (must be 1-16 for AES)
73+
if pad_len < 1 or pad_len > 16:
74+
raise ValueError(f"Invalid padding length: {pad_len}")
75+
76+
# Validate PKCS7 padding - all padding bytes should have the same value
77+
padding_bytes = decrypted[-pad_len:]
78+
if not all(b == pad_len for b in padding_bytes):
79+
raise ValueError("Invalid PKCS7 padding")
80+
81+
unpadded = decrypted[:-pad_len]
82+
result = unpadded.decode('utf-8')
83+
return result
84+
except Exception:
85+
continue
86+
87+
# If all secrets failed
88+
return None
89+
90+
91+
def encrypt_value(value: str) -> Optional[str]:
92+
"""Encrypt a value using AES-256-CBC.
93+
94+
Args:
95+
value (str): Plain text value to encrypt
96+
97+
Returns:
98+
Optional[str]: Base64-encoded encrypted value, or None if encryption fails
99+
"""
100+
if not value:
101+
return None
102+
103+
try:
104+
secret = get_encryption_secret()
105+
key = get_key(secret)
106+
107+
# Generate random IV
108+
iv = os.urandom(16)
109+
110+
# Add PKCS7 padding
111+
pad_len = 16 - (len(value) % 16)
112+
padded = value + chr(pad_len) * pad_len
113+
114+
# Create cipher and encrypt
115+
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
116+
encryptor = cipher.encryptor()
117+
encrypted = encryptor.update(padded.encode()) + encryptor.finalize()
118+
119+
# Combine IV and encrypted data, then base64 encode
120+
return base64.b64encode(iv + encrypted).decode('utf-8')
121+
except Exception as e:
122+
print(f"Error encrypting value: {e}")
123+
return None

0 commit comments

Comments
 (0)