repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...
git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...
2fa351d6 docs: add automaton and perps launch sources 16d ago
| #1 | import json |
| #2 | import multiprocessing as mp |
| #3 | import os |
| #4 | import time |
| #5 | from collections import defaultdict |
| #6 | |
| #7 | from dotenv import load_dotenv |
| #8 | from jinja2 import Template |
| #9 | from langgraph.checkpoint.memory import MemorySaver |
| #10 | from langgraph.prebuilt import create_react_agent |
| #11 | from langgraph.store.memory import InMemoryStore |
| #12 | from langgraph.utils.config import get_store |
| #13 | from langmem import create_manage_memory_tool, create_search_memory_tool |
| #14 | from openai import OpenAI |
| #15 | from prompts import ANSWER_PROMPT |
| #16 | from tqdm import tqdm |
| #17 | |
# NOTE(review): presumably loads a local .env file into the process environment so the
# os.getenv("MODEL") / os.getenv("EMBEDDING_MODEL") lookups in this module resolve — confirm.
load_dotenv()

# Module-level OpenAI client; get_answer() issues its chat completion through it.
client = OpenAI()

# ANSWER_PROMPT wrapped in a Jinja2 Template once at import; get_answer() renders it per question.
ANSWER_PROMPT_TEMPLATE = Template(ANSWER_PROMPT)
| #23 | |
| #24 | |
def get_answer(question, speaker_1_user_id, speaker_1_memories, speaker_2_user_id, speaker_2_memories):
    """Answer ``question`` with the chat model, given both speakers' memories.

    The question, the two speaker ids and their memories are rendered into
    ``ANSWER_PROMPT_TEMPLATE``; the rendered text is sent as a single system
    message to the model named by the ``MODEL`` environment variable, with
    temperature 0.0.

    Returns:
        A ``(answer_text, seconds)`` tuple: the content of the first choice
        and the wall-clock duration of the completion call.
    """
    rendered_prompt = ANSWER_PROMPT_TEMPLATE.render(
        question=question,
        speaker_1_user_id=speaker_1_user_id,
        speaker_1_memories=speaker_1_memories,
        speaker_2_user_id=speaker_2_user_id,
        speaker_2_memories=speaker_2_memories,
    )
    system_message = {"role": "system", "content": rendered_prompt}

    # Only the completion request itself is timed, not the template rendering.
    started = time.time()
    completion = client.chat.completions.create(
        model=os.getenv("MODEL"),
        messages=[system_message],
        temperature=0.0,
    )
    elapsed = time.time() - started

    return completion.choices[0].message.content, elapsed
| #40 | |
| #41 | |
def prompt(state):
    """Build the LLM message list: a system message carrying memories, then the chat so far.

    The store returned by ``get_store()`` is searched in the ``("memories",)``
    namespace using the content of the latest message as the query, and the
    search result is interpolated into the system message.
    """
    memory_store = get_store()
    latest_query = state["messages"][-1].content
    memories = memory_store.search(("memories",), query=latest_query)

    system_msg = f"""You are a helpful assistant.

## Memories
<memories>
{memories}
</memories>
"""
    system_entry = {"role": "system", "content": system_msg}
    return [system_entry] + list(state["messages"])
| #57 | |
| #58 | |
class LangMem:
    """A ReAct agent equipped with manage/search memory tools over an in-memory store.

    Every instance builds its own ``InMemoryStore`` and ``MemorySaver``, so
    stored memories and checkpoints are not shared between instances.
    """

    def __init__(
        self,
    ):
        # NOTE(review): "dims" presumably has to match the vector size produced by
        # the model named in EMBEDDING_MODEL — confirm when changing either one.
        self.store = InMemoryStore(
            index={
                "dims": 1536,
                "embed": f"openai:{os.getenv('EMBEDDING_MODEL')}",
            }
        )
        self.checkpointer = MemorySaver()  # Checkpoint graph state

        # Both tools operate on the same ("memories",) namespace that the
        # module-level prompt() function searches.
        self.agent = create_react_agent(
            f"openai:{os.getenv('MODEL')}",
            prompt=prompt,
            tools=[
                create_manage_memory_tool(namespace=("memories",)),
                create_search_memory_tool(namespace=("memories",)),
            ],
            store=self.store,
            checkpointer=self.checkpointer,
        )

    def add_memory(self, message, config):
        """Send ``message`` to the agent as a user turn and return the agent's raw result."""
        return self.agent.invoke({"messages": [{"role": "user", "content": message}]}, config=config)

    def search_memory(self, query, config):
        """Ask the agent ``query`` and time the call.

        Returns:
            A ``(text, seconds)`` tuple. ``text`` is the content of the last
            message in the agent's response, or ``""`` if the call failed;
            ``seconds`` is the wall-clock time spent up to the response or
            the failure.
        """
        # Start the clock outside the try so the error path can always compute
        # an elapsed time; previously it read a `t2` that was never assigned
        # when invoke() raised, which turned every failure into UnboundLocalError.
        t1 = time.time()
        try:
            response = self.agent.invoke({"messages": [{"role": "user", "content": query}]}, config=config)
            t2 = time.time()
            return response["messages"][-1].content, t2 - t1
        except Exception as e:
            # Deliberately best-effort: report the failure and return an empty
            # answer so a single failed query does not abort the whole run.
            print(f"Error in search_memory: {e}")
            return "", time.time() - t1
