repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...2fa351d6docs: add automaton and perps launch sources16d ago| #1 | import datetime |
| #2 | import itertools |
| #3 | import json |
| #4 | import logging |
| #5 | import os |
| #6 | import re |
| #7 | import string |
| #8 | from typing import Any |
| #9 | |
| #10 | from schema import Optional, Or, Schema |
| #11 | from tqdm import tqdm |
| #12 | |
| #13 | from embedchain.models.data_type import DataType |
| #14 | |
| #15 | logger = logging.getLogger(__name__) |
| #16 | |
| #17 | |
def parse_content(content, type):
    """
    Parse markup with BeautifulSoup, remove boilerplate elements and return the cleaned text.

    Args:
        content: The markup to parse. It is handed to `BeautifulSoup` unchanged.
        type (str): Name of the BeautifulSoup parser to use. Must be one of
            "html.parser", "lxml", "lxml-xml", "xml" or "html5lib".

    Returns:
        str: The document text that remains after the excluded tags, ids and
        classes have been removed, passed through `clean_string`.

    Raises:
        ValueError: If `type` is not one of the supported parser names.
    """
    supported_parsers = ["html.parser", "lxml", "lxml-xml", "xml", "html5lib"]
    if type not in supported_parsers:
        raise ValueError(f"Parser type {type} not implemented. Please choose one of {supported_parsers}")

    # Imported only after the parser name has been validated.
    from bs4 import BeautifulSoup

    soup = BeautifulSoup(content, type)
    # Text length before any element is removed; only used for the log line below.
    original_size = len(str(soup.get_text()))

    # Whole elements that are dropped regardless of their attributes.
    boilerplate_tags = [
        "nav",
        "aside",
        "form",
        "header",
        "noscript",
        "svg",
        "canvas",
        "footer",
        "script",
        "style",
    ]
    for element in soup(boilerplate_tags):
        element.decompose()

    # Elements dropped because of their `id` attribute.
    for element_id in ["sidebar", "main-navigation", "menu-main-menu"]:
        for element in soup.find_all(id=element_id):
            element.decompose()

    # Elements dropped because of one of their CSS classes.
    boilerplate_classes = [
        "elementor-location-header",
        "navbar-header",
        "nav",
        "header-sidebar-wrapper",
        "blog-sidebar-wrapper",
        "related-posts",
    ]
    for css_class in boilerplate_classes:
        for element in soup.find_all(class_=css_class):
            element.decompose()

    content = clean_string(soup.get_text())
    cleaned_size = len(content)

    # Skip the report for empty documents so the percentage never divides by zero.
    if original_size != 0:
        shrunk_chars = original_size - cleaned_size
        shrunk_percent = round((1 - (cleaned_size / original_size)) * 100, 2)
        logger.info(
            f"Cleaned page size: {cleaned_size} characters, down from {original_size} (shrunk: {shrunk_chars} chars, {shrunk_percent}%)"  # noqa:E501
        )

    return content
| #72 | |
| #73 | |
def clean_string(text):
    """
    Normalize a piece of text through a fixed sequence of clean-up steps.

    Args:
        text (str): The text to be cleaned.

    Returns:
        str: The text with surrounding whitespace stripped, whitespace runs
        folded to one space, backslashes removed, `#` replaced by a space, and
        runs of one repeated symbol reduced to a single occurrence.
    """
    # Step 1: trim both ends and fold every whitespace run into one space.
    collapsed = re.sub(r"\s+", " ", text.strip())

    # Step 2: in a single pass, delete backslashes and turn each hash into a space.
    # Spaces introduced here are not folded again.
    substituted = collapsed.translate(str.maketrans({"\\": None, "#": " "}))

    # Step 3: squeeze runs of the same non-word, non-whitespace character
    # down to one occurrence, e.g. "!!! hello !!!" becomes "! hello !".
    return re.sub(r"([^\w\s])\1*", r"\1", substituted)
| #103 | |
| #104 | |
def is_readable(s):
    """
    Heuristic check for whether a string consists almost entirely of printable characters.

    :param s: string
    :return: True if strictly more than 95% of the characters are in
        `string.printable`; False otherwise, including for an empty string.
    """
    total = len(s)
    if total == 0:
        # An empty string has no ratio to compute and is treated as unreadable.
        return False

    printable = frozenset(string.printable)
    printable_count = 0
    for char in s:
        if char in printable:
            printable_count += 1

    return printable_count / total > 0.95  # 95% of characters are printable
