repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...2fa351d6docs: add automaton and perps launch sources16d ago| #1 | import logging |
| #2 | |
| #3 | from mem0.memory.utils import format_entities |
| #4 | |
| #5 | try: |
| #6 | import kuzu |
| #7 | except ImportError: |
| #8 | raise ImportError("kuzu is not installed. Please install it using pip install kuzu") |
| #9 | |
| #10 | try: |
| #11 | from rank_bm25 import BM25Okapi |
| #12 | except ImportError: |
| #13 | raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25") |
| #14 | |
| #15 | from mem0.graphs.tools import ( |
| #16 | DELETE_MEMORY_STRUCT_TOOL_GRAPH, |
| #17 | DELETE_MEMORY_TOOL_GRAPH, |
| #18 | EXTRACT_ENTITIES_STRUCT_TOOL, |
| #19 | EXTRACT_ENTITIES_TOOL, |
| #20 | RELATIONS_STRUCT_TOOL, |
| #21 | RELATIONS_TOOL, |
| #22 | ) |
| #23 | from mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_delete_messages |
| #24 | from mem0.utils.factory import EmbedderFactory, LlmFactory |
| #25 | |
| #26 | logger = logging.getLogger(__name__) |
| #27 | |
| #28 | |
class MemoryGraph:
    """Graph memory stored in a Kuzu database.

    Entities extracted from text by an LLM are stored as ``Entity`` nodes scoped
    by ``user_id`` (and optionally ``agent_id`` / ``run_id``) and linked by
    ``CONNECTED_TO`` relationships whose ``name`` holds the relation type. Node
    ``embedding`` vectors are compared with cosine similarity to match new
    mentions against existing nodes.
    """

    def __init__(self, config):
        """Create the embedder, open the Kuzu database, ensure the schema and build the LLM.

        Args:
            config: Memory configuration. Reads ``embedder``, ``vector_store.config``,
                ``graph_store`` (``config.db``, optional ``llm``, ``threshold``,
                ``custom_prompt``) and ``llm``.

        Raises:
            ValueError: If the embedder does not report a positive ``embedding_dims``.
        """
        self.config = config

        self.embedding_model = EmbedderFactory.create(
            self.config.embedder.provider,
            self.config.embedder.config,
            self.config.vector_store.config,
        )
        self.embedding_dims = self.embedding_model.config.embedding_dims

        # The dimension is interpolated into every FLOAT[n] cast below, so fail fast.
        if self.embedding_dims is None or self.embedding_dims <= 0:
            raise ValueError(f"embedding_dims must be a positive integer. Given: {self.embedding_dims}")

        self.db = kuzu.Database(self.config.graph_store.config.db)
        self.graph = kuzu.Connection(self.db)

        self.node_label = ":Entity"
        self.rel_label = ":CONNECTED_TO"
        self.kuzu_create_schema()

        # Default to openai if no specific provider is configured; a provider set
        # on the graph store overrides the global LLM provider.
        self.llm_provider = "openai"
        if self.config.llm and self.config.llm.provider:
            self.llm_provider = self.config.llm.provider
        if self.config.graph_store and self.config.graph_store.llm and self.config.graph_store.llm.provider:
            self.llm_provider = self.config.graph_store.llm.provider
        # Get LLM config with proper null checks
        llm_config = None
        if self.config.graph_store and self.config.graph_store.llm and hasattr(self.config.graph_store.llm, "config"):
            llm_config = self.config.graph_store.llm.config
        elif hasattr(self.config.llm, "config"):
            llm_config = self.config.llm.config
        self.llm = LlmFactory.create(self.llm_provider, llm_config)

        self.user_id = None
        # Use threshold from graph_store config, default to 0.7 for backward
        # compatibility. An explicit None also falls back to the default instead
        # of being passed into the similarity queries as the cut-off.
        threshold = getattr(self.config.graph_store, "threshold", None)
        self.threshold = threshold if threshold is not None else 0.7

    def kuzu_create_schema(self):
        """Create the ``Entity`` node table and ``CONNECTED_TO`` rel table if they are missing."""
        self.kuzu_execute(
            """
            CREATE NODE TABLE IF NOT EXISTS Entity(
                id SERIAL PRIMARY KEY,
                user_id STRING,
                agent_id STRING,
                run_id STRING,
                name STRING,
                mentions INT64,
                created TIMESTAMP,
                embedding FLOAT[]);
            """
        )
        self.kuzu_execute(
            """
            CREATE REL TABLE IF NOT EXISTS CONNECTED_TO(
                FROM Entity TO Entity,
                name STRING,
                mentions INT64,
                created TIMESTAMP,
                updated TIMESTAMP
            );
            """
        )

    def kuzu_execute(self, query, parameters=None):
        """Run a Cypher ``query`` on the connection and return every row as a dict."""
        results = self.graph.execute(query, parameters)
        return list(results.rows_as_dict())

    @staticmethod
    def _scope_props(filters, params):
        """Build ``key: $key`` property fragments for the user/agent/run scope.

        ``user_id`` is always included; ``agent_id`` and ``run_id`` only when
        truthy in ``filters``. The matching values are written into ``params``
        in place, so every returned fragment has its parameter bound.

        Returns:
            list[str]: e.g. ``["user_id: $user_id", "run_id: $run_id"]``.
        """
        props = ["user_id: $user_id"]
        params["user_id"] = filters["user_id"]
        for key in ("agent_id", "run_id"):
            if filters.get(key):
                props.append(f"{key}: ${key}")
                params[key] = filters[key]
        return props

    @staticmethod
    def _user_identity(filters):
        """Format the user/agent/run scope as a string for LLM prompts."""
        identity = f"user_id: {filters['user_id']}"
        if filters.get("agent_id"):
            identity += f", agent_id: {filters['agent_id']}"
        if filters.get("run_id"):
            identity += f", run_id: {filters['run_id']}"
        return identity

    def add(self, data, filters):
        """
        Adds data to the graph.

        Entities and relations are extracted from ``data``; existing relations
        the LLM judges outdated are deleted, then the new ones are merged in.

        Args:
            data (str): The data to add to the graph.
            filters (dict): A dictionary containing filters to be applied during the addition.

        Returns:
            dict: ``{"deleted_entities": [...], "added_entities": [...]}`` holding
            the per-relation query results of each phase.
        """
        entity_type_map = self._retrieve_nodes_from_data(data, filters)
        to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
        search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
        to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)

        deleted_entities = self._delete_entities(to_be_deleted, filters)
        added_entities = self._add_entities(to_be_added, filters, entity_type_map)

        return {"deleted_entities": deleted_entities, "added_entities": added_entities}

    def search(self, query, filters, limit=5):
        """
        Search the graph for relations relevant to ``query``.

        Entities are extracted from the query by the LLM, similar nodes and their
        relations are fetched, and the candidate triples are re-ranked with BM25
        against the space-split query.

        Args:
            query (str): Query to search for.
            filters (dict): A dictionary containing filters to be applied during the search.
            limit (int): The maximum number of relations to return. Defaults to 5.

        Returns:
            list[dict]: Up to ``limit`` dicts with ``source``, ``relationship`` and
            ``destination`` keys; an empty list when nothing matched.
        """
        entity_type_map = self._retrieve_nodes_from_data(query, filters)
        search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)

        if not search_output:
            return []

        search_outputs_sequence = [
            [item["source"], item["relationship"], item["destination"]] for item in search_output
        ]
        bm25 = BM25Okapi(search_outputs_sequence)

        tokenized_query = query.split(" ")
        reranked_results = bm25.get_top_n(tokenized_query, search_outputs_sequence, n=limit)

        search_results = [
            {"source": source, "relationship": relationship, "destination": destination}
            for source, relationship, destination in reranked_results
        ]

        logger.info(f"Returned {len(search_results)} search results")

        return search_results

    def delete_all(self, filters):
        """Delete every ``Entity`` node in the filter scope together with its relationships."""
        params = {}
        node_props_str = ", ".join(self._scope_props(filters, params))

        cypher = f"""
        MATCH (n {self.node_label} {{{node_props_str}}})
        DETACH DELETE n
        """
        self.kuzu_execute(cypher, parameters=params)

    def get_all(self, filters, limit=100):
        """
        Retrieves relationships between nodes in the filter scope.

        Args:
            filters (dict): A dictionary containing filters to be applied during the retrieval.
            limit (int): The maximum number of relationships to retrieve. Defaults to 100.

        Returns:
            list[dict]: One dict per relationship with ``source``, ``relationship``
            and ``target`` (node/relationship names).
        """
        params = {"limit": limit}
        # Both endpoints must belong to the same scope.
        node_props_str = ", ".join(self._scope_props(filters, params))

        query = f"""
        MATCH (n {self.node_label} {{{node_props_str}}})-[r]->(m {self.node_label} {{{node_props_str}}})
        RETURN
            n.name AS source,
            r.name AS relationship,
            m.name AS target
        LIMIT $limit
        """
        results = self.kuzu_execute(query, parameters=params)

        final_results = [
            {
                "source": row["source"],
                "relationship": row["relationship"],
                "target": row["target"],
            }
            for row in results
        ]

        logger.info(f"Retrieved {len(final_results)} relationships")

        return final_results

    def _retrieve_nodes_from_data(self, data, filters):
        """Extract the entities mentioned in ``data`` via an LLM tool call.

        Returns:
            dict: Entity name -> entity type, both lower-cased with spaces replaced
            by underscores. Malformed tool output is logged and yields whatever was
            parsed before the error (possibly an empty dict).
        """
        _tools = [EXTRACT_ENTITIES_TOOL]
        if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
            _tools = [EXTRACT_ENTITIES_STRUCT_TOOL]
        search_results = self.llm.generate_response(
            messages=[
                {
                    "role": "system",
                    "content": f"You are a smart assistant who understands entities and their types in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source entity. Extract all the entities from the text. ***DO NOT*** answer the question itself if the given text is a question.",
                },
                {"role": "user", "content": data},
            ],
            tools=_tools,
        )

        entity_type_map = {}

        try:
            for tool_call in search_results["tool_calls"]:
                if tool_call["name"] != "extract_entities":
                    continue
                for item in tool_call["arguments"]["entities"]:
                    entity_type_map[item["entity"]] = item["entity_type"]
        except Exception as e:
            logger.exception(
                f"Error in search tool: {e}, llm_provider={self.llm_provider}, search_results={search_results}"
            )

        entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
        logger.debug(f"Entity type map: {entity_type_map}\n search_results={search_results}")
        return entity_type_map

    def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
        """Ask the LLM for relations among the extracted entities.

        Returns:
            list[dict]: Relations with normalized ``source``, ``relationship`` and
            ``destination``; empty if the LLM made no tool call.
        """
        system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", self._user_identity(filters))

        custom_prompt = self.config.graph_store.custom_prompt
        if custom_prompt:
            # Add the custom prompt line if configured
            system_content = system_content.replace("CUSTOM_PROMPT", f"4. {custom_prompt}")
            user_content = data
        else:
            # Strip the placeholder so the literal token is not sent to the LLM
            # (a no-op if the template contains no such placeholder).
            system_content = system_content.replace("CUSTOM_PROMPT", "")
            user_content = f"List of entities: {list(entity_type_map.keys())}. \n\nText: {data}"

        messages = [
            {"role": "system", "content": system_content},
            {"role": "user", "content": user_content},
        ]

        _tools = [RELATIONS_TOOL]
        if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
            _tools = [RELATIONS_STRUCT_TOOL]

        extracted_entities = self.llm.generate_response(
            messages=messages,
            tools=_tools,
        )

        entities = []
        if extracted_entities.get("tool_calls"):
            entities = extracted_entities["tool_calls"][0].get("arguments", {}).get("entities", [])

        entities = self._remove_spaces_from_entities(entities)
        logger.debug(f"Extracted entities: {entities}")
        return entities

    def _search_graph_db(self, node_list, filters, limit=100, threshold=None):
        """Find nodes similar to each name in ``node_list`` and return their relations.

        For every name, its embedding is compared (cosine similarity) against the
        stored embeddings of ``Entity`` nodes in the filter scope. Each node with
        similarity >= threshold contributes its outgoing and incoming relations
        to other in-scope nodes.

        Args:
            node_list (list[str]): Entity names to look up.
            filters (dict): Scope filters; ``user_id`` required, ``agent_id``/``run_id`` optional.
            limit (int): Row limit per query and per input name after merging both directions.
            threshold (float | None): Similarity cut-off; ``None`` uses ``self.threshold``.

        Returns:
            list[dict]: Rows with ``source``, ``source_id``, ``relationship``,
            ``relation_id``, ``destination``, ``destination_id`` and ``similarity``.
        """
        params = {
            # ``is not None`` so an explicit threshold of 0.0 is honoured.
            "threshold": threshold if threshold is not None else self.threshold,
            "limit": limit,
        }
        node_props_str = ", ".join(self._scope_props(filters, params))

        # The query text does not depend on the node being searched, so build the
        # outgoing and incoming variants once.
        queries = [
            f"""
            MATCH (n {self.node_label} {{{node_props_str}}})
            WHERE n.embedding IS NOT NULL
            WITH n, array_cosine_similarity(n.embedding, CAST($n_embedding,'FLOAT[{self.embedding_dims}]')) AS similarity
            WHERE similarity >= CAST($threshold, 'DOUBLE')
            MATCH {match_fragment}
            RETURN
                src.name AS source,
                id(src) AS source_id,
                r.name AS relationship,
                id(r) AS relation_id,
                dst.name AS destination,
                id(dst) AS destination_id,
                similarity
            LIMIT $limit
            """
            for match_fragment in (
                f"(n)-[r]->(m {self.node_label} {{{node_props_str}}}) WITH n as src, r, m as dst, similarity",
                f"(m {self.node_label} {{{node_props_str}}})-[r]->(n) WITH m as src, r, n as dst, similarity",
            )
        ]

        result_relations = []
        for node in node_list:
            params["n_embedding"] = self.embedding_model.embed(node)

            results = []
            for query in queries:
                results.extend(self.kuzu_execute(query, parameters=params))

            # Kuzu does not support sort/limit over unions. Do it manually for now.
            result_relations.extend(sorted(results, key=lambda x: x["similarity"], reverse=True)[:limit])

        return result_relations

    def _get_delete_entities_from_search_output(self, search_output, data, filters):
        """Ask the LLM which existing relations ``data`` makes obsolete.

        Returns:
            list[dict]: Normalized ``source``/``relationship``/``destination`` dicts
            taken from ``delete_graph_memory`` tool calls.
        """
        search_output_string = format_entities(search_output)
        system_prompt, user_prompt = get_delete_messages(search_output_string, data, self._user_identity(filters))

        _tools = [DELETE_MEMORY_TOOL_GRAPH]
        if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
            _tools = [
                DELETE_MEMORY_STRUCT_TOOL_GRAPH,
            ]

        memory_updates = self.llm.generate_response(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            tools=_tools,
        )

        # ``or []`` also covers an explicit ``tool_calls: None``; calls without an
        # arguments dict are skipped rather than crashing the normalizer below.
        to_be_deleted = [
            item["arguments"]
            for item in (memory_updates.get("tool_calls") or [])
            if item.get("name") == "delete_graph_memory" and isinstance(item.get("arguments"), dict)
        ]
        # Clean entities formatting
        to_be_deleted = self._remove_spaces_from_entities(to_be_deleted)
        logger.debug(f"Deleted relationships: {to_be_deleted}")
        return to_be_deleted

    def _delete_entities(self, to_be_deleted, filters):
        """Delete the named relationship between each source/destination pair in scope.

        Returns:
            list[list[dict]]: Query result rows per requested deletion.
        """
        scope_params = {}
        scope_props = self._scope_props(filters, scope_params)
        source_props_str = ", ".join(["name: $source_name", *scope_props])
        dest_props_str = ", ".join(["name: $dest_name", *scope_props])

        # Delete the specific relationship between nodes
        cypher = f"""
        MATCH (n {self.node_label} {{{source_props_str}}})
        -[r {self.rel_label} {{name: $relationship_name}}]->
        (m {self.node_label} {{{dest_props_str}}})
        DELETE r
        RETURN
            n.name AS source,
            r.name AS relationship,
            m.name AS target
        """

        results = []
        for item in to_be_deleted:
            params = {
                **scope_params,
                "source_name": item["source"],
                "dest_name": item["destination"],
                "relationship_name": item["relationship"],
            }
            results.append(self.kuzu_execute(cypher, parameters=params))

        return results

    def _add_entities(self, to_be_added, filters, entity_type_map):
        """Add the new relations to the graph, reusing similar existing nodes.

        For each relation, the source and destination are embedded and matched
        against existing in-scope nodes (similarity >= ``self.threshold``). A
        matched node is referenced by its internal id and its ``mentions`` are
        incremented; an unmatched one is MERGEd by name and scope. The
        relationship itself is MERGEd by name.

        ``entity_type_map`` is currently unused; it is kept for interface compatibility.

        Returns:
            list[list[dict]]: Query result rows per added relation.
        """
        scope_params = {}
        scope_props = self._scope_props(filters, scope_params)
        source_label = self.node_label
        destination_label = self.node_label
        relationship_label = self.rel_label

        results = []
        for item in to_be_added:
            # entities
            source = item["source"]
            destination = item["destination"]
            relationship = item["relationship"]

            # embeddings
            source_embedding = self.embedding_model.embed(source)
            dest_embedding = self.embedding_model.embed(destination)

            # search for the nodes with the closest embeddings
            source_node_search_result = self._search_source_node(source_embedding, filters, threshold=self.threshold)
            destination_node_search_result = self._search_destination_node(dest_embedding, filters, threshold=self.threshold)

            if not destination_node_search_result and source_node_search_result:
                # Existing source, new (or name-matched) destination.
                merge_props_str = ", ".join(["name: $destination_name", *scope_props])
                params = {
                    **scope_params,
                    "table_id": source_node_search_result[0]["id"]["table"],
                    "offset_id": source_node_search_result[0]["id"]["offset"],
                    "destination_name": destination,
                    "destination_embedding": dest_embedding,
                    "relationship_name": relationship,
                }

                cypher = f"""
                MATCH (source)
                WHERE id(source) = internal_id($table_id, $offset_id)
                SET source.mentions = coalesce(source.mentions, 0) + 1
                WITH source
                MERGE (destination {destination_label} {{{merge_props_str}}})
                ON CREATE SET
                    destination.created = current_timestamp(),
                    destination.mentions = 1,
                    destination.embedding = CAST($destination_embedding,'FLOAT[{self.embedding_dims}]')
                ON MATCH SET
                    destination.mentions = coalesce(destination.mentions, 0) + 1,
                    destination.embedding = CAST($destination_embedding,'FLOAT[{self.embedding_dims}]')
                WITH source, destination
                MERGE (source)-[r {relationship_label} {{name: $relationship_name}}]->(destination)
                ON CREATE SET
                    r.created = current_timestamp(),
                    r.mentions = 1
                ON MATCH SET
                    r.mentions = coalesce(r.mentions, 0) + 1
                RETURN
                    source.name AS source,
                    r.name AS relationship,
                    destination.name AS target
                """
            elif destination_node_search_result and not source_node_search_result:
                # Existing destination, new (or name-matched) source.
                merge_props_str = ", ".join(["name: $source_name", *scope_props])
                params = {
                    **scope_params,
                    "table_id": destination_node_search_result[0]["id"]["table"],
                    "offset_id": destination_node_search_result[0]["id"]["offset"],
                    "source_name": source,
                    "source_embedding": source_embedding,
                    "relationship_name": relationship,
                }

                cypher = f"""
                MATCH (destination)
                WHERE id(destination) = internal_id($table_id, $offset_id)
                SET destination.mentions = coalesce(destination.mentions, 0) + 1
                WITH destination
                MERGE (source {source_label} {{{merge_props_str}}})
                ON CREATE SET
                    source.created = current_timestamp(),
                    source.mentions = 1,
                    source.embedding = CAST($source_embedding,'FLOAT[{self.embedding_dims}]')
                ON MATCH SET
                    source.mentions = coalesce(source.mentions, 0) + 1,
                    source.embedding = CAST($source_embedding,'FLOAT[{self.embedding_dims}]')
                WITH source, destination
                MERGE (source)-[r {relationship_label} {{name: $relationship_name}}]->(destination)
                ON CREATE SET
                    r.created = current_timestamp(),
                    r.mentions = 1
                ON MATCH SET
                    r.mentions = coalesce(r.mentions, 0) + 1
                RETURN
                    source.name AS source,
                    r.name AS relationship,
                    destination.name AS target
                """
            elif source_node_search_result and destination_node_search_result:
                # Both endpoints already exist. This query references no scope
                # parameters, so none are bound here.
                cypher = f"""
                MATCH (source)
                WHERE id(source) = internal_id($src_table, $src_offset)
                SET source.mentions = coalesce(source.mentions, 0) + 1
                WITH source
                MATCH (destination)
                WHERE id(destination) = internal_id($dst_table, $dst_offset)
                SET destination.mentions = coalesce(destination.mentions, 0) + 1
                MERGE (source)-[r {relationship_label} {{name: $relationship_name}}]->(destination)
                ON CREATE SET
                    r.created = current_timestamp(),
                    r.updated = current_timestamp(),
                    r.mentions = 1
                ON MATCH SET r.mentions = coalesce(r.mentions, 0) + 1
                RETURN
                    source.name AS source,
                    r.name AS relationship,
                    destination.name AS target
                """

                params = {
                    "src_table": source_node_search_result[0]["id"]["table"],
                    "src_offset": source_node_search_result[0]["id"]["offset"],
                    "dst_table": destination_node_search_result[0]["id"]["table"],
                    "dst_offset": destination_node_search_result[0]["id"]["offset"],
                    "relationship_name": relationship,
                }
            else:
                # Neither endpoint matched by similarity: MERGE both by name and scope.
                source_props_str = ", ".join(["name: $source_name", *scope_props])
                dest_props_str = ", ".join(["name: $dest_name", *scope_props])
                params = {
                    **scope_params,
                    "source_name": source,
                    "dest_name": destination,
                    "relationship_name": relationship,
                    "source_embedding": source_embedding,
                    "dest_embedding": dest_embedding,
                }

                cypher = f"""
                MERGE (source {source_label} {{{source_props_str}}})
                ON CREATE SET
                    source.created = current_timestamp(),
                    source.mentions = 1,
                    source.embedding = CAST($source_embedding,'FLOAT[{self.embedding_dims}]')
                ON MATCH SET
                    source.mentions = coalesce(source.mentions, 0) + 1,
                    source.embedding = CAST($source_embedding,'FLOAT[{self.embedding_dims}]')
                WITH source
                MERGE (destination {destination_label} {{{dest_props_str}}})
                ON CREATE SET
                    destination.created = current_timestamp(),
                    destination.mentions = 1,
                    destination.embedding = CAST($dest_embedding,'FLOAT[{self.embedding_dims}]')
                ON MATCH SET
                    destination.mentions = coalesce(destination.mentions, 0) + 1,
                    destination.embedding = CAST($dest_embedding,'FLOAT[{self.embedding_dims}]')
                WITH source, destination
                MERGE (source)-[rel {relationship_label} {{name: $relationship_name}}]->(destination)
                ON CREATE SET
                    rel.created = current_timestamp(),
                    rel.mentions = 1
                ON MATCH SET
                    rel.mentions = coalesce(rel.mentions, 0) + 1
                RETURN
                    source.name AS source,
                    rel.name AS relationship,
                    destination.name AS target
                """

            results.append(self.kuzu_execute(cypher, parameters=params))

        return results

    def _remove_spaces_from_entities(self, entity_list):
        """Lower-case and underscore-join ``source``/``relationship``/``destination`` in place.

        Returns the same (mutated) list for convenience.
        """
        for item in entity_list:
            item["source"] = item["source"].lower().replace(" ", "_")
            item["relationship"] = item["relationship"].lower().replace(" ", "_")
            item["destination"] = item["destination"].lower().replace(" ", "_")
        return entity_list

    def _search_source_node(self, source_embedding, filters, threshold=0.9):
        """Return up to two in-scope node ids most similar to ``source_embedding``.

        Only nodes with cosine similarity >= ``threshold`` qualify; rows are
        ordered by descending similarity and carry ``id`` and ``source_similarity``.
        """
        params = {
            "source_embedding": source_embedding,
            "user_id": filters["user_id"],
            "threshold": threshold,
        }
        where_conditions = ["source_candidate.embedding IS NOT NULL", "source_candidate.user_id = $user_id"]
        if filters.get("agent_id"):
            where_conditions.append("source_candidate.agent_id = $agent_id")
            params["agent_id"] = filters["agent_id"]
        if filters.get("run_id"):
            where_conditions.append("source_candidate.run_id = $run_id")
            params["run_id"] = filters["run_id"]
        where_clause = " AND ".join(where_conditions)

        cypher = f"""
            MATCH (source_candidate {self.node_label})
            WHERE {where_clause}

            WITH source_candidate,
            array_cosine_similarity(source_candidate.embedding, CAST($source_embedding,'FLOAT[{self.embedding_dims}]')) AS source_similarity

            WHERE source_similarity >= $threshold

            WITH source_candidate, source_similarity
            ORDER BY source_similarity DESC
            LIMIT 2

            RETURN id(source_candidate) as id, source_similarity
            """

        return self.kuzu_execute(cypher, parameters=params)

    def _search_destination_node(self, destination_embedding, filters, threshold=0.9):
        """Return up to two in-scope node ids most similar to ``destination_embedding``.

        Only nodes with cosine similarity >= ``threshold`` qualify; rows are
        ordered by descending similarity and carry ``id`` and ``destination_similarity``.
        """
        params = {
            "destination_embedding": destination_embedding,
            "user_id": filters["user_id"],
            "threshold": threshold,
        }
        where_conditions = ["destination_candidate.embedding IS NOT NULL", "destination_candidate.user_id = $user_id"]
        if filters.get("agent_id"):
            where_conditions.append("destination_candidate.agent_id = $agent_id")
            params["agent_id"] = filters["agent_id"]
        if filters.get("run_id"):
            where_conditions.append("destination_candidate.run_id = $run_id")
            params["run_id"] = filters["run_id"]
        where_clause = " AND ".join(where_conditions)

        cypher = f"""
            MATCH (destination_candidate {self.node_label})
            WHERE {where_clause}

            WITH destination_candidate,
            array_cosine_similarity(destination_candidate.embedding, CAST($destination_embedding,'FLOAT[{self.embedding_dims}]')) AS destination_similarity

            WHERE destination_similarity >= $threshold

            WITH destination_candidate, destination_similarity
            ORDER BY destination_similarity DESC
            LIMIT 2

            RETURN id(destination_candidate) as id, destination_similarity
            """

        return self.kuzu_execute(cypher, parameters=params)

    # Reset is not defined in base.py
    def reset(self):
        """Reset the graph by clearing all nodes and relationships."""
        logger.warning("Clearing graph...")
        cypher_query = """
        MATCH (n) DETACH DELETE n
        """
        return self.kuzu_execute(cypher_query)
| #715 |