repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...2fa351d6docs: add automaton and perps launch sources16d ago| #1 | import logging |
| #2 | import os |
| #3 | import re |
| #4 | import tempfile |
| #5 | import time |
| #6 | import uuid |
| #7 | from pathlib import Path |
| #8 | from typing import cast |
| #9 | |
| #10 | from openai import OpenAI |
| #11 | from openai.types.beta.threads import Message |
| #12 | from openai.types.beta.threads.text_content_block import TextContentBlock |
| #13 | |
| #14 | from embedchain import Client, Pipeline |
| #15 | from embedchain.config import AddConfig |
| #16 | from embedchain.data_formatter import DataFormatter |
| #17 | from embedchain.models.data_type import DataType |
| #18 | from embedchain.telemetry.posthog import AnonymousTelemetry |
| #19 | from embedchain.utils.misc import detect_datatype |
| #20 | |
# Set up the user directory if it doesn't exist already.
# This is a module-level call, so it runs as a side effect when the module is imported,
# before any assistant class below is instantiated.
Client.setup()
| #23 | |
| #24 | |
class OpenAIAssistant:
    """Wrapper around the OpenAI Assistants (beta) API.

    Creates (or retrieves) an assistant, keeps one conversation thread, uploads
    local files or embedchain-loadable sources as assistant files, and sends chat
    messages, blocking until each run reaches a terminal state.
    """

    def __init__(
        self,
        name=None,
        instructions=None,
        tools=None,
        thread_id=None,
        model="gpt-4-1106-preview",
        data_sources=None,
        assistant_id=None,
        log_level=logging.INFO,
        collect_metrics=True,
    ):
        """Create or retrieve the assistant and open (or resume) a thread.

        Args:
            name: Assistant name used when creating; defaults to "OpenAI Assistant".
            instructions: Instructions passed on creation and on every run.
            tools: Tool specs for creation; defaults to ``[{"type": "retrieval"}]``.
            thread_id: Existing thread to resume; a new thread is created otherwise.
            model: Model name used when creating the assistant.
            data_sources: List of dicts with a ``"source"`` key and an optional
                ``"data_type"`` key, uploaded and attached to the assistant.
            assistant_id: If given, retrieve this assistant instead of creating one.
            log_level: Stored on the instance; not otherwise used in this class.
            collect_metrics: Enables anonymous telemetry events.
        """
        self.name = name or "OpenAI Assistant"
        self.instructions = instructions
        self.tools = tools or [{"type": "retrieval"}]
        self.model = model
        self.data_sources = data_sources or []
        self.log_level = log_level
        self._client = OpenAI()
        self._initialize_assistant(assistant_id)
        self.thread_id = thread_id or self._create_thread()
        self._telemetry_props = {"class": self.__class__.__name__}
        self.telemetry = AnonymousTelemetry(enabled=collect_metrics)
        self.telemetry.capture(event_name="init", properties=self._telemetry_props)

    def add(self, source, data_type=None):
        """Upload ``source`` (a local file path or a loadable source) and attach it to the assistant."""
        file_path = self._prepare_source_path(source, data_type)
        self._add_file_to_assistant(file_path)

        event_props = {
            **self._telemetry_props,
            "data_type": data_type or detect_datatype(source),
        }
        self.telemetry.capture(event_name="add", properties=event_props)
        logging.info("Data successfully added to the assistant.")

    def chat(self, message):
        """Send ``message`` on the current thread and return the newest message's text (or None)."""
        self._send_message(message)
        self.telemetry.capture(event_name="chat", properties=self._telemetry_props)
        return self._get_latest_response()

    def delete_thread(self):
        """Delete the current thread on the server and start a fresh one."""
        self._client.beta.threads.delete(self.thread_id)
        self.thread_id = self._create_thread()

    # Internal methods
    def _initialize_assistant(self, assistant_id):
        """Set ``self.assistant`` by retrieving ``assistant_id`` or creating a new assistant.

        Data sources are attached in both cases. When creating, files are uploaded
        first and passed as ``file_ids``, because attaching requires ``self.assistant``,
        which does not exist until ``create`` returns.
        """
        if assistant_id:
            self.assistant = self._client.beta.assistants.retrieve(assistant_id)
            for ds in self.data_sources:
                self._add_file_to_assistant(self._prepare_source_path(ds["source"], ds.get("data_type")))
        else:
            file_ids = self._generate_file_ids(self.data_sources)
            self.assistant = self._client.beta.assistants.create(
                name=self.name, model=self.model, file_ids=file_ids, instructions=self.instructions, tools=self.tools
            )

    def _create_thread(self):
        """Create a new, empty thread and return its id."""
        thread = self._client.beta.threads.create()
        return thread.id

    @staticmethod
    def _is_local_file(source):
        """Return True if ``source`` names an existing regular file; never raises.

        ``Path(...)`` raises TypeError for non-path sources (e.g. tuples), and
        ``is_file()`` can raise OSError (e.g. file name too long) or ValueError
        (e.g. embedded NUL) for arbitrary text; all of those mean "not a local file".
        """
        try:
            return Path(source).is_file()
        except (TypeError, ValueError, OSError):
            return False

    def _prepare_source_path(self, source, data_type=None):
        """Return a local file path holding the content of ``source``.

        Existing local files are returned unchanged. Anything else is loaded through
        embedchain's loader for its data type, and all loaded documents are written
        to a temp file.

        Raises:
            ValueError: If the loader returns no documents for ``source``.
        """
        if self._is_local_file(source):
            return source
        data_type = data_type or detect_datatype(source)
        formatter = DataFormatter(data_type=DataType(data_type), config=AddConfig())
        data = formatter.loader.load_data(source)["data"]
        if not data:
            raise ValueError(f"No data could be loaded from source: {source}")
        # Keep every loaded document, not just the first one, so multi-document sources are not truncated.
        content = "\n\n".join(item["content"] for item in data)
        return self._save_temp_data(data=content.encode("utf-8"), source=source)

    def _upload_file(self, file_path):
        """Upload ``file_path`` with purpose "assistants" and return the new file id."""
        with open(file_path, "rb") as file:
            file_obj = self._client.files.create(file=file, purpose="assistants")
        return file_obj.id

    def _add_file_to_assistant(self, file_path):
        """Upload ``file_path``, attach it to ``self.assistant``, and return the file id."""
        file_id = self._upload_file(file_path)
        self._client.beta.assistants.files.create(assistant_id=self.assistant.id, file_id=file_id)
        return file_id

    def _generate_file_ids(self, data_sources):
        """Upload each data source (without attaching) and return the resulting file ids."""
        return [self._upload_file(self._prepare_source_path(ds["source"], ds.get("data_type"))) for ds in data_sources]

    def _send_message(self, message):
        """Post ``message`` as a user message and block until the resulting run finishes."""
        self._client.beta.threads.messages.create(thread_id=self.thread_id, role="user", content=message)
        self._wait_for_completion()

    def _wait_for_completion(self):
        """Start a run on the current thread and poll until it reaches a terminal state.

        Raises:
            ValueError: If the run fails, is cancelled, expires, or requires tool
                outputs (this wrapper never submits them, so waiting would only
                end when the run expires).
        """
        run = self._client.beta.threads.runs.create(
            thread_id=self.thread_id,
            assistant_id=self.assistant.id,
            instructions=self.instructions,
        )
        while run.status in ("queued", "in_progress", "cancelling"):
            time.sleep(0.1)  # Sleep before making the next API call to avoid hitting rate limits
            run = self._client.beta.threads.runs.retrieve(thread_id=self.thread_id, run_id=run.id)

        if run.status == "failed":
            raise ValueError(f"Thread run failed with the following error: {run.last_error}")
        if run.status == "requires_action":
            raise ValueError("Thread run requires tool outputs, which OpenAIAssistant does not submit.")
        if run.status in ("cancelled", "expired"):
            raise ValueError(f"Thread run did not complete; final status: {run.status}")

    def _get_latest_response(self):
        """Return the text of the newest message in the thread, or None if it has none."""
        history = self._get_history()
        return self._format_message(history[0]) if history else None

    def _get_history(self):
        """Return the thread's messages as a list, newest first."""
        messages = self._client.beta.threads.messages.list(thread_id=self.thread_id, order="desc")
        return list(messages)

    @staticmethod
    def _format_message(thread_message):
        """Join the text of all text content blocks in a message with single spaces."""
        thread_message = cast(Message, thread_message)  # type-checker hint only
        content = [c.text.value for c in thread_message.content if isinstance(c, TextContentBlock)]
        return " ".join(content)

    @staticmethod
    def _save_temp_data(data, source):
        """Write ``data`` (bytes) to a new temp directory, under a filename derived from ``source``.

        Path separators and other special characters are replaced with "_". The
        name is truncated to 255 UTF-8 bytes (the common filesystem name limit).
        It falls back to "data" when nothing usable remains.
        """
        special_chars_pattern = r'[\\/:*?"<>|&=% ]+'
        sanitized_source = re.sub(special_chars_pattern, "_", str(source))
        sanitized_source = sanitized_source.encode("utf-8")[:255].decode("utf-8", "ignore")
        # "", "." and ".." would make the joined path point at a directory, not a file.
        if not sanitized_source.strip("."):
            sanitized_source = "data"
        temp_dir = tempfile.mkdtemp()
        file_path = os.path.join(temp_dir, sanitized_source)
        with open(file_path, "wb") as file:
            file.write(data)
        return file_path