| #118 | |
| #119 | |
def use_pysqlite3():
    """
    Swap std-lib sqlite3 with pysqlite3.

    Only acts on Linux when the bundled SQLite is older than 3.35.0. In that
    case it pip-installs `pysqlite3-binary` with the running interpreter and
    registers the module under the name `sqlite3` in `sys.modules`. Any failure
    is printed and swallowed; nothing is raised to the caller.
    """
    import platform
    import sqlite3

    swap_needed = platform.system() == "Linux" and sqlite3.sqlite_version_info < (3, 35, 0)
    if not swap_needed:
        return

    try:
        # According to the Chroma team, this patch only works on Linux
        import datetime
        import subprocess
        import sys

        def _prefix(level):
            # Millisecond-precision timestamp followed by the package tag and level.
            stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
            return f"{stamp} [embedchain] [{level}]"

        install_command = [
            sys.executable,
            "-m",
            "pip",
            "install",
            "pysqlite3-binary",
            "--quiet",
            "--disable-pip-version-check",
        ]
        subprocess.check_call(install_command)

        __import__("pysqlite3")
        # Later `import sqlite3` statements now resolve to pysqlite3.
        sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")

        # Let the user know what happened.
        print(
            _prefix("INFO"),
            "Swapped std-lib sqlite3 with pysqlite3 for ChromaDb compatibility.",
            f"Your original version was {sqlite3.sqlite_version}.",
        )
    except Exception as e:
        # Deliberately best-effort: report the failure and carry on.
        print(
            _prefix("ERROR"),
            "Failed to swap std-lib sqlite3 with pysqlite3 for ChromaDb compatibility.",
            "Error:",
            e,
        )
| #157 | |
| #158 | |
def format_source(source: str, limit: int = 20) -> str:
    """
    Shorten a string to its first `limit` and last `limit` characters.

    This makes it easier to display a URL, keeping familiarity while ensuring a
    consistent length. The two kept parts are joined by "...". A string of at
    most `2 * limit` characters is returned unchanged.

    :param source: the string to shorten
    :param limit: number of characters to keep from each end
    :return: the shortened string, or `source` itself if it is short enough
    """
    if limit <= 0:
        # Nothing is kept from either end. Without this guard `source[-0:]`
        # would be the whole string and the result would be longer than the input.
        return "..." if source else source

    if len(source) > 2 * limit:
        return source[:limit] + "..." + source[-limit:]
    return source
