repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...2fa351d6docs: add automaton and perps launch sources16d ago| #1 | import logging |
| #2 | from typing import Any, Optional, Union |
| #3 | |
| #4 | from embedchain.config import ZillizDBConfig |
| #5 | from embedchain.helpers.json_serializable import register_deserializable |
| #6 | from embedchain.vectordb.base import BaseVectorDB |
| #7 | |
| #8 | try: |
| #9 | from pymilvus import ( |
| #10 | Collection, |
| #11 | CollectionSchema, |
| #12 | DataType, |
| #13 | FieldSchema, |
| #14 | MilvusClient, |
| #15 | connections, |
| #16 | utility, |
| #17 | ) |
| #18 | except ImportError: |
| #19 | raise ImportError( |
| #20 | "Zilliz requires extra dependencies. Install with `pip install --upgrade embedchain[milvus]`" |
| #21 | ) from None |
| #22 | |
| #23 | logger = logging.getLogger(__name__) |
| #24 | |
| #25 | |
@register_deserializable
class ZillizVectorDB(BaseVectorDB):
    """Vector database backend that stores embeddings in Zilliz / Milvus.

    Documents are kept in a single collection with four fields: ``id``,
    ``text``, ``embeddings`` and ``metadata`` (a JSON field). The class talks
    to the server through both a ``MilvusClient`` (insert / query / search /
    delete) and the ORM-style ``Collection`` object (schema, index, flush).
    """

    def __init__(self, config: Optional[ZillizDBConfig] = None):
        """Initialize the database. Save the config and client as attributes.

        :param config: Database configuration class instance. A default
            ``ZillizDBConfig()`` is created when omitted.
        :type config: Optional[ZillizDBConfig]
        """
        self.config = ZillizDBConfig() if config is None else config

        self.client = MilvusClient(
            uri=self.config.uri,
            token=self.config.token,
        )

        # The ORM helpers used below (`utility`, `Collection`) rely on a
        # registered connection, which is separate from the MilvusClient.
        self.connection = connections.connect(
            uri=self.config.uri,
            token=self.config.token,
        )

        super().__init__(config=self.config)

    def _initialize(self):
        """
        This method is needed because the `embedder` attribute needs to be set
        externally before the collection can be initialized.

        So it can't be done in __init__ in one step.
        """
        self._get_or_create_collection(self.config.collection_name)

    def _get_or_create_db(self):
        """Get or create the database.

        :return: The ``MilvusClient`` created in ``__init__``.
        """
        return self.client

    def _get_or_create_collection(self, name):
        """
        Get or create a named collection and store it on ``self.collection``.

        :param name: Name of the collection
        :type name: str
        :return: The existing or newly created collection.
        """
        if utility.has_collection(name):
            logger.info("[ZillizDB]: found an existing collection %s, make sure the auto-id is disabled.", name)
            self.collection = Collection(name)
        else:
            fields = [
                FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=512),
                FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=2048),
                # The vector width comes from the embedder, which is why this
                # method can only run after the embedder has been attached.
                FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=self.embedder.vector_dimension),
                FieldSchema(name="metadata", dtype=DataType.JSON),
            ]

            schema = CollectionSchema(fields, enable_dynamic_field=True)
            self.collection = Collection(name=name, schema=schema)

            # NOTE(review): the index is created only for a brand-new
            # collection; the original nesting was lost in extraction -
            # confirm against the upstream file.
            index = {
                "index_type": "AUTOINDEX",
                "metric_type": self.config.metric_type,
            }
            self.collection.create_index("embeddings", index)
        return self.collection

    @staticmethod
    def _quote(value: Any) -> str:
        """Render ``value`` as a double-quoted filter-expression string literal.

        Backslashes and double quotes are escaped so that a value cannot
        terminate the literal early and alter the surrounding expression.
        """
        escaped = str(value).replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, Any]] = None, limit: Optional[int] = None):
        """
        Get ids and metadata of documents present in the vector database.

        :param ids: Optional. list of doc ids to check for existence
        :type ids: Optional[list[str]]
        :param where: Optional. metadata key/value pairs every result must match
        :type where: Optional[dict[str, Any]]
        :param limit: Optional. maximum number of documents
        :type limit: Optional[int]
        :return: ``{"ids": [...], "metadatas": [...]}`` with one entry per match.
        :rtype: dict[str, list]
        """
        data_ids = []
        metadatas = []
        if self.collection.num_entities == 0 or self.collection.is_empty:
            return {"ids": data_ids, "metadatas": metadatas}

        # Build each condition separately and AND them together, so that the
        # id restriction is kept when a metadata filter is also supplied.
        clauses = []
        if ids:
            id_list = ", ".join(self._quote(doc_id) for doc_id in ids)
            clauses.append(f"id in [{id_list}]")
        if where:
            clauses.append(self._generate_zilliz_filter(where))
        filter_ = " and ".join(clauses)

        # Only forward `limit` when the caller asked for one.
        extra = {} if limit is None else {"limit": limit}
        results = self.client.query(
            collection_name=self.config.collection_name,
            filter=filter_,
            output_fields=["*"],
            **extra,
        )
        for res in results:
            data_ids.append(res.get("id"))
            metadatas.append(res.get("metadata", {}))

        return {"ids": data_ids, "metadatas": metadatas}

    def add(
        self,
        documents: list[str],
        metadatas: list[object],
        ids: list[str],
        **kwargs: Optional[dict[str, Any]],
    ):
        """Embed ``documents`` and insert them into the collection.

        :param documents: Texts to embed and store.
        :type documents: list[str]
        :param metadatas: Metadata objects, one per document.
        :type metadatas: list[object]
        :param ids: Primary keys, one per document.
        :type ids: list[str]
        :param kwargs: Extra keyword arguments forwarded to every ``insert`` call.
        """
        embeddings = self.embedder.embedding_fn(documents)

        # zip() stops at the shortest input, so surplus items in a longer
        # list are ignored.
        for doc_id, doc, metadata, embedding in zip(ids, documents, metadatas, embeddings):
            data = {"id": doc_id, "text": doc, "embeddings": embedding, "metadata": metadata}
            self.client.insert(collection_name=self.config.collection_name, data=data, **kwargs)

        self.collection.load()
        self.collection.flush()
        self.client.flush(self.config.collection_name)

    def query(
        self,
        input_query: str,
        n_results: int,
        where: dict[str, Any],
        citations: bool = False,
        **kwargs: Optional[dict[str, Any]],
    ) -> Union[list[tuple[str, dict]], list[str]]:
        """
        Query contents from vector database based on vector similarity

        :param input_query: query string
        :type input_query: str
        :param n_results: no of similar documents to fetch from database
        :type n_results: int
        :param where: metadata key/value pairs to filter on; may be empty or None
        :type where: dict[str, Any]
        :param citations: we use citations boolean param to return context along with the answer.
        :type citations: bool, default is False.
        :return: The text of each matching document; when ``citations`` is true,
            ``(text, metadata)`` tuples where the metadata carries a ``"score"`` entry
        :rtype: list[str], if citations=False, otherwise list[tuple[str, dict]]
        """
        if self.collection.is_empty:
            return []

        query_vector = self.embedder.embedding_fn([input_query])[0]

        hits = self.client.search(
            collection_name=self.config.collection_name,
            data=[query_vector],
            filter=self._generate_zilliz_filter(where),
            limit=n_results,
            output_fields=["*"],
            **kwargs,
        )[0]  # one query vector was sent, so only the first result set is used

        contexts = []
        for hit in hits:
            entity = hit["entity"]
            context = entity["text"]

            if citations:
                # A stored metadata value of None would otherwise break the
                # item assignment below.
                metadata = entity.get("metadata") or {}
                metadata["score"] = hit["distance"]
                contexts.append((context, metadata))
            else:
                contexts.append(context)
        return contexts

    def count(self) -> int:
        """
        Count number of documents/chunks embedded in the database.

        :return: number of documents
        :rtype: int
        """
        return self.collection.num_entities

    def reset(self, collection_names: list[str] = None):
        """
        Resets the database. Deletes all embeddings irreversibly.

        :param collection_names: Optional. When given, each listed collection
            that exists is dropped. Otherwise the configured collection is
            dropped and re-created empty.
        :type collection_names: Optional[list[str]]
        """
        if not self.config.collection_name:
            return

        # NOTE(review): the if/else nesting was lost in extraction; this
        # follows the only reading where the `else` pairs with
        # `if collection_names` - confirm against the upstream file.
        if collection_names:
            # Fetch the server-side list once rather than once per name.
            existing = set(self.client.list_collections())
            for collection_name in collection_names:
                if collection_name in existing:
                    self.client.drop_collection(collection_name=collection_name)
                    # Keeps a repeated name from being dropped twice.
                    existing.discard(collection_name)
        else:
            self.client.drop_collection(collection_name=self.config.collection_name)
            self._get_or_create_collection(self.config.collection_name)

    def set_collection_name(self, name: str):
        """
        Set the name of the collection. A collection is an isolated space for vectors.

        :param name: Name of the collection.
        :type name: str
        :raises TypeError: If ``name`` is not a string.
        """
        if not isinstance(name, str):
            raise TypeError("Collection name must be a string")
        self.config.collection_name = name

    def _generate_zilliz_filter(self, where: dict[str, str]):
        """Translate a ``where`` mapping into a filter expression string.

        Every key/value pair becomes ``(metadata["key"] == "value")`` and the
        pairs are joined with ``and``. Values are compared as strings; quotes
        and backslashes in keys and values are escaped.

        :param where: Metadata key/value pairs; may be empty or None.
        :type where: dict[str, str]
        :return: The filter expression, or an empty string when there is
            nothing to filter on.
        :rtype: str
        """
        if not where:
            return ""
        operands = [f"(metadata[{self._quote(key)}] == {self._quote(value)})" for key, value in where.items()]
        return " and ".join(operands)

    def delete(self, where: dict[str, Any]):
        """
        Delete the embeddings whose metadata matches ``where``.

        The matching primary keys are looked up with ``get`` first and then
        deleted by key; nothing is sent when there are no matches.

        :param where: Metadata key/value pairs selecting the entries to delete.
        :type where: dict[str, Any]
        """
        data = self.get(where=where)
        keys = data.get("ids", [])
        if keys:
            self.client.delete(collection_name=self.config.collection_name, pks=keys)
| #253 |