repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...
git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...
2fa351d6 docs: add automaton and perps launch sources (16d ago)
| #1 | import logging |
| #2 | import subprocess |
| #3 | import sys |
| #4 | import threading |
| #5 | from typing import List, Optional, Union |
| #6 | |
| #7 | import httpx |
| #8 | |
| #9 | import mem0 |
| #10 | |
| #11 | try: |
| #12 | import litellm |
| #13 | except ImportError: |
| #14 | try: |
| #15 | subprocess.check_call([sys.executable, "-m", "pip", "install", "litellm"]) |
| #16 | import litellm |
| #17 | except subprocess.CalledProcessError: |
| #18 | print("Failed to install 'litellm'. Please install it manually using 'pip install litellm'.") |
| #19 | sys.exit(1) |
| #20 | |
| #21 | from mem0 import Memory, MemoryClient |
| #22 | from mem0.configs.prompts import MEMORY_ANSWER_PROMPT |
| #23 | from mem0.memory.telemetry import capture_client_event, capture_event |
| #24 | |
| #25 | logger = logging.getLogger(__name__) |
| #26 | |
| #27 | |
class Mem0:
    """Top-level entry point that pairs a mem0 client with a chat interface.

    The client is a ``MemoryClient`` when ``api_key`` is truthy; otherwise it
    is a ``Memory`` instance, built from ``config`` when one is supplied and
    default-constructed when not. The chosen client is exposed as
    ``self.mem0_client`` and handed to ``Chat``, available as ``self.chat``.
    """

    def __init__(
        self,
        config: Optional[dict] = None,
        api_key: Optional[str] = None,
        host: Optional[str] = None,
    ):
        """Select the memory client and build the chat namespace.

        Args:
            config: Passed to ``Memory.from_config`` when truthy and no
                ``api_key`` is given.
            api_key: When truthy, a ``MemoryClient`` is created with it.
            host: Forwarded to ``MemoryClient`` together with ``api_key``.
        """
        if not api_key:
            # No API key: use a Memory instance, configured only if asked.
            if config:
                client = Memory.from_config(config)
            else:
                client = Memory()
        else:
            client = MemoryClient(api_key, host)

        self.mem0_client = client
        self.chat = Chat(client)
| #41 | |
| #42 | |
class Chat:
    """Namespace object whose ``completions`` attribute wraps the given client."""

    def __init__(self, mem0_client):
        """Create the ``Completions`` helper bound to ``mem0_client``."""
        completions = Completions(mem0_client)
        self.completions = completions