| #168 | |
| #169 | |
def detect_datatype(source: Any) -> DataType:
    """
    Automatically detect the datatype of the given source.

    The checks run in this order:

    1. `source` is a string that parses as a URL (scheme and netloc present, or
       a `file` scheme): decided by host name first, then by path suffix, then
       by a few host/path heuristics, falling back to `WEB_PAGE`.
    2. `source` is not a string: a 2-tuple of strings is a `QNA_PAIR`, anything
       else is rejected.
    3. `source` is the path of an existing local file: decided by file
       extension, falling back to `TEXT_FILE` when the content is readable.
    4. Any other string: `JSON` if it parses as JSON, otherwise `TEXT`.

    :param source: the source to base the detection on
    :return: the detected `DataType` member
    :raises TypeError: if `source` is an unsupported non-string, or a remote
        `.yaml` that cannot be parsed or is not an OpenAPI document
    :raises ValueError: if a local `.yaml` is not an OpenAPI document, or a
        local file's type cannot be detected
    """
    from urllib.parse import urlparse

    import requests
    import yaml

    def is_openapi_yaml(yaml_content):
        # currently the following two fields are required in openapi spec yaml config.
        # The mapping check keeps an empty document (None) from raising and stops
        # plain strings or lists from matching through substring/element containment.
        return isinstance(yaml_content, dict) and "openapi" in yaml_content and "info" in yaml_content

    def is_google_drive_folder(url):
        # checks if url is a Google Drive folder url against a regex.
        # The `u/<n>/` segment is optional so that both `drive/folders/<id>` and
        # `drive/u/<n>/folders/<id>` are accepted.
        regex = r"^drive\.google\.com\/drive\/(?:u\/\d+\/)?folders\/([a-zA-Z0-9_-]+)$"
        return re.match(regex, url)

    try:
        if not isinstance(source, str):
            raise ValueError("Source is not a string and thus cannot be a URL.")
        url = urlparse(source)
        # Check if both scheme and netloc are present. Local file system URIs are acceptable too.
        if not all([url.scheme, url.netloc]) and url.scheme != "file":
            raise ValueError("Not a valid URL.")
    except ValueError:
        url = False

    # Shortened form of the source, used only in log messages.
    formatted_source = format_source(str(source), 30)

    if url:
        YOUTUBE_ALLOWED_NETLOCKS = {
            "www.youtube.com",
            "m.youtube.com",
            "youtu.be",
            "youtube.com",
            "vid.plus",
            "www.youtube-nocookie.com",
        }

        if url.netloc in YOUTUBE_ALLOWED_NETLOCKS:
            logger.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
            return DataType.YOUTUBE_VIDEO

        if url.netloc in {"notion.so", "notion.site"}:
            logger.debug(f"Source of `{formatted_source}` detected as `notion`.")
            return DataType.NOTION

        if url.path.endswith(".pdf"):
            logger.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
            return DataType.PDF_FILE

        # A remote `.xml` is treated as a sitemap; a local `.xml` file is handled further down.
        if url.path.endswith(".xml"):
            logger.debug(f"Source of `{formatted_source}` detected as `sitemap`.")
            return DataType.SITEMAP

        if url.path.endswith(".csv"):
            logger.debug(f"Source of `{formatted_source}` detected as `csv`.")
            return DataType.CSV

        if url.path.endswith((".mdx", ".md")):
            logger.debug(f"Source of `{formatted_source}` detected as `mdx`.")
            return DataType.MDX

        if url.path.endswith(".docx"):
            logger.debug(f"Source of `{formatted_source}` detected as `docx`.")
            return DataType.DOCX

        if url.path.endswith(
            (".mp3", ".mp4", ".mp2", ".aac", ".wav", ".flac", ".pcm", ".m4a", ".ogg", ".opus", ".webm")
        ):
            logger.debug(f"Source of `{formatted_source}` detected as `audio`.")
            return DataType.AUDIO

        if url.path.endswith(".yaml"):
            try:
                # NOTE(review): no timeout is passed to this request; consider adding one.
                response = requests.get(source)
                response.raise_for_status()
                try:
                    yaml_content = yaml.safe_load(response.text)
                except yaml.YAMLError as exc:
                    logger.error(f"Error parsing YAML: {exc}")
                    raise TypeError(f"Not a valid data type. Error loading YAML: {exc}") from exc

                if is_openapi_yaml(yaml_content):
                    logger.debug(f"Source of `{formatted_source}` detected as `openapi`.")
                    return DataType.OPENAPI
                else:
                    # Adjacent literals keep source indentation out of the message text.
                    logger.error(
                        f"Source of `{formatted_source}` does not contain all the required "
                        "fields of OpenAPI yaml. Check 'https://spec.openapis.org/oas/v3.1.0'"
                    )
                    raise TypeError(
                        "Not a valid data type. Check 'https://spec.openapis.org/oas/v3.1.0', "
                        "make sure you have all the required fields in YAML config data"
                    )
            except requests.exceptions.RequestException as e:
                # A failed download is only logged; detection continues with the checks below.
                logger.error(f"Error fetching URL {formatted_source}: {e}")

        if url.path.endswith(".json"):
            logger.debug(f"Source of `{formatted_source}` detected as `json_file`.")
            return DataType.JSON

        if "docs" in url.netloc or ("docs" in url.path and url.scheme != "file"):
            # `docs_site` detection via path is not accepted for local filesystem URIs,
            # because that would mean all paths that contain `docs` are now doc sites, which is too aggressive.
            logger.debug(f"Source of `{formatted_source}` detected as `docs_site`.")
            return DataType.DOCS_SITE

        if "github.com" in url.netloc:
            logger.debug(f"Source of `{formatted_source}` detected as `github`.")
            return DataType.GITHUB

        if is_google_drive_folder(url.netloc + url.path):
            logger.debug(f"Source of `{formatted_source}` detected as `google drive folder`.")
            return DataType.GOOGLE_DRIVE_FOLDER

        # If none of the above conditions are met, it's a general web page
        logger.debug(f"Source of `{formatted_source}` detected as `web_page`.")
        return DataType.WEB_PAGE

    elif not isinstance(source, str):
        # For datatypes where source is not a string.

        if isinstance(source, tuple) and len(source) == 2 and isinstance(source[0], str) and isinstance(source[1], str):
            logger.debug(f"Source of `{formatted_source}` detected as `qna_pair`.")
            return DataType.QNA_PAIR

        # Raise an error if it isn't a string and also not a valid non-string type (one of the previous).
        # We could stringify it, but it is better to raise an error and let the user decide how they want to do that.
        raise TypeError(
            "Source is not a string and a valid non-string type could not be detected. If you want to embed it, please stringify it, for instance by using `str(source)` or `(', ').join(source)`."  # noqa: E501
        )

    elif os.path.isfile(source):
        # For datatypes that support conventional file references.
        # Note: checking for string is not necessary anymore.

        if source.endswith(".docx"):
            logger.debug(f"Source of `{formatted_source}` detected as `docx`.")
            return DataType.DOCX

        if source.endswith(".csv"):
            logger.debug(f"Source of `{formatted_source}` detected as `csv`.")
            return DataType.CSV

        if source.endswith(".xml"):
            logger.debug(f"Source of `{formatted_source}` detected as `xml`.")
            return DataType.XML

        if source.endswith((".mdx", ".md")):
            logger.debug(f"Source of `{formatted_source}` detected as `mdx`.")
            return DataType.MDX

        if source.endswith(".txt"):
            logger.debug(f"Source of `{formatted_source}` detected as `text`.")
            return DataType.TEXT_FILE

        if source.endswith(".pdf"):
            logger.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
            return DataType.PDF_FILE

        if source.endswith(".yaml"):
            with open(source, "r") as file:
                yaml_content = yaml.safe_load(file)
                if is_openapi_yaml(yaml_content):
                    logger.debug(f"Source of `{formatted_source}` detected as `openapi`.")
                    return DataType.OPENAPI
                else:
                    logger.error(
                        f"Source of `{formatted_source}` does not contain all the required "
                        "fields of OpenAPI yaml. Check 'https://spec.openapis.org/oas/v3.1.0'"
                    )
                    raise ValueError(
                        "Invalid YAML data. Check 'https://spec.openapis.org/oas/v3.1.0', "
                        "make sure to add all the required params"
                    )

        if source.endswith(".json"):
            logger.debug(f"Source of `{formatted_source}` detected as `json`.")
            return DataType.JSON

        # Unknown extension: read the file (closing the handle afterwards) and
        # accept it as text when its content is mostly printable.
        with open(source) as candidate_file:
            candidate_text = candidate_file.read()
        if is_readable(candidate_text):
            logger.debug(f"Source of `{formatted_source}` detected as `text_file`.")
            return DataType.TEXT_FILE

        # If the source is a valid file, that's not detectable as a type, an error is raised.
        # It does not fall back to text.
        raise ValueError(
            "Source points to a valid file, but based on the filename, no `data_type` can be detected. Please be aware, that not all data_types allow conventional file references, some require the use of the `file URI scheme`. Please refer to the embedchain documentation (https://docs.embedchain.ai/advanced/data_types#remote-data-types)."  # noqa: E501
        )

    else:
        # Source is not a URL.

        # TODO: check if source is gmail query

        # check if the source is valid json string
        if is_valid_json_string(source):
            logger.debug(f"Source of `{formatted_source}` detected as `json`.")
            return DataType.JSON

        # Use text as final fallback.
        logger.debug(f"Source of `{formatted_source}` detected as `text`.")
        return DataType.TEXT
