Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions examples/costorm_examples/run_costorm_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def main(args):
if os.getenv("OPENAI_API_TYPE") == "azure":
openai_kwargs["api_base"] = os.getenv("AZURE_API_BASE")
openai_kwargs["api_version"] = os.getenv("AZURE_API_VERSION")
gpt_4o_model_name = os.getenv("AZURE_API_MODEL")

# STORM is a LM system so different components can be powered by different models.
# For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm
Expand Down Expand Up @@ -199,7 +200,7 @@ def main(args):
os.makedirs(args.output_dir, exist_ok=True)

# Save article
with open(os.path.join(args.output_dir, "report.md"), "w") as f:
with open(os.path.join(args.output_dir, "report.md"), "w", encoding="utf-8") as f:
f.write(article)

# Save instance dump
Expand All @@ -210,7 +211,17 @@ def main(args):
# Save logging
log_dump = costorm_runner.dump_logging_and_reset()
with open(os.path.join(args.output_dir, "log.json"), "w") as f:
json.dump(log_dump, f, indent=2)
for stage in log_dump:
stage_obj = log_dump[stage]
for metric in stage_obj:
metric_obj=stage_obj[metric]
if metric == "lm_history":
# values = []
for history in metric_obj:
history["response"] = history["response"]["choices"][0].message.content
# values.append(text)
# metric_obj[history] = ".".join(values)
json.dump(stage_obj, f, indent=2)


if __name__ == "__main__":
Expand All @@ -225,6 +236,7 @@ def main(args):
parser.add_argument(
"--retriever",
type=str,
default="duckduckgo",
choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"],
help="The search engine API to use for retrieving information.",
)
Expand Down Expand Up @@ -280,7 +292,7 @@ def main(args):
parser.add_argument(
"--max_thread_num",
type=int,
default=10,
default=1,
help=(
"Maximum number of threads to use. "
"Consider reducing it if you keep getting 'Exceed rate limit' errors when calling the LM API."
Expand Down
3 changes: 3 additions & 0 deletions examples/storm_examples/run_storm_wiki_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ def main(args):
if os.getenv("OPENAI_API_TYPE") == "azure":
openai_kwargs["api_base"] = os.getenv("AZURE_API_BASE")
openai_kwargs["api_version"] = os.getenv("AZURE_API_VERSION")
openai_kwargs["api_key"] = os.getenv("AZURE_API_KEY")
gpt_35_model_name = os.getenv("AZURE_API_MODEL")
gpt_4_model_name = os.getenv("AZURE_API_MODEL")

# STORM is a LM system so different components can be powered by different models.
# For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm
Expand Down
4 changes: 2 additions & 2 deletions knowledge_storm/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ class AzureOpenAIModel(dspy.LM):

def __init__(
self,
azure_endpoint: str,
api_base: str,
api_version: str,
model: str,
api_key: str,
Expand All @@ -482,7 +482,7 @@ def __init__(
self.model_type = model_type

self.client = AzureOpenAI(
azure_endpoint=azure_endpoint,
azure_endpoint=api_base,
api_key=api_key,
api_version=api_version,
)
Expand Down