| #94 | |
| #95 | |
class LangMemManager:
    """Run the LangMem question-answering evaluation over a JSON conversation dataset."""

    def __init__(self, dataset_path):
        """Load the JSON dataset at ``dataset_path`` into ``self.data``."""
        self.dataset_path = dataset_path
        with open(self.dataset_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)

    @staticmethod
    def _process_conversation(key_value_pair):
        """Ingest one conversation into two agents and answer its questions.

        Args:
            key_value_pair: ``(key, value)`` item of the dataset, where
                ``value["conversation"]`` is a list of turns with ``speaker``,
                ``timestamp`` and ``text`` entries and ``value["question"]``
                is a list of entries with ``question``, ``answer`` and
                ``category``.

        Returns:
            A mapping ``{key: [record, ...]}`` with one record per answered
            question.

        Raises:
            ValueError: if the conversation does not have exactly two speakers.
        """
        # This is a static method rather than a closure because
        # multiprocessing pickles the callable it dispatches to workers, and
        # locally defined functions cannot be pickled.
        key, value = key_value_pair
        result = defaultdict(list)

        chat_history = value["conversation"]
        questions = value["question"]

        agent1 = LangMem()
        agent2 = LangMem()
        config = {"configurable": {"thread_id": f"thread-{key}"}}

        # dict.fromkeys de-duplicates while keeping first-appearance order, so
        # which speaker becomes speaker1/speaker2 is the same on every run
        # (iteration order of a set of strings is not).
        speakers = list(dict.fromkeys(conv["speaker"] for conv in chat_history))
        if len(speakers) != 2:
            raise ValueError(f"Expected 2 speakers, got {len(speakers)}")
        speaker1, speaker2 = speakers
        agent_by_speaker = {speaker1: agent1, speaker2: agent2}

        # Add memories for each message; each speaker's turns go to that
        # speaker's own agent.
        for conv in tqdm(chat_history, desc=f"Processing messages {key}", leave=False):
            message = f"{conv['timestamp']} | {conv['speaker']}: {conv['text']}"
            agent_by_speaker[conv["speaker"]].add_memory(message, config)

        # Process questions
        for q in tqdm(questions, desc=f"Processing questions {key}", leave=False):
            category = q["category"]

            # Category 5 questions are excluded from the evaluation.
            if int(category) == 5:
                continue

            answer = q["answer"]
            question = q["question"]
            response1, speaker1_memory_time = agent1.search_memory(question, config)
            response2, speaker2_memory_time = agent2.search_memory(question, config)

            generated_answer, response_time = get_answer(question, speaker1, response1, speaker2, response2)

            result[key].append(
                {
                    "question": question,
                    "answer": answer,
                    "response1": response1,
                    "response2": response2,
                    "category": category,
                    "speaker1_memory_time": speaker1_memory_time,
                    "speaker2_memory_time": speaker2_memory_time,
                    "response_time": response_time,
                    "response": generated_answer,
                }
            )

        return result

    def process_all_conversations(self, output_file_path, num_workers=10):
        """Process every conversation in parallel and write the results as JSON.

        Args:
            output_file_path: Destination of the combined JSON results.
            num_workers: Number of worker processes in the pool (default 10,
                the value that used to be hard-coded).
        """
        items = list(self.data.items())

        results = []
        # With nothing to process there is no reason to start worker processes.
        if items:
            # Use multiprocessing to process conversations in parallel
            with mp.Pool(processes=num_workers) as pool:
                results = list(
                    tqdm(
                        pool.imap(self._process_conversation, items),
                        total=len(items),
                        desc="Processing conversations",
                    )
                )

        # Combine results from all workers
        combined = defaultdict(list)
        for result in results:
            for key, records in result.items():
                combined[key].extend(records)

        # Save final results
        with open(output_file_path, "w", encoding="utf-8") as f:
            json.dump(combined, f, indent=4)
| #186 |