repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...
git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...
2fa351d6 docs: add automaton and perps launch sources (16d ago)
| #1 | import json |
| #2 | import os |
| #3 | from typing import Literal, Optional |
| #4 | |
| #5 | try: |
| #6 | import boto3 |
| #7 | except ImportError: |
| #8 | raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.") |
| #9 | |
| #10 | import numpy as np |
| #11 | |
| #12 | from mem0.configs.embeddings.base import BaseEmbedderConfig |
| #13 | from mem0.embeddings.base import EmbeddingBase |
| #14 | |
| #15 | |
class AWSBedrockEmbedding(EmbeddingBase):
    """AWS Bedrock embedding implementation.

    Calls a Bedrock-hosted embedding model through a ``bedrock-runtime``
    boto3 client. The request/response payload format is selected from the
    provider prefix of the model id (the text before the first ``.``):

    * ``cohere``: request ``{"input_type": "search_document", "texts": [text]}``,
      response field ``embeddings`` (first element is used);
    * any other provider (e.g. Amazon Titan): request ``{"inputText": text}``,
      response field ``embedding``.
    """

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        """Initialise the embedder and create the Bedrock runtime client.

        Each credential is resolved independently: a non-empty value on
        ``config`` wins, otherwise the matching environment variable
        (``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY``,
        ``AWS_SESSION_TOKEN``) is used. Values that end up empty are passed to
        boto3 as ``None`` so that boto3's own credential resolution applies.

        Args:
            config: Embedder configuration. ``model`` defaults to
                ``"amazon.titan-embed-text-v1"`` and the region defaults to
                ``"us-west-2"`` when not set.
        """
        super().__init__(config)

        self.config.model = self.config.model or "amazon.titan-embed-text-v1"

        # Prefer a non-empty config value; otherwise fall back to the
        # environment. A config attribute that exists but is None/"" must not
        # discard a credential supplied via the environment.
        aws_access_key = getattr(self.config, "aws_access_key_id", None) or os.environ.get(
            "AWS_ACCESS_KEY_ID", ""
        )
        aws_secret_key = getattr(self.config, "aws_secret_access_key", None) or os.environ.get(
            "AWS_SECRET_ACCESS_KEY", ""
        )
        aws_session_token = getattr(self.config, "aws_session_token", None) or os.environ.get(
            "AWS_SESSION_TOKEN", ""
        )

        # getattr keeps this working even for a config object without the field.
        aws_region = getattr(self.config, "aws_region", None) or "us-west-2"

        self.client = boto3.client(
            "bedrock-runtime",
            region_name=aws_region,
            aws_access_key_id=aws_access_key or None,
            aws_secret_access_key=aws_secret_key or None,
            aws_session_token=aws_session_token or None,
        )

    def _normalize_vector(self, embeddings):
        """Scale an embedding to unit L2 norm.

        Args:
            embeddings: Sequence of numbers.

        Returns:
            list: The normalised vector. An all-zero vector is returned
            unchanged, because dividing by a zero norm would produce NaNs.
        """
        emb = np.asarray(embeddings, dtype=float)
        norm = np.linalg.norm(emb)
        if norm == 0:
            return emb.tolist()
        return (emb / norm).tolist()

    def _get_embedding(self, text):
        """Call the Bedrock ``invoke_model`` endpoint and extract the vector.

        Args:
            text (str): The text to embed.

        Returns:
            list: The embedding vector returned by the model.

        Raises:
            ValueError: If the Bedrock call fails or the response does not
                contain the expected embedding field. The underlying error is
                chained as ``__cause__``.
        """
        # Payload format depends on the provider prefix of the model id.
        provider = self.config.model.split(".")[0]

        if provider == "cohere":
            input_body = {"input_type": "search_document", "texts": [text]}
        else:
            # Amazon and other providers
            input_body = {"inputText": text}

        body = json.dumps(input_body)

        try:
            response = self.client.invoke_model(
                body=body,
                modelId=self.config.model,
                accept="application/json",
                contentType="application/json",
            )

            response_body = json.loads(response.get("body").read())

            if provider == "cohere":
                embeddings_list = response_body.get("embeddings")
                if not embeddings_list:
                    raise KeyError("response has no 'embeddings' field")
                embeddings = embeddings_list[0]
            else:
                embeddings = response_body.get("embedding")
                if embeddings is None:
                    raise KeyError("response has no 'embedding' field")

            return embeddings
        except Exception as e:
            # Chain the original error so the real cause stays in the traceback.
            raise ValueError(f"Error getting embedding from AWS Bedrock: {e}") from e

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text using AWS Bedrock.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update".
                Accepted for interface compatibility; this implementation does not use it.
                Defaults to None.
        Returns:
            list: The embedding vector.
        """
        return self._get_embedding(text)
| #101 |