repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...2fa351d6docs: add automaton and perps launch sources16d ago| #1 | import os |
| #2 | |
| #3 | import pytest |
| #4 | |
| #5 | from embedchain.config import BaseLlmConfig |
| #6 | from embedchain.llm.cohere import CohereLlm |
| #7 | |
| #8 | |
@pytest.fixture
def cohere_llm_config():
    """Yield a ``BaseLlmConfig`` for the Cohere tests with a placeholder API key.

    ``COHERE_API_KEY`` is set to a dummy value before the config is built.
    On teardown the variable is put back to the value it had before the
    fixture ran (or removed if it was unset), so the fixture never deletes a
    key that was already present in the environment and never raises if the
    variable has already been removed by the time teardown runs.
    """
    previous_key = os.environ.get("COHERE_API_KEY")
    os.environ["COHERE_API_KEY"] = "test_api_key"
    try:
        config = BaseLlmConfig(model="command-r", max_tokens=100, temperature=0.7, top_p=0.8, token_usage=False)
        yield config
    finally:
        if previous_key is None:
            # Default of None avoids a KeyError when the variable is already gone.
            os.environ.pop("COHERE_API_KEY", None)
        else:
            os.environ["COHERE_API_KEY"] = previous_key
| #15 | |
| #16 | |
def test_init_raises_value_error_without_api_key(mocker):
    """Constructing ``CohereLlm`` with an emptied environment must raise ``ValueError``."""
    # Clear every environment variable for the duration of this test only;
    # the patch is undone automatically by the mocker fixture.
    mocker.patch.dict(os.environ, {}, clear=True)

    with pytest.raises(ValueError):
        CohereLlm()
| #21 | |
| #22 | |
def test_get_llm_model_answer_raises_value_error_for_system_prompt(cohere_llm_config):
    """``get_llm_model_answer`` must raise ``ValueError`` once a system prompt is set."""
    cohere = CohereLlm(cohere_llm_config)
    # Set the system prompt on the instance's config after construction.
    cohere.config.system_prompt = "system_prompt"

    with pytest.raises(ValueError):
        cohere.get_llm_model_answer("prompt")
| #28 | |
| #29 | |
def test_get_llm_model_answer(cohere_llm_config, mocker):
    """``get_llm_model_answer`` returns whatever the patched ``_get_answer`` produces."""
    expected = "Test answer"
    # Replace the private answer helper with a stub returning a canned string.
    mocker.patch(
        "embedchain.llm.cohere.CohereLlm._get_answer",
        return_value=expected,
    )

    result = CohereLlm(cohere_llm_config).get_llm_model_answer("Test query")

    assert result == expected
| #37 | |
| #38 | |
def test_get_llm_model_answer_with_token_usage(cohere_llm_config, mocker):
    """With ``token_usage=True`` the call returns the answer plus a usage/cost dict."""
    base = cohere_llm_config
    # Same settings as the shared fixture config, but with token accounting enabled.
    usage_config = BaseLlmConfig(
        model=base.model,
        max_tokens=base.max_tokens,
        temperature=base.temperature,
        top_p=base.top_p,
        token_usage=True,
    )

    raw_usage = {"input_tokens": 1, "output_tokens": 2}
    mocker.patch(
        "embedchain.llm.cohere.CohereLlm._get_answer",
        return_value=("Test answer", raw_usage),
    )

    expected_token_info = {
        "prompt_tokens": 1,
        "completion_tokens": 2,
        "total_tokens": 3,
        "total_cost": 3.5e-06,
        "cost_currency": "USD",
    }

    answer, token_info = CohereLlm(usage_config).get_llm_model_answer("Test query")

    assert answer == "Test answer"
    assert token_info == expected_token_info
| #63 | |
| #64 | |
def test_get_answer_mocked_cohere(cohere_llm_config, mocker):
    """The answer is the ``content`` of what the patched ``ChatCohere`` client's ``invoke`` returns."""
    chat_cls = mocker.patch("embedchain.llm.cohere.ChatCohere")
    # Instantiating the patched class gives a mock client; its invoke() result
    # exposes the canned text through ``.content``.
    mock_response = chat_cls.return_value.invoke.return_value
    mock_response.content = "Mocked answer"

    result = CohereLlm(cohere_llm_config).get_llm_model_answer("Test query")

    assert result == "Mocked answer"
| #74 |