repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...2fa351d6docs: add automaton and perps launch sources16d ago| #1 | from typing import Any, Dict, List, Optional, Union |
| #2 | |
| #3 | import pyarrow as pa |
| #4 | |
| #5 | try: |
| #6 | import lancedb |
| #7 | except ImportError: |
| #8 | raise ImportError('LanceDB is required. Install with pip install "embedchain[lancedb]"') from None |
| #9 | |
| #10 | from embedchain.config.vector_db.lancedb import LanceDBConfig |
| #11 | from embedchain.helpers.json_serializable import register_deserializable |
| #12 | from embedchain.vectordb.base import BaseVectorDB |
| #13 | |
| #14 | |
@register_deserializable
class LanceDB(BaseVectorDB):
    """
    LanceDB as vector database.

    Each row stores the chunk text (``doc``), its metadata serialized with ``str()``
    (``metadata``), its id (``id``) and, when the embedder passed the check in
    ``_initialize``, its embedding (``vector``).
    """

    def __init__(
        self,
        config: Optional[LanceDBConfig] = None,
    ):
        """LanceDB as vector database.

        :param config: LanceDB database config, defaults to None
        :type config: LanceDBConfig, optional
        """
        self.config = config if config else LanceDBConfig()

        self.client = lancedb.connect(self.config.dir or "~/.lancedb")
        # Assume the embedder works until `_initialize` proves otherwise; a failing embedder
        # switches the collection to a schema without a `vector` column.
        self.embedder_check = True

        super().__init__(config=self.config)

    def _initialize(self):
        """
        This method is needed because `embedder` attribute needs to be set externally before it can be initialized.

        :raises ValueError: if no embedder has been set.
        """
        if not self.embedder:
            raise ValueError(
                "Embedder not set. Please set an embedder with `_set_embedder()` function before initialization."
            )

        # Probe the embedder with the same list-of-texts call shape that `add` and `query` use.
        try:
            self.embedder.embedding_fn(["Hello LanceDB"])
        except Exception:
            self.embedder_check = False

        self._get_or_create_collection(self.config.collection_name)

    def _get_or_create_db(self):
        """
        Called during initialization; returns the LanceDB connection created in `__init__`.
        """
        return self.client

    @staticmethod
    def _to_sql_literal(value: Any) -> str:
        """
        Render a non-None Python value as a SQL literal for a LanceDB filter expression.

        Booleans become TRUE/FALSE, ints/floats are emitted as-is, and everything else is
        converted with ``str()`` and single-quoted, doubling embedded single quotes.
        """
        if isinstance(value, bool):
            return "TRUE" if value else "FALSE"
        if isinstance(value, (int, float)):
            return str(value)
        escaped = str(value).replace("'", "''")
        return f"'{escaped}'"

    def _generate_where_clause(self, where: Dict[str, Any]) -> str:
        """
        Build a SQL filter that ANDs one equality test per entry of `where`.

        Example: ``{"a": 1, "b": "x"}`` -> ``"a = 1 AND b = 'x'"``.

        :param where: Mapping of column name to the value it must equal
        :type where: Dict[str, Any]
        :return: The filter expression, or an empty string when `where` is empty
        :rtype: str
        """
        conditions = []
        for key, value in where.items():
            if value is None:
                # `= NULL` never matches in SQL; NULL needs an explicit IS test.
                conditions.append(f"{key} IS NULL")
            else:
                conditions.append(f"{key} = {self._to_sql_literal(value)}")
        return " AND ".join(conditions)

    def _get_or_create_collection(self, table_name: str, reset: bool = False):
        """
        Get or create a named collection (a LanceDB table) and store it as `self.collection`.

        :param table_name: Name of the collection
        :type table_name: str
        :param reset: Drop and recreate the table instead of reusing it, defaults to False
        :type reset: bool
        :return: The opened table
        """
        fields = [
            pa.field("doc", pa.string()),
            pa.field("metadata", pa.string()),
            pa.field("id", pa.string()),
        ]
        if self.embedder_check:
            # Only store vectors when the embedder passed the check in `_initialize`.
            fields.insert(0, pa.field("vector", pa.list_(pa.float32(), list_size=self.embedder.vector_dimension)))
        schema = pa.schema(fields)

        table_exists = table_name in self.client.table_names()
        if reset and table_exists:
            # Guarded so resetting a never-created table does not fail on the drop.
            self.client.drop_table(table_name)
            table_exists = False

        if table_exists:
            self.collection = self.client.open_table(table_name)
        else:
            self.collection = self.client.create_table(table_name, schema=schema)

        return self.collection

    def _scan_id_metadata(self, filter_expr: str) -> Dict[str, List[Any]]:
        """Return the `id` and `metadata` columns of rows matching `filter_expr`, as a dict of lists."""
        return (
            self.collection.to_lance()
            .scanner(filter=filter_expr, columns=["id", "metadata"])
            .to_table()
            .to_pydict()
        )

    def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, Any]] = None, limit: Optional[int] = None):
        """
        Get existing doc ids present in vector database

        Only rows whose id is in `ids` are looked up; when `ids` is None or empty, empty
        results are returned.

        :param ids: list of doc ids to check for existence
        :type ids: List[str]
        :param where: Optional. column/value equality filter applied on top of `ids`
        :type where: Dict[str, Any]
        :param limit: Optional. maximum number of documents; all matches are returned when None
        :type limit: Optional[int]
        :return: Existing documents as ``{"ids": [...], "metadatas": [...]}``
        :rtype: Dict[str, List[str]]
        """
        results = {"ids": [], "metadatas": []}
        if not ids:
            return results

        # Build the IN list explicitly: `tuple(ids)` yields invalid SQL for one id ("('a',)").
        id_filter = "id IN ({})".format(", ".join(self._to_sql_literal(str(doc_id)) for doc_id in ids))
        records = self._scan_id_metadata(id_filter)

        # Only narrow by `where` once some ids are known to exist, so a lookup against an empty
        # table never evaluates a filter on columns the schema may not have.
        if records["id"] and where:
            records = self._scan_id_metadata(f"({id_filter}) AND ({self._generate_where_clause(where)})")

        found_ids = list(records["id"])
        found_metadatas = list(records["metadata"])
        if limit is not None:
            found_ids = found_ids[:limit]
            found_metadatas = found_metadatas[:limit]

        results["ids"] = found_ids
        results["metadatas"] = found_metadatas
        return results

    def add(
        self,
        documents: List[str],
        metadatas: List[object],
        ids: List[str],
    ) -> Any:
        """
        Add vectors to lancedb database

        Documents, metadatas and ids are paired positionally (extra items in longer lists are
        ignored). Metadata is stored as its ``str()`` representation.

        :param documents: Documents
        :type documents: List[str]
        :param metadatas: Metadatas
        :type metadatas: List[object]
        :param ids: ids
        :type ids: List[str]
        """
        data = []
        for doc, meta, doc_id in zip(documents, metadatas, ids):
            row = {"doc": doc}
            if self.embedder_check:
                row["vector"] = self.embedder.embedding_fn([doc])[0]
            row["metadata"] = str(meta)
            row["id"] = doc_id
            data.append(row)

        if not data:
            # Nothing to insert; skip the call instead of handing LanceDB an empty batch.
            return

        self.collection.add(data=data)

    def _format_result(self, results) -> list:
        """
        Format LanceDB results (not called by `query` in this class).

        :param results: LanceDB query results to format; must provide `tolist()`
        :return: ``results.tolist()``
        :rtype: list
        """
        return results.tolist()

    def query(
        self,
        input_query: str,
        n_results: int = 3,
        where: Optional[dict[str, Any]] = None,
        raw_filter: Optional[dict[str, Any]] = None,
        citations: bool = False,
        **kwargs: Any,
    ) -> Union[list[tuple[str, str]], list[str]]:
        """
        Query contents from vector database based on vector similarity

        :param input_query: query string
        :type input_query: str
        :param n_results: no of similar documents to fetch from database
        :type n_results: int
        :param where: to filter data (validated only; see note below)
        :type where: dict[str, Any]
        :param raw_filter: Raw filter to apply (validated only; see note below)
        :type raw_filter: dict[str, Any]
        :param citations: we use citations boolean param to return context along with the answer.
        :type citations: bool, default is False.
        :raises ValueError: if both `where` and `raw_filter` are given
        :return: The matched document texts, or ``(doc, metadata_string)`` pairs when citations is true
        :rtype: list[str] if citations=False, otherwise list[tuple[str, str]]
        """
        if where and raw_filter:
            raise ValueError("Both `where` and `raw_filter` cannot be used together.")

        # NOTE: `where` / `raw_filter` are not applied to the search: metadata is stored as one
        # string column, so there are no per-key columns a SQL filter could reference.
        # NOTE(review): when `embedder_check` is False the table has no `vector` column, so this
        # search presumably cannot succeed — confirm the intended fallback.
        query_embedding = self.embedder.embedding_fn([input_query])[0]
        matches = self.collection.search(query_embedding).limit(n_results).to_list()

        if citations:
            return [(match["doc"], match["metadata"]) for match in matches]
        return [match["doc"] for match in matches]

    def set_collection_name(self, name: str):
        """
        Set the name of the collection. A collection is an isolated space for vectors.

        :param name: Name of the collection.
        :type name: str
        :raises TypeError: if `name` is not a string
        """
        if not isinstance(name, str):
            raise TypeError("Collection name must be a string")
        self.config.collection_name = name
        self._get_or_create_collection(self.config.collection_name)

    def count(self) -> int:
        """
        Count number of documents/chunks embedded in the database.

        :return: number of documents
        :rtype: int
        """
        return self.collection.count_rows()

    def delete(self, where):
        """
        Delete rows from the collection.

        :param where: SQL filter string, or a dict of column/value equalities converted with
            `_generate_where_clause`
        :raises ValueError: if `where` is an empty dict (refused rather than sent as an empty filter)
        """
        if isinstance(where, dict):
            if not where:
                raise ValueError("`where` must contain at least one condition to delete rows.")
            where = self._generate_where_clause(where)
        return self.collection.delete(where=where)

    def reset(self):
        """
        Resets the database. Deletes all embeddings irreversibly.

        When `allow_reset` is disabled in the config, only prints a warning.
        """
        if self.config.allow_reset:
            # Drop and recreate the collection with a fresh schema.
            self._get_or_create_collection(self.config.collection_name, reset=True)
        else:
            print(
                "For safety reasons, resetting is disabled. "
                "Please enable it by setting `allow_reset=True` in your LanceDbConfig"
            )
| #306 |