repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...
git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...
2fa351d6 docs: add automaton and perps launch sources (16d ago)
| #1 | import unittest |
| #2 | import uuid |
| #3 | |
| #4 | from mock import patch |
| #5 | from qdrant_client.http import models |
| #6 | from qdrant_client.http.models import Batch |
| #7 | |
| #8 | from embedchain import App |
| #9 | from embedchain.config import AppConfig |
| #10 | from embedchain.config.vector_db.pinecone import PineconeDBConfig |
| #11 | from embedchain.embedder.base import BaseEmbedder |
| #12 | from embedchain.vectordb.qdrant import QdrantDB |
| #13 | |
| #14 | |
def mock_embedding_fn(texts: list[str]) -> list[list[float]]:
    """Return deterministic fake embeddings, one vector per input text.

    The vectors alternate between ``[1, 2, 3]`` and ``[4, 5, 6]``, so a
    two-text call yields ``[[1, 2, 3], [4, 5, 6]]`` and the first vector of
    any non-empty call is ``[1, 2, 3]`` -- the values the Qdrant tests in this
    module assert on. Earlier versions always returned exactly two vectors
    regardless of input, which silently mis-sized batches of any other length.

    Args:
        texts: The texts to "embed"; only their count is used.

    Returns:
        A new list containing ``len(texts)`` freshly allocated vectors, so a
        caller mutating a returned vector cannot affect later calls.
    """
    base_vectors = ([1, 2, 3], [4, 5, 6])
    return [list(base_vectors[i % len(base_vectors)]) for i in range(len(texts))]
| #18 | |
| #19 | |
class TestQdrantDB(unittest.TestCase):
    """Unit tests for ``QdrantDB`` with the Qdrant client replaced by a mock.

    Every test that talks to Qdrant patches
    ``embedchain.vectordb.qdrant.QdrantClient`` and then asserts against the
    calls recorded on the mock client instance (``qdrant_client_mock.return_value``).
    """

    # Values handed out, in order, by the patched ``uuid.uuid4`` in ``test_add``.
    TEST_UUIDS = ["abc", "def", "ghi"]

    # Collection name these tests expect for an embedder with vector dimension 1536.
    COLLECTION_NAME = "embedchain-store-1536"

    def _create_db(self) -> QdrantDB:
        """Build a ``QdrantDB`` attached to an ``App`` with a mock embedder.

        Must be called from inside a test decorated with the ``QdrantClient``
        patch so the database is constructed against the mock client.

        Returns:
            The ``QdrantDB`` instance passed to the ``App``.
        """
        embedder = BaseEmbedder()
        embedder.set_vector_dimension(1536)
        embedder.set_embedding_fn(mock_embedding_fn)

        db = QdrantDB()
        app_config = AppConfig(collect_metrics=False)
        # The App object itself is not used; it is constructed for its side
        # effects on ``db`` (test_initialize checks the resulting collection name).
        App(config=app_config, db=db, embedding_model=embedder)
        return db

    def test_incorrect_config_throws_error(self):
        """Test the init method of the Qdrant class throws error for incorrect config"""
        with self.assertRaises(TypeError):
            QdrantDB(config=PineconeDBConfig())

    @patch("embedchain.vectordb.qdrant.QdrantClient")
    def test_initialize(self, qdrant_client_mock):
        """Initialization binds the mock client and queries collections once."""
        db = self._create_db()

        self.assertEqual(db.collection_name, self.COLLECTION_NAME)
        self.assertEqual(db.client, qdrant_client_mock.return_value)
        qdrant_client_mock.return_value.get_collections.assert_called_once()

    @patch("embedchain.vectordb.qdrant.QdrantClient")
    def test_get(self, qdrant_client_mock):
        """``get`` returns empty ids/metadatas when the mocked ``scroll`` finds no points."""
        qdrant_client_mock.return_value.scroll.return_value = ([], None)
        db = self._create_db()

        resp = db.get(ids=[], where={})
        self.assertEqual(resp, {"ids": [], "metadatas": []})
        resp2 = db.get(ids=["123", "456"], where={"url": "https://ai.ai"})
        self.assertEqual(resp2, {"ids": [], "metadatas": []})

    @patch("embedchain.vectordb.qdrant.QdrantClient")
    @patch.object(uuid, "uuid4", side_effect=TEST_UUIDS)
    def test_add(self, uuid_mock, qdrant_client_mock):
        """``add`` upserts a single ``Batch`` of ids, payloads and vectors."""
        # Decorators apply bottom-up, so the uuid4 mock arrives first. Explicit
        # ids are passed below; the patch only keeps any uuid4 call deterministic.
        qdrant_client_mock.return_value.scroll.return_value = ([], None)
        db = self._create_db()

        documents = ["This is a test document.", "This is another test document."]
        metadatas = [{}, {}]
        ids = ["123", "456"]
        db.add(documents, metadatas, ids)
        qdrant_client_mock.return_value.upsert.assert_called_once_with(
            collection_name=self.COLLECTION_NAME,
            points=Batch(
                ids=["123", "456"],
                payloads=[
                    {
                        "identifier": "123",
                        "text": "This is a test document.",
                        "metadata": {"text": "This is a test document."},
                    },
                    {
                        "identifier": "456",
                        "text": "This is another test document.",
                        "metadata": {"text": "This is another test document."},
                    },
                ],
                vectors=[[1, 2, 3], [4, 5, 6]],
            ),
        )

    @patch("embedchain.vectordb.qdrant.QdrantClient")
    def test_query(self, qdrant_client_mock):
        """``query`` searches with the embedded query and a ``metadata.*`` filter."""
        db = self._create_db()

        # Query for the document.
        db.query(input_query="This is a test document.", n_results=1, where={"doc_id": "123"})

        qdrant_client_mock.return_value.search.assert_called_once_with(
            collection_name=self.COLLECTION_NAME,
            query_filter=models.Filter(
                must=[
                    models.FieldCondition(
                        key="metadata.doc_id",
                        match=models.MatchValue(
                            value="123",
                        ),
                    )
                ]
            ),
            query_vector=[1, 2, 3],
            limit=1,
        )

    @patch("embedchain.vectordb.qdrant.QdrantClient")
    def test_count(self, qdrant_client_mock):
        """``count`` fetches the collection by name."""
        db = self._create_db()

        db.count()
        qdrant_client_mock.return_value.get_collection.assert_called_once_with(collection_name=self.COLLECTION_NAME)

    @patch("embedchain.vectordb.qdrant.QdrantClient")
    def test_reset(self, qdrant_client_mock):
        """``reset`` deletes the collection by name."""
        db = self._create_db()

        db.reset()
        qdrant_client_mock.return_value.delete_collection.assert_called_once_with(
            collection_name=self.COLLECTION_NAME
        )
| #164 | |
| #165 | |
# Run this module's test cases when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main(module="__main__")
| #168 |