repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...2fa351d6docs: add automaton and perps launch sources16d ago| #1 | import json |
| #2 | import logging |
| #3 | import re |
| #4 | from typing import Any, Dict, List, Optional, Union |
| #5 | |
| #6 | try: |
| #7 | import boto3 |
| #8 | from botocore.exceptions import ClientError, NoCredentialsError |
| #9 | except ImportError: |
| #10 | raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.") |
| #11 | |
| #12 | from mem0.configs.llms.base import BaseLlmConfig |
| #13 | from mem0.configs.llms.aws_bedrock import AWSBedrockConfig |
| #14 | from mem0.llms.base import LLMBase |
| #15 | from mem0.memory.utils import extract_json |
| #16 | |
logger = logging.getLogger(__name__)

# Provider / model-family tokens recognised in Bedrock model identifiers.
# Order matters: extract_provider returns the FIRST entry (in this list's
# order) that appears as a whole word in the model id.
PROVIDERS = [
    "ai21",
    "amazon",
    "anthropic",
    "cohere",
    "meta",
    "mistral",
    "stability",
    "writer",
    "deepseek",
    "gpt-oss",
    "perplexity",
    "snowflake",
    "titan",
    "command",
    "j2",
    "llama",
]


def extract_provider(model: str) -> str:
    """
    Return the provider token found in a Bedrock model identifier.

    Each entry of ``PROVIDERS`` is tried in list order and matched as a whole
    word (regex ``\\b`` boundaries), so ``"us.anthropic.claude-..."`` yields
    ``"anthropic"``.

    Args:
        model: Bedrock model id (or any string containing one).

    Returns:
        The first matching entry of ``PROVIDERS``.

    Raises:
        ValueError: If no entry of ``PROVIDERS`` occurs in ``model``.
    """
    found = next(
        (candidate for candidate in PROVIDERS if re.search(rf"\b{re.escape(candidate)}\b", model)),
        None,
    )
    if found is None:
        raise ValueError(f"Unknown provider in model: {model}")
    return found