| #147 | |
| #148 | |
class AIAssistant:
    """Embedchain-pipeline assistant whose retrieval is scoped to one assistant and thread.

    Sources passed via ``add`` are tagged with this assistant's id and the current
    thread id. Sources passed as ``data_sources`` at construction are tagged with the
    shared ``"global_knowledge"`` thread, so ``chat`` sees them from any thread.
    """

    def __init__(
        self,
        name=None,
        instructions=None,
        yaml_path=None,
        assistant_id=None,
        thread_id=None,
        data_sources=None,
        log_level=logging.INFO,
        collect_metrics=True,
    ):
        """Build the pipeline, announce the ids, emit telemetry, and ingest ``data_sources``.

        Args:
            name: Display name; defaults to "AI Assistant".
            instructions: If set, used as the pipeline's system prompt.
            yaml_path: Optional pipeline config file; a default Pipeline is used otherwise.
            assistant_id: Existing assistant id; a random UUID4 string otherwise.
            thread_id: Existing thread id; a random UUID4 string otherwise.
            data_sources: List of dicts with ``"source"`` and optional ``"data_type"``.
            log_level: Stored on the instance; not otherwise used in this class.
            collect_metrics: Enables anonymous telemetry events.
        """
        self.name = name or "AI Assistant"
        self.data_sources = data_sources or []
        self.log_level = log_level
        self.instructions = instructions
        self.assistant_id = assistant_id or str(uuid.uuid4())
        self.thread_id = thread_id or str(uuid.uuid4())

        if yaml_path:
            self.pipeline = Pipeline.from_config(config_path=yaml_path)
        else:
            self.pipeline = Pipeline()
        # The pipeline is identified by this thread's id (local id first, then config id).
        self.pipeline.local_id = self.thread_id
        self.pipeline.config.id = self.thread_id

        if self.instructions:
            self.pipeline.system_prompt = self.instructions

        print(
            f"🎉 Created AI Assistant with name: {self.name}, assistant_id: {self.assistant_id}, thread_id: {self.thread_id}"  # noqa: E501
        )

        # telemetry related properties
        self._telemetry_props = {"class": self.__class__.__name__}
        self.telemetry = AnonymousTelemetry(enabled=collect_metrics)
        self.telemetry.capture(event_name="init", properties=self._telemetry_props)

        # Initial sources go into the shared "global_knowledge" bucket; a fresh metadata dict per call.
        for spec in self.data_sources:
            self.pipeline.add(
                spec["source"],
                spec.get("data_type"),
                metadata={"assistant_id": self.assistant_id, "thread_id": "global_knowledge"},
            )

    def add(self, source, data_type=None):
        """Add ``source`` to the pipeline, scoped to this assistant and the current thread."""
        self.pipeline.add(
            source,
            data_type=data_type,
            metadata={"assistant_id": self.assistant_id, "thread_id": self.thread_id},
        )
        self.telemetry.capture(
            event_name="add",
            properties={**self._telemetry_props, "data_type": data_type or detect_datatype(source)},
        )

    def chat(self, query):
        """Answer ``query`` using only this assistant's thread-local and global documents."""
        visible_threads = [self.thread_id, "global_knowledge"]
        assistant_filter = {"assistant_id": {"$eq": self.assistant_id}}
        thread_filter = {"thread_id": {"$in": visible_threads}}
        return self.pipeline.chat(query, where={"$and": [assistant_filter, thread_filter]})

    def delete(self):
        """Reset the underlying pipeline via ``Pipeline.reset()``."""
        self.pipeline.reset()
| #207 |