| #46 | |
| #47 | |
class Completions:
    """Chat-completion wrapper that enriches the last user message with memories.

    ``create`` forwards a request to ``litellm.completion``. When the final
    message has role ``"user"``, the conversation is first handed to the mem0
    client for storage (on a background thread), memories matching the recent
    conversation are searched, and the user message sent to the model is
    rewritten to include them.
    """

    def __init__(self, mem0_client):
        """Keep a reference to the mem0 client used for ``add`` and ``search``."""
        self.mem0_client = mem0_client

    def create(
        self,
        model: str,
        messages: Optional[List] = None,
        # Mem0 arguments
        user_id: Optional[str] = None,
        agent_id: Optional[str] = None,
        run_id: Optional[str] = None,
        metadata: Optional[dict] = None,
        filters: Optional[dict] = None,
        limit: Optional[int] = 10,
        # LLM arguments
        timeout: Optional[Union[float, str, "httpx.Timeout"]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        n: Optional[int] = None,
        stream: Optional[bool] = None,
        stream_options: Optional[dict] = None,
        stop=None,
        max_tokens: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[dict] = None,
        user: Optional[str] = None,
        # openai v1.0+ new params
        response_format: Optional[dict] = None,
        seed: Optional[int] = None,
        tools: Optional[List] = None,
        tool_choice: Optional[Union[str, dict]] = None,
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        parallel_tool_calls: Optional[bool] = None,
        deployment_id=None,
        extra_headers: Optional[dict] = None,
        # soon to be deprecated params by OpenAI
        functions: Optional[List] = None,
        function_call: Optional[str] = None,
        # set api_base, api_version, api_key
        base_url: Optional[str] = None,
        api_version: Optional[str] = None,
        api_key: Optional[str] = None,
        model_list: Optional[list] = None,  # pass in a list of api_base,keys, etc.
    ):
        """Run a chat completion, injecting relevant memories when applicable.

        Args:
            model: Model name; must pass ``litellm.supports_function_calling``.
            messages: Chat messages as dicts with ``"role"`` and ``"content"``.
                ``None`` (the default) is treated as an empty conversation.
                The list and its dicts are never modified.
            user_id, agent_id, run_id: Memory scope; at least one is required.
            metadata: Passed to the mem0 client's ``add``.
            filters: Passed to the mem0 client's ``add`` and ``search``.
            limit: Passed to the mem0 client's ``search``.
            All remaining keyword arguments are forwarded unchanged to
            ``litellm.completion``.

        Returns:
            Whatever ``litellm.completion`` returns.

        Raises:
            ValueError: If none of ``user_id``/``agent_id``/``run_id`` is given,
                or if ``model`` does not support function calling.
        """
        if not any([user_id, agent_id, run_id]):
            raise ValueError("One of user_id, agent_id, run_id must be provided")

        if not litellm.supports_function_calling(model):
            raise ValueError(
                f"Model '{model}' does not support function calling. Please use a model that supports function calling."
            )

        # A fresh list per call instead of a shared mutable default.
        messages = [] if messages is None else messages

        prepared_messages = self._prepare_messages(messages)
        if prepared_messages[-1]["role"] == "user":
            self._async_add_to_memory(messages, user_id, agent_id, run_id, metadata, filters)
            relevant_memories = self._fetch_relevant_memories(messages, user_id, agent_id, run_id, filters, limit)
            results, _ = self._split_search_result(relevant_memories)
            logger.debug("Retrieved %d relevant memories", len(results))
            # Replace the last entry with a modified copy: writing into the
            # original dict would alter the caller's message and could leak the
            # augmented prompt into the background memory write.
            prepared_messages[-1] = {
                **prepared_messages[-1],
                "content": self._format_query_with_memories(messages, relevant_memories),
            }

        response = litellm.completion(
            model=model,
            messages=prepared_messages,
            temperature=temperature,
            top_p=top_p,
            n=n,
            timeout=timeout,
            stream=stream,
            stream_options=stream_options,
            stop=stop,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            user=user,
            response_format=response_format,
            seed=seed,
            tools=tools,
            tool_choice=tool_choice,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            parallel_tool_calls=parallel_tool_calls,
            deployment_id=deployment_id,
            extra_headers=extra_headers,
            functions=functions,
            function_call=function_call,
            base_url=base_url,
            api_version=api_version,
            api_key=api_key,
            model_list=model_list,
        )
        # Telemetry uses a different helper depending on the client class.
        if isinstance(self.mem0_client, Memory):
            capture_event("mem0.chat.create", self.mem0_client)
        else:
            capture_client_event("mem0.chat.create", self.mem0_client)
        return response

    def _prepare_messages(self, messages: List[dict]) -> List[dict]:
        """Return a new message list that starts with a system message.

        ``MEMORY_ANSWER_PROMPT`` is prepended as the system message unless the
        first message already has role ``"system"``. The result is always a new
        list, so callers may replace its entries without touching ``messages``.
        """
        if not messages or messages[0]["role"] != "system":
            return [{"role": "system", "content": MEMORY_ANSWER_PROMPT}, *(messages or [])]
        return list(messages)

    def _async_add_to_memory(self, messages, user_id, agent_id, run_id, metadata, filters):
        """Call the mem0 client's ``add`` on a daemon thread and return at once."""
        # Snapshot the list so later appends by the caller are not picked up
        # by the background thread.
        snapshot = list(messages)

        def add_task():
            logger.debug("Adding to memory asynchronously")
            self.mem0_client.add(
                messages=snapshot,
                user_id=user_id,
                agent_id=agent_id,
                run_id=run_id,
                metadata=metadata,
                filters=filters,
            )

        threading.Thread(target=add_task, daemon=True).start()

    def _fetch_relevant_memories(self, messages, user_id, agent_id, run_id, filters, limit):
        """Search the mem0 client using the last six messages as the query.

        Each message is rendered as ``"<role>: <content>"`` and the lines are
        joined with newlines. Returns the client's ``search`` result unchanged.
        """
        # Currently, only pass the last 6 messages to the search API to prevent long query
        message_input = [f"{message['role']}: {message['content']}" for message in messages[-6:]]
        # TODO: Make it better by summarizing the past conversation
        return self.mem0_client.search(
            query="\n".join(message_input),
            user_id=user_id,
            agent_id=agent_id,
            run_id=run_id,
            filters=filters,
            limit=limit,
        )

    @staticmethod
    def _split_search_result(relevant_memories):
        """Normalise a ``search`` result into ``(results, relations)`` lists.

        Accepts a dict carrying ``"results"`` (and optionally ``"relations"``)
        or a plain sequence of memory dicts; missing or empty parts become
        empty lists.
        """
        if isinstance(relevant_memories, dict):
            return relevant_memories.get("results") or [], relevant_memories.get("relations") or []
        return relevant_memories or [], []

    def _format_query_with_memories(self, messages, relevant_memories):
        """Build the augmented user prompt from memories and the last message.

        The result shape, not the client class, decides how memories are read,
        so any client returning either supported shape works and the memory
        text is always defined.
        """
        results, relations = self._split_search_result(relevant_memories)
        memories_text = "\n".join(memory["memory"] for memory in results)
        entities = list(relations)
        return f"- Relevant Memories/Facts: {memories_text}\n\n- Entities: {entities}\n\n- User Question: {messages[-1]['content']}"
| #190 |