repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...2fa351d6docs: add automaton and perps launch sources16d ago| #1 | import logging |
| #2 | from typing import Dict, List, Optional |
| #3 | |
| #4 | from pydantic import BaseModel |
| #5 | |
| #6 | try: |
| #7 | from langchain_community.vectorstores import VectorStore |
| #8 | except ImportError: |
| #9 | raise ImportError( |
| #10 | "The 'langchain_community' library is required. Please install it using 'pip install langchain_community'." |
| #11 | ) |
| #12 | |
| #13 | from mem0.vector_stores.base import VectorStoreBase |
| #14 | |
| #15 | logger = logging.getLogger(__name__) |
| #16 | |
| #17 | |
class OutputData(BaseModel):
    """One record returned by the LangChain vector-store adapter.

    Instances are built from either LangChain Document objects or from
    column-oriented result dicts (``ids`` / ``distances`` / ``metadatas``).
    """

    id: Optional[str]  # identifier of the stored memory
    score: Optional[float]  # distance reported by the store; None when the source has no score
    payload: Optional[Dict]  # metadata attached to the memory
| #22 | |
| #23 | |
class Langchain(VectorStoreBase):
    """mem0 vector-store adapter that delegates storage to a LangChain ``VectorStore``.

    Every operation is forwarded to the wrapped ``client``; results are
    converted into ``OutputData`` records via ``_parse_output``.
    """

    def __init__(self, client: VectorStore, collection_name: str = "mem0"):
        """
        Args:
            client (VectorStore): An already-constructed LangChain vector store.
            collection_name (str): Name reported by ``list_cols`` / ``col_info``.
                It is only recorded on the adapter, not passed to ``client``.
        """
        self.client = client
        self.collection_name = collection_name

    def _parse_output(self, data: Dict) -> List[OutputData]:
        """
        Parse raw store output into ``OutputData`` records.

        Two shapes are accepted:

        * a list of Document-like objects (anything with a ``metadata``
          attribute), e.g. from ``similarity_search_by_vector`` / ``get_by_ids``;
        * a dict with optional ``"ids"``, ``"distances"`` and ``"metadatas"``
          lists. A column nested one level (``[[...]]``) is unwrapped to its
          first inner list.

        Args:
            data (Dict | List): Output data or list of Document objects.

        Returns:
            List[OutputData]: Parsed output data. Missing ids/scores/payloads
            are filled with ``None``; a dict with no list-valued columns yields ``[]``.
        """
        # Document-list format. Items without a __dict__ are not inspected by
        # this guard (same detection rule as before).
        if isinstance(data, list) and all(hasattr(doc, "metadata") for doc in data if hasattr(doc, "__dict__")):
            return [
                OutputData(
                    id=getattr(doc, "id", None),
                    score=None,  # Document objects typically don't include scores
                    payload=getattr(doc, "metadata", {}),
                )
                for doc in data
            ]

        # Column-oriented dict format.
        columns = []
        for key in ("ids", "distances", "metadatas"):
            value = data.get(key, [])
            # Unwrap one level of nesting by taking the first inner list.
            if isinstance(value, list) and value and isinstance(value[0], list):
                value = value[0]
            columns.append(value)

        ids, distances, metadatas = columns
        # default=0: previously max() raised ValueError when no column was a list.
        row_count = max((len(col) for col in columns if isinstance(col, list)), default=0)

        def _at(column, index):
            """Return column[index], or None if column is not a list or is too short."""
            if isinstance(column, list) and index < len(column):
                return column[index]
            return None

        return [
            OutputData(id=_at(ids, i), score=_at(distances, i), payload=_at(metadatas, i))
            for i in range(row_count)
        ]

    def create_col(self, name, vector_size=None, distance=None):
        """
        Record ``name`` as the collection name and return the wrapped client.

        ``vector_size`` and ``distance`` are accepted for interface
        compatibility and ignored.
        """
        self.collection_name = name
        return self.client

    def insert(
        self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
    ):
        """
        Insert vectors into the LangChain vectorstore.

        If the client has ``add_embeddings``, the precomputed vectors are passed
        to it. Otherwise ``add_texts`` is used with each payload's ``"data"``
        value as the text (or ``""`` per vector when there are no payloads); in
        that path ``vectors`` is only used for its length.

        Args:
            vectors: One embedding per record.
            payloads: One metadata dict per record, parallel to ``vectors``.
            ids: One id per record, parallel to ``vectors``.
        """
        if hasattr(self.client, "add_embeddings"):
            # NOTE(review): add_embeddings signatures differ between LangChain
            # stores (some take (text, vector) pairs or require `texts`);
            # confirm the target store accepts these keywords.
            self.client.add_embeddings(embeddings=vectors, metadatas=payloads, ids=ids)
            return

        texts = [payload.get("data", "") for payload in payloads] if payloads else [""] * len(vectors)
        self.client.add_texts(texts=texts, metadatas=payloads, ids=ids)

    def search(self, query: str, vectors: List[List[float]], limit: int = 5, filters: Optional[Dict] = None):
        """
        Search for the ``limit`` nearest records to ``vectors``.

        ``vectors`` is forwarded unchanged as the ``embedding`` argument of
        ``similarity_search_by_vector``. NOTE(review): that LangChain method
        takes a single embedding, so callers presumably pass one vector despite
        the annotation. ``query`` is not used. ``filters`` is forwarded as
        ``filter`` only when non-empty.

        Returns:
            List[OutputData]: Results parsed by ``_parse_output``.
        """
        search_kwargs = {"embedding": vectors, "k": limit}
        if filters:
            search_kwargs["filter"] = filters
        results = self.client.similarity_search_by_vector(**search_kwargs)
        return self._parse_output(results)

    def delete(self, vector_id):
        """
        Delete a vector by ID.
        """
        self.client.delete(ids=[vector_id])

    def update(self, vector_id, vector=None, payload=None):
        """
        Update a vector and its payload by deleting and re-inserting it.

        Args:
            vector_id: ID of the record to replace.
            vector (List[float]): New embedding for this single record.
            payload (Dict): New metadata for this single record.

        Note:
            Delete and insert are separate client calls, so a failing insert
            leaves the record deleted.
        """
        self.delete(vector_id)
        # insert() expects parallel lists, so wrap the single record. Passing
        # the bare dict made insert() iterate its keys (str has no .get).
        self.insert([vector], [payload] if payload is not None else None, [vector_id])

    def get(self, vector_id):
        """
        Retrieve a vector by ID.

        Returns:
            OutputData | None: The first document returned by the client's
            ``get_by_ids``, parsed; ``None`` when nothing is returned.
        """
        docs = self.client.get_by_ids([vector_id])
        if not docs:
            return None
        return self._parse_output([docs[0]])[0]

    def list_cols(self):
        """
        List all collections.

        The adapter tracks a single name, so this is always a one-element list.
        """
        # LangChain doesn't have collections
        return [self.collection_name]

    def delete_col(self):
        """
        Delete a collection.

        Tries the client's ``delete_collection``, then ``reset_collection``,
        and finally falls back to ``delete(ids=None)``. NOTE(review): what
        ``delete(ids=None)`` does is store-specific; confirm it clears all data.
        """
        logger.warning("Deleting collection")
        if hasattr(self.client, "delete_collection"):
            self.client.delete_collection()
        elif hasattr(self.client, "reset_collection"):
            self.client.reset_collection()
        else:
            self.client.delete(ids=None)

    def col_info(self):
        """
        Get information about a collection (only its recorded name).
        """
        return {"name": self.collection_name}

    def list(self, filters=None, limit=None):
        """
        List vectors in the collection.

        Works only when the client exposes a ``_collection`` object with a
        ``get`` method; otherwise returns ``[]``.

        Args:
            filters (Dict, optional): Passed through unchanged as the ``where``
                clause; empty or None means no filtering.
            limit (int, optional): Maximum number of records to fetch.

        Returns:
            list: ``[records]``, i.e. the parsed ``OutputData`` list wrapped in
            a one-element list, or ``[]`` when nothing usable was returned or an
            error occurred (the error is logged, not raised).
        """
        try:
            if hasattr(self.client, "_collection") and hasattr(self.client._collection, "get"):
                # mem0 filters are passed straight through as the where clause.
                where_clause = filters or None
                result = self.client._collection.get(where=where_clause, limit=limit)

                if result and isinstance(result, dict):
                    return [self._parse_output(result)]
            return []
        except Exception as e:
            # Best-effort listing: log with traceback instead of propagating.
            logger.exception("Error listing vectors from Chroma: %s", e)
            return []

    def reset(self):
        """Reset the index by deleting the collection via ``delete_col``.

        Nothing is recreated here. NOTE(review): confirm the wrapped client
        remains usable after deletion for the target store.
        """
        logger.warning("Resetting collection: %s", self.collection_name)
        self.delete_col()
| #181 |