repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...2fa351d6docs: add automaton and perps launch sources16d ago| #1 | import hashlib |
| #2 | import json |
| #3 | import logging |
| #4 | from typing import Any, Optional, Union |
| #5 | |
| #6 | from dotenv import load_dotenv |
| #7 | from langchain.docstore.document import Document |
| #8 | |
| #9 | from embedchain.cache import ( |
| #10 | adapt, |
| #11 | get_gptcache_session, |
| #12 | gptcache_data_convert, |
| #13 | gptcache_update_cache_callback, |
| #14 | ) |
| #15 | from embedchain.chunkers.base_chunker import BaseChunker |
| #16 | from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig |
| #17 | from embedchain.config.base_app_config import BaseAppConfig |
| #18 | from embedchain.core.db.models import ChatHistory, DataSource |
| #19 | from embedchain.data_formatter import DataFormatter |
| #20 | from embedchain.embedder.base import BaseEmbedder |
| #21 | from embedchain.helpers.json_serializable import JSONSerializable |
| #22 | from embedchain.llm.base import BaseLlm |
| #23 | from embedchain.loaders.base_loader import BaseLoader |
| #24 | from embedchain.models.data_type import ( |
| #25 | DataType, |
| #26 | DirectDataType, |
| #27 | IndirectDataType, |
| #28 | SpecialDataType, |
| #29 | ) |
| #30 | from embedchain.utils.misc import detect_datatype, is_valid_json_string |
| #31 | from embedchain.vectordb.base import BaseVectorDB |
| #32 | |
# Populate environment variables from a `.env` file (python-dotenv) as a side
# effect of importing this module.
load_dotenv()

