repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...2fa351d6docs: add automaton and perps launch sources16d ago| #1 | import argparse |
| #2 | import json |
| #3 | import os |
| #4 | import time |
| #5 | from collections import defaultdict |
| #6 | |
| #7 | from dotenv import load_dotenv |
| #8 | from jinja2 import Template |
| #9 | from openai import OpenAI |
| #10 | from prompts import ANSWER_PROMPT_ZEP |
| #11 | from tqdm import tqdm |
| #12 | from zep_cloud import EntityEdge, EntityNode |
| #13 | from zep_cloud.client import Zep |
| #14 | |
# Load variables from a local .env file into the process environment; the
# class below reads ZEP_API_KEY and MODEL via os.getenv.
load_dotenv()

# Memory-context block rendered into the answer prompt.
# ZepSearch.compose_search_context fills it via str.format:
#   {facts}    -> one " - <fact> (<valid_at> - <invalid_at>)" line per Zep edge
#   {entities} -> one " - <name>: <summary>" line per Zep node
# Because str.format is used, any literal braces added to this text must be
# doubled ({{ }}).
# NOTE(review): the header advertises "FACT (Date range: from - to)" but the
# rendered facts omit the "Date range:" label. Left unchanged on purpose:
# editing this text changes what the answering model sees.
TEMPLATE = """
FACTS and ENTITIES represent relevant context to the current conversation.

# These are the most relevant facts and their valid date ranges
# format: FACT (Date range: from - to)

{facts}


# These are the most relevant entities
# ENTITY_NAME: entity summary

{entities}

"""
| #32 | |
| #33 | |
class ZepSearch:
    """Answer dataset questions using Zep graph search results as memory context.

    For each question, the Zep graph belonging to that question's conversation
    is searched for relevant facts (edges) and entities (nodes). The hits are
    rendered into ``TEMPLATE``, and an OpenAI chat model answers from the
    resulting prompt. Results accumulate in ``self.results``, keyed by
    conversation index, and are checkpointed to disk as processing proceeds.
    """

    def __init__(self):
        self.zep_client = Zep(api_key=os.getenv("ZEP_API_KEY"))
        # conversation index -> list of per-question result dicts, in processing order
        self.results = defaultdict(list)
        self.openai_client = OpenAI()
        # The answer prompt never changes, so parse it once instead of once per question.
        self.answer_template = Template(ANSWER_PROMPT_ZEP)

    def format_edge_date_range(self, edge: "EntityEdge") -> str:
        """Render an edge's validity window as ``"<valid_at> - <invalid_at>"``.

        A missing (falsy) start is shown as ``date unknown`` and a missing end
        as ``present``. Timestamps are emitted exactly as Zep returns them,
        without reformatting.
        """
        return f"{edge.valid_at or 'date unknown'} - {edge.invalid_at or 'present'}"

    def compose_search_context(self, edges: "list[EntityEdge]", nodes: "list[EntityNode]") -> str:
        """Fill ``TEMPLATE`` with one bullet per fact edge and one per entity node.

        Args:
            edges: Edge hits from a Zep ``scope="edges"`` search; ``None`` is
                treated as no hits.
            nodes: Node hits from a Zep ``scope="nodes"`` search; ``None`` is
                treated as no hits.

        Returns:
            The rendered context string.
        """
        facts = [f" - {edge.fact} ({self.format_edge_date_range(edge)})" for edge in edges or []]
        entities = [f" - {node.name}: {node.summary}" for node in nodes or []]
        return TEMPLATE.format(facts="\n".join(facts), entities="\n".join(entities))

    def search_memory(self, run_id, idx, query, max_retries=3, retry_delay=1):
        """Search the Zep graph of conversation ``idx`` and build the prompt context.

        Two searches are issued against user ``run_id_<run_id>_experiment_user_<idx>``
        (presumably created by the ingestion step; verify against it):
        top-20 edges with the ``cross_encoder`` reranker and top-20 nodes with
        the ``rrf`` reranker.

        Args:
            run_id: Run identifier embedded in the Zep user id.
            idx: Conversation index embedded in the Zep user id.
            query: The question text used as the search query.
            max_retries: Total number of search attempts; must be >= 1.
            retry_delay: Seconds to sleep between failed attempts.

        Returns:
            ``(context, elapsed_seconds)``. The elapsed time spans all attempts,
            including retry sleeps.

        Raises:
            ValueError: If ``max_retries`` < 1 (previously an UnboundLocalError).
            Exception: Whatever the last failed search attempt raised.
        """
        if max_retries < 1:
            raise ValueError(f"max_retries must be >= 1, got {max_retries}")

        user_id = f"run_id_{run_id}_experiment_user_{idx}"
        # perf_counter is monotonic, so the measured latency is immune to wall-clock jumps.
        start_time = time.perf_counter()
        for attempt in range(1, max_retries + 1):
            try:
                edges_results = self.zep_client.graph.search(
                    user_id=user_id, reranker="cross_encoder", query=query, scope="edges", limit=20
                ).edges
                node_results = self.zep_client.graph.search(
                    user_id=user_id, reranker="rrf", query=query, scope="nodes", limit=20
                ).nodes
            except Exception:
                if attempt == max_retries:
                    raise
                print("Retrying...")
                time.sleep(retry_delay)
            else:
                # Formatting is deterministic; keep it out of the retried region so a
                # formatting bug surfaces immediately instead of re-querying Zep.
                context = self.compose_search_context(edges_results, node_results)
                return context, time.perf_counter() - start_time

    def process_question(self, run_id, val, idx):
        """Answer one QA item and bundle its reference fields with the model output.

        ``val`` is read with ``.get`` defaults, so missing keys yield ``""``,
        ``-1`` (category) or ``[]`` (evidence) rather than raising.
        """
        question = val.get("question", "")
        answer = val.get("answer", "")
        category = val.get("category", -1)
        evidence = val.get("evidence", [])
        adversarial_answer = val.get("adversarial_answer", "")

        response, search_memory_time, response_time, context = self.answer_question(run_id, idx, question)

        return {
            "question": question,
            "answer": answer,
            "category": category,
            "evidence": evidence,
            "response": response,
            "adversarial_answer": adversarial_answer,
            "search_memory_time": search_memory_time,
            "response_time": response_time,
            "context": context,
        }

    def answer_question(self, run_id, idx, question):
        """Retrieve Zep context for ``question`` and ask the chat model to answer it.

        The rendered prompt is sent as a single system message, at temperature
        0.0, to the model named by the ``MODEL`` environment variable.

        Returns:
            ``(answer_text, search_memory_time, response_time, context)``, with
            both times in seconds.
        """
        context, search_memory_time = self.search_memory(run_id, idx, question)
        answer_prompt = self.answer_template.render(memories=context, question=question)

        t1 = time.perf_counter()
        response = self.openai_client.chat.completions.create(
            model=os.getenv("MODEL"), messages=[{"role": "system", "content": answer_prompt}], temperature=0.0
        )
        response_time = time.perf_counter() - t1
        return response.choices[0].message.content, search_memory_time, response_time, context

    def _save_results(self, output_file_path):
        """Atomically write ``self.results`` as indented JSON to ``output_file_path``.

        The data is written to a sibling ``.tmp`` file and then moved into place
        with ``os.replace``. An interrupt mid-write therefore never leaves a
        truncated checkpoint behind.
        """
        tmp_path = f"{output_file_path}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(self.results, f, indent=4)
        os.replace(tmp_path, output_file_path)

    def process_data_file(self, file_path, run_id, output_file_path):
        """Answer every question in a JSON dataset, checkpointing after each one.

        Args:
            file_path: JSON file containing a list of conversation items, each
                with a ``"qa"`` list of question dicts.
            run_id: Run identifier forwarded to ``search_memory``.
            output_file_path: Where results are written. The parent directory is
                created if needed. Integer conversation keys become JSON strings.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Create the output directory up front so a missing folder doesn't crash
        # the run after the first (paid) API calls have already been made.
        output_dir = os.path.dirname(output_file_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        for idx, item in tqdm(enumerate(data), total=len(data), desc="Processing conversations"):
            qa = item["qa"]

            for question_item in tqdm(
                qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False
            ):
                result = self.process_question(run_id, question_item, idx)
                self.results[idx].append(result)

                # Checkpoint after every question so an interruption loses at most one answer.
                self._save_results(output_file_path)

        # Final save; also produces the file when the dataset contains no questions.
        self._save_results(output_file_path)
| #133 | |
| #134 | |
def parse_args(argv=None):
    """Parse command-line options for a Zep search/answer run.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Namespace with ``run_id``, ``dataset`` and ``output``. The path defaults
        are the previously hard-coded values and are relative to the current
        working directory.
    """
    parser = argparse.ArgumentParser(description="Answer dataset questions using Zep graph search as memory.")
    parser.add_argument(
        "--run_id",
        type=str,
        required=True,
        help="Run id used to build per-conversation Zep user ids (run_id_<run_id>_experiment_user_<idx>).",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="../../dataset/locomo10.json",
        help="Path to the input JSON dataset.",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="results/zep_search_results.json",
        help="Path of the JSON results file (rewritten after every question).",
    )
    return parser.parse_args(argv)


def main(argv=None):
    """Entry point: parse options and process the whole dataset file."""
    args = parse_args(argv)
    zep_search = ZepSearch()
    zep_search.process_data_file(args.dataset, args.run_id, args.output)


if __name__ == "__main__":
    main()
| #141 |