77import requests
88from dsp import backoff_hdlr , giveup_hdlr
99
10- from langchain_huggingface import HuggingFaceEmbeddings
11- from langchain_qdrant import Qdrant
12- from qdrant_client import QdrantClient
13-
1410from .utils import WebPageHelper
1511
1612
@@ -199,6 +195,8 @@ def __init__(
199195 device : str = "mps" ,
200196 k : int = 3 ,
201197 ):
198+ from langchain_huggingface import HuggingFaceEmbeddings
199+
202200 """
203201 Params:
204202 collection_name: Name of the Qdrant collection.
@@ -228,6 +226,8 @@ def __init__(
228226 self .qdrant = None
229227
230228 def _check_collection (self ):
229+ from langchain_qdrant import Qdrant
230+
231231 """
232232 Check if the Qdrant collection exists and create it if it does not.
233233 """
@@ -248,6 +248,8 @@ def _check_collection(self):
248248 )
249249
250250 def init_online_vector_db (self , url : str , api_key : str ):
251+ from qdrant_client import QdrantClient
252+
251253 """
252254 Initialize the Qdrant client that is connected to an online vector store with the given URL and API key.
253255
@@ -269,6 +271,8 @@ def init_online_vector_db(self, url: str, api_key: str):
269271 raise ValueError (f"Error occurs when connecting to the server: { e } " )
270272
271273 def init_offline_vector_db (self , vector_store_path : str ):
274+ from qdrant_client import QdrantClient
275+
272276 """
273277 Initialize the Qdrant client that is connected to an offline vector store with the given vector store folder path.
274278
@@ -336,36 +340,42 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
336340class StanfordOvalArxivRM (dspy .Retrieve ):
337341 """[Alpha] This retrieval class is for internal use only, not intended for the public."""
338342
def __init__(self, endpoint, k=3, rerank=True):
    """Initialize the retriever.

    Args:
        endpoint: Stored on the instance as ``self.endpoint``.
        k: Passed through to the parent initializer as ``k``.
        rerank: Stored on the instance as ``self.rerank``. Defaults to
            ``True``, so callers written against the two-argument
            signature keep working.
    """
    super().__init__(k=k)
    self.endpoint = endpoint
    self.rerank = rerank
    # Usage counter starts at zero; get_usage_and_reset() reads and clears it.
    self.usage = 0
343348
def get_usage_and_reset(self):
    """Report the current usage count and reset the counter.

    Returns:
        A single-entry dict keyed by ``"StanfordOvalArxivRM"`` whose value
        is ``self.usage`` as it was before this call reset it to zero.
    """
    # Read and clear in one step so the reported value is the pre-reset count.
    usage, self.usage = self.usage, 0
    return {"StanfordOvalArxivRM": usage}
349354
350355 def _retrieve (self , query : str ):
351- payload = {"query" : query , "num_blocks" : self .k }
356+ payload = {"query" : query , "num_blocks" : self .k , "rerank" : self . rerank }
352357
353358 response = requests .post (
354359 self .endpoint , json = payload , headers = {"Content-Type" : "application/json" }
355360 )
356361
357362 # Check if the request was successful
358363 if response .status_code == 200 :
359- data = response .json ()[0 ]
364+ response_data_list = response .json ()[0 ][ "results" ]
360365 results = []
361- for i in range ( len ( data [ "title" ])) :
366+ for response_data in response_data_list :
362367 result = {
363- "title" : data ["title" ][i ],
364- "url" : data ["title" ][i ],
365- "snippets" : [data ["text" ][i ]],
366- "description" : "N/A" ,
367- "meta" : {"section_title" : data ["full_section_title" ][i ]},
368+ "title" : response_data ["document_title" ],
369+ "url" : response_data ["url" ],
370+ "snippets" : [response_data ["content" ]],
371+ "description" : response_data .get ("description" , "N/A" ),
372+ "meta" : {
373+ key : value
374+ for key , value in response_data .items ()
375+ if key not in ["document_title" , "url" , "content" ]
376+ },
368377 }
378+
369379 results .append (result )
370380
371381 return results
@@ -537,9 +547,7 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
537547 snippets = [organic .get ("snippet" )]
538548 if self .ENABLE_EXTRA_SNIPPET_EXTRACTION :
539549 snippets .extend (
540- valid_url_to_snippets .get (url .strip ("'" ), {}).get (
541- "snippets" , []
542- )
550+ valid_url_to_snippets .get (url , {}).get ("snippets" , [])
543551 )
544552 collected_results .append (
545553 {
0 commit comments