repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...
git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...
2fa351d6 · docs: add automaton and perps launch sources · 16d ago
| #1 | import logging |
| #2 | from abc import ABC, abstractmethod |
| #3 | |
| #4 | from mem0.memory.utils import format_entities |
| #5 | |
| #6 | try: |
| #7 | from rank_bm25 import BM25Okapi |
| #8 | except ImportError: |
| #9 | raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25") |
| #10 | |
| #11 | from mem0.graphs.tools import ( |
| #12 | DELETE_MEMORY_STRUCT_TOOL_GRAPH, |
| #13 | DELETE_MEMORY_TOOL_GRAPH, |
| #14 | EXTRACT_ENTITIES_STRUCT_TOOL, |
| #15 | EXTRACT_ENTITIES_TOOL, |
| #16 | RELATIONS_STRUCT_TOOL, |
| #17 | RELATIONS_TOOL, |
| #18 | ) |
| #19 | from mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_delete_messages |
| #20 | from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory |
| #21 | |
| #22 | logger = logging.getLogger(__name__) |
| #23 | |
| #24 | |
class NeptuneBase(ABC):
    """
    Abstract base class for neptune (neptune analytics and neptune db) calls using OpenCypher
    to store/retrieve data.

    The concrete methods in this class read the following instance attributes, which
    subclasses are expected to set before calling them:

    - ``self.config``: configuration object (``graph_store.custom_prompt`` is read here).
    - ``self.llm`` / ``self.llm_provider``: LLM client and the provider name.
    - ``self.embedding_model``: object exposing ``embed(text)``.
    - ``self.graph``: object exposing ``query(cypher, params=...)``.
    - ``self.threshold``: similarity threshold forwarded to the node searches.

    Subclasses supply the OpenCypher text through the ``*_cypher`` hooks; each hook
    returns a ``(cypher, params)`` pair that is handed to ``self.graph.query``.
    """

    @staticmethod
    def _create_embedding_model(config):
        """
        :param config: configuration exposing ``embedder.provider`` and ``embedder.config``
        :return: the Embedder model used for memory store
        """
        return EmbedderFactory.create(
            config.embedder.provider,
            config.embedder.config,
            {"enable_embeddings": True},
        )

    @staticmethod
    def _create_llm(config, llm_provider):
        """
        :param config: configuration exposing ``llm.config``
        :param llm_provider: name of the llm provider
        :return: the llm model used for memory store
        """
        return LlmFactory.create(llm_provider, config.llm.config)

    @staticmethod
    def _create_vector_store(vector_store_provider, config):
        """
        :param vector_store_provider: name of vector store
        :param config: configuration exposing ``vector_store.config``
        :return: the vector store created by ``VectorStoreFactory``
        """
        return VectorStoreFactory.create(vector_store_provider, config.vector_store.config)

    def add(self, data, filters):
        """
        Adds data to the graph.

        Args:
            data (str): The data to add to the graph.
            filters (dict): A dictionary containing filters to be applied during the addition.
                Must contain the key ``"user_id"``.

        Returns:
            dict: ``{"deleted_entities": [...], "added_entities": [...]}`` holding the raw
            query results of the delete and add steps respectively.
        """
        entity_type_map = self._retrieve_nodes_from_data(data, filters)
        to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
        search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
        to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)

        # Deletions run before additions, matching the original ordering.
        deleted_entities = self._delete_entities(to_be_deleted, filters["user_id"])
        added_entities = self._add_entities(to_be_added, filters["user_id"], entity_type_map)

        return {"deleted_entities": deleted_entities, "added_entities": added_entities}

    def _retrieve_nodes_from_data(self, data, filters):
        """
        Extract all entities mentioned in the query.

        Returns:
            dict: entity name -> entity type, both lower-cased with spaces replaced by
            underscores. Empty when the LLM response cannot be parsed.
        """
        _tools = [EXTRACT_ENTITIES_TOOL]
        if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
            _tools = [EXTRACT_ENTITIES_STRUCT_TOOL]
        search_results = self.llm.generate_response(
            messages=[
                {
                    "role": "system",
                    "content": f"You are a smart assistant who understands entities and their types in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source entity. Extract all the entities from the text. ***DO NOT*** answer the question itself if the given text is a question.",
                },
                {"role": "user", "content": data},
            ],
            tools=_tools,
        )

        entity_type_map = {}

        # Best-effort parse: a malformed LLM response is logged and whatever was
        # collected so far is kept, rather than failing the whole operation.
        try:
            for tool_call in search_results["tool_calls"]:
                if tool_call["name"] != "extract_entities":
                    continue
                for item in tool_call["arguments"]["entities"]:
                    entity_type_map[item["entity"]] = item["entity_type"]
        except Exception as e:
            logger.exception(
                f"Error in search tool: {e}, llm_provider={self.llm_provider}, search_results={search_results}"
            )

        entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
        return entity_type_map

    def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
        """
        Establish relations among the extracted nodes.

        Returns:
            list[dict]: relations with normalised ``source``/``relationship``/``destination``
            values. Empty when the LLM produced no tool call.
        """
        system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"])
        custom_prompt = self.config.graph_store.custom_prompt
        if custom_prompt:
            # With a custom prompt the raw text is sent as-is and the prompt is
            # injected as an extra numbered rule.
            system_content = system_content.replace("CUSTOM_PROMPT", f"4. {custom_prompt}")
            user_content = data
        else:
            user_content = f"List of entities: {list(entity_type_map.keys())}. \n\nText: {data}"

        messages = [
            {"role": "system", "content": system_content},
            {"role": "user", "content": user_content},
        ]

        _tools = [RELATIONS_TOOL]
        if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
            _tools = [RELATIONS_STRUCT_TOOL]

        extracted_entities = self.llm.generate_response(
            messages=messages,
            tools=_tools,
        )

        # Tolerate a response with a missing/None/empty "tool_calls" entry or a
        # tool call without the expected arguments instead of raising.
        entities = []
        tool_calls = extracted_entities.get("tool_calls")
        if tool_calls:
            entities = tool_calls[0].get("arguments", {}).get("entities", [])

        entities = self._remove_spaces_from_entities(entities)
        logger.debug(f"Extracted entities: {entities}")
        return entities

    def _remove_spaces_from_entities(self, entity_list):
        """
        Lower-case ``source``, ``relationship`` and ``destination`` of every item and
        replace spaces with underscores.

        The items are modified in place and the same list object is returned.
        """
        for item in entity_list:
            for key in ("source", "relationship", "destination"):
                item[key] = item[key].lower().replace(" ", "_")
        return entity_list

    def _get_delete_entities_from_search_output(self, search_output, data, filters):
        """
        Get the entities to be deleted from the search output.

        Returns:
            list[dict]: arguments of every ``delete_graph_memory`` tool call, normalised
            via ``_remove_spaces_from_entities``.
        """

        search_output_string = format_entities(search_output)
        system_prompt, user_prompt = get_delete_messages(search_output_string, data, filters["user_id"])

        _tools = [DELETE_MEMORY_TOOL_GRAPH]
        if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
            _tools = [
                DELETE_MEMORY_STRUCT_TOOL_GRAPH,
            ]

        memory_updates = self.llm.generate_response(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            tools=_tools,
        )

        # "tool_calls" may be missing or None when the LLM decides nothing has to
        # be deleted; treat that as "no deletions" rather than crashing.
        tool_calls = memory_updates.get("tool_calls") or []
        to_be_deleted = [call["arguments"] for call in tool_calls if call["name"] == "delete_graph_memory"]
        # in case if it is not in the correct format
        to_be_deleted = self._remove_spaces_from_entities(to_be_deleted)
        logger.debug(f"Deleted relationships: {to_be_deleted}")
        return to_be_deleted

    def _delete_entities(self, to_be_deleted, user_id):
        """
        Delete the entities from the graph.

        Returns:
            list: one raw ``self.graph.query`` result per item in ``to_be_deleted``.
        """

        results = []
        for item in to_be_deleted:
            source = item["source"]
            destination = item["destination"]
            relationship = item["relationship"]

            # Delete the specific relationship between nodes
            cypher, params = self._delete_entities_cypher(source, destination, relationship, user_id)
            result = self.graph.query(cypher, params=params)
            results.append(result)
        return results

    @abstractmethod
    def _delete_entities_cypher(self, source, destination, relationship, user_id):
        """
        Returns the OpenCypher query and parameters for deleting entities in the graph DB
        """

        pass

    def _add_entities(self, to_be_added, user_id, entity_type_map):
        """
        Add the new entities to the graph. Merge the nodes if they already exist.

        Returns:
            list: one raw ``self.graph.query`` result per item in ``to_be_added``.
        """

        results = []
        for item in to_be_added:
            # entities
            source = item["source"]
            destination = item["destination"]
            relationship = item["relationship"]

            # types ("__User__" is the fallback for entities the extractor did not type)
            source_type = entity_type_map.get(source, "__User__")
            destination_type = entity_type_map.get(destination, "__User__")

            # embeddings
            source_embedding = self.embedding_model.embed(source)
            dest_embedding = self.embedding_model.embed(destination)

            # search for the nodes with the closest embeddings
            source_node_search_result = self._search_source_node(source_embedding, user_id, threshold=self.threshold)
            destination_node_search_result = self._search_destination_node(
                dest_embedding, user_id, threshold=self.threshold
            )

            cypher, params = self._add_entities_cypher(
                source_node_search_result,
                source,
                source_embedding,
                source_type,
                destination_node_search_result,
                destination,
                dest_embedding,
                destination_type,
                relationship,
                user_id,
            )
            result = self.graph.query(cypher, params=params)
            results.append(result)
        return results

    def _add_entities_cypher(
        self,
        source_node_list,
        source,
        source_embedding,
        source_type,
        destination_node_list,
        destination,
        dest_embedding,
        destination_type,
        relationship,
        user_id,
    ):
        """
        Returns the OpenCypher query and parameters for adding entities in the graph DB

        Dispatches to one of four hooks depending on which of the two node search
        results is non-empty.
        """
        if source_node_list and destination_node_list:
            # Both endpoints already exist: only the relationship is added.
            return self._add_relationship_entities_cypher(
                source_node_list,
                destination_node_list,
                relationship,
                user_id,
            )
        if source_node_list:
            # Only the source exists.
            return self._add_entities_by_source_cypher(
                source_node_list,
                destination,
                dest_embedding,
                destination_type,
                relationship,
                user_id,
            )
        if destination_node_list:
            # Only the destination exists.
            return self._add_entities_by_destination_cypher(
                source,
                source_embedding,
                source_type,
                destination_node_list,
                relationship,
                user_id,
            )
        # else source_node_list and destination_node_list are empty
        return self._add_new_entities_cypher(
            source,
            source_embedding,
            source_type,
            destination,
            dest_embedding,
            destination_type,
            relationship,
            user_id,
        )

    @abstractmethod
    def _add_entities_by_source_cypher(
        self,
        source_node_list,
        destination,
        dest_embedding,
        destination_type,
        relationship,
        user_id,
    ):
        """
        Returns the OpenCypher query and parameters used when only the source node
        search returned results.
        """
        pass

    @abstractmethod
    def _add_entities_by_destination_cypher(
        self,
        source,
        source_embedding,
        source_type,
        destination_node_list,
        relationship,
        user_id,
    ):
        """
        Returns the OpenCypher query and parameters used when only the destination node
        search returned results.
        """
        pass

    @abstractmethod
    def _add_relationship_entities_cypher(
        self,
        source_node_list,
        destination_node_list,
        relationship,
        user_id,
    ):
        """
        Returns the OpenCypher query and parameters used when both node searches
        returned results.
        """
        pass

    @abstractmethod
    def _add_new_entities_cypher(
        self,
        source,
        source_embedding,
        source_type,
        destination,
        dest_embedding,
        destination_type,
        relationship,
        user_id,
    ):
        """
        Returns the OpenCypher query and parameters used when neither node search
        returned results.
        """
        pass

    def search(self, query, filters, limit=100):
        """
        Search for memories and related graph data.

        Args:
            query (str): Query to search for.
            filters (dict): A dictionary containing filters to be applied during the search.
            limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.

        Returns:
            list: Up to five dictionaries with the keys ``"source"``, ``"relationship"``
            and ``"destination"``, re-ranked with BM25 against the query. An empty
            list when the graph search yields nothing.
        """

        entity_type_map = self._retrieve_nodes_from_data(query, filters)
        # Forward ``limit`` so the caller-supplied value is honoured.
        search_output = self._search_graph_db(
            node_list=list(entity_type_map.keys()),
            filters=filters,
            limit=limit,
        )

        if not search_output:
            return []

        search_outputs_sequence = [
            [item["source"], item["relationship"], item["destination"]] for item in search_output
        ]
        bm25 = BM25Okapi(search_outputs_sequence)

        tokenized_query = query.split(" ")
        reranked_results = bm25.get_top_n(tokenized_query, search_outputs_sequence, n=5)

        return [
            {"source": item[0], "relationship": item[1], "destination": item[2]}
            for item in reranked_results
        ]

    def _search_source_node(self, source_embedding, user_id, threshold=0.9):
        """
        Run the source-node search query and return the raw ``self.graph.query`` result.
        """
        cypher, params = self._search_source_node_cypher(source_embedding, user_id, threshold)
        result = self.graph.query(cypher, params=params)
        return result

    @abstractmethod
    def _search_source_node_cypher(self, source_embedding, user_id, threshold):
        """
        Returns the OpenCypher query and parameters to search for source nodes
        """
        pass

    def _search_destination_node(self, destination_embedding, user_id, threshold=0.9):
        """
        Run the destination-node search query and return the raw ``self.graph.query`` result.
        """
        cypher, params = self._search_destination_node_cypher(destination_embedding, user_id, threshold)
        result = self.graph.query(cypher, params=params)
        return result

    @abstractmethod
    def _search_destination_node_cypher(self, destination_embedding, user_id, threshold):
        """
        Returns the OpenCypher query and parameters to search for destination nodes
        """
        pass

    def delete_all(self, filters):
        """
        Run the query produced by ``_delete_all_cypher`` for the given filters.
        """
        cypher, params = self._delete_all_cypher(filters)
        self.graph.query(cypher, params=params)

    @abstractmethod
    def _delete_all_cypher(self, filters):
        """
        Returns the OpenCypher query and parameters to delete all edges/nodes in the memory store
        """
        pass

    def get_all(self, filters, limit=100):
        """
        Retrieves all nodes and relationships from the graph database based on filtering criteria.

        Args:
            filters (dict): A dictionary containing filters to be applied during the retrieval.
            limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.
        Returns:
            list: A list of dictionaries, each containing the keys ``'source'``,
            ``'relationship'`` and ``'target'`` copied from the query result rows.
        """

        # return all nodes and relationships
        query, params = self._get_all_cypher(filters, limit)
        results = self.graph.query(query, params=params)

        final_results = [
            {
                "source": result["source"],
                "relationship": result["relationship"],
                "target": result["target"],
            }
            for result in results
        ]

        logger.debug(f"Retrieved {len(final_results)} relationships")

        return final_results

    @abstractmethod
    def _get_all_cypher(self, filters, limit):
        """
        Returns the OpenCypher query and parameters to get all edges/nodes in the memory store
        """
        pass

    def _search_graph_db(self, node_list, filters, limit=100):
        """
        Search similar nodes among and their respective incoming and outgoing relations.

        Each node name is embedded and queried separately; the query results are
        concatenated in ``node_list`` order (no de-duplication is performed here).
        """
        result_relations = []

        for node in node_list:
            n_embedding = self.embedding_model.embed(node)
            cypher_query, params = self._search_graph_db_cypher(n_embedding, filters, limit)
            ans = self.graph.query(cypher_query, params=params)
            result_relations.extend(ans)

        return result_relations

    @abstractmethod
    def _search_graph_db_cypher(self, n_embedding, filters, limit):
        """
        Returns the OpenCypher query and parameters to search for similar nodes in the memory store
        """
        pass

    # Reset is not defined in base.py
    def reset(self):
        """
        Reset the graph by clearing all nodes and relationships.

        Calls ``reset_graph`` on ``self.graph.client`` with ``skipSnapshot=True`` and then
        blocks on the ``graph_available`` waiter (10 second delay, at most 60 attempts).

        link: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/neptune-graph/client/reset_graph.html
        """

        logger.warning("Clearing graph...")
        graph_id = self.graph.graph_identifier
        self.graph.client.reset_graph(
            graphIdentifier=graph_id,
            skipSnapshot=True,
        )
        waiter = self.graph.client.get_waiter("graph_available")
        waiter.wait(graphIdentifier=graph_id, WaiterConfig={"Delay": 10, "MaxAttempts": 60})
| #498 |