repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...2fa351d6docs: add automaton and perps launch sources16d ago| #1 | import importlib |
| #2 | import os |
| #3 | from typing import Optional |
| #4 | |
| #5 | from langchain_community.llms.replicate import Replicate |
| #6 | |
| #7 | from embedchain.config import BaseLlmConfig |
| #8 | from embedchain.helpers.json_serializable import register_deserializable |
| #9 | from embedchain.llm.base import BaseLlm |
| #10 | |
| #11 | |
@register_deserializable
class Llama2Llm(BaseLlm):
    """LLM wrapper that answers prompts with a Llama 2 chat model hosted on Replicate.

    Requires the optional ``replicate`` package and a Replicate API token, taken
    from ``config.api_key`` or, if that is unset, from the
    ``REPLICATE_API_TOKEN`` environment variable.
    """

    def __init__(self, config: Optional[BaseLlmConfig] = None):
        """Check dependencies and credentials, then apply Llama2-specific defaults.

        Args:
            config: Optional LLM configuration. When omitted, a new
                ``BaseLlmConfig`` is created with ``max_tokens=500`` and
                ``temperature=0.75``. Whether supplied or not, a config without a
                ``model`` gets the pinned a16z-infra Llama 2 13B chat version.

        Raises:
            ModuleNotFoundError: If the ``replicate`` package cannot be imported.
            ValueError: If no non-empty API token is available from the config
                or the ``REPLICATE_API_TOKEN`` environment variable.
        """
        try:
            importlib.import_module("replicate")
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The required dependencies for Llama2 are not installed. "
                'Please install with `pip install --upgrade "embedchain[llama2]"`'
            ) from None

        # Set default config values specific to this llm.
        if not config:
            config = BaseLlmConfig()
            # Parent-class defaults are overridden only for a config we created
            # ourselves, so values in a caller-supplied config are respected.
            config.max_tokens = 500
            config.temperature = 0.75
        # Fields that are `None` by default get a Llama2-specific fallback.
        if not config.model:
            config.model = (
                "a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5"
            )

        super().__init__(config=config)
        # Validate with the same lookup used at call time, so an empty
        # REPLICATE_API_TOKEN is rejected here instead of failing on first request.
        if not self._resolve_api_key():
            raise ValueError("Please set the REPLICATE_API_TOKEN environment variable or pass it in the config.")

    def _resolve_api_key(self) -> Optional[str]:
        """Return the Replicate token from the config, else from the environment."""
        return self.config.api_key or os.getenv("REPLICATE_API_TOKEN")

    def get_llm_model_answer(self, prompt):
        """Send ``prompt`` to the configured Replicate model and return its answer.

        The config's ``temperature``, ``max_tokens`` (sent as ``max_length``) and
        ``top_p`` are forwarded as model inputs.

        Args:
            prompt: The prompt text to send to the model.

        Returns:
            The result of LangChain's ``Replicate.invoke`` for the prompt.

        Raises:
            ValueError: If ``config.system_prompt`` is set; this wrapper does not
                support system prompts.
        """
        # TODO: Move the model and other inputs into config
        if self.config.system_prompt:
            raise ValueError("Llama2 does not support `system_prompt`")
        llm = Replicate(
            model=self.config.model,
            replicate_api_token=self._resolve_api_key(),
            input={
                "temperature": self.config.temperature,
                "max_length": self.config.max_tokens,
                "top_p": self.config.top_p,
            },
        )
        return llm.invoke(prompt)
| #54 |