repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...
git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...
2fa351d6 docs: add automaton and perps launch sources 16d ago
| #1 | import copy |
| #2 | import os |
| #3 | from typing import Any, Optional, Union |
| #4 | |
| #5 | try: |
| #6 | from qdrant_client import QdrantClient |
| #7 | from qdrant_client.http import models |
| #8 | from qdrant_client.http.models import Batch |
| #9 | from qdrant_client.models import Distance, VectorParams |
| #10 | except ImportError: |
| #11 | raise ImportError("Qdrant requires extra dependencies. Install with `pip install embedchain[qdrant]`") from None |
| #12 | |
| #13 | from tqdm import tqdm |
| #14 | |
| #15 | from embedchain.config.vector_db.qdrant import QdrantDBConfig |
| #16 | from embedchain.vectordb.base import BaseVectorDB |
| #17 | |
| #18 | |
class QdrantDB(BaseVectorDB):
    """
    Qdrant as vector database
    """

    def __init__(self, config: Optional[QdrantDBConfig] = None):
        """
        Qdrant as vector database

        :param config: Qdrant database config to be used for connection.
            A default `QdrantDBConfig` is created when omitted.
        :type config: Optional[QdrantDBConfig]
        :raises TypeError: if `config` is given but is not a `QdrantDBConfig` instance
        """
        if config is None:
            config = QdrantDBConfig()
        elif not isinstance(config, QdrantDBConfig):
            raise TypeError(
                "config is not a `QdrantDBConfig` instance. "
                "Please make sure the type is right and that you are passing an instance."
            )
        self.config = config
        self.batch_size = self.config.batch_size
        # The endpoint and API key are read from the environment, not from the config object.
        self.client = QdrantClient(url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
        # Call parent init here because embedder is needed
        super().__init__(config=self.config)

    def _initialize(self):
        """
        This method is needed because `embedder` attribute needs to be set externally before it can be initialized.

        Resolves the collection name and creates the collection when it is not
        listed among the collections reported by the client.

        :raises ValueError: if no embedder has been set
        """
        if not self.embedder:
            raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")

        self.collection_name = self._get_or_create_collection()
        existing_names = {collection.name for collection in self.client.get_collections().collections}
        if self.collection_name not in existing_names:
            self.client.recreate_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=self.embedder.vector_dimension,
                    distance=Distance.COSINE,
                    hnsw_config=self.config.hnsw_config,
                    quantization_config=self.config.quantization_config,
                    on_disk=self.config.on_disk,
                ),
            )

    def _get_or_create_db(self):
        """Return the underlying Qdrant client."""
        return self.client

    def _get_or_create_collection(self):
        """
        Build the collection name used for this configuration.

        The name is the configured collection name suffixed with the embedder's
        vector dimension, lower-cased and with underscores replaced by hyphens.
        This method only formats the name; it does not talk to the server.

        :return: The normalized collection name
        :rtype: str
        """
        return f"{self.config.collection_name}-{self.embedder.vector_dimension}".lower().replace("_", "-")

    @staticmethod
    def _metadata_conditions(where: dict) -> list:
        """
        Translate a `where` mapping into Qdrant field conditions on the payload's `metadata` object.

        :param where: mapping of metadata key to the exact value it must match
        :type where: dict
        :return: one `FieldCondition` per entry of `where`
        :rtype: list
        """
        return [
            models.FieldCondition(
                key=f"metadata.{key}",
                match=models.MatchValue(value=value),
            )
            for key, value in where.items()
        ]

    def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, Any]] = None, limit: Optional[int] = None):
        """
        Get existing doc ids present in vector database

        :param ids: list of doc ids to check for existence
        :type ids: list[str]
        :param where: to filter data
        :type where: dict[str, Any]
        :param limit: The maximum number of entries to be fetched; None fetches every match
        :type limit: Optional int, defaults to None
        :return: The identifiers and metadata payloads of the matching entries,
            as ``{"ids": [...], "metadatas": [...]}``
        :rtype: dict[str, list]
        """
        must_conditions = []
        if ids:
            must_conditions.append(
                models.FieldCondition(
                    key="identifier",
                    match=models.MatchAny(any=ids),
                )
            )
        must_conditions.extend(self._metadata_conditions(where or {}))
        scroll_filter = models.Filter(must=must_conditions)

        existing_ids = []
        metadatas = []
        offset = 0
        # `scroll` returns (records, next_offset); a next_offset of None marks the last page.
        while offset is not None:
            page_size = self.batch_size
            if limit is not None:
                remaining = limit - len(existing_ids)
                if remaining <= 0:
                    break
                # Never request more than is still needed to satisfy `limit`.
                page_size = min(page_size, remaining)
            records, offset = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=scroll_filter,
                offset=offset,
                limit=page_size,
            )
            for record in records:
                existing_ids.append(record.payload["identifier"])
                metadatas.append(record.payload["metadata"])
        return {"ids": existing_ids, "metadatas": metadatas}

    def add(
        self,
        documents: list[str],
        metadatas: list[object],
        ids: list[str],
        **kwargs: Optional[dict[str, Any]],
    ):
        """add data in vector database

        The caller's metadata objects are not modified: each one is deep-copied
        before the document text is stored under its ``"text"`` key.

        :param documents: list of texts to add
        :type documents: list[str]
        :param metadatas: list of metadata associated with docs
        :type metadatas: list[object]
        :param ids: ids of docs
        :type ids: list[str]
        :param kwargs: extra keyword arguments forwarded to the client's `upsert` call
        """
        embeddings = self.embedder.embedding_fn(documents)

        payloads = []
        qdrant_ids = []
        for doc_id, document, metadata in zip(ids, documents, metadatas):
            # Copy first, then annotate, so the stored payload carries the text
            # without leaking the extra key back into the caller's dict.
            payload_metadata = copy.deepcopy(metadata)
            payload_metadata["text"] = document
            qdrant_ids.append(doc_id)
            payloads.append({"identifier": doc_id, "text": document, "metadata": payload_metadata})

        for start in tqdm(range(0, len(qdrant_ids), self.batch_size), desc="Adding data in batches"):
            stop = start + self.batch_size
            self.client.upsert(
                collection_name=self.collection_name,
                points=Batch(
                    ids=qdrant_ids[start:stop],
                    payloads=payloads[start:stop],
                    vectors=embeddings[start:stop],
                ),
                **kwargs,
            )

    def query(
        self,
        input_query: str,
        n_results: int,
        where: Optional[dict[str, Any]],
        citations: bool = False,
        **kwargs: Optional[dict[str, Any]],
    ) -> Union[list[tuple[str, dict]], list[str]]:
        """
        query contents from vector database based on vector similarity

        :param input_query: query string
        :type input_query: str
        :param n_results: no of similar documents to fetch from database
        :type n_results: int
        :param where: Optional. to filter data; None or an empty dict applies no metadata filter
        :type where: dict[str, Any]
        :param citations: we use citations boolean param to return context along with the answer.
        :type citations: bool, default is False.
        :param kwargs: extra keyword arguments forwarded to the client's `search` call
        :return: The content of the documents that matched your query. If the citations
            flag is true, each entry is paired with the stored metadata of the document,
            to which the similarity score is added under the ``"score"`` key.
        :rtype: list[str], if citations=False, otherwise list[tuple[str, dict]]
        """
        query_vector = self.embedder.embedding_fn([input_query])[0]

        results = self.client.search(
            collection_name=self.collection_name,
            query_filter=self._generate_query(where or {}),
            query_vector=query_vector,
            limit=n_results,
            **kwargs,
        )

        contexts = []
        for result in results:
            context = result.payload["text"]
            if citations:
                metadata = result.payload["metadata"]
                metadata["score"] = result.score
                contexts.append((context, metadata))
            else:
                contexts.append(context)
        return contexts

    def count(self) -> int:
        """
        Return the number of points the server reports for the current collection.

        :return: points count of the collection
        :rtype: int
        """
        response = self.client.get_collection(collection_name=self.collection_name)
        return response.points_count

    def reset(self):
        """Delete the current collection and run `_initialize` again."""
        self.client.delete_collection(collection_name=self.collection_name)
        self._initialize()

    def set_collection_name(self, name: str):
        """
        Set the name of the collection. A collection is an isolated space for vectors.

        Only the configured and derived names are updated here; this method
        does not create the collection.

        :param name: Name of the collection.
        :type name: str
        :raises TypeError: if `name` is not a string
        """
        if not isinstance(name, str):
            raise TypeError("Collection name must be a string")
        self.config.collection_name = name
        self.collection_name = self._get_or_create_collection()

    @staticmethod
    def _generate_query(where: dict):
        """
        Build a Qdrant filter requiring every entry of `where` to match the stored metadata.

        :param where: mapping of metadata key to the exact value it must match
        :type where: dict
        :return: a filter whose `must` clause holds one condition per entry
        """
        return models.Filter(must=QdrantDB._metadata_conditions(where))

    def delete(self, where: dict):
        """
        Delete the points whose metadata matches `where`.

        NOTE(review): an empty `where` yields a filter with no conditions, which
        presumably matches every point in the collection - confirm before
        calling this with `{}`.

        :param where: mapping of metadata key to the exact value it must match
        :type where: dict
        """
        db_filter = self._generate_query(where)
        self.client.delete(collection_name=self.collection_name, points_selector=db_filter)