repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...
git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...
2fa351d6 docs: add automaton and perps launch sources 16d ago
| #1 | import logging |
| #2 | from typing import Any, Optional, Union |
| #3 | |
| #4 | try: |
| #5 | from elasticsearch import Elasticsearch |
| #6 | from elasticsearch.helpers import bulk |
| #7 | except ImportError: |
| #8 | raise ImportError( |
| #9 | "Elasticsearch requires extra dependencies. Install with `pip install --upgrade embedchain[elasticsearch]`" |
| #10 | ) from None |
| #11 | |
| #12 | from embedchain.config import ElasticsearchDBConfig |
| #13 | from embedchain.helpers.json_serializable import register_deserializable |
| #14 | from embedchain.utils.misc import chunks |
| #15 | from embedchain.vectordb.base import BaseVectorDB |
| #16 | |
| #17 | logger = logging.getLogger(__name__) |
| #18 | |
| #19 | |
@register_deserializable
class ElasticsearchDB(BaseVectorDB):
    """
    Elasticsearch as vector database

    Documents are stored in one index per (collection name, embedding dimension) pair, see `_get_index`.
    Each stored document carries its `text`, its `metadata` and its `embeddings` vector.
    """

    def __init__(
        self,
        config: Optional[ElasticsearchDBConfig] = None,
        es_config: Optional[ElasticsearchDBConfig] = None,  # Backwards compatibility
    ):
        """Elasticsearch as vector database.

        :param config: Elasticsearch database config, defaults to None
        :type config: ElasticsearchDBConfig, optional
        :param es_config: `es_config` is supported as an alias for `config` (for backwards compatibility),
        defaults to None
        :type es_config: ElasticsearchDBConfig, optional
        :raises TypeError: The provided config is not an `ElasticsearchDBConfig` instance
        :raises ValueError: The config has neither `ES_URL` nor `CLOUD_ID` set
        """
        # Resolve the alias *before* validating, otherwise passing only `es_config`
        # would be rejected because `config` is still None.
        provided_config = config if config is not None else es_config
        if provided_config is None:
            self.config = ElasticsearchDBConfig()
        else:
            if not isinstance(provided_config, ElasticsearchDBConfig):
                raise TypeError(
                    "config is not a `ElasticsearchDBConfig` instance. "
                    "Please make sure the type is right and that you are passing an instance."
                )
            self.config = provided_config

        # An explicit URL takes precedence over a cloud id.
        if self.config.ES_URL:
            self.client = Elasticsearch(self.config.ES_URL, **self.config.ES_EXTRA_PARAMS)
        elif self.config.CLOUD_ID:
            self.client = Elasticsearch(cloud_id=self.config.CLOUD_ID, **self.config.ES_EXTRA_PARAMS)
        else:
            raise ValueError(
                "Something is wrong with your config. Please check again - `https://docs.embedchain.ai/components/vector-databases#elasticsearch`"  # noqa: E501
            )

        self.batch_size = self.config.batch_size
        # Call parent init here because embedder is needed
        super().__init__(config=self.config)

    def _initialize(self):
        """
        This method is needed because `embedder` attribute needs to be set externally before it can be initialized.

        Logs the cluster info and creates the index for the current collection if it does not exist yet.
        """
        logger.info(self.client.info())
        index_settings = {
            "mappings": {
                "properties": {
                    "text": {"type": "text"},
                    # The vector field is stored but not indexed; similarity is computed
                    # by the script in `query`.
                    "embeddings": {"type": "dense_vector", "index": False, "dims": self.embedder.vector_dimension},
                }
            }
        }
        es_index = self._get_index()
        if not self.client.indices.exists(index=es_index):
            # create index if not exist
            logger.info("Creating index %s %s", es_index, index_settings)
            self.client.indices.create(index=es_index, body=index_settings)

    def _get_or_create_db(self):
        """Called during initialization"""
        return self.client

    def _get_or_create_collection(self, name):
        """Note: nothing to return here. Discuss later"""

    def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, Any]] = None, limit: Optional[int] = None):
        """
        Get existing doc ids present in vector database

        :param ids: list of doc ids to check for existence
        :type ids: list[str]
        :param where: to filter data; every key/value pair must match `metadata.<key>.keyword` exactly
        :type where: dict[str, Any]
        :param limit: passed through as the search `size`; None leaves the page size to the
        client/server default
        :type limit: int, optional
        :return: dict with the matching document `ids` and, in the same order, `metadatas`
        holding one `{"doc_id": ...}` entry per hit
        :rtype: dict[str, list]
        """
        must_clauses = []
        if ids:
            must_clauses.append({"ids": {"values": ids}})
        if where:
            must_clauses.extend({"term": {f"metadata.{key}.keyword": value}} for key, value in where.items())
        query = {"bool": {"must": must_clauses}}

        response = self.client.search(index=self._get_index(), query=query, _source=True, size=limit)
        hits = response["hits"]["hits"]

        # Result is modified for compatibility with other vector databases
        # TODO: Add method in vector database to return result in a standard format
        return {
            "ids": [hit["_id"] for hit in hits],
            "metadatas": [{"doc_id": hit["_source"]["metadata"]["doc_id"]} for hit in hits],
        }

    def add(
        self,
        documents: list[str],
        metadatas: list[object],
        ids: list[str],
        **kwargs: Optional[dict[str, Any]],
    ) -> Any:
        """
        add data in vector database

        The documents are embedded in one call, then written with the bulk helper in batches of
        `self.batch_size`. The index is refreshed once all batches have been sent.

        :param documents: list of texts to add
        :type documents: list[str]
        :param metadatas: list of metadata associated with docs
        :type metadatas: list[object]
        :param ids: ids of docs
        :type ids: list[str]
        :param kwargs: forwarded unchanged to `elasticsearch.helpers.bulk`
        """

        embeddings = self.embedder.embedding_fn(documents)
        index_name = self._get_index()

        for chunk in chunks(
            list(zip(ids, documents, metadatas, embeddings)),
            self.batch_size,
            desc="Inserting batches in elasticsearch",
        ):  # noqa: E501
            batch_docs = [
                {
                    "_index": index_name,
                    "_id": doc_id,
                    "_source": {"text": text, "metadata": metadata, "embeddings": embedding},
                }
                for doc_id, text, metadata, embedding in chunk
            ]
            bulk(self.client, batch_docs, **kwargs)
        # NOTE(review): refresh is assumed to run once after the loop, not per batch - confirm.
        self.client.indices.refresh(index=index_name)

    def query(
        self,
        input_query: str,
        n_results: int,
        where: dict[str, Any],
        citations: bool = False,
        **kwargs: Optional[dict[str, Any]],
    ) -> Union[list[tuple[str, dict]], list[str]]:
        """
        query contents from vector database based on vector similarity

        :param input_query: query string
        :type input_query: str
        :param n_results: no of similar documents to fetch from database
        :type n_results: int
        :param where: Optional. to filter data; every key/value pair must match `metadata.<key>.keyword` exactly
        :type where: dict[str, Any]
        :param citations: we use citations boolean param to return context along with the answer.
        :type citations: bool, default is False.
        :param kwargs: accepted for interface compatibility; not used by this implementation
        :return: The content of the documents that matched your query. If the citations flag is true, each
        entry is a `(content, metadata)` tuple where the metadata additionally carries the hit `score`
        :rtype: list[str], if citations=False, otherwise list[tuple[str, dict]]
        """
        query_vector = self.embedder.embedding_fn([input_query])[0]

        must_clauses = [{"exists": {"field": "text"}}]
        if where:
            must_clauses.extend({"term": {f"metadata.{key}.keyword": value}} for key, value in where.items())

        # `https://www.elastic.co/guide/en/elasticsearch/reference/7.17/query-dsl-script-score-query.html`
        query = {
            "script_score": {
                "query": {"bool": {"must": must_clauses}},
                "script": {
                    # Cosine similarity lies in [-1, 1]; adding 1.0 keeps the score non-negative.
                    "source": "cosineSimilarity(params.input_query_vector, 'embeddings') + 1.0",
                    "params": {"input_query_vector": query_vector},
                },
            }
        }

        response = self.client.search(
            index=self._get_index(), query=query, _source=["text", "metadata"], size=n_results
        )
        hits = response["hits"]["hits"]

        if not citations:
            return [hit["_source"]["text"] for hit in hits]

        contexts = []
        for hit in hits:
            source = hit["_source"]
            metadata = source["metadata"]
            metadata["score"] = hit["_score"]
            contexts.append((source["text"], metadata))
        return contexts

    def set_collection_name(self, name: str):
        """
        Set the name of the collection. A collection is an isolated space for vectors.

        :param name: Name of the collection.
        :type name: str
        :raises TypeError: `name` is not a string
        """
        if not isinstance(name, str):
            raise TypeError("Collection name must be a string")
        self.config.collection_name = name

    def count(self) -> int:
        """
        Count number of documents/chunks embedded in the database.

        :return: number of documents
        :rtype: int
        """
        response = self.client.count(index=self._get_index(), query={"match_all": {}})
        return response["count"]

    def reset(self):
        """
        Resets the database. Deletes all embeddings irreversibly.

        Only the index of the current collection is dropped; nothing happens if it does not exist.
        """
        index_name = self._get_index()
        # Delete all data from the database
        if self.client.indices.exists(index=index_name):
            # delete index in Es
            self.client.indices.delete(index=index_name)

    def _get_index(self) -> str:
        """Get the Elasticsearch index for a collection

        :return: Elasticsearch index, built as `<collection_name>_<vector_dimension>` in lower case
        :rtype: str
        """
        # NOTE: The method is preferred to an attribute, because if collection name changes,
        # it's always up-to-date.
        return f"{self.config.collection_name}_{self.embedder.vector_dimension}".lower()

    def delete(self, where):
        """Delete documents from the database.

        :param where: metadata filter; every key/value pair must match `metadata.<key>.keyword` exactly.
        An empty dict adds no filter clause, so the delete-by-query is not restricted to any subset.
        :type where: dict[str, Any]
        """
        must_clauses = [{"term": {f"metadata.{key}.keyword": value}} for key, value in where.items()]
        query = {"query": {"bool": {"must": must_clauses}}}
        index_name = self._get_index()
        self.client.delete_by_query(index=index_name, body=query)
        self.client.indices.refresh(index=index_name)
| #270 |