repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...2fa351d6docs: add automaton and perps launch sources16d ago| #1 | import logging |
| #2 | import time |
| #3 | from typing import Any, Optional, Union |
| #4 | |
| #5 | from tqdm import tqdm |
| #6 | |
| #7 | try: |
| #8 | from opensearchpy import OpenSearch |
| #9 | from opensearchpy.helpers import bulk |
| #10 | except ImportError: |
| #11 | raise ImportError( |
| #12 | "OpenSearch requires extra dependencies. Install with `pip install --upgrade embedchain[opensearch]`" |
| #13 | ) from None |
| #14 | |
| #15 | from langchain_community.embeddings.openai import OpenAIEmbeddings |
| #16 | from langchain_community.vectorstores import OpenSearchVectorSearch |
| #17 | |
| #18 | from embedchain.config import OpenSearchDBConfig |
| #19 | from embedchain.helpers.json_serializable import register_deserializable |
| #20 | from embedchain.vectordb.base import BaseVectorDB |
| #21 | |
| #22 | logger = logging.getLogger(__name__) |
| #23 | |
| #24 | |
@register_deserializable
class OpenSearchDB(BaseVectorDB):
    """
    OpenSearch as vector database.

    Every document is stored in the index named by ``config.collection_name``
    with three source fields: ``text``, ``metadata`` and ``embeddings``.
    """

    def __init__(self, config: OpenSearchDBConfig):
        """OpenSearch as vector database.

        Creates the OpenSearch client and immediately calls ``client.info()``,
        so connection problems surface here rather than on first use.

        :param config: OpenSearch domain config
        :type config: OpenSearchDBConfig
        :raises ValueError: if ``config`` is None
        """
        if config is None:
            raise ValueError("OpenSearchDBConfig is required")
        self.config = config
        self.batch_size = self.config.batch_size
        self.client = OpenSearch(
            hosts=[self.config.opensearch_url],
            http_auth=self.config.http_auth,
            **self.config.extra_params,
        )
        info = self.client.info()
        logger.info("Connected to %s. Version: %s", info["version"]["distribution"], info["version"]["number"])
        # NOTE(review): credentials stay on self.config; query() reads
        # self.config.http_auth again, so they must not be stripped here.
        super().__init__(config=self.config)

    @staticmethod
    def _term_filters(where: dict[str, Any]) -> list[dict[str, Any]]:
        """Translate a ``where`` mapping into OpenSearch ``term`` clauses.

        Each ``key: value`` pair becomes an exact-match clause on
        ``metadata.<key>.keyword``.

        :param where: metadata key/value pairs to match
        :type where: dict[str, Any]
        :return: one ``term`` clause per entry, in iteration order
        :rtype: list[dict[str, Any]]
        """
        return [{"term": {f"metadata.{key}.keyword": value}} for key, value in where.items()]

    def _initialize(self):
        """Create the collection's index with a k-NN mapping if it does not exist yet."""
        logger.info(self.client.info())
        index_name = self._get_index()
        if self.client.indices.exists(index=index_name):
            logger.info(f"Index '{index_name}' already exists.")
            return

        index_body = {
            "settings": {"knn": True},
            "mappings": {
                "properties": {
                    "text": {"type": "text"},
                    "embeddings": {
                        "type": "knn_vector",
                        # NOTE(review): presumably left unindexed because query()
                        # searches with script scoring - confirm.
                        "index": False,
                        "dimension": self.config.vector_dimension,
                    },
                }
            },
        }
        self.client.indices.create(index_name, body=index_body)
        logger.info(self.client.indices.get(index_name))

    def _get_or_create_db(self):
        """Called during initialization"""
        return self.client

    def _get_or_create_collection(self, name):
        """Note: nothing to return here. Discuss later"""

    def get(
        self, ids: Optional[list[str]] = None, where: Optional[dict[str, Any]] = None, limit: Optional[int] = None
    ) -> dict[str, Any]:
        """
        Get existing doc ids present in vector database

        :param ids: list of doc ids to check for existence
        :type ids: list[str]
        :param where: to filter data
        :type where: dict[str, Any]
        :param limit: forwarded as the ``size`` of the search request
        :type limit: Optional[int]
        :return: ``{"ids": [...], "metadatas": [{"doc_id": ...}, ...]}`` for the matching hits
        :rtype: dict[str, Any]
        """
        must_clauses = []
        if ids:
            must_clauses.append({"ids": {"values": ids}})
        if where:
            must_clauses.extend(self._term_filters(where))
        query = {"query": {"bool": {"must": must_clauses}}}

        # OpenSearch syntax is different from Elasticsearch
        response = self.client.search(index=self._get_index(), body=query, _source=True, size=limit)
        hits = response["hits"]["hits"]

        # Result is modified for compatibility with other vector databases
        # TODO: Add method in vector database to return result in a standard format
        return {
            "ids": [hit["_id"] for hit in hits],
            "metadatas": [{"doc_id": hit["_source"]["metadata"]["doc_id"]} for hit in hits],
        }

    def add(self, documents: list[str], metadatas: list[object], ids: list[str], **kwargs: Optional[dict[str, Any]]):
        """Adds documents to the opensearch index

        Embeds all documents up front with ``self.embedder.embedding_fn`` and
        uploads them in batches of ``self.batch_size``.

        :param documents: texts to store
        :type documents: list[str]
        :param metadatas: metadata stored next to each text
        :type metadatas: list[object]
        :param ids: document ids, used as the OpenSearch ``_id``
        :type ids: list[str]
        :param kwargs: forwarded unchanged to ``opensearchpy.helpers.bulk``
        """
        embeddings = self.embedder.embedding_fn(documents)
        index_name = self._get_index()
        for batch_start in tqdm(range(0, len(documents), self.batch_size), desc="Inserting batches in opensearch"):
            batch_end = batch_start + self.batch_size

            # Create document entries for bulk upload
            batch_entries = [
                {
                    "_index": index_name,
                    "_id": doc_id,
                    "_source": {"text": text, "metadata": metadata, "embeddings": embedding},
                }
                for doc_id, text, metadata, embedding in zip(
                    ids[batch_start:batch_end],
                    documents[batch_start:batch_end],
                    metadatas[batch_start:batch_end],
                    embeddings[batch_start:batch_end],
                )
            ]

            # Perform bulk operation
            bulk(self.client, batch_entries, **kwargs)
            self.client.indices.refresh(index=index_name)

            # Sleep to avoid rate limiting
            time.sleep(0.1)

    def query(
        self,
        input_query: str,
        n_results: int,
        where: dict[str, Any],
        citations: bool = False,
        **kwargs: Optional[dict[str, Any]],
    ) -> Union[list[tuple[str, dict]], list[str]]:
        """
        query contents from vector database based on vector similarity

        :param input_query: query string
        :type input_query: str
        :param n_results: no of similar documents to fetch from database
        :type n_results: int
        :param where: Optional. to filter data; ``None`` or an empty dict means no filter
        :type where: dict[str, Any]
        :param citations: we use citations boolean param to return context along with the answer.
        :type citations: bool, default is False.
        :return: The content of the document that matched your query,
        along with its metadata including the similarity score (if citations flag is true)
        :rtype: list[str], if citations=False, otherwise list[tuple[str, dict]]
        """
        # NOTE(review): the query is embedded with OpenAIEmbeddings while add()
        # uses self.embedder - confirm both produce compatible vectors.
        embeddings = OpenAIEmbeddings()
        docsearch = OpenSearchVectorSearch(
            index_name=self._get_index(),
            embedding_function=embeddings,
            opensearch_url=f"{self.config.opensearch_url}",
            http_auth=self.config.http_auth,
            use_ssl=getattr(self.config, "use_ssl", False),
            verify_certs=getattr(self.config, "verify_certs", False),
        )

        # A truthiness check (rather than len()) keeps where=None from raising.
        if where:
            pre_filter = {"bool": {"must": self._term_filters(where)}}
        else:
            pre_filter = {"match_all": {}}  # default

        docs = docsearch.similarity_search_with_score(
            input_query,
            search_type="script_scoring",
            space_type="cosinesimil",
            vector_field="embeddings",
            text_field="text",
            metadata_field="metadata",
            pre_filter=pre_filter,
            k=n_results,
            **kwargs,
        )

        contexts = []
        for doc, score in docs:
            context = doc.page_content
            if citations:
                metadata = doc.metadata
                metadata["score"] = score
                contexts.append((context, metadata))
            else:
                contexts.append(context)
        return contexts

    def set_collection_name(self, name: str):
        """
        Set the name of the collection. A collection is an isolated space for vectors.

        :param name: Name of the collection.
        :type name: str
        :raises TypeError: if ``name`` is not a string
        """
        if not isinstance(name, str):
            raise TypeError("Collection name must be a string")
        self.config.collection_name = name

    def count(self) -> int:
        """
        Count number of documents/chunks embedded in the database.

        :return: number of documents
        :rtype: int
        """
        query = {"query": {"match_all": {}}}
        response = self.client.count(index=self._get_index(), body=query)
        return response["count"]

    def reset(self):
        """
        Resets the database. Deletes all embeddings irreversibly.
        """
        index_name = self._get_index()
        # Delete the whole index; nothing to do when it was never created
        if self.client.indices.exists(index=index_name):
            self.client.indices.delete(index=index_name)

    def delete(self, where):
        """Deletes documents matching ``where`` from the OpenSearch index

        :param where: metadata key/value pairs every deleted document must match
        :type where: dict[str, Any]
        """
        # NOTE(review): an empty ``where`` sends a bool query with no clauses,
        # which presumably matches (and deletes) every document - confirm
        # callers never pass {} unintentionally.
        query = {"query": {"bool": {"must": self._term_filters(where)}}}
        self.client.delete_by_query(index=self._get_index(), body=query)

    def _get_index(self) -> str:
        """Get the OpenSearch index for a collection

        :return: OpenSearch index
        :rtype: str
        """
        return self.config.collection_name
| #254 |