repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...
2fa351d6 docs: add automaton and perps launch sources (16d ago)
| #1 | import logging |
| #2 | from typing import Any, Optional, Union |
| #3 | |
| #4 | from chromadb import Collection, QueryResult |
| #5 | from langchain.docstore.document import Document |
| #6 | from tqdm import tqdm |
| #7 | |
| #8 | from embedchain.config import ChromaDbConfig |
| #9 | from embedchain.helpers.json_serializable import register_deserializable |
| #10 | from embedchain.vectordb.base import BaseVectorDB |
| #11 | |
| #12 | try: |
| #13 | import chromadb |
| #14 | from chromadb.config import Settings |
| #15 | from chromadb.errors import InvalidDimensionException |
| #16 | except RuntimeError: |
| #17 | from embedchain.utils.misc import use_pysqlite3 |
| #18 | |
| #19 | use_pysqlite3() |
| #20 | import chromadb |
| #21 | from chromadb.config import Settings |
| #22 | from chromadb.errors import InvalidDimensionException |
| #23 | |
| #24 | |
# Module logger; ChromaDB.__init__ reports the server it connects to through it.
logger = logging.getLogger(name=__name__)
| #26 | |
| #27 | |
@register_deserializable
class ChromaDB(BaseVectorDB):
    """Vector database using ChromaDB."""

    def __init__(self, config: Optional[ChromaDbConfig] = None):
        """Initialize a new ChromaDB instance.

        When both ``config.host`` and ``config.port`` are set, the client is configured
        for a remote Chroma server (FastAPI implementation). Otherwise, it uses a
        persistent local store in ``config.dir``, which is set to ``"db"`` when unset.

        :param config: Configuration options for Chroma, defaults to None, in which
            case a default ``ChromaDbConfig()`` is used.
        :type config: Optional[ChromaDbConfig], optional
        """
        self.config = config if config else ChromaDbConfig()

        self.settings = Settings(anonymized_telemetry=False)
        # ``allow_reset`` is read defensively: a config without the attribute leaves reset disabled.
        self.settings.allow_reset = getattr(self.config, "allow_reset", False)
        self.batch_size = self.config.batch_size
        if self.config.chroma_settings:
            # Copy user overrides onto chromadb's Settings. Keys Settings does not define are skipped,
            # but reported so that typos do not go unnoticed.
            for key, value in self.config.chroma_settings.items():
                if hasattr(self.settings, key):
                    setattr(self.settings, key, value)
                else:
                    logger.warning("Ignoring unknown chroma_settings key: %s", key)

        if self.config.host and self.config.port:
            logger.info(f"Connecting to ChromaDB server: {self.config.host}:{self.config.port}")
            self.settings.chroma_server_host = self.config.host
            self.settings.chroma_server_http_port = self.config.port
            self.settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
        else:
            if self.config.dir is None:
                self.config.dir = "db"

            self.settings.persist_directory = self.config.dir
            self.settings.is_persistent = True

        self.client = chromadb.Client(self.settings)
        super().__init__(config=self.config)

    def _initialize(self):
        """
        Create (or open) the configured collection.

        This method is needed because the `embedder` attribute must be set externally
        before the collection can be initialized.

        :raises ValueError: No embedder has been set.
        """
        if not self.embedder:
            raise ValueError(
                "Embedder not set. Please set an embedder with `_set_embedder()` function before initialization."
            )
        self._get_or_create_collection(self.config.collection_name)

    def _get_or_create_db(self):
        """Called during initialization; returns the Chroma client created in ``__init__``."""
        return self.client

    @staticmethod
    def _generate_where_clause(where: Optional[dict[str, Any]]) -> dict[str, Any]:
        """
        Translate a flat ``{field: value}`` mapping into a Chroma ``where`` filter.

        A mapping with at most one key is returned unchanged. Chroma does not need a
        single filter to be wrapped in ``$and``.

        With several keys, each entry whose value is a scalar (``str``, ``int``,
        ``float`` or ``bool``) becomes an equality filter. Other values (e.g. ``None``
        or nested operator dicts) are skipped. Use ``raw_filter`` in ``query`` for
        operator expressions.

        :param where: Field/value pairs to match, or None.
        :type where: Optional[dict[str, Any]]
        :raises ValueError: Several keys were given, but none had a scalar value.
        :return: A Chroma ``where`` clause; ``{}`` when ``where`` is None.
        :rtype: dict[str, Any]
        """
        if where is None:
            return {}
        if len(where) <= 1:
            return where
        # bool is a subclass of int, so boolean filters are kept as well.
        where_filters = [{key: value} for key, value in where.items() if isinstance(value, (str, int, float))]
        if not where_filters:
            raise ValueError(f"No usable filters in where clause: {where}")
        if len(where_filters) == 1:
            # Chroma's where validation rejects `$and` with fewer than two operands.
            return where_filters[0]
        return {"$and": where_filters}

    def _get_or_create_collection(self, name: str) -> Collection:
        """
        Get or create a named collection.

        The collection is stored on ``self.collection`` and also returned.

        :param name: Name of the collection
        :type name: str
        :raises ValueError: No embedder configured.
        :return: Created collection
        :rtype: Collection
        """
        if not hasattr(self, "embedder") or not self.embedder:
            raise ValueError("Cannot create a Chroma database collection without an embedder.")
        self.collection = self.client.get_or_create_collection(
            name=name,
            embedding_function=self.embedder.embedding_fn,
        )
        return self.collection

    def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, Any]] = None, limit: Optional[int] = None):
        """
        Get existing documents from the vector database.

        Falsy arguments (``None``, empty list/dict, ``0``) are not forwarded to Chroma.

        :param ids: Optional. list of doc ids to look up
        :type ids: Optional[list[str]]
        :param where: Optional. field/value pairs to filter data
        :type where: Optional[dict[str, Any]]
        :param limit: Optional. maximum number of documents
        :type limit: Optional[int]
        :return: Whatever ``Collection.get`` returns for the active collection.
        """
        args = {}
        if ids:
            args["ids"] = ids
        if where:
            args["where"] = self._generate_where_clause(where)
        if limit:
            args["limit"] = limit
        return self.collection.get(**args)

    def add(
        self,
        documents: list[str],
        metadatas: list[object],
        ids: list[str],
        **kwargs: Optional[dict[str, Any]],
    ) -> Any:
        """
        Add documents to the chroma database in batches of ``self.batch_size``.

        :param documents: Documents
        :type documents: list[str]
        :param metadatas: Metadatas, one per document
        :type metadatas: list[object]
        :param ids: ids, one per document
        :type ids: list[str]
        :raises ValueError: The three lists do not have the same length.
        """
        size = len(documents)
        if len(metadatas) != size or len(ids) != size:
            raise ValueError(
                "Cannot add documents to chromadb with inconsistent sizes. Documents size: {}, Metadata size: {},"
                " Ids size: {}".format(len(documents), len(metadatas), len(ids))
            )

        for start in tqdm(range(0, size, self.batch_size), desc="Inserting batches in chromadb"):
            end = start + self.batch_size
            self.collection.add(
                documents=documents[start:end],
                metadatas=metadatas[start:end],
                ids=ids[start:end],
            )

    @staticmethod
    def _format_result(results: QueryResult) -> list[tuple[Document, float]]:
        """
        Format Chroma results.

        Only the first query's results are used (index ``[0]``), since ``query`` sends a
        single query text. Missing metadata is replaced by an empty dict.

        :param results: ChromaDB query results to format.
        :type results: QueryResult
        :return: Formatted results as (document, distance) pairs
        :rtype: list[tuple[Document, float]]
        """
        return [
            (Document(page_content=result[0], metadata=result[1] or {}), result[2])
            for result in zip(
                results["documents"][0],
                results["metadatas"][0],
                results["distances"][0],
            )
        ]

    def query(
        self,
        input_query: str,
        n_results: int,
        where: Optional[dict[str, Any]] = None,
        raw_filter: Optional[dict[str, Any]] = None,
        citations: bool = False,
        **kwargs: Optional[dict[str, Any]],
    ) -> Union[list[tuple[str, dict]], list[str]]:
        """
        Query contents from vector database based on vector similarity.

        :param input_query: query string
        :type input_query: str
        :param n_results: no of similar documents to fetch from database
        :type n_results: int
        :param where: field/value pairs to filter data (converted with ``_generate_where_clause``)
        :type where: dict[str, Any]
        :param raw_filter: Raw Chroma filter, passed through unchanged
        :type raw_filter: dict[str, Any]
        :param citations: we use citations boolean param to return context along with the answer.
        :type citations: bool, default is False.
        :raises ValueError: Both ``where`` and ``raw_filter`` were given.
        :raises InvalidDimensionException: Dimensions do not match.
        :return: The content of each matching document. If citations is true, each entry is
            a (content, metadata) pair, where metadata is a copy of the document metadata
            with an added "score" key holding the Chroma distance.
        :rtype: list[str], if citations=False, otherwise list[tuple[str, dict]]
        """
        if where and raw_filter:
            raise ValueError("Both `where` and `raw_filter` cannot be used together.")

        where_clause = None
        if raw_filter:
            where_clause = raw_filter
        if where:
            where_clause = self._generate_where_clause(where)
        try:
            result = self.collection.query(
                query_texts=[input_query],
                n_results=n_results,
                where=where_clause,
            )
        except InvalidDimensionException as e:
            # Chain the original error so the chroma-side traceback stays available as __cause__.
            raise InvalidDimensionException(
                e.message()
                + ". This is commonly a side-effect when an embedding function, different from the one used to add the"
                " embeddings, is used to retrieve an embedding from the database."
            ) from e

        contexts = []
        for document, score in self._format_result(result):
            if citations:
                metadata = dict(document.metadata)
                metadata["score"] = score
                contexts.append((document.page_content, metadata))
            else:
                contexts.append(document.page_content)
        return contexts

    def set_collection_name(self, name: str):
        """
        Set the name of the collection. A collection is an isolated space for vectors.

        The collection is created immediately if it does not exist.

        :param name: Name of the collection.
        :type name: str
        :raises TypeError: ``name`` is not a string.
        """
        if not isinstance(name, str):
            raise TypeError("Collection name must be a string")
        self.config.collection_name = name
        self._get_or_create_collection(self.config.collection_name)

    def count(self) -> int:
        """
        Count number of documents/chunks embedded in the database.

        :return: number of documents
        :rtype: int
        """
        return self.collection.count()

    def delete(self, where: Optional[dict[str, Any]]):
        """
        Delete the documents in the active collection that match ``where``.

        :param where: field/value pairs to filter data (converted with ``_generate_where_clause``)
        :type where: Optional[dict[str, Any]]
        :return: Whatever ``Collection.delete`` returns.
        """
        return self.collection.delete(where=self._generate_where_clause(where))

    def reset(self):
        """
        Resets the database. Deletes all embeddings irreversibly.

        The configured collection is deleted and then recreated empty.

        :raises ValueError: The client refused to delete the collection.
        """
        # Delete all data from the collection
        try:
            self.client.delete_collection(self.config.collection_name)
        except ValueError as e:
            # NOTE(review): delete_collection may raise ValueError for reasons other than reset being disabled
            # (e.g. a missing collection); chaining keeps the real cause visible to the caller.
            raise ValueError(
                "For safety reasons, resetting is disabled. "
                "Please enable it by setting `allow_reset=True` in your ChromaDbConfig"
            ) from e
        # Recreate
        self._get_or_create_collection(self.config.collection_name)

    # TODO: Automatically recreating a collection with the same name cannot be the best way to handle a reset.
    # A downside of this implementation is, if you have two instances,
    # the other instance will not get the updated `self.collection` attribute.
    # A better way would be to create the collection if it is called again after being reset.
    # That means, checking if collection exists in the db-consuming methods, and creating it if it doesn't.
    # That's an extra step for all uses, just to satisfy a niche use case in a niche method. For now, this will do.
| #291 |