repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...2fa351d6docs: add automaton and perps launch sources16d ago| #1 | import copy |
| #2 | import os |
| #3 | from typing import Optional, Union |
| #4 | |
| #5 | try: |
| #6 | import weaviate |
| #7 | except ImportError: |
| #8 | raise ImportError( |
| #9 | "Weaviate requires extra dependencies. Install with `pip install --upgrade 'embedchain[weaviate]'`" |
| #10 | ) from None |
| #11 | |
| #12 | from embedchain.config.vector_db.weaviate import WeaviateDBConfig |
| #13 | from embedchain.helpers.json_serializable import register_deserializable |
| #14 | from embedchain.vectordb.base import BaseVectorDB |
| #15 | |
| #16 | |
@register_deserializable
class WeaviateDB(BaseVectorDB):
    """
    Weaviate as vector database.

    Each collection is stored as two Weaviate classes: ``<index>`` holds a chunk
    (``identifier``, ``text``) together with its embedding, and ``<index>_metadata``
    holds that chunk's metadata. The chunk points to its metadata object through the
    ``metadata`` cross-reference, which is the path the ``where`` filters traverse.
    """

    def __init__(
        self,
        config: Optional[WeaviateDBConfig] = None,
    ):
        """Weaviate as vector database.

        The client connects to the URL in the ``WEAVIATE_ENDPOINT`` environment variable,
        authenticates with ``WEAVIATE_API_KEY``, and receives ``config.extra_params`` as
        additional keyword arguments.

        :param config: Weaviate database config, defaults to None (a default ``WeaviateDBConfig`` is used)
        :type config: WeaviateDBConfig, optional
        :raises TypeError: ``config`` was given but is not a ``WeaviateDBConfig`` instance
        """
        if config is None:
            self.config = WeaviateDBConfig()
        else:
            if not isinstance(config, WeaviateDBConfig):
                raise TypeError(
                    "config is not a `WeaviateDBConfig` instance. "
                    "Please make sure the type is right and that you are passing an instance."
                )
            self.config = config
        self.batch_size = self.config.batch_size
        self.client = weaviate.Client(
            url=os.environ.get("WEAVIATE_ENDPOINT"),
            auth_client_secret=weaviate.AuthApiKey(api_key=os.environ.get("WEAVIATE_API_KEY")),
            **self.config.extra_params,
        )
        # Since weaviate uses graphQL, we need to keep track of metadata keys added in the vectordb.
        # This is needed to filter data while querying.
        self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id"}

        # Call parent init here because embedder is needed
        super().__init__(config=self.config)

    def _initialize(self):
        """
        Compute the index name and create the Weaviate schema for it if it does not exist yet.

        This method is needed because `embedder` attribute needs to be set externally before it can be
        initialized: the index name includes the embedder's vector dimension.

        :raises ValueError: no embedder has been set
        """
        if not self.embedder:
            raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")

        self.index_name = self._get_index_name()
        if not self.client.schema.exists(self.index_name):
            # id is a reserved field in Weaviate, hence we had to change the name of the id field to identifier.
            # The none vectorizer is crucial as we have our own custom embedding function.
            # TODO: wait for weaviate to add indexing on `object[]` data-type so that we can add filter while
            # querying. Once that is done, change `dataType` of "metadata" field to `object[]` and update the
            # queries below.
            metadata_class = self.index_name + "_metadata"
            class_obj = {
                "classes": [
                    {
                        "class": self.index_name,
                        "vectorizer": "none",
                        "properties": [
                            {"name": "identifier", "dataType": ["text"]},
                            {"name": "text", "dataType": ["text"]},
                            # Cross-reference to the metadata class defined below.
                            {"name": "metadata", "dataType": [metadata_class]},
                        ],
                    },
                    {
                        "class": metadata_class,
                        "vectorizer": "none",
                        "properties": [
                            {"name": field_name, "dataType": ["text"]}
                            for field_name in ("data_type", "doc_id", "url", "hash", "app_id")
                        ],
                    },
                ]
            }

            self.client.schema.create(class_obj)

    def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
        """
        Get existing docs present in the vector database.

        :param ids: list of doc ids to check for existence; a doc matches if its id equals any of them
        :type ids: list[str], optional
        :param where: metadata filters; every key/value pair must match
        :type where: dict[str, any], optional
        :param limit: maximum number of docs to fetch, defaults to all matching docs
        :type limit: int, optional
        :return: the matching doc ids and their metadata, in the same order
        :rtype: dict with keys ``"ids"`` (list[str]) and ``"metadatas"`` (list[dict])
        """
        operands = []
        if ids:
            # A single object can equal only one id, so the ids are alternatives (OR), not conjuncts.
            id_operands = [{"path": ["identifier"], "operator": "Equal", "valueText": doc_id} for doc_id in ids]
            operands.append(self._combine_operands(id_operands, "Or"))
        operands.extend(self._metadata_operands(where))
        # None when there is no filter at all: an empty "And" clause must not be sent to Weaviate.
        where_clause = self._combine_operands(operands, "And")

        keys = set(where) if where else set()
        query_metadata_keys = list(self.metadata_keys.union(keys))
        page_size = limit or self.batch_size

        existing_ids = []
        metadatas = []
        offset = 0
        while True:
            # The query builder is rebuilt for every page so each request carries only its own offset.
            page_query = self.client.query.get(
                self.index_name,
                [
                    "identifier",
                    weaviate.LinkTo("metadata", self.index_name + "_metadata", query_metadata_keys),
                ],
            )
            if where_clause is not None:
                page_query = page_query.with_where(where_clause)
            page_query = page_query.with_additional(["id"]).with_limit(page_size)
            results = self._query_with_offset(page_query, offset)

            fetched_results = results["data"]["Get"].get(self.index_name, [])
            if not fetched_results:
                break

            for result in fetched_results:
                existing_ids.append(result["identifier"])
                metadatas.append(result["metadata"][0])
            offset += len(fetched_results)

            if limit is not None and len(existing_ids) >= limit:
                break

        return {"ids": existing_ids, "metadatas": metadatas}

    def add(self, documents: list[str], metadatas: list[object], ids: list[str], **kwargs: Optional[dict[str, any]]):
        """Add data in vector database.

        Each document is stored as a chunk object plus a linked metadata object (the metadata
        also receives the chunk text); both objects are given the document's embedding.

        :param documents: list of texts to add
        :type documents: list[str]
        :param metadatas: list of metadata associated with docs (entries may be None)
        :type metadatas: list[object]
        :param ids: ids of docs
        :type ids: list[str]
        :param kwargs: forwarded to ``batch.add_reference``
        """
        embeddings = self.embedder.embedding_fn(documents)
        metadata_class = self.index_name + "_metadata"
        self.client.batch.configure(batch_size=self.batch_size, timeout_retries=3)  # Configure batch
        with self.client.batch as batch:  # Initialize a batch process
            for doc_id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
                doc = {"identifier": doc_id, "text": text}
                updated_metadata = {"text": text}
                if metadata is not None:
                    updated_metadata.update(metadata)

                obj_uuid = batch.add_data_object(
                    data_object=copy.deepcopy(doc), class_name=self.index_name, vector=embedding
                )
                metadata_uuid = batch.add_data_object(
                    data_object=copy.deepcopy(updated_metadata),
                    class_name=metadata_class,
                    vector=embedding,
                )
                batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, metadata_class, **kwargs)

    def query(
        self, input_query: str, n_results: int, where: dict[str, any], citations: bool = False
    ) -> Union[list[tuple[str, dict]], list[str]]:
        """
        Query contents from vector database based on vector similarity.

        :param input_query: query string
        :type input_query: str
        :param n_results: no of similar documents to fetch from database
        :type n_results: int
        :param where: Optional. metadata filters; every key/value pair must match
        :type where: dict[str, any]
        :param citations: we use citations boolean param to return context along with the answer.
        :type citations: bool, default is False.
        :return: The content of the documents that matched your query. If citations is true, each
            entry is a ``(text, metadata)`` tuple whose metadata carries a ``"score"`` key set to the
            Weaviate vector distance.
        :rtype: list[str] if citations=False, otherwise list[tuple[str, dict]]
        """
        query_vector = self.embedder.embedding_fn([input_query])[0]
        keys = set(where) if where else set()
        data_fields = ["text"]
        if citations:
            query_metadata_keys = list(self.metadata_keys.union(keys))
            data_fields.append(weaviate.LinkTo("metadata", self.index_name + "_metadata", query_metadata_keys))

        search = self.client.query.get(self.index_name, data_fields)
        where_clause = self._combine_operands(self._metadata_operands(where), "And")
        if where_clause is not None:
            search = search.with_where(where_clause)
        results = search.with_near_vector({"vector": query_vector}).with_limit(n_results).with_additional(
            ["distance"]
        ).do()

        docs = results["data"]["Get"].get(self.index_name)
        if docs is None:
            return []

        contexts = []
        for doc in docs:
            context = doc["text"]
            if citations:
                metadata = doc["metadata"][0]
                metadata["score"] = doc["_additional"]["distance"]
                contexts.append((context, metadata))
            else:
                contexts.append(context)
        return contexts

    def set_collection_name(self, name: str):
        """
        Set the name of the collection. A collection is an isolated space for vectors.

        If the database was already initialized, the index name is recomputed (and its schema
        created if missing) so that subsequent operations target the new collection.

        :param name: Name of the collection.
        :type name: str
        :raises TypeError: ``name`` is not a string
        """
        if not isinstance(name, str):
            raise TypeError("Collection name must be a string")
        self.config.collection_name = name
        # `index_name` is cached by `_initialize`; without this it would keep pointing at the old collection.
        if getattr(self, "index_name", None) is not None:
            self._initialize()

    def count(self) -> int:
        """
        Count number of documents/chunks embedded in the database.

        :return: number of documents (0 if Weaviate returns no aggregate for the index)
        :rtype: int
        """
        data = self.client.query.aggregate(self.index_name).with_meta_count().do()
        aggregate = data["data"]["Aggregate"].get(self.index_name)
        if not aggregate:
            return 0
        return aggregate[0]["meta"]["count"]

    def _get_or_create_db(self):
        """Called during initialization; returns the Weaviate client."""
        return self.client

    def reset(self):
        """
        Resets the database. Deletes all embeddings irreversibly.
        """
        # Delete all data from the database.
        # NOTE(review): only objects of the chunk class are deleted here; objects in the
        # `<index>_metadata` class are not targeted by this call — confirm whether that is intended.
        self.client.batch.delete_objects(
            self.index_name, where={"path": ["identifier"], "operator": "Like", "valueText": ".*"}
        )

    # Weaviate internally by default capitalizes the class name
    def _get_index_name(self) -> str:
        """Get the Weaviate index for a collection.

        Note that ``str.capitalize`` also lowercases every character after the first.

        :return: Weaviate index
        :rtype: str
        """
        return f"{self.config.collection_name}_{self.embedder.vector_dimension}".capitalize().replace("-", "_")

    @staticmethod
    def _query_with_offset(query, offset):
        """Apply ``offset`` to the query builder (when non-zero) and execute it."""
        if offset:
            query.with_offset(offset)
        results = query.do()
        return results

    def _metadata_operands(self, where: Optional[dict[str, any]]) -> list[dict]:
        """Build one ``Equal`` operand per ``where`` entry, matched against the linked metadata object."""
        if not where:
            return []
        return [
            {
                "path": ["metadata", self.index_name + "_metadata", key],
                "operator": "Equal",
                "valueText": value,
            }
            for key, value in where.items()
        ]

    @staticmethod
    def _combine_operands(operands: list[dict], operator: str) -> Optional[dict]:
        """Join operands with ``operator``; one operand is returned as-is and no operands yields None."""
        if not operands:
            return None
        if len(operands) == 1:
            return operands[0]
        return {"operator": operator, "operands": operands}

    def _generate_query(self, where: dict):
        """Build the Weaviate ``where`` clause AND-ing every metadata filter in ``where``."""
        weaviate_where_operands = self._metadata_operands(where)
        if len(weaviate_where_operands) == 1:
            return weaviate_where_operands[0]
        return {"operator": "And", "operands": weaviate_where_operands}

    def delete(self, where: dict):
        """Delete from database.

        Deletes objects of the chunk class whose linked metadata matches every filter.

        :param where: to filter data
        :type where: dict[str, any]
        """
        query = self._generate_query(where)
        self.client.batch.delete_objects(self.index_name, where=query)
| #364 |