| #378 | |
| #379 | |
| #380 | # check if the source is valid json string |
def is_valid_json_string(source: str):
    """
    Tell whether `source` can be parsed as JSON.

    :param source: the string to test
    :return: True if `json.loads` accepts the string, False if it raises
        `json.JSONDecodeError`
    """
    try:
        json.loads(source)
    except json.JSONDecodeError:
        return False
    else:
        return True
| #387 | |
| #388 | |
def validate_config(config_data):
    """
    Validate an application config against the expected schema.

    Every top-level section (`app`, `llm`, `vectordb`, `embedder`,
    `embedding_model`, `chunker`, `cache`, `memory`) and every key inside a
    section is optional; keys that are present must have the listed type or be
    one of the listed values.

    :param config_data: the config to validate
    :return: whatever `Schema.validate` returns for `config_data`
    """
    app_section = {
        Optional("config"): {
            Optional("id"): str,
            Optional("name"): str,
            Optional("log_level"): Or("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"),
            Optional("collect_metrics"): bool,
            Optional("collection_name"): str,
        }
    }

    llm_providers = (
        "openai",
        "azure_openai",
        "anthropic",
        "huggingface",
        "cohere",
        "together",
        "gpt4all",
        "ollama",
        "jina",
        "llama2",
        "vertexai",
        "google",
        "aws_bedrock",
        "mistralai",
        "clarifai",
        "vllm",
        "groq",
        "nvidia",
    )
    llm_section = {
        Optional("provider"): Or(*llm_providers),
        Optional("config"): {
            Optional("model"): str,
            Optional("model_name"): str,
            Optional("number_documents"): int,
            Optional("temperature"): float,
            Optional("max_tokens"): int,
            Optional("top_p"): Or(float, int),
            Optional("stream"): bool,
            Optional("online"): bool,
            Optional("token_usage"): bool,
            Optional("template"): str,
            Optional("prompt"): str,
            Optional("system_prompt"): str,
            Optional("deployment_name"): str,
            Optional("where"): dict,
            Optional("query_type"): str,
            Optional("api_key"): str,
            # `base_url` is listed once here; a repeated identical entry adds nothing to the schema.
            Optional("base_url"): str,
            Optional("endpoint"): str,
            Optional("model_kwargs"): dict,
            Optional("local"): bool,
            Optional("default_headers"): dict,
            Optional("api_version"): Or(str, datetime.date),
            Optional("http_client_proxies"): Or(str, dict),
            Optional("http_async_client_proxies"): Or(str, dict),
        },
    }

    vectordb_section = {
        Optional("provider"): Or(
            "chroma", "elasticsearch", "opensearch", "lancedb", "pinecone", "qdrant", "weaviate", "zilliz"
        ),
        Optional("config"): object,  # TODO: add particular config schema for each provider
    }

    embedder_providers = (
        "openai",
        "gpt4all",
        "huggingface",
        "vertexai",
        "azure_openai",
        "google",
        "mistralai",
        "clarifai",
        "nvidia",
        "ollama",
        "cohere",
        "aws_bedrock",
    )
    embedder_section = {
        Optional("provider"): Or(*embedder_providers),
        Optional("config"): {
            Optional("model"): Optional(str),
            Optional("deployment_name"): Optional(str),
            Optional("api_key"): str,
            Optional("api_base"): str,
            Optional("title"): str,
            Optional("task_type"): str,
            Optional("vector_dimension"): int,
            Optional("base_url"): str,
            Optional("endpoint"): str,
            Optional("model_kwargs"): dict,
            Optional("http_client_proxies"): Or(str, dict),
            Optional("http_async_client_proxies"): Or(str, dict),
        },
    }

    # Note: this provider list differs from the `embedder` one (no "cohere").
    embedding_model_providers = (
        "openai",
        "gpt4all",
        "huggingface",
        "vertexai",
        "azure_openai",
        "google",
        "mistralai",
        "clarifai",
        "nvidia",
        "ollama",
        "aws_bedrock",
    )
    embedding_model_section = {
        Optional("provider"): Or(*embedding_model_providers),
        Optional("config"): {
            Optional("model"): str,
            Optional("deployment_name"): str,
            Optional("api_key"): str,
            Optional("title"): str,
            Optional("task_type"): str,
            Optional("vector_dimension"): int,
            Optional("base_url"): str,
        },
    }

    chunker_section = {
        Optional("chunk_size"): int,
        Optional("chunk_overlap"): int,
        Optional("length_function"): str,
        Optional("min_chunk_size"): int,
    }

    cache_section = {
        Optional("similarity_evaluation"): {
            Optional("strategy"): Or("distance", "exact"),
            Optional("max_distance"): float,
            Optional("positive"): bool,
        },
        Optional("config"): {
            Optional("similarity_threshold"): float,
            Optional("auto_flush"): int,
        },
    }

    memory_section = {
        Optional("top_k"): int,
    }

    schema = Schema(
        {
            Optional("app"): app_section,
            Optional("llm"): llm_section,
            Optional("vectordb"): vectordb_section,
            Optional("embedder"): embedder_section,
            Optional("embedding_model"): embedding_model_section,
            Optional("chunker"): chunker_section,
            Optional("cache"): cache_section,
            Optional("memory"): memory_section,
        }
    )

    return schema.validate(config_data)
| #534 | |
| #535 | |
def chunks(iterable, batch_size=100, desc="Processing chunks"):
    """
    Break an iterable into tuples of at most `batch_size` items, with a progress bar.

    This is a generator: each chunk is yielded as a tuple, and the progress bar
    advances by the chunk's length once the consumer asks for the next chunk.
    The last chunk may be shorter than `batch_size`.

    :param iterable: the items to split; sized containers and unsized
        iterables such as generators are both accepted
    :param batch_size: maximum number of items per chunk
    :param desc: label shown on the progress bar
    :return: generator of tuples
    """
    try:
        total_size = len(iterable)
    except TypeError:
        # Generators and other unsized iterables have no len(); with an unknown
        # total the progress bar only shows a running count.
        total_size = None

    it = iter(iterable)

    with tqdm(total=total_size, desc=desc, unit="batch") as pbar:
        chunk = tuple(itertools.islice(it, batch_size))
        while chunk:
            yield chunk
            pbar.update(len(chunk))
            chunk = tuple(itertools.islice(it, batch_size))
| #547 |