| #31 | |
| #32 | |
class AWSBedrockLLM(LLMBase):
    """
    AWS Bedrock LLM integration for Mem0.

    The provider is detected from the model identifier (see ``extract_provider``).
    Anthropic and Amazon Nova models, and every tool-calling request, go through
    the Bedrock Converse API. Other models use ``invoke_model`` with a
    provider-specific JSON body built by ``_prepare_input``.
    """

    def __init__(self, config: Optional[Union[AWSBedrockConfig, BaseLlmConfig, Dict]] = None):
        """
        Initialize AWS Bedrock LLM.

        Args:
            config: ``None`` for defaults, a ``dict`` of ``AWSBedrockConfig``
                keyword arguments, a generic ``BaseLlmConfig`` (its model and
                sampling fields are copied over), or an ``AWSBedrockConfig``.

        Raises:
            ValueError: If the Bedrock runtime client cannot be set up, or no
                known provider appears in the model identifier.
        """
        # Normalise every accepted config shape into an AWSBedrockConfig.
        if config is None:
            config = AWSBedrockConfig()
        elif isinstance(config, dict):
            config = AWSBedrockConfig(**config)
        elif isinstance(config, BaseLlmConfig) and not isinstance(config, AWSBedrockConfig):
            config = AWSBedrockConfig(
                model=config.model,
                temperature=config.temperature,
                max_tokens=config.max_tokens,
                top_p=config.top_p,
                top_k=config.top_k,
                enable_vision=getattr(config, "enable_vision", False),
            )

        super().__init__(config)
        self.config = config

        self._initialize_aws_client()

        self.model_config = self.config.get_model_config()
        self.provider = extract_provider(self.config.model)

        self._initialize_provider_settings()

    def _initialize_aws_client(self):
        """
        Create the ``bedrock-runtime`` client and run a best-effort model check.

        Raises:
            ValueError: If AWS credentials are missing, access to Bedrock is
                denied, or another botocore ``ClientError`` occurs.
        """
        try:
            aws_config = self.config.get_aws_config()
            self.client = boto3.client("bedrock-runtime", **aws_config)
            # Best effort: _test_connection logs and swallows its own failures.
            self._test_connection()
        except NoCredentialsError as e:
            raise ValueError(
                "AWS credentials not found. Please set AWS_ACCESS_KEY_ID, "
                "AWS_SECRET_ACCESS_KEY, and AWS_REGION environment variables, "
                "or provide them in the config."
            ) from e
        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code")
            # Bedrock reports permission problems as AccessDeniedException;
            # UnauthorizedOperation is kept so the old check still applies.
            if error_code in ("UnauthorizedOperation", "AccessDeniedException"):
                raise ValueError(
                    f"Unauthorized access to Bedrock. Please ensure your AWS credentials "
                    f"have permission to access Bedrock in region {self.config.aws_region}."
                ) from e
            raise ValueError(f"AWS Bedrock error: {e}") from e

    def _test_connection(self):
        """
        List foundation models and warn if the configured model is not among them.

        Sets ``self.available_models`` to the listed model ids, or to ``[]``
        when the listing fails. Never raises: any error is logged as a warning.
        """
        try:
            bedrock_client = boto3.client("bedrock", **self.config.get_aws_config())
            response = bedrock_client.list_foundation_models()
            self.available_models = [model["modelId"] for model in response["modelSummaries"]]

            if self.config.model not in self.available_models:
                logger.warning(f"Model {self.config.model} may not be available in region {self.config.aws_region}")
                logger.info(f"Available models: {', '.join(self.available_models[:5])}...")

        except Exception as e:
            logger.warning(f"Could not verify model availability: {e}")
            self.available_models = []

    def _initialize_provider_settings(self):
        """Set capability flags and pick the prompt formatter for ``self.provider``."""
        self.supports_tools = self.provider in ["anthropic", "cohere", "amazon"]
        self.supports_vision = self.provider in ["anthropic", "amazon", "meta", "mistral"]
        self.supports_streaming = self.provider in ["anthropic", "cohere", "mistral", "amazon", "meta"]

        formatters = {
            "anthropic": self._format_messages_anthropic,
            "cohere": self._format_messages_cohere,
            "amazon": self._format_messages_amazon,
            "meta": self._format_messages_meta,
            "mistral": self._format_messages_mistral,
        }
        self._format_messages = formatters.get(self.provider, self._format_messages_generic)

    @staticmethod
    def _format_messages_converse(
        messages: List[Dict[str, str]],
    ) -> tuple[List[Dict[str, Any]], Optional[str]]:
        """
        Convert chat messages into Bedrock Converse API format.

        ``user`` and ``assistant`` messages become
        ``{"role": ..., "content": [{"text": ...}]}``. The ``system`` message
        is returned separately because Converse takes it through its ``system``
        parameter. If several system messages are present, the last one wins.
        Messages with any other role are dropped.

        Returns:
            ``(formatted_messages, system_message_or_None)``.
        """
        formatted_messages = []
        system_message = None

        for message in messages:
            role = message["role"]
            content = message["content"]

            if role == "system":
                system_message = content
            elif role in ("user", "assistant"):
                formatted_messages.append({"role": role, "content": [{"text": content}]})

        return formatted_messages, system_message

    def _format_messages_anthropic(self, messages: List[Dict[str, str]]) -> tuple[List[Dict[str, Any]], Optional[str]]:
        """
        Format messages for Anthropic models (Converse API format).

        The system prompt is returned separately, see
        https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/system-prompts
        """
        return self._format_messages_converse(messages)

    def _format_messages_cohere(self, messages: List[Dict[str, str]]) -> str:
        """Render messages as ``"Role: content"`` lines joined by newlines (Cohere prompt)."""
        formatted_messages = []

        for message in messages:
            role = message["role"].capitalize()
            content = message["content"]
            formatted_messages.append(f"{role}: {content}")

        return "\n".join(formatted_messages)

    def _format_messages_amazon(self, messages: List[Dict[str, str]]) -> List[Dict[str, Any]]:
        """
        Return ``{"role", "content"}`` dicts with plain-string content.

        Keeps system/user/assistant messages and drops other roles. This is
        NOT the Converse content-block format; Converse requests (including
        Nova) use ``_format_messages_converse`` instead.
        """
        formatted_messages = []

        for message in messages:
            role = message["role"]
            content = message["content"]

            if role == "system":
                formatted_messages.append({"role": "system", "content": content})
            elif role == "user":
                formatted_messages.append({"role": "user", "content": content})
            elif role == "assistant":
                formatted_messages.append({"role": "assistant", "content": content})

        return formatted_messages

    def _format_messages_meta(self, messages: List[Dict[str, str]]) -> str:
        """Render messages as ``"Role: content"`` lines joined by newlines (Meta prompt)."""
        formatted_messages = []

        for message in messages:
            role = message["role"].capitalize()
            content = message["content"]
            formatted_messages.append(f"{role}: {content}")

        return "\n".join(formatted_messages)

    def _format_messages_mistral(self, messages: List[Dict[str, str]]) -> List[Dict[str, Any]]:
        """Return ``{"role", "content"}`` dicts for system/user/assistant messages (Mistral)."""
        formatted_messages = []

        for message in messages:
            role = message["role"]
            content = message["content"]

            if role == "system":
                formatted_messages.append({"role": "system", "content": content})
            elif role == "user":
                formatted_messages.append({"role": "user", "content": content})
            elif role == "assistant":
                formatted_messages.append({"role": "assistant", "content": content})

        return formatted_messages

    def _format_messages_generic(self, messages: List[Dict[str, str]]) -> str:
        """
        Render messages as a single text prompt for providers without a chat format.

        Output: ``"\\n\\nHuman: "`` + ``"\\n\\nRole: content"`` for each message
        + ``"\\n\\nAssistant:"``.
        """
        formatted_messages = []

        for message in messages:
            role = message["role"].capitalize()
            content = message["content"]
            formatted_messages.append(f"\n\n{role}: {content}")

        return "\n\nHuman: " + "".join(formatted_messages) + "\n\nAssistant:"

    def _prepare_input(self, prompt: str) -> Dict[str, Any]:
        """
        Build the ``invoke_model`` request body for the current provider.

        Args:
            prompt: Text prompt to send.

        Returns:
            A JSON-serialisable request body dict.
        """
        input_body = {"prompt": prompt}

        # Generic renames of config keys to provider-specific field names. For
        # amazon/anthropic/meta/mistral the body is rebuilt from scratch below,
        # so this mapping only survives for the other providers.
        provider_mappings = {
            "meta": {"max_tokens": "max_gen_len"},
            "ai21": {"max_tokens": "maxTokens", "top_p": "topP"},
            "mistral": {"max_tokens": "max_tokens"},
            "cohere": {"max_tokens": "max_tokens", "top_p": "p"},
            "amazon": {"max_tokens": "maxTokenCount", "top_p": "topP"},
            "anthropic": {"max_tokens": "max_tokens", "top_p": "top_p"},
        }

        if self.provider in provider_mappings:
            for old_key, new_key in provider_mappings[self.provider].items():
                if old_key in self.model_config:
                    input_body[new_key] = self.model_config[old_key]

        if self.provider == "cohere" and "cohere.command" in self.config.model:
            input_body["message"] = input_body.pop("prompt")
        elif self.provider == "amazon":
            if "nova" in self.config.model.lower():
                # NOTE(review): Nova generation goes through the Converse API
                # (_generate_converse), so this body is not sent by this class.
                # Its shape does not match Nova's native invoke_model schema.
                input_body = {
                    "messages": [{"role": "user", "content": prompt}],
                    "max_tokens": self.model_config.get("max_tokens", 5000),
                    "temperature": self.model_config.get("temperature", 0.1),
                    "top_p": self.model_config.get("top_p", 0.9),
                }
            else:
                # Legacy Amazon (Titan) text models.
                input_body = {
                    "inputText": prompt,
                    "textGenerationConfig": {
                        "maxTokenCount": self.model_config.get("max_tokens", 5000),
                        "topP": self.model_config.get("top_p", 0.9),
                        "temperature": self.model_config.get("temperature", 0.1),
                    },
                }
                # Drop unset values so the service does not reject them.
                input_body["textGenerationConfig"] = {
                    k: v for k, v in input_body["textGenerationConfig"].items() if v is not None
                }
        elif self.provider == "anthropic":
            input_body = {
                "messages": [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
                "max_tokens": self.model_config.get("max_tokens", 2000),
                "temperature": self.model_config.get("temperature", 0.1),
                "top_p": self.model_config.get("top_p", 0.9),
                "anthropic_version": "bedrock-2023-05-31",
            }
        elif self.provider == "meta":
            input_body = {
                "prompt": prompt,
                "max_gen_len": self.model_config.get("max_tokens", 5000),
                "temperature": self.model_config.get("temperature", 0.1),
                "top_p": self.model_config.get("top_p", 0.9),
            }
        elif self.provider == "mistral":
            input_body = {
                "prompt": prompt,
                "max_tokens": self.model_config.get("max_tokens", 5000),
                "temperature": self.model_config.get("temperature", 0.1),
                "top_p": self.model_config.get("top_p", 0.9),
            }
        else:
            # Generic case: pass every model config parameter through as-is.
            input_body.update(self.model_config)

        return input_body

    def _convert_tool_format(self, original_tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Convert OpenAI-style function tools to Bedrock ``toolSpec`` entries.

        Args:
            original_tools: Tool definitions; only ``type == "function"`` entries
                are converted, others are skipped.

        Returns:
            Converted tools. A function without ``parameters`` gets an empty
            object schema.
        """
        new_tools = []

        for tool in original_tools:
            if tool["type"] != "function":
                continue
            function = tool["function"]
            parameters = function.get("parameters", {})
            new_tools.append(
                {
                    "toolSpec": {
                        "name": function["name"],
                        "description": function.get("description", ""),
                        "inputSchema": {
                            "json": {
                                "type": "object",
                                "properties": dict(parameters.get("properties", {})),
                                "required": parameters.get("required", []),
                            }
                        },
                    }
                }
            )

        return new_tools

    @staticmethod
    def _extract_converse_text(payload: Dict[str, Any]) -> str:
        """
        Join the text blocks of ``payload["output"]["message"]["content"]``.

        Non-text blocks (for example tool use or reasoning) are skipped.
        Raises ``KeyError`` if the ``output.message.content`` path is missing.
        """
        content = payload["output"]["message"]["content"]
        return "".join(block["text"] for block in content if "text" in block)

    def _parse_invoke_body(self, response_json: Dict[str, Any]) -> str:
        """Extract generated text from a decoded ``invoke_model`` JSON body."""
        if self.provider == "anthropic":
            return response_json.get("content", [{"text": ""}])[0].get("text", "")

        if self.provider == "amazon":
            if "nova" in self.config.model.lower():
                # Nova bodies nest the reply like a Converse response.
                if "output" in response_json:
                    return self._extract_converse_text(response_json)
                if "content" in response_json:
                    return response_json["content"][0]["text"]
                return response_json.get("completion", "")
            # Titan text models return {"results": [{"outputText": ...}, ...]}.
            results = response_json.get("results")
            if results:
                return results[0].get("outputText", "")
            return response_json.get("completion", "")

        if self.provider == "meta":
            return response_json.get("generation", "")

        if self.provider == "mistral":
            return response_json.get("outputs", [{"text": ""}])[0].get("text", "")

        if self.provider == "cohere":
            # Command R returns a top-level "text"; legacy Command uses "generations".
            if isinstance(response_json.get("text"), str):
                return response_json["text"]
            return response_json.get("generations", [{"text": ""}])[0].get("text", "")

        if self.provider == "ai21":
            return response_json.get("completions", [{"data": {"text": ""}}])[0].get("data", {}).get("text", "")

        # Generic parsing: try common response fields in order.
        for field in ["content", "text", "completion", "generation"]:
            if field in response_json:
                value = response_json[field]
                if isinstance(value, list) and value:
                    return value[0].get("text", "")
                if isinstance(value, str):
                    return value

        return str(response_json)

    def _parse_response(
        self, response: Dict[str, Any], tools: Optional[List[Dict]] = None
    ) -> Union[str, Dict[str, Any]]:
        """
        Parse a Bedrock ``converse`` or ``invoke_model`` response.

        Args:
            response: Raw API response.
            tools: Tools sent with the request. When truthy, ``response`` is
                read as a Converse response and its ``toolUse`` blocks are
                returned.

        Returns:
            ``{"tool_calls": [{"name": ..., "arguments": ...}, ...]}`` when
            ``tools`` is given. Otherwise the generated text, or
            ``"Error parsing response"`` if the response could not be parsed.
        """
        if tools:
            processed_response = {"tool_calls": []}
            content = response.get("output", {}).get("message", {}).get("content") or []
            for item in content:
                if "toolUse" in item:
                    processed_response["tool_calls"].append(
                        {
                            "name": item["toolUse"]["name"],
                            "arguments": json.loads(extract_json(json.dumps(item["toolUse"]["input"]))),
                        }
                    )
            return processed_response

        try:
            body = response.get("body")
            if body is None:
                # Converse responses have no streaming body; the reply is
                # under output.message.content instead.
                return self._extract_converse_text(response)
            return self._parse_invoke_body(json.loads(body.read().decode()))
        except Exception as e:
            logger.warning(f"Could not parse response: {e}")
            return "Error parsing response"

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format: Optional[str] = None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
        stream: bool = False,
        **kwargs,
    ) -> Union[str, Dict[str, Any]]:
        """
        Generate a response using AWS Bedrock.

        Args:
            messages: Chat messages with ``role`` and ``content`` keys.
            response_format: Accepted for interface compatibility; not used here.
            tools: OpenAI-style tool definitions. They are used only when the
                provider supports tools.
            tool_choice: Accepted for interface compatibility; not used here.
            stream: Passed through, but streaming is not implemented.
            **kwargs: Ignored.

        Returns:
            Generated text, or a ``{"tool_calls": [...]}`` dict when tools were used.

        Raises:
            RuntimeError: Wrapping any error raised while generating.
        """
        try:
            if tools and self.supports_tools:
                return self._generate_with_tools(messages, tools, stream)
            return self._generate_standard(messages, stream)
        except Exception as e:
            logger.error(f"Failed to generate response: {e}")
            raise RuntimeError(f"Failed to generate response: {e}") from e

    @staticmethod
    def _convert_tools_to_converse_format(tools: List[Dict]) -> List[Dict]:
        """
        Convert OpenAI-style tools to Converse API ``toolSpec`` entries.

        A function without ``parameters`` gets an empty object schema.
        Non-function tools are skipped.
        """
        if not tools:
            return []

        converse_tools = []
        for tool in tools:
            if tool.get("type") == "function" and "function" in tool:
                func = tool["function"]
                converse_tools.append(
                    {
                        "toolSpec": {
                            "name": func["name"],
                            "description": func.get("description", ""),
                            "inputSchema": {
                                "json": func.get("parameters") or {"type": "object", "properties": {}}
                            },
                        }
                    }
                )

        return converse_tools

    def _converse_inference_config(self, default_max_tokens: int = 2000) -> Dict[str, Any]:
        """
        Build the Converse ``inferenceConfig`` from ``self.model_config``.

        Entries whose value is ``None`` are omitted, because botocore rejects
        ``None`` for these numeric fields.
        """
        inference_config = {
            "maxTokens": self.model_config.get("max_tokens", default_max_tokens),
            "temperature": self.model_config.get("temperature", 0.1),
            "topP": self.model_config.get("top_p", 0.9),
        }
        return {key: value for key, value in inference_config.items() if value is not None}

    def _generate_converse(self, messages: List[Dict[str, str]], default_max_tokens: int = 2000) -> str:
        """
        Generate plain text through the Converse API.

        Args:
            messages: Chat messages; a system message is sent via ``system``.
            default_max_tokens: ``maxTokens`` used when the config sets none.

        Returns:
            The joined text blocks of the reply, or ``str(response)`` if the
            response has no ``output.message``.
        """
        formatted_messages, system_message = self._format_messages_converse(messages)

        converse_params = {
            "modelId": self.config.model,
            "messages": formatted_messages,
            "inferenceConfig": self._converse_inference_config(default_max_tokens),
        }
        if system_message:
            converse_params["system"] = [{"text": system_message}]

        response = self.client.converse(**converse_params)

        if "message" in response.get("output", {}):
            return self._extract_converse_text(response)
        return str(response)

    def _generate_with_tools(self, messages: List[Dict[str, str]], tools: List[Dict], stream: bool = False) -> Dict[str, Any]:
        """
        Generate a response with tool calling via the Converse API.

        The full message history is sent in Converse format, with any system
        message passed through ``system``. ``stream`` is currently ignored.

        Returns:
            ``{"tool_calls": [...]}`` as produced by ``_parse_response``.
        """
        formatted_messages, system_message = self._format_messages_converse(messages)

        converse_params = {
            "modelId": self.config.model,
            "messages": formatted_messages,
            "inferenceConfig": self._converse_inference_config(2000),
        }

        if system_message:
            converse_params["system"] = [{"text": system_message}]

        converse_tools = self._convert_tools_to_converse_format(tools)
        if converse_tools:
            converse_params["toolConfig"] = {"tools": converse_tools}

        response = self.client.converse(**converse_params)

        return self._parse_response(response, tools)

    def _generate_standard(self, messages: List[Dict[str, str]], stream: bool = False) -> str:
        """
        Generate a text response without tools.

        Anthropic and Amazon Nova models use the Converse API. All other
        models use ``invoke_model`` with the body from ``_prepare_input``.
        ``stream`` is currently ignored.
        """
        if self.provider == "anthropic":
            return self._generate_converse(messages, default_max_tokens=2000)

        if self.provider == "amazon" and "nova" in self.config.model.lower():
            return self._generate_converse(messages, default_max_tokens=5000)

        # Legacy Amazon (Titan) models need a single string prompt rather
        # than the list produced by _format_messages_amazon.
        if self.provider == "amazon":
            prompt = self._format_messages_generic(messages)
        else:
            prompt = self._format_messages(messages)

        body = json.dumps(self._prepare_input(prompt))

        response = self.client.invoke_model(
            body=body,
            modelId=self.config.model,
            accept="application/json",
            contentType="application/json",
        )

        return self._parse_response(response)

    def list_available_models(self) -> List[Dict[str, Any]]:
        """
        List the foundation models available in the configured region.

        Returns:
            One dict per model summary. A model whose provider is not
            recognised by ``extract_provider`` gets ``provider="unknown"``
            instead of aborting the listing. Returns ``[]`` if listing fails.
        """
        try:
            bedrock_client = boto3.client("bedrock", **self.config.get_aws_config())
            response = bedrock_client.list_foundation_models()

            models = []
            for model in response["modelSummaries"]:
                model_id = model["modelId"]
                try:
                    provider = extract_provider(model_id)
                except ValueError:
                    provider = "unknown"
                models.append(
                    {
                        "model_id": model_id,
                        "provider": provider,
                        # Text after the first "." (the whole id if there is none).
                        "model_name": model_id.split(".", 1)[-1],
                        "modelArn": model.get("modelArn", ""),
                        "providerName": model.get("providerName", ""),
                        "inputModalities": model.get("inputModalities", []),
                        "outputModalities": model.get("outputModalities", []),
                        "responseStreamingSupported": model.get("responseStreamingSupported", False),
                    }
                )

            return models

        except Exception as e:
            logger.warning(f"Could not list models: {e}")
            return []

    def get_model_capabilities(self) -> Dict[str, Any]:
        """Return the model id, provider, capability flags and configured max tokens."""
        return {
            "model_id": self.config.model,
            "provider": self.provider,
            "model_name": self.config.model_name,
            "supports_tools": self.supports_tools,
            "supports_vision": self.supports_vision,
            "supports_streaming": self.supports_streaming,
            "max_tokens": self.model_config.get("max_tokens", 2000),
        }

    def validate_model_access(self) -> bool:
        """
        Check whether the configured model can be invoked.

        Sends one real request. Anthropic and Nova use a Converse call with
        ``maxTokens=10``. Other models use an ``invoke_model`` body built by
        ``_prepare_input``, so each provider gets its expected request shape.

        Returns:
            True if the request succeeds, False on any exception.
        """
        try:
            uses_converse = self.provider == "anthropic" or (
                self.provider == "amazon" and "nova" in self.config.model.lower()
            )
            if uses_converse:
                self.client.converse(
                    modelId=self.config.model,
                    messages=[{"role": "user", "content": [{"text": "test"}]}],
                    inferenceConfig={"maxTokens": 10},
                )
            else:
                self.client.invoke_model(
                    body=json.dumps(self._prepare_input("test")),
                    modelId=self.config.model,
                    accept="application/json",
                    contentType="application/json",
                )
            return True
        except Exception:
            return False
| #666 |