
Commit 564a507

Merge pull request #185 from stanford-oval/costorm-integration
Costorm integration
2 parents: 33a03a3 + efac123


45 files changed: +5191 −270 lines

README.md

Lines changed: 166 additions & 41 deletions
Large diffs are not rendered by default.

assets/co-storm-workflow.jpg

767 KB
Lines changed: 241 additions & 0 deletions
@@ -0,0 +1,241 @@

```python
"""
Co-STORM pipeline powered by GPT-4o/4o-mini and Bing search engine.
You need to set up the following environment variables to run this script:
    - OPENAI_API_KEY: OpenAI API key
    - OPENAI_API_TYPE: OpenAI API type (e.g., 'openai' or 'azure')
    - AZURE_API_BASE: Azure API base URL if using Azure API
    - AZURE_API_VERSION: Azure API version if using Azure API
    - BING_SEARCH_API_KEY: Bing Search API key; alternatively, SERPER_API_KEY (Serper),
      BRAVE_API_KEY (Brave), or TAVILY_API_KEY (Tavily), depending on the retriever you choose

Output will be structured as below:
args.output_dir/
    log.json     # Log of information-seeking conversation
    report.md    # Final article generated
"""

import os
import json
from argparse import ArgumentParser

from knowledge_storm.collaborative_storm.engine import CollaborativeStormLMConfigs, RunnerArgument, CoStormRunner
from knowledge_storm.collaborative_storm.modules.callback import LocalConsolePrintCallBackHandler
from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel
from knowledge_storm.logging_wrapper import LoggingWrapper
from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG
from knowledge_storm.utils import load_api_key


def main(args):
    load_api_key(toml_file_path='secrets.toml')
    lm_config: CollaborativeStormLMConfigs = CollaborativeStormLMConfigs()
    openai_kwargs = {
        "api_key": os.getenv("OPENAI_API_KEY"),
        "api_provider": "openai",
        "temperature": 1.0,
        "top_p": 0.9,
        "api_base": None,
    } if os.getenv('OPENAI_API_TYPE') == 'openai' else {
        "api_key": os.getenv("AZURE_API_KEY"),
        "temperature": 1.0,
        "top_p": 0.9,
        "api_base": os.getenv("AZURE_API_BASE"),
        "api_version": os.getenv("AZURE_API_VERSION"),
    }

    ModelClass = OpenAIModel if os.getenv('OPENAI_API_TYPE') == 'openai' else AzureOpenAIModel
    # If you are using the Azure service, make sure the model name matches your own deployed model name.
    # The default names here are only used for demonstration and may not match your case.
    gpt_4o_mini_model_name = 'gpt-4o-mini'
    gpt_4o_model_name = 'gpt-4o'
    if os.getenv('OPENAI_API_TYPE') == 'azure':
        openai_kwargs['api_base'] = os.getenv('AZURE_API_BASE')
        openai_kwargs['api_version'] = os.getenv('AZURE_API_VERSION')

    # Co-STORM is an LM system, so different components can be powered by different models.
    # For a good balance between cost and quality, consider a cheaper/faster model (e.g.,
    # gpt_4o_mini_model_name defined above) for lighter-weight components such as question
    # asking and utterance polishing, and a stronger model for question answering and
    # knowledge base organization.
    question_answering_lm = ModelClass(model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs)
    discourse_manage_lm = ModelClass(model=gpt_4o_model_name, max_tokens=500, **openai_kwargs)
    utterance_polishing_lm = ModelClass(model=gpt_4o_model_name, max_tokens=2000, **openai_kwargs)
    warmstart_outline_gen_lm = ModelClass(model=gpt_4o_model_name, max_tokens=500, **openai_kwargs)
    question_asking_lm = ModelClass(model=gpt_4o_model_name, max_tokens=300, **openai_kwargs)
    knowledge_base_lm = ModelClass(model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs)

    lm_config.set_question_answering_lm(question_answering_lm)
    lm_config.set_discourse_manage_lm(discourse_manage_lm)
    lm_config.set_utterance_polishing_lm(utterance_polishing_lm)
    lm_config.set_warmstart_outline_gen_lm(warmstart_outline_gen_lm)
    lm_config.set_question_asking_lm(question_asking_lm)
    lm_config.set_knowledge_base_lm(knowledge_base_lm)

    topic = input('Topic: ')
    runner_argument = RunnerArgument(
        topic=topic,
        retrieve_top_k=args.retrieve_top_k,
        max_search_queries=args.max_search_queries,
        total_conv_turn=args.total_conv_turn,
        max_search_thread=args.max_search_thread,
        max_search_queries_per_turn=args.max_search_queries_per_turn,
        warmstart_max_num_experts=args.warmstart_max_num_experts,
        warmstart_max_turn_per_experts=args.warmstart_max_turn_per_experts,
        warmstart_max_thread=args.warmstart_max_thread,
        max_thread_num=args.max_thread_num,
        max_num_round_table_experts=args.max_num_round_table_experts,
        moderator_override_N_consecutive_answering_turn=args.moderator_override_N_consecutive_answering_turn,
        node_expansion_trigger_count=args.node_expansion_trigger_count)
    logging_wrapper = LoggingWrapper(lm_config)
    callback_handler = LocalConsolePrintCallBackHandler() if args.enable_log_print else None

    # Co-STORM is a knowledge curation system which consumes information from the retrieval module.
    # Currently, the information source is the Internet, and we use a search engine API as the
    # retrieval module. (Note: match statements require Python 3.10+.)
    match args.retriever:
        case 'bing':
            rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=runner_argument.retrieve_top_k)
        case 'you':
            rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=runner_argument.retrieve_top_k)
        case 'brave':
            rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=runner_argument.retrieve_top_k)
        case 'duckduckgo':
            rm = DuckDuckGoSearchRM(k=runner_argument.retrieve_top_k, safe_search='On', region='us-en')
        case 'serper':
            rm = SerperRM(serper_search_api_key=os.getenv('SERPER_API_KEY'), query_params={'autocorrect': True, 'num': 10, 'page': 1})
        case 'tavily':
            rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=runner_argument.retrieve_top_k, include_raw_content=True)
        case 'searxng':
            rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=runner_argument.retrieve_top_k)
        case _:
            raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"')

    costorm_runner = CoStormRunner(lm_config=lm_config,
                                   runner_argument=runner_argument,
                                   logging_wrapper=logging_wrapper,
                                   rm=rm,
                                   callback_handler=callback_handler)

    # warm start the system
    costorm_runner.warm_start()

    # Below is an example of how users may interact with Co-STORM to seek information together.
    # In actual deployment, we suggest allowing the user to decide whether to observe the
    # agent utterance or inject a turn.

    # observe Co-STORM LLM agent utterances for one turn (increase the range to observe more turns)
    for _ in range(1):
        conv_turn = costorm_runner.step()
        print(f"**{conv_turn.role}**: {conv_turn.utterance}\n")

    # engage actively by injecting your own utterance
    your_utterance = input('Your utterance: ')
    costorm_runner.step(user_utterance=your_utterance)

    # continue observing
    conv_turn = costorm_runner.step()
    print(f"**{conv_turn.role}**: {conv_turn.utterance}\n")

    # generate report ('reogranize' is the method name as spelled in the knowledge_storm library)
    costorm_runner.knowledge_base.reogranize()
    article = costorm_runner.generate_report()

    # save results
    os.makedirs(args.output_dir, exist_ok=True)

    # Save article
    with open(os.path.join(args.output_dir, "report.md"), "w") as f:
        f.write(article)

    # Save logging
    log_dump = costorm_runner.dump_logging_and_reset()
    with open(os.path.join(args.output_dir, "log.json"), "w") as f:
        json.dump(log_dump, f, indent=2)


if __name__ == '__main__':
    parser = ArgumentParser()
    # global arguments
    parser.add_argument('--output-dir', type=str, default='./results/co-storm',
                        help='Directory to store the outputs.')
    parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng'],
                        help='The search engine API to use for retrieving information.')
    # hyperparameters for co-storm
    parser.add_argument(
        '--retrieve_top_k',
        type=int,
        default=10,
        help='Retrieve top k results for each query in retriever.'
    )
    parser.add_argument(
        '--max_search_queries',
        type=int,
        default=2,
        help='Maximum number of search queries to consider for each question.'
    )
    parser.add_argument(
        '--total_conv_turn',
        type=int,
        default=20,
        help='Maximum number of turns in conversation.'
    )
    parser.add_argument(
        '--max_search_thread',
        type=int,
        default=5,
        help='Maximum number of parallel threads for retriever.'
    )
    parser.add_argument(
        '--max_search_queries_per_turn',
        type=int,
        default=3,
        help='Maximum number of search queries to consider in each turn.'
    )
    parser.add_argument(
        '--warmstart_max_num_experts',
        type=int,
        default=3,
        help='Max number of experts in perspective-guided QA during warm start.'
    )
    parser.add_argument(
        '--warmstart_max_turn_per_experts',
        type=int,
        default=2,
        help='Max number of turns per perspective during warm start.'
    )
    parser.add_argument(
        '--warmstart_max_thread',
        type=int,
        default=3,
        help='Max number of threads for parallel perspective-guided QA during warm start.'
    )
    parser.add_argument(
        '--max_thread_num',
        type=int,
        default=10,
        help=("Maximum number of threads to use. "
              "Consider reducing it if you keep getting 'Exceed rate limit' errors when calling the LM API.")
    )
    parser.add_argument(
        '--max_num_round_table_experts',
        type=int,
        default=2,
        help='Max number of active experts in round table discussion.'
    )
    parser.add_argument(
        '--moderator_override_N_consecutive_answering_turn',
        type=int,
        default=3,
        help='Number of consecutive expert answering turns before the moderator overrides the conversation.'
    )
    parser.add_argument(
        '--node_expansion_trigger_count',
        type=int,
        default=10,
        help='Trigger node expansion for nodes that contain more than N snippets.'
    )

    # Boolean flags
    parser.add_argument(
        '--enable_log_print',
        action='store_true',
        help='If set, enable console log print.'
    )

    main(parser.parse_args())
```
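
For reference, a typical invocation of this new example script might look like the shell command below. This rendering of the diff does not show the script's path, so the filename here is a placeholder; substitute the actual location of the file added by this commit.

```
# Hypothetical path -- replace with the actual location of the Co-STORM example script.
python examples/costorm_examples/run_costorm_gpt.py \
    --retriever bing \
    --output-dir ./results/co-storm \
    --enable_log_print
```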

examples/README.md renamed to examples/storm_examples/README.md

Lines changed: 6 additions & 6 deletions
````diff
@@ -11,7 +11,7 @@ We host a number of example scripts for various customization of STORM (e.g., us
 2. Run the following command under the root directory of the repository:

 ```
-python examples/run_storm_wiki_mistral.py \
+python examples/storm_examples/run_storm_wiki_mistral.py \
     --url $URL \
     --port $PORT \
     --output-dir $OUTPUT_DIR \
@@ -50,7 +50,7 @@ By default, STORM is grounded on the Internet using the search engine, but it ca
 To create the vector store offline, run

 ```
-python examples/run_storm_wiki_gpt_with_VectorRM.py \
+python examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py \
     --output-dir $OUTPUT_DIR \
     --vector-db-mode offline \
     --offline-vector-db-dir $OFFLINE_VECTOR_DB_DIR \
@@ -65,7 +65,7 @@ By default, STORM is grounded on the Internet using the search engine, but it ca
 To create the vector store online on a Qdrant server, run

 ```
-python examples/run_storm_wiki_gpt_with_VectorRM.py \
+python examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py \
     --output-dir $OUTPUT_DIR \
     --vector-db-mode online \
     --online-vector-db-url $ONLINE_VECTOR_DB_URL \
@@ -83,12 +83,12 @@ By default, STORM is grounded on the Internet using the search engine, but it ca
 - Run the following command under the root directory to downsample the dataset by filtering papers with terms `[cs.CV]` and get a csv file that match the format mentioned above.

 ```
-python examples/helper/process_kaggle_arxiv_abstract_dataset.py --input-path $PATH_TO_THE_DOWNLOADED_FILE --output-path $PATH_TO_THE_PROCESSED_CSV
+python examples/storm_examples/helper/process_kaggle_arxiv_abstract_dataset.py --input-path $PATH_TO_THE_DOWNLOADED_FILE --output-path $PATH_TO_THE_PROCESSED_CSV
 ```
 - Run the following command to run STORM grounding on the processed dataset. You can input a topic related to computer vision (e.g., "The progress of multimodal models in computer vision") to see the generated article. (Note that the generated article may not include enough details since the quick test only use the abstracts of arxiv papers.)

 ```
-python examples/run_storm_wiki_gpt_with_VectorRM.py \
+python examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py \
     --output-dir $OUTPUT_DIR \
     --vector-db-mode offline \
     --offline-vector-db-dir $OFFLINE_VECTOR_DB_DIR \
@@ -102,7 +102,7 @@ By default, STORM is grounded on the Internet using the search engine, but it ca
 - For a quicker run, you can also download the pre-embedded vector store directly from [here](https://drive.google.com/file/d/1bijFkw5BKU7bqcmXMhO-5hg2fdKAL9bf/view?usp=share_link).

 ```
-python examples/run_storm_wiki_gpt_with_VectorRM.py \
+python examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py \
     --output-dir $OUTPUT_DIR \
     --vector-db-mode offline \
     --offline-vector-db-dir $DOWNLOADED_VECTOR_DB_DR \
````
File renamed without changes.

examples/run_storm_wiki_gpt_with_VectorRM.py renamed to examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -27,10 +27,8 @@
 """

 import os
-import sys
 from argparse import ArgumentParser

-sys.path.append('./')
 from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
 from knowledge_storm.rm import VectorRM
 from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel
```
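
Removing the `sys.path.append('./')` hack suggests these examples now rely on `knowledge_storm` being importable as an installed package rather than resolved from the repository root. Presumably (an inference, not stated in this diff) the scripts are meant to be run after installing the package, e.g.:

```
# Assumes the package is published under this name on PyPI.
pip install knowledge-storm
```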
