repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...2fa351d6docs: add automaton and perps launch sources16d ago| #1 | import json |
| #2 | import os |
| #3 | import warnings |
| #4 | from typing import Any, Callable, Dict, Optional, Type, Union |
| #5 | |
| #6 | from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler |
| #7 | from langchain.schema import BaseMessage, HumanMessage, SystemMessage |
| #8 | from langchain_core.tools import BaseTool |
| #9 | from langchain_openai import ChatOpenAI |
| #10 | from pydantic import BaseModel |
| #11 | |
| #12 | from embedchain.config import BaseLlmConfig |
| #13 | from embedchain.helpers.json_serializable import register_deserializable |
| #14 | from embedchain.llm.base import BaseLlm |
| #15 | |
| #16 | |
@register_deserializable
class OpenAILlm(BaseLlm):
    """OpenAI chat-completion backend built on LangChain's ``ChatOpenAI``.

    The instance can optionally be given a single tool definition; when it is,
    every query is answered with the model's first tool call (serialized as
    JSON) instead of free-form text.
    """

    # Fallback model name. Shared by the request and by the pricing lookup so
    # the two can never disagree when ``config.model`` is left unset.
    _DEFAULT_MODEL = "gpt-4o-mini"

    def __init__(
        self,
        config: Optional[BaseLlmConfig] = None,
        tools: Optional[Union[Dict[str, Any], Type[BaseModel], Callable[..., Any], BaseTool]] = None,
    ):
        """Store the optional tool definition and delegate the rest to ``BaseLlm``.

        Args:
            config: LLM configuration, forwarded unchanged to ``BaseLlm``.
            tools: A single tool definition (dict, pydantic model, callable or
                ``BaseTool``) to bind to every request, or ``None``.
        """
        # Set before the base initializer runs, exactly as the original did.
        self.tools = tools
        super().__init__(config=config)

    def get_llm_model_answer(self, prompt) -> Union[str, tuple[str, dict[str, Any]]]:
        """Answer ``prompt``, optionally with token counts and cost.

        Args:
            prompt: The user prompt sent to the model.

        Returns:
            The answer text when ``config.token_usage`` is falsy. Otherwise a
            ``(answer, usage)`` tuple where ``usage`` holds ``prompt_tokens``,
            ``completion_tokens``, ``total_tokens``, ``total_cost`` and
            ``cost_currency``.

        Raises:
            ValueError: If token usage is requested and the model has no entry
                in ``config.model_pricing_map``.
        """
        if not self.config.token_usage:
            return self._get_answer(prompt, self.config)

        response, token_info = self._get_answer(prompt, self.config)

        # Use the same fallback as the request itself; concatenating with an
        # unset (None) model name would raise a TypeError.
        model_name = "openai/" + (self.config.model or self._DEFAULT_MODEL)
        pricing_map = self.config.model_pricing_map
        if model_name not in pricing_map:
            # Adjacent literals instead of a backslash continuation, which
            # leaked the source indentation into the message.
            raise ValueError(
                f"Model {model_name} not found in `model_prices_and_context_window.json`. "
                "You can disable token usage by setting `token_usage` to False."
            )

        pricing = pricing_map[model_name]
        prompt_tokens = token_info["prompt_tokens"]
        completion_tokens = token_info["completion_tokens"]
        total_cost = (
            pricing["input_cost_per_token"] * prompt_tokens
            + pricing["output_cost_per_token"] * completion_tokens
        )
        response_token_info = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
            "total_cost": round(total_cost, 10),
            "cost_currency": "USD",
        }
        return response, response_token_info

    def _get_answer(self, prompt: str, config: BaseLlmConfig) -> Union[str, tuple[str, dict[str, Any]]]:
        """Build a ``ChatOpenAI`` client from ``config`` and run one request.

        Args:
            prompt: The user prompt; sent after the optional system prompt.
            config: Source of model, sampling, endpoint and client settings.

        Returns:
            The answer text, or ``(answer, token_usage)`` when
            ``self.config.token_usage`` is truthy. With tools configured the
            answer is the JSON-serialized first tool call.

        Raises:
            KeyError: If no API key is configured and ``OPENAI_API_KEY`` is
                not set in the environment.
        """
        messages = []
        if config.system_prompt:
            messages.append(SystemMessage(content=config.system_prompt))
        messages.append(HumanMessage(content=prompt))

        kwargs = {
            "model": config.model or self._DEFAULT_MODEL,
            "temperature": config.temperature,
            "max_tokens": config.max_tokens,
            "model_kwargs": config.model_kwargs or {},
        }
        api_key = config.api_key or os.environ["OPENAI_API_KEY"]
        # Precedence: explicit config, deprecated env var, current env var,
        # then the public OpenAI endpoint.
        base_url = (
            config.base_url
            or os.getenv("OPENAI_API_BASE")
            or os.getenv("OPENAI_BASE_URL")
            or "https://api.openai.com/v1"
        )
        if os.environ.get("OPENAI_API_BASE"):
            warnings.warn(
                "The environment variable 'OPENAI_API_BASE' is deprecated and will be removed in the 0.1.140. "
                "Please use 'OPENAI_BASE_URL' instead.",
                DeprecationWarning
            )

        if config.top_p:
            kwargs["top_p"] = config.top_p
        if config.default_headers:
            kwargs["default_headers"] = config.default_headers
        if config.stream:
            # Streaming only adds these two arguments, so a single constructor
            # call below covers both modes.
            kwargs["streaming"] = config.stream
            kwargs["callbacks"] = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]

        chat = ChatOpenAI(
            **kwargs,
            api_key=api_key,
            base_url=base_url,
            http_client=config.http_client,
            http_async_client=config.http_async_client,
        )

        if self.tools:
            # Keep the raw message so token usage stays available; returning
            # only the string broke the caller's tuple unpacking.
            answer, chat_response = self._invoke_with_tools(chat, self.tools, messages)
        else:
            chat_response = chat.invoke(messages)
            answer = chat_response.content

        if self.config.token_usage:
            # NOTE(review): with streaming enabled the response metadata may
            # not include "token_usage" -- confirm against the client version.
            return answer, chat_response.response_metadata["token_usage"]
        return answer

    def _invoke_with_tools(
        self,
        chat: ChatOpenAI,
        tools: Optional[Union[Dict[str, Any], Type[BaseModel], Callable[..., Any], BaseTool]],
        messages: list[BaseMessage],
    ) -> tuple[str, Any]:
        """Run ``messages`` with ``tools`` bound; return the answer and raw message.

        Returns:
            ``(answer, message)`` where ``answer`` is the JSON dump of the
            first parsed tool call, or a fixed fallback sentence when the
            model produced no tool call, and ``message`` is the unparsed
            model response.
        """
        # Imported lazily, as in the original, so the parser modules are only
        # loaded when tools are actually used.
        from langchain.output_parsers.openai_tools import JsonOutputToolsParser
        from langchain_core.utils.function_calling import convert_to_openai_tool

        openai_tools = [convert_to_openai_tool(tools)]
        message = chat.bind(tools=openai_tools).invoke(messages)
        # Parsing the message directly is the step the former
        # `.pipe(JsonOutputToolsParser())` chain applied to the model output.
        tool_calls = JsonOutputToolsParser().invoke(message)
        try:
            answer = json.dumps(tool_calls[0])
        except IndexError:
            answer = "Input could not be mapped to the function!"
        return answer, message

    def _query_function_call(
        self,
        chat: ChatOpenAI,
        tools: Optional[Union[Dict[str, Any], Type[BaseModel], Callable[..., Any], BaseTool]],
        messages: list[BaseMessage],
    ) -> str:
        """Return the first tool call for ``messages`` as a JSON string.

        Returns:
            The JSON dump of the first parsed tool call, or
            ``"Input could not be mapped to the function!"`` when the model
            returned no tool call.
        """
        answer, _ = self._invoke_with_tools(chat, tools, messages)
        return answer
| #121 |