repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...2fa351d6docs: add automaton and perps launch sources16d ago| #1 | import unittest |
| #2 | from unittest.mock import patch |
| #3 | |
| #4 | from embedchain import App |
| #5 | from embedchain.config import AppConfig |
| #6 | from embedchain.config.vector_db.pinecone import PineconeDBConfig |
| #7 | from embedchain.embedder.base import BaseEmbedder |
| #8 | from embedchain.vectordb.weaviate import WeaviateDB |
| #9 | |
| #10 | |
def mock_embedding_fn(texts: list[str]) -> list[list[float]]:
    """Stand-in embedding function used by the tests below.

    ``texts`` is ignored: the same two 3-element vectors are always
    returned, built as new lists on every call.
    """
    vectors = [[1, 2, 3]]
    vectors.append([4, 5, 6])
    return vectors
| #14 | |
| #15 | |
class TestWeaviateDb(unittest.TestCase):
    """Unit tests for ``WeaviateDB`` with the ``weaviate`` client module patched out."""

    @staticmethod
    def _create_db():
        """Create a ``WeaviateDB`` attached to an ``App`` with a 1536-dim mock embedder.

        Call this only while ``embedchain.vectordb.weaviate.weaviate`` is patched:
        constructing the ``App`` is what initializes the Weaviate client and schema
        (see ``test_initialize``).

        Returns:
            The ``WeaviateDB`` instance that was handed to the ``App``.
        """
        embedder = BaseEmbedder()
        embedder.set_vector_dimension(1536)
        embedder.set_embedding_fn(mock_embedding_fn)

        db = WeaviateDB()
        app_config = AppConfig(collect_metrics=False)
        # The App instance itself is not needed; building it wires the embedder into the db.
        App(config=app_config, db=db, embedding_model=embedder)
        return db

    def test_incorrect_config_throws_error(self):
        """Test the init method of the WeaviateDb class throws error for incorrect config"""
        with self.assertRaises(TypeError):
            WeaviateDB(config=PineconeDBConfig())

    @patch("embedchain.vectordb.weaviate.weaviate")
    def test_initialize(self, weaviate_mock):
        """Initializing the DB creates the main and metadata schema classes."""
        weaviate_client_mock = weaviate_mock.Client.return_value
        weaviate_client_schema_mock = weaviate_client_mock.schema

        # Mock that schema doesn't already exist so that a new schema is created.
        weaviate_client_schema_mock.exists.return_value = False

        db = self._create_db()

        expected_class_obj = {
            "classes": [
                {
                    "class": "Embedchain_store_1536",
                    "vectorizer": "none",
                    "properties": [
                        {
                            "name": "identifier",
                            "dataType": ["text"],
                        },
                        {
                            "name": "text",
                            "dataType": ["text"],
                        },
                        {
                            "name": "metadata",
                            "dataType": ["Embedchain_store_1536_metadata"],
                        },
                    ],
                },
                {
                    "class": "Embedchain_store_1536_metadata",
                    "vectorizer": "none",
                    "properties": [
                        {
                            "name": "data_type",
                            "dataType": ["text"],
                        },
                        {
                            "name": "doc_id",
                            "dataType": ["text"],
                        },
                        {
                            "name": "url",
                            "dataType": ["text"],
                        },
                        {
                            "name": "hash",
                            "dataType": ["text"],
                        },
                        {
                            "name": "app_id",
                            "dataType": ["text"],
                        },
                    ],
                },
            ]
        }

        # Assert that the Weaviate client was initialized and the schema created once.
        weaviate_mock.Client.assert_called_once()
        self.assertEqual(db.index_name, "Embedchain_store_1536")
        weaviate_client_schema_mock.create.assert_called_once_with(expected_class_obj)

    @patch("embedchain.vectordb.weaviate.weaviate")
    def test_get_or_create_db(self, weaviate_mock):
        """Test the _get_or_create_db method returns the (mocked) Weaviate client."""
        weaviate_client_mock = weaviate_mock.Client.return_value

        db = self._create_db()

        client = db._get_or_create_db()
        self.assertEqual(client, weaviate_client_mock)

    @patch("embedchain.vectordb.weaviate.weaviate")
    def test_add(self, weaviate_mock):
        """Adding a document writes both a main object and a metadata object via the batch."""
        weaviate_client_mock = weaviate_mock.Client.return_value
        weaviate_client_batch_mock = weaviate_client_mock.batch
        # Objects are added on the value returned by entering the batch context manager.
        weaviate_client_batch_enter_mock = weaviate_client_batch_mock.__enter__.return_value

        db = self._create_db()

        documents = ["This is test document"]
        metadatas = [None]
        ids = ["id_1"]
        db.add(documents, metadatas, ids)

        weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=100, timeout_retries=3)
        # Metadata object: with no extra metadata it carries only the document text.
        weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
            data_object={"text": documents[0]},
            class_name="Embedchain_store_1536_metadata",
            vector=[1, 2, 3],
        )
        # Main object: carries the identifier and text properties declared in the
        # "Embedchain_store_1536" schema (see test_initialize). The original test
        # repeated the metadata assertion here instead, leaving this path unchecked.
        weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
            data_object={"identifier": ids[0], "text": documents[0]},
            class_name="Embedchain_store_1536",
            vector=[1, 2, 3],
        )

    @patch("embedchain.vectordb.weaviate.weaviate")
    def test_query_without_where(self, weaviate_mock):
        """Querying with an empty where filter goes straight to a near-vector search."""
        weaviate_client_mock = weaviate_mock.Client.return_value
        weaviate_client_query_mock = weaviate_client_mock.query
        weaviate_client_query_get_mock = weaviate_client_query_mock.get.return_value

        db = self._create_db()

        db.query(input_query="This is a test document.", n_results=1, where={})

        weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
        weaviate_client_query_get_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})

    @patch("embedchain.vectordb.weaviate.weaviate")
    def test_query_with_where(self, weaviate_mock):
        """Querying with a where filter applies it on the metadata reference before the vector search."""
        weaviate_client_mock = weaviate_mock.Client.return_value
        weaviate_client_query_mock = weaviate_client_mock.query
        weaviate_client_query_get_mock = weaviate_client_query_mock.get.return_value
        weaviate_client_query_get_where_mock = weaviate_client_query_get_mock.with_where.return_value

        db = self._create_db()

        db.query(input_query="This is a test document.", n_results=1, where={"doc_id": "123"})

        weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
        weaviate_client_query_get_mock.with_where.assert_called_once_with(
            {"operator": "Equal", "path": ["metadata", "Embedchain_store_1536_metadata", "doc_id"], "valueText": "123"}
        )
        weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})

    @patch("embedchain.vectordb.weaviate.weaviate")
    def test_reset(self, weaviate_mock):
        """Test the reset method deletes every object in the main class."""
        weaviate_client_mock = weaviate_mock.Client.return_value
        weaviate_client_batch_mock = weaviate_client_mock.batch

        db = self._create_db()

        db.reset()

        weaviate_client_batch_mock.delete_objects.assert_called_once_with(
            "Embedchain_store_1536", where={"path": ["identifier"], "operator": "Like", "valueText": ".*"}
        )

    @patch("embedchain.vectordb.weaviate.weaviate")
    def test_count(self, weaviate_mock):
        """Test the count method aggregates over the main class."""
        weaviate_client_mock = weaviate_mock.Client.return_value
        weaviate_client_query = weaviate_client_mock.query

        db = self._create_db()

        # Count the documents in the database.
        db.count()

        weaviate_client_query.aggregate.assert_called_once_with("Embedchain_store_1536")
| #238 |