# Module-level logger named after this module; used by every method below.
logger = logging.getLogger(__name__)
| #36 | |
| #37 | |
class EmbedChain(JSONSerializable):
    """
    Core app object: ingests sources into a vector database and answers queries with an LLM.

    NOTE(review): ``db_session``, ``telemetry`` and ``_telemetry_props`` are used by several
    methods but are never assigned in this class; presumably a subclass (or the caller)
    provides them — confirm before using ``EmbedChain`` directly.
    """

    def __init__(
        self,
        config: BaseAppConfig,
        llm: BaseLlm,
        db: Optional[BaseVectorDB] = None,
        embedder: Optional[BaseEmbedder] = None,
        system_prompt: Optional[str] = None,
    ):
        """
        Initializes the EmbedChain instance, wires the embedder into the vector DB and
        initializes the DB.

        :param config: Configuration just for the app, not the db or llm or embedder.
        :type config: BaseAppConfig
        :param llm: Instance of the LLM you want to use.
        :type llm: BaseLlm
        :param db: Instance of the vector database to use. When None, ``config.db`` is used
            for backwards compatibility, defaults to None
        :type db: BaseVectorDB, optional
        :param embedder: Instance of the embedder to use. Required despite the None default.
        :type embedder: BaseEmbedder, optional
        :param system_prompt: System prompt to use in the llm query, defaults to None
        :type system_prompt: Optional[str], optional
        :raises ValueError: No database or embedder provided.
        """
        self.config = config
        self.cache_config = None
        self.memory_config = None
        self.mem0_memory = None
        # Llm
        self.llm = llm
        # Database has support for config assignment for backwards compatibility
        if db is None and getattr(self.config, "db", None) is None:
            raise ValueError("App requires Database.")
        self.db = db if db is not None else self.config.db
        # Embedder
        if embedder is None:
            raise ValueError("App requires Embedder.")
        self.embedder = embedder

        # Initialize database
        self.db._set_embedder(self.embedder)
        self.db._initialize()
        # Set collection name from app config for backwards compatibility.
        if config.collection_name:
            self.db.set_collection_name(config.collection_name)

        # Add variables that are "shortcuts"
        if system_prompt:
            self.llm.config.system_prompt = system_prompt

        # Fetch the history from the database if exists
        self.llm.update_history(app_id=self.config.id)

        # Attributes that aren't subclass related.
        # Every `add` call records [source, data_type value, metadata] here.
        self.user_asks = []

        # App-level chunker config used by `add` when no AddConfig is passed.
        self.chunker: Optional[ChunkerConfig] = None

    @property
    def collect_metrics(self):
        """Whether anonymous telemetry is sent (stored on the app config)."""
        return self.config.collect_metrics

    @collect_metrics.setter
    def collect_metrics(self, value):
        if not isinstance(value, bool):
            raise ValueError(f"Boolean value expected but got {type(value)}.")
        self.config.collect_metrics = value

    @property
    def online(self):
        """The LLM config's ``online`` flag."""
        return self.llm.config.online

    @online.setter
    def online(self, value):
        if not isinstance(value, bool):
            raise ValueError(f"Boolean value expected but got {type(value)}.")
        self.llm.config.online = value

    def add(
        self,
        source: Any,
        data_type: Optional[DataType] = None,
        metadata: Optional[dict[str, Any]] = None,
        config: Optional[AddConfig] = None,
        dry_run=False,
        loader: Optional[BaseLoader] = None,
        chunker: Optional[BaseChunker] = None,
        **kwargs: Any,
    ):
        """
        Adds the data from the given source to the vector db.
        Loads the data, chunks it, creates an embedding for each chunk
        and then stores the embeddings in the vector database.

        :param source: The data to embed, can be a URL, local file or raw content, depending on the data type.
        :type source: Any
        :param data_type: Automatically detected, but can be forced with this argument. The type of the data to add,
        defaults to None
        :type data_type: Optional[DataType], optional
        :param metadata: Metadata associated with the data source, defaults to None
        :type metadata: Optional[dict[str, Any]], optional
        :param config: The `AddConfig` instance to use as configuration options, defaults to None
        :type config: Optional[AddConfig], optional
        :param dry_run: Optional. A dry run returns the chunks so you can check that the loader and chunker
        work as intended; nothing is written to the vector DB or the 'ec_data_sources' table.
        defaults to False
        :type dry_run: bool
        :param loader: The loader to use to load the data, defaults to None
        :type loader: BaseLoader, optional
        :param chunker: The chunker to use to chunk the data, defaults to None
        :type chunker: BaseChunker, optional
        :param kwargs: Extra arguments forwarded to chunking and to the vector DB ``add`` call
        :type kwargs: dict[str, Any]
        :return: source_hash, a md5-hash of the source, in hexadecimal representation; or, for a dry run,
        a dict with the keys "chunks", "metadata", "count" and "type".
        :rtype: str | dict
        """
        if config is None:
            config = AddConfig(chunker=self.chunker) if self.chunker is not None else AddConfig()

        # Legacy call order was `.add(data_type, source)`: if `source` parses as a DataType,
        # swap the two arguments back.
        try:
            DataType(source)
        except ValueError:
            pass
        else:
            logger.warning(
                f"""Starting from version v0.0.40, Embedchain can automatically detect the data type. So, in the `add` method, the argument order has changed. You no longer need to specify '{source}' for the `source` argument. So the code snippet will be `.add("{data_type}", "{source}")`"""  # noqa #E501
            )
            logger.warning(
                "Embedchain is swapping the arguments for you. This functionality might be deprecated in the future, so please adjust your code."  # noqa #E501
            )
            source, data_type = data_type, source

        if data_type:
            try:
                data_type = DataType(data_type)
            except ValueError:
                logger.info(
                    f"Invalid data_type: '{data_type}', using `custom` instead.\n Check docs to pass the valid data type: `https://docs.embedchain.ai/data-sources/overview`"  # noqa: E501
                )
                data_type = DataType.CUSTOM

        if not data_type:
            data_type = detect_datatype(source)

        # `source_hash` is the md5 hash of the source argument
        source_hash = hashlib.md5(str(source).encode("utf-8")).hexdigest()

        self.user_asks.append([source, data_type.value, metadata])

        data_formatter = DataFormatter(data_type, config, loader, chunker)
        documents, metadatas, _ids, new_chunks = self._load_and_embed(
            data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, config, dry_run, **kwargs
        )
        if data_type == DataType.DOCS_SITE:
            self.is_docs_site_instance = True

        # A dry run must not persist anything, so return the preview before the
        # 'ec_data_sources' row is written.
        if dry_run:
            data_chunks_info = {"chunks": documents, "metadata": metadatas, "count": len(documents), "type": data_type}
            logger.debug(f"Dry run info : {data_chunks_info}")
            return data_chunks_info

        # Insert the data into the 'ec_data_sources' table (the source is stored as a string).
        self.db_session.add(
            DataSource(
                hash=source_hash,
                app_id=self.config.id,
                type=data_type.value,
                value=source if isinstance(source, str) else str(source),
                metadata=json.dumps(metadata),
            )
        )
        try:
            self.db_session.commit()
        except Exception as e:
            # Best effort: the embeddings are already stored; only the bookkeeping row failed.
            logger.error(f"Error adding data source: {e}")
            self.db_session.rollback()

        # Send anonymous telemetry
        if self.config.collect_metrics:
            # it's quicker to check the variable twice than to count words when they won't be submitted.
            word_count = data_formatter.chunker.get_word_count(documents)
            event_properties = {
                **self._telemetry_props,
                "data_type": data_type.value,
                "word_count": word_count,
                "chunks_count": new_chunks,
            }
            self.telemetry.capture(event_name="add", properties=event_properties)

        return source_hash

    def _build_source_where(self, chunker: BaseChunker, src: Any) -> dict[str, Any]:
        """
        Build the vector-DB metadata filter that identifies chunks originating from `src`.

        QNA pairs are matched on their question (``src[0]``); valid JSON strings are matched on
        the sha256 of the string; everything else is matched on ``url == src``. The filter is
        scoped to this app's id when one is set.
        """
        if chunker.data_type == DataType.QNA_PAIR:
            where = {"question": src[0]}
        elif chunker.data_type == DataType.JSON and is_valid_json_string(src):
            where = {"url": hashlib.sha256((src).encode("utf-8")).hexdigest()}
        else:
            where = {"url": src}
        if self.config.id is not None:
            where["app_id"] = self.config.id
        return where

    def _get_existing_doc_id(self, chunker: BaseChunker, src: Any):
        """
        Get the doc_id of an already-stored document for `src`, or None.

        Direct data types (raw content) are never considered updates, so they always
        return None. Indirect types are looked up by their source reference; of the special
        types, only QNA pairs are supported (looked up by question).

        :raises NotImplementedError: a SpecialDataType without lookup logic.
        :raises TypeError: the chunker's data type belongs to none of the known families.
        """
        data_type_value = chunker.data_type.value
        if data_type_value in {item.value for item in DirectDataType}:
            # DirectDataTypes can't be updated.
            # Think of a text:
            # Either it's the same, then it won't change, so it's not an update.
            # Or it's different, then it will be added as a new text.
            return None

        if data_type_value in {item.value for item in IndirectDataType}:
            # These types have an indirect source reference.
            # As long as the reference is the same, they can be updated.
            pass
        elif data_type_value in {item.value for item in SpecialDataType}:
            # These types don't contain indirect references; custom logic attributes them to a source.
            if chunker.data_type != DataType.QNA_PAIR:
                raise NotImplementedError(
                    f"SpecialDataType {chunker.data_type} must have a custom logic to check for existing data"
                )
        else:
            raise TypeError(
                f"{chunker.data_type} is type {type(chunker.data_type)}. "
                "When it should be DirectDataType, IndirectDataType or SpecialDataType."
            )

        existing_embeddings = self.db.get(
            where=self._build_source_where(chunker, src),
            limit=1,
        )
        existing_metadatas = existing_embeddings.get("metadatas", [])
        if existing_metadatas:
            return existing_metadatas[0]["doc_id"]
        return None

    def _load_and_embed(
        self,
        loader: BaseLoader,
        chunker: BaseChunker,
        src: Any,
        metadata: Optional[dict[str, Any]] = None,
        source_hash: Optional[str] = None,
        add_config: Optional[AddConfig] = None,
        dry_run=False,
        **kwargs: Any,
    ):
        """
        Loads the data from the given source, chunks it, and adds it to the vector database.

        :param loader: The loader to use to load the data.
        :type loader: BaseLoader
        :param chunker: The chunker to use to chunk the data.
        :type chunker: BaseChunker
        :param src: The data to be handled by the loader. Can be a URL for
        remote sources or local content for local loaders.
        :type src: Any
        :param metadata: Metadata associated with the data source; merged into every chunk's metadata.
        :type metadata: dict[str, Any], optional
        :param source_hash: Hexadecimal hash of the source; stored as the "hash" metadata of every chunk.
        :type source_hash: str, optional
        :param add_config: The `AddConfig` instance to use as configuration options. Defaults to `AddConfig()`.
        :type add_config: AddConfig, optional
        :param dry_run: A dry run returns chunks and doesn't update DB.
        :type dry_run: bool, defaults to False
        :return: (list) documents (embedded text), (list) metadata, (list) ids, (int) number of new chunks
        """
        if add_config is None:
            # Same fallback `add` uses when neither a config nor an app-level chunker is set.
            add_config = AddConfig()

        existing_doc_id = self._get_existing_doc_id(chunker=chunker, src=src)
        app_id = self.config.id if self.config is not None else None

        # Create chunks
        embeddings_data = chunker.create_chunks(loader, src, app_id=app_id, config=add_config.chunker, **kwargs)
        documents = embeddings_data["documents"]
        metadatas = embeddings_data["metadatas"]
        ids = embeddings_data["ids"]
        new_doc_id = embeddings_data["doc_id"]

        replacing_doc = False
        if existing_doc_id:
            if existing_doc_id == new_doc_id:
                logger.info("Doc content has not changed. Skipping creating chunks and embeddings")
                return [], [], [], 0
            # The source was stored before but its content changed: drop the stale chunks.
            logger.info("Doc content has changed. Recomputing chunks and embeddings intelligently.")
            replacing_doc = True
            if not dry_run:
                self.db.delete({"doc_id": existing_doc_id})

        # Discard chunks whose ids already exist for this source. Skipped on a dry run that
        # replaces a changed doc: the stale chunks were deliberately not deleted, and a real
        # run would delete them first, so including them gives a more faithful preview.
        if not (dry_run and replacing_doc):
            db_result = self.db.get(ids=ids, where=self._build_source_where(chunker, src))  # optional filter
            existing_ids = set(db_result["ids"])
            if existing_ids:
                pending = {
                    chunk_id: (doc, meta)
                    for chunk_id, doc, meta in zip(ids, documents, metadatas)
                    if chunk_id not in existing_ids
                }
                if not pending:
                    src_preview = str(src)
                    if len(src_preview) > 50:
                        src_preview = src_preview[:50] + "..."
                    logger.info(f"All data from {src_preview} already exists in the database.")
                    # Make sure to return a matching return type
                    return [], [], [], 0

                ids = list(pending)
                documents = [doc for doc, _ in pending.values()]
                metadatas = [meta for _, meta in pending.values()]

        # Enrich every chunk's metadata in place.
        for chunk_meta in metadatas:
            # Add app id in metadatas so that they can be queried on later
            if self.config.id:
                chunk_meta["app_id"] = self.config.id
            # Add hashed source (used by `delete` to find the chunks again)
            chunk_meta["hash"] = source_hash
            # Spread the caller-supplied metadata on top.
            if metadata:
                chunk_meta.update(metadata)
        metadatas = list(metadatas)

        if dry_run:
            return list(documents), metadatas, ids, 0

        # Count before, to calculate a delta in the end.
        chunks_before_addition = self.db.count()

        # Drop empty / non-string documents together with their metadata and ids so the
        # three lists stay aligned (filtering documents alone would pair chunks with the
        # wrong metadata and make the batch lengths disagree).
        valid_rows = [
            (doc, meta, chunk_id)
            for doc, meta, chunk_id in zip(documents, metadatas, ids)
            if doc and isinstance(doc, str)
        ]
        documents = [doc for doc, _, _ in valid_rows]
        metadatas = [meta for _, meta, _ in valid_rows]
        ids = [chunk_id for _, _, chunk_id in valid_rows]

        # Add in batches of 2048; helps with large loads of embeddings that hit OpenAI limits.
        batch_size = 2048
        for start in range(0, len(documents), batch_size):
            end = start + batch_size
            try:
                self.db.add(
                    documents=documents[start:end], metadatas=metadatas[start:end], ids=ids[start:end], **kwargs
                )
            except Exception as e:
                # Best effort: skip the failing batch but make the failure visible.
                logger.error(f"Failed to add batch starting at chunk {start}: {e}")

        count_new_chunks = self.db.count() - chunks_before_addition
        logger.info(f"Successfully saved {str(src)[:100]} ({chunker.data_type}). New chunks count: {count_new_chunks}")

        return list(documents), metadatas, ids, count_new_chunks

    @staticmethod
    def _format_result(results):
        """Turn the first result set of a raw query response into (Document, distance) pairs."""
        return [
            (Document(page_content=result[0], metadata=result[1] or {}), result[2])
            for result in zip(
                results["documents"][0],
                results["metadatas"][0],
                results["distances"][0],
            )
        ]

    def _retrieve_from_database(
        self,
        input_query: str,
        config: Optional[BaseLlmConfig] = None,
        where=None,
        citations: bool = False,
        **kwargs: Any,
    ) -> Union[list[tuple[str, str, str]], list[str]]:
        """
        Queries the vector database based on the given input query.
        Gets relevant doc based on the query

        :param input_query: The query to use.
        :type input_query: str
        :param config: The query configuration, defaults to None (uses the LLM's config)
        :type config: Optional[BaseLlmConfig], optional
        :param where: A dictionary of key-value pairs to filter the database results, defaults to None
        (falls back to ``config.where``). The dict is copied, never modified.
        :type where: dict, optional
        :param citations: A boolean to indicate if db should fetch citation source
        :type citations: bool
        :return: List of contents of the document that matched your query
        :rtype: list[str]
        """
        query_config = config or self.llm.config
        # Copy the filter so adding `app_id` never mutates the caller's dict or the shared config.
        if where is not None:
            where = dict(where)
        elif query_config is not None and query_config.where is not None:
            where = dict(query_config.where)
        else:
            where = {}

        if self.config.id is not None:
            where["app_id"] = self.config.id

        return self.db.query(
            input_query=input_query,
            n_results=query_config.number_documents,
            where=where,
            citations=citations,
            **kwargs,
        )

    @staticmethod
    def _contexts_for_llm(contexts, citations: bool):
        """When citations were fetched as tuples, keep only the context text (first element) for the LLM."""
        if citations and contexts and isinstance(contexts[0], tuple):
            return [context[0] for context in contexts]
        return contexts

    def _format_response(self, answer, contexts, citations: bool, token_info):
        """Shape the return value of `query`/`chat` according to `citations` and `token_usage`."""
        if citations:
            if self.llm.config.token_usage:
                return {"answer": answer, "contexts": contexts, "usage": token_info}
            return answer, contexts
        if self.llm.config.token_usage:
            return {"answer": answer, "usage": token_info}

        logger.warning(
            "Starting from v0.1.125 the return type of query method will be changed to tuple containing `answer`."
        )
        return answer

    def query(
        self,
        input_query: str,
        config: Optional[BaseLlmConfig] = None,
        dry_run=False,
        where: Optional[dict] = None,
        citations: bool = False,
        **kwargs: Any,
    ) -> Union[tuple[str, list[tuple[str, dict]]], str, dict[str, Any]]:
        """
        Queries the vector database based on the given input query.
        Gets relevant doc based on the query and then passes it to an
        LLM as context to get the answer.

        :param input_query: The query to use.
        :type input_query: str
        :param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
        To persistently use a config, declare it during app init., defaults to None
        :type config: BaseLlmConfig, optional
        :param dry_run: A dry run does everything except send the resulting prompt to
        the LLM. The purpose is to test the prompt, not the response., defaults to False
        :type dry_run: bool, optional
        :param where: A dictionary of key-value pairs to filter the database results., defaults to None
        :type where: dict[str, str], optional
        :param citations: A boolean to indicate if db should fetch citation source
        :type citations: bool
        :param kwargs: Extra parameters forwarded to the vector DB query
        :type kwargs: dict[str, Any]
        :return: The answer (or dry run result). With citations: ``(answer, contexts)``. With token_usage
        enabled: a dict with "answer", "usage" and, if citations, "contexts". "usage" is None when the
        answer came through the cache path.
        :rtype: str | tuple | dict
        """
        contexts = self._retrieve_from_database(
            input_query=input_query, config=config, where=where, citations=citations, **kwargs
        )
        contexts_data_for_llm_query = self._contexts_for_llm(contexts, citations)

        # Stays None on the cached path, which does not report token usage (previously a NameError).
        # NOTE(review): confirm how `adapt` behaves when llm.query returns (answer, token_info).
        token_info = None
        if self.cache_config is not None:
            logger.info("Cache enabled. Checking cache...")
            answer = adapt(
                llm_handler=self.llm.query,
                cache_data_convert=gptcache_data_convert,
                update_cache_callback=gptcache_update_cache_callback,
                session=get_gptcache_session(session_id=self.config.id),
                input_query=input_query,
                contexts=contexts_data_for_llm_query,
                config=config,
                dry_run=dry_run,
            )
        else:
            llm_result = self.llm.query(
                input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
            )
            if self.llm.config.token_usage:
                answer, token_info = llm_result
            else:
                answer = llm_result

        # Send anonymous telemetry
        if self.config.collect_metrics:
            self.telemetry.capture(event_name="query", properties=self._telemetry_props)

        return self._format_response(answer, contexts, citations, token_info)

    def chat(
        self,
        input_query: str,
        config: Optional[BaseLlmConfig] = None,
        dry_run=False,
        session_id: str = "default",
        where: Optional[dict[str, str]] = None,
        citations: bool = False,
        **kwargs: Any,
    ) -> Union[tuple[str, list[tuple[str, dict]]], str, dict[str, Any]]:
        """
        Queries the vector database on the given input query.
        Gets relevant doc based on the query and then passes it to an
        LLM as context to get the answer.

        Maintains the whole conversation in memory, per session.

        :param input_query: The query to use.
        :type input_query: str
        :param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
        To persistently use a config, declare it during app init., defaults to None
        :type config: BaseLlmConfig, optional
        :param dry_run: A dry run does everything except send the resulting prompt to
        the LLM. The purpose is to test the prompt, not the response., defaults to False
        :type dry_run: bool, optional
        :param session_id: The session id to use for chat history, defaults to 'default'.
        :type session_id: str, optional
        :param where: A dictionary of key-value pairs to filter the database results., defaults to None
        :type where: dict[str, str], optional
        :param citations: A boolean to indicate if db should fetch citation source
        :type citations: bool
        :param kwargs: Extra parameters forwarded to the vector DB query
        :type kwargs: dict[str, Any]
        :return: The answer (or dry run result). With citations: ``(answer, contexts)``. With token_usage
        enabled: a dict with "answer", "usage" and, if citations, "contexts". "usage" is None when the
        answer came through the cache path.
        :rtype: str | tuple | dict
        """
        contexts = self._retrieve_from_database(
            input_query=input_query, config=config, where=where, citations=citations, **kwargs
        )
        contexts_data_for_llm_query = self._contexts_for_llm(contexts, citations)

        memories = None
        if self.mem0_memory:
            memories = self.mem0_memory.search(
                query=input_query, agent_id=self.config.id, user_id=session_id, limit=self.memory_config.top_k
            )

        # Update the history beforehand so that we can handle multiple chat sessions in the same python session
        self.llm.update_history(app_id=self.config.id, session_id=session_id)

        # Stays None on the cached path, which does not report token usage (previously a NameError).
        token_info = None
        if self.cache_config is not None:
            logger.debug("Cache enabled. Checking cache...")
            cache_id = f"{session_id}--{self.config.id}"
            answer = adapt(
                llm_handler=self.llm.chat,
                cache_data_convert=gptcache_data_convert,
                update_cache_callback=gptcache_update_cache_callback,
                session=get_gptcache_session(session_id=cache_id),
                input_query=input_query,
                contexts=contexts_data_for_llm_query,
                config=config,
                dry_run=dry_run,
            )
        else:
            logger.debug("Cache disabled. Running chat without cache.")
            llm_result = self.llm.query(
                input_query=input_query,
                contexts=contexts_data_for_llm_query,
                config=config,
                dry_run=dry_run,
                memories=memories,
            )
            if self.llm.config.token_usage:
                answer, token_info = llm_result
            else:
                answer = llm_result

        # Add to Mem0 memory if enabled
        # The answer is stored because it is more useful than the input question itself.
        if self.mem0_memory:
            self.mem0_memory.add(data=answer, agent_id=self.config.id, user_id=session_id)

        # add conversation in memory
        self.llm.add_history(self.config.id, input_query, answer, session_id=session_id)

        # Send anonymous telemetry
        if self.config.collect_metrics:
            self.telemetry.capture(event_name="chat", properties=self._telemetry_props)

        return self._format_response(answer, contexts, citations, token_info)

    def search(self, query, num_documents=3, where=None, raw_filter=None, namespace=None):
        """
        Search for similar documents related to the query in the vector database.

        Args:
            query (str): The query to use.
            num_documents (int, optional): Number of similar documents to fetch. Defaults to 3.
            where (dict[str, any], optional): Filter criteria for the search.
            raw_filter (dict[str, any], optional): Advanced raw filter criteria for the search.
            namespace (str, optional): The namespace to search in. Defaults to None.

        Raises:
            ValueError: If both `raw_filter` and `where` are used simultaneously.

        Returns:
            list[dict]: A list of dictionaries, each containing the 'context' and 'metadata' of a document.
        """
        # Send anonymous telemetry
        if self.config.collect_metrics:
            self.telemetry.capture(event_name="search", properties=self._telemetry_props)

        if raw_filter and where:
            raise ValueError("You can't use both `raw_filter` and `where` together.")

        # Exactly one of the two filter keywords is forwarded to the DB.
        filter_type = "raw_filter" if raw_filter else "where"
        filter_criteria = raw_filter if raw_filter else where

        params = {
            "input_query": query,
            "n_results": num_documents,
            "citations": True,
            "app_id": self.config.id,
            "namespace": namespace,
            filter_type: filter_criteria,
        }

        return [{"context": c[0], "metadata": c[1]} for c in self.db.query(**params)]

    def set_collection_name(self, name: str):
        """
        Set the name of the collection. A collection is an isolated space for vectors.

        Using `app.db.set_collection_name` method is preferred to this.

        :param name: Name of the collection.
        :type name: str
        """
        self.db.set_collection_name(name)
        # Create the collection if it does not exist
        self.db._get_or_create_collection(name)
        # TODO: Check whether it is necessary to assign to the `self.collection` attribute,
        # since the main purpose is the creation.

    def reset(self):
        """
        Resets the database. Deletes all embeddings irreversibly.
        `App` does not have to be reinitialized after using this method.

        The app's 'ec_data_sources' and chat history rows are deleted first; if that fails
        the transaction is rolled back and the vector DB is left untouched.
        """
        try:
            self.db_session.query(DataSource).filter_by(app_id=self.config.id).delete()
            self.db_session.query(ChatHistory).filter_by(app_id=self.config.id).delete()
            self.db_session.commit()
        except Exception as e:
            logger.error(f"Error deleting data sources: {e}")
            self.db_session.rollback()
            return None
        self.db.reset()
        self.delete_all_chat_history(app_id=self.config.id)
        # Send anonymous telemetry
        if self.config.collect_metrics:
            self.telemetry.capture(event_name="reset", properties=self._telemetry_props)

    def get_history(
        self,
        num_rounds: int = 10,
        display_format: bool = True,
        session_id: Optional[str] = "default",
        fetch_all: bool = False,
    ):
        """Return this app's chat history for `session_id`, as provided by the LLM's memory store."""
        return self.llm.memory.get(
            app_id=self.config.id,
            session_id=session_id,
            num_rounds=num_rounds,
            display_format=display_format,
            fetch_all=fetch_all,
        )

    def delete_session_chat_history(self, session_id: str = "default"):
        """Delete the chat history of one session of this app, then reload the LLM's history."""
        self.llm.memory.delete(app_id=self.config.id, session_id=session_id)
        self.llm.update_history(app_id=self.config.id)

    def delete_all_chat_history(self, app_id: str):
        """Delete every session's chat history for `app_id`, then reload the LLM's history."""
        self.llm.memory.delete(app_id=app_id)
        self.llm.update_history(app_id=app_id)

    def delete(self, source_id: str):
        """
        Deletes the data of one source from the 'ec_data_sources' table and the vector database.

        :param source_id: The hash of the source, as returned by `add`.
        :type source_id: str
        """
        try:
            self.db_session.query(DataSource).filter_by(hash=source_id, app_id=self.config.id).delete()
            self.db_session.commit()
        except Exception as e:
            logger.error(f"Error deleting data sources: {e}")
            self.db_session.rollback()
            return None
        # Chunks carry the source hash in their "hash" metadata (set in `_load_and_embed`).
        self.db.delete(where={"hash": source_id})
        logger.info(f"Successfully deleted {source_id}")
        # Send anonymous telemetry
        if self.config.collect_metrics:
            self.telemetry.capture(event_name="delete", properties=self._telemetry_props)
| #790 |