repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...2fa351d6docs: add automaton and perps launch sources16d ago| #1 | import unittest |
| #2 | from unittest.mock import MagicMock, patch |
| #3 | import pytest |
| #4 | from mem0.graphs.neptune.neptunedb import MemoryGraph |
| #5 | from mem0.graphs.neptune.base import NeptuneBase |
| #6 | |
| #7 | |
class TestNeptuneMemory(unittest.TestCase):
    """Test suite for the Neptune Memory implementation.

    Each test gets a fresh ``MemoryGraph`` built from a ``MagicMock`` config.
    ``NeptuneGraph`` and the ``NeptuneBase`` factories for the embedding model,
    LLM and vector store are patched out, so no live Neptune endpoint, model or
    vector store is needed.
    """

    # Sentinel for _make_collection_config meaning "do not assign
    # vector_store.config.collection_name" (it stays an auto-created MagicMock
    # attribute). None cannot play this role: an explicit None is itself one of
    # the cases under test.
    _UNSET = object()

    def setUp(self):
        """Set up test fixtures before each test method.

        Patchers are stopped through ``addCleanup``, registered right after each
        ``start()``, rather than in ``tearDown``. ``tearDown`` is skipped when
        ``setUp`` raises part-way through, which would leave the already-started
        patches active for every later test.
        """
        # Create a mock config
        self.config = MagicMock()
        self.config.graph_store.config.endpoint = "neptune-db://test-graph"
        self.config.graph_store.config.base_label = True
        self.config.graph_store.threshold = 0.7
        self.config.llm.provider = "openai_structured"
        self.config.graph_store.llm = None
        self.config.graph_store.custom_prompt = None
        self.config.vector_store.provider = "qdrant"
        self.config.vector_store.config = MagicMock()

        # Objects the patched constructor / factory methods hand back to MemoryGraph.
        self.mock_graph = MagicMock()
        self.mock_embedding_model = MagicMock()
        self.mock_llm = MagicMock()
        self.mock_vector_store = MagicMock()

        # Patch the NeptuneGraph name as looked up inside the neptunedb module.
        self.neptune_graph_patcher = patch("mem0.graphs.neptune.neptunedb.NeptuneGraph")
        self.mock_neptune_graph = self._start_patcher(self.neptune_graph_patcher)
        self.mock_neptune_graph.return_value = self.mock_graph

        # Patch the NeptuneBase factory methods.
        self.create_embedding_model_patcher = patch.object(NeptuneBase, "_create_embedding_model")
        self.mock_create_embedding_model = self._start_patcher(self.create_embedding_model_patcher)
        self.mock_create_embedding_model.return_value = self.mock_embedding_model

        self.create_llm_patcher = patch.object(NeptuneBase, "_create_llm")
        self.mock_create_llm = self._start_patcher(self.create_llm_patcher)
        self.mock_create_llm.return_value = self.mock_llm

        self.create_vector_store_patcher = patch.object(NeptuneBase, "_create_vector_store")
        self.mock_create_vector_store = self._start_patcher(self.create_vector_store_patcher)
        self.mock_create_vector_store.return_value = self.mock_vector_store

        # Create the MemoryGraph instance
        self.memory_graph = MemoryGraph(self.config)

        # Set up common test data
        self.user_id = "test_user"
        self.test_filters = {"user_id": self.user_id}

    def _start_patcher(self, patcher):
        """Start ``patcher``, register its ``stop`` as a cleanup, and return what ``start()`` returned."""
        started = patcher.start()
        self.addCleanup(patcher.stop)
        return started

    @classmethod
    def _make_collection_config(cls, graph_collection_name, vector_collection_name=_UNSET):
        """Return a mocked config for the collection-name tests.

        Args:
            graph_collection_name: value for ``graph_store.config.collection_name``.
            vector_collection_name: value for ``vector_store.config.collection_name``.
                When omitted, the attribute is left unassigned (an auto-created
                MagicMock attribute).
        """
        config = MagicMock()
        config.graph_store.config.endpoint = "neptune-db://test-graph"
        config.graph_store.config.base_label = True
        config.graph_store.config.collection_name = graph_collection_name
        config.llm.provider = "openai"
        config.graph_store.llm = None
        config.vector_store.provider = "qdrant"
        config.vector_store.config = MagicMock()
        if vector_collection_name is not cls._UNSET:
            config.vector_store.config.collection_name = vector_collection_name
        return config

    def test_initialization(self):
        """Test that the MemoryGraph is initialized correctly."""
        # The instance must hold exactly the objects returned by the patched constructor/factories.
        self.assertIs(self.memory_graph.graph, self.mock_graph)
        self.assertIs(self.memory_graph.embedding_model, self.mock_embedding_model)
        self.assertIs(self.memory_graph.llm, self.mock_llm)
        self.assertIs(self.memory_graph.vector_store, self.mock_vector_store)
        self.assertEqual(self.memory_graph.llm_provider, "openai_structured")
        self.assertEqual(self.memory_graph.node_label, ":`__Entity__`")
        self.assertEqual(self.memory_graph.threshold, 0.7)
        self.assertEqual(self.memory_graph.vector_store_limit, 5)

    def test_collection_name_variants(self):
        """Test all collection_name configuration variants.

        Each case is a ``subTest`` so one failing variant does not hide the others.
        """
        cases = (
            # graph_store.config.collection_name is set
            ("graph_store collection_name set", "custom_collection", self._UNSET, "custom_collection"),
            # vector_store.config.collection_name exists, graph_store.config.collection_name is None
            (
                "only vector_store collection_name set",
                None,
                "existing_collection",
                "existing_collection_neptune_vector_store",
            ),
            # Neither collection_name is set (default case)
            ("neither collection_name set", None, None, "mem0_neptune_vector_store"),
        )
        for label, graph_name, vector_name, expected in cases:
            with self.subTest(label):
                config = self._make_collection_config(graph_name, vector_name)
                MemoryGraph(config)
                self.assertEqual(config.vector_store.config.collection_name, expected)

    def test_init(self):
        """Test that MemoryGraph raises ValueError for a missing or non-``neptune-db://`` endpoint."""
        bad_endpoints = (
            ("missing endpoint", None),
            ("wrong endpoint type", "neptune-graph://test-graph"),
        )
        for label, endpoint in bad_endpoints:
            with self.subTest(label):
                bad_config = MagicMock()
                bad_config.graph_store.config.endpoint = endpoint
                with self.assertRaises(ValueError):
                    MemoryGraph(bad_config)

    def test_add_method(self):
        """Test the add method with mocked components."""
        # Mock the necessary methods that add() calls
        self.memory_graph._retrieve_nodes_from_data = MagicMock(return_value={"alice": "person", "bob": "person"})
        self.memory_graph._establish_nodes_relations_from_data = MagicMock(
            return_value=[{"source": "alice", "relationship": "knows", "destination": "bob"}]
        )
        self.memory_graph._search_graph_db = MagicMock(return_value=[])
        self.memory_graph._get_delete_entities_from_search_output = MagicMock(return_value=[])
        self.memory_graph._delete_entities = MagicMock(return_value=[])
        self.memory_graph._add_entities = MagicMock(
            return_value=[{"source": "alice", "relationship": "knows", "target": "bob"}]
        )

        # Call the add method
        result = self.memory_graph.add("Alice knows Bob", self.test_filters)

        # Verify the method calls
        self.memory_graph._retrieve_nodes_from_data.assert_called_once_with("Alice knows Bob", self.test_filters)
        self.memory_graph._establish_nodes_relations_from_data.assert_called_once()
        self.memory_graph._search_graph_db.assert_called_once()
        self.memory_graph._get_delete_entities_from_search_output.assert_called_once()
        self.memory_graph._delete_entities.assert_called_once_with([], self.user_id)
        self.memory_graph._add_entities.assert_called_once()

        # Check the result structure
        self.assertIn("deleted_entities", result)
        self.assertIn("added_entities", result)

    def test_search_method(self):
        """Test the search method with mocked components."""
        # Mock the necessary methods that search() calls
        self.memory_graph._retrieve_nodes_from_data = MagicMock(return_value={"alice": "person"})

        # Mock search results
        mock_search_results = [
            {"source": "alice", "relationship": "knows", "destination": "bob"},
            {"source": "alice", "relationship": "works_with", "destination": "charlie"},
        ]
        self.memory_graph._search_graph_db = MagicMock(return_value=mock_search_results)

        # Mock BM25Okapi where the neptune base module looks it up
        with patch("mem0.graphs.neptune.base.BM25Okapi") as mock_bm25:
            mock_bm25_instance = MagicMock()
            mock_bm25.return_value = mock_bm25_instance

            # Mock get_top_n to return reranked results
            reranked_results = [["alice", "knows", "bob"], ["alice", "works_with", "charlie"]]
            mock_bm25_instance.get_top_n.return_value = reranked_results

            # Call the search method
            result = self.memory_graph.search("Find Alice", self.test_filters, limit=5)

        # Verify the method calls
        self.memory_graph._retrieve_nodes_from_data.assert_called_once_with("Find Alice", self.test_filters)
        self.memory_graph._search_graph_db.assert_called_once_with(node_list=["alice"], filters=self.test_filters)

        # Check the result structure
        self.assertEqual(len(result), 2)
        self.assertEqual(result[0]["source"], "alice")
        self.assertEqual(result[0]["relationship"], "knows")
        self.assertEqual(result[0]["destination"], "bob")

    def test_get_all_method(self):
        """Test the get_all method."""
        # Mock the _get_all_cypher method
        mock_cypher = "MATCH (n) RETURN n"
        mock_params = {"user_id": self.user_id, "limit": 10}
        self.memory_graph._get_all_cypher = MagicMock(return_value=(mock_cypher, mock_params))

        # Mock the graph.query result
        mock_query_result = [
            {"source": "alice", "relationship": "knows", "target": "bob"},
            {"source": "bob", "relationship": "works_with", "target": "charlie"},
        ]
        self.mock_graph.query.return_value = mock_query_result

        # Call the get_all method
        result = self.memory_graph.get_all(self.test_filters, limit=10)

        # Verify the method calls
        self.memory_graph._get_all_cypher.assert_called_once_with(self.test_filters, 10)
        self.mock_graph.query.assert_called_once_with(mock_cypher, params=mock_params)

        # Check the result structure
        self.assertEqual(len(result), 2)
        self.assertEqual(result[0]["source"], "alice")
        self.assertEqual(result[0]["relationship"], "knows")
        self.assertEqual(result[0]["target"], "bob")

    def test_delete_all_method(self):
        """Test the delete_all method."""
        # Mock the _delete_all_cypher method
        mock_cypher = "MATCH (n) DETACH DELETE n"
        mock_params = {"user_id": self.user_id}
        self.memory_graph._delete_all_cypher = MagicMock(return_value=(mock_cypher, mock_params))

        # Call the delete_all method
        self.memory_graph.delete_all(self.test_filters)

        # Verify the method calls
        self.memory_graph._delete_all_cypher.assert_called_once_with(self.test_filters)
        self.mock_graph.query.assert_called_once_with(mock_cypher, params=mock_params)

    def test_search_source_node(self):
        """Test the _search_source_node method."""
        # Mock embedding
        mock_embedding = [0.1, 0.2, 0.3]

        # Mock the _search_source_node_cypher method
        mock_cypher = "MATCH (n) RETURN n"
        mock_params = {"source_embedding": mock_embedding, "user_id": self.user_id, "threshold": 0.9}
        self.memory_graph._search_source_node_cypher = MagicMock(return_value=(mock_cypher, mock_params))

        # Mock the graph.query result
        mock_query_result = [{"id(source_candidate)": 123, "cosine_similarity": 0.95}]
        self.mock_graph.query.return_value = mock_query_result

        # Call the _search_source_node method
        result = self.memory_graph._search_source_node(mock_embedding, self.user_id, threshold=0.9)

        # Verify the method calls
        self.memory_graph._search_source_node_cypher.assert_called_once_with(mock_embedding, self.user_id, 0.9)
        self.mock_graph.query.assert_called_once_with(mock_cypher, params=mock_params)

        # Check the result
        self.assertEqual(result, mock_query_result)

    def test_search_destination_node(self):
        """Test the _search_destination_node method."""
        # Mock embedding
        mock_embedding = [0.1, 0.2, 0.3]

        # Mock the _search_destination_node_cypher method
        mock_cypher = "MATCH (n) RETURN n"
        mock_params = {"destination_embedding": mock_embedding, "user_id": self.user_id, "threshold": 0.9}
        self.memory_graph._search_destination_node_cypher = MagicMock(return_value=(mock_cypher, mock_params))

        # Mock the graph.query result
        mock_query_result = [{"id(destination_candidate)": 456, "cosine_similarity": 0.92}]
        self.mock_graph.query.return_value = mock_query_result

        # Call the _search_destination_node method
        result = self.memory_graph._search_destination_node(mock_embedding, self.user_id, threshold=0.9)

        # Verify the method calls
        self.memory_graph._search_destination_node_cypher.assert_called_once_with(mock_embedding, self.user_id, 0.9)
        self.mock_graph.query.assert_called_once_with(mock_cypher, params=mock_params)

        # Check the result
        self.assertEqual(result, mock_query_result)

    def test_search_graph_db(self):
        """Test the _search_graph_db method."""
        # Mock node list
        node_list = ["alice", "bob"]

        # Mock embedding
        mock_embedding = [0.1, 0.2, 0.3]
        self.mock_embedding_model.embed.return_value = mock_embedding

        # Mock the _search_graph_db_cypher method
        mock_cypher = "MATCH (n) RETURN n"
        mock_params = {"n_embedding": mock_embedding, "user_id": self.user_id, "threshold": 0.7, "limit": 10}
        self.memory_graph._search_graph_db_cypher = MagicMock(return_value=(mock_cypher, mock_params))

        # Mock the graph.query results: one result list per node in node_list
        mock_query_result1 = [{"source": "alice", "relationship": "knows", "destination": "bob"}]
        mock_query_result2 = [{"source": "bob", "relationship": "works_with", "destination": "charlie"}]
        self.mock_graph.query.side_effect = [mock_query_result1, mock_query_result2]

        # Call the _search_graph_db method
        result = self.memory_graph._search_graph_db(node_list, self.test_filters, limit=10)

        # Verify the method calls
        self.assertEqual(self.mock_embedding_model.embed.call_count, 2)
        self.assertEqual(self.memory_graph._search_graph_db_cypher.call_count, 2)
        self.assertEqual(self.mock_graph.query.call_count, 2)

        # Check the result
        expected_result = mock_query_result1 + mock_query_result2
        self.assertEqual(result, expected_result)

    def test_add_entities(self):
        """Test the _add_entities method."""
        # Mock data
        to_be_added = [{"source": "alice", "relationship": "knows", "destination": "bob"}]
        entity_type_map = {"alice": "person", "bob": "person"}

        # Mock embeddings
        mock_embedding = [0.1, 0.2, 0.3]
        self.mock_embedding_model.embed.return_value = mock_embedding

        # Mock search results
        mock_source_search = [{"id(source_candidate)": 123, "cosine_similarity": 0.95}]
        mock_dest_search = [{"id(destination_candidate)": 456, "cosine_similarity": 0.92}]

        # Mock the search methods
        self.memory_graph._search_source_node = MagicMock(return_value=mock_source_search)
        self.memory_graph._search_destination_node = MagicMock(return_value=mock_dest_search)

        # Mock the _add_entities_cypher method
        mock_cypher = "MATCH (n) RETURN n"
        mock_params = {"source_id": 123, "destination_id": 456}
        self.memory_graph._add_entities_cypher = MagicMock(return_value=(mock_cypher, mock_params))

        # Mock the graph.query result
        mock_query_result = [{"source": "alice", "relationship": "knows", "target": "bob"}]
        self.mock_graph.query.return_value = mock_query_result

        # Call the _add_entities method
        result = self.memory_graph._add_entities(to_be_added, self.user_id, entity_type_map)

        # Verify the method calls
        self.assertEqual(self.mock_embedding_model.embed.call_count, 2)
        self.memory_graph._search_source_node.assert_called_once_with(mock_embedding, self.user_id, threshold=0.7)
        self.memory_graph._search_destination_node.assert_called_once_with(mock_embedding, self.user_id, threshold=0.7)
        self.memory_graph._add_entities_cypher.assert_called_once()
        self.mock_graph.query.assert_called_once_with(mock_cypher, params=mock_params)

        # Check the result
        self.assertEqual(result, [mock_query_result])

    def test_delete_entities(self):
        """Test the _delete_entities method."""
        # Mock data
        to_be_deleted = [{"source": "alice", "relationship": "knows", "destination": "bob"}]

        # Mock the _delete_entities_cypher method
        mock_cypher = "MATCH (n) RETURN n"
        mock_params = {"source_name": "alice", "dest_name": "bob", "user_id": self.user_id}
        self.memory_graph._delete_entities_cypher = MagicMock(return_value=(mock_cypher, mock_params))

        # Mock the graph.query result
        mock_query_result = [{"source": "alice", "relationship": "knows", "target": "bob"}]
        self.mock_graph.query.return_value = mock_query_result

        # Call the _delete_entities method
        result = self.memory_graph._delete_entities(to_be_deleted, self.user_id)

        # Verify the method calls
        self.memory_graph._delete_entities_cypher.assert_called_once_with("alice", "bob", "knows", self.user_id)
        self.mock_graph.query.assert_called_once_with(mock_cypher, params=mock_params)

        # Check the result
        self.assertEqual(result, [mock_query_result])
| #386 | |
| #387 | |
if __name__ == "__main__":
    # Run this module's TestCases with the stdlib runner when executed as a script;
    # "__main__" is unittest.main's default module, spelled out for clarity.
    unittest.main(module="__main__")
| #390 |