repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...
git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...
2fa351d6 docs: add automaton and perps launch sources 16d ago
| #1 | import logging |
| #2 | from typing import Dict, Optional |
| #3 | |
| #4 | from pydantic import BaseModel |
| #5 | |
| #6 | from mem0.configs.vector_stores.milvus import MetricType |
| #7 | from mem0.vector_stores.base import VectorStoreBase |
| #8 | |
| #9 | try: |
| #10 | import pymilvus # noqa: F401 |
| #11 | except ImportError: |
| #12 | raise ImportError("The 'pymilvus' library is required. Please install it using 'pip install pymilvus'.") |
| #13 | |
| #14 | from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient |
| #15 | |
| #16 | logger = logging.getLogger(__name__) |
| #17 | |
| #18 | |
class OutputData(BaseModel):
    """Single record handed back by the vector store.

    All three fields are declared ``Optional`` and carry no default value.
    """

    # Memory id.
    id: Optional[str]
    # Distance reported for a search hit.
    score: Optional[float]
    # Metadata stored alongside the vector.
    payload: Optional[Dict]
| #23 | |
| #24 | |
class MilvusDB(VectorStoreBase):
    """Vector store backed by a Milvus / Zilliz collection.

    Every record is stored with three fields: ``id`` (VARCHAR primary key),
    ``vectors`` (FLOAT_VECTOR) and ``metadata`` (JSON).
    """

    def __init__(
        self,
        url: str,
        token: str,
        collection_name: str,
        embedding_model_dims: int,
        metric_type: MetricType,
        db_name: str,
    ) -> None:
        """Connect to Milvus and make sure the target collection exists.

        All arguments are required by this constructor; no defaults are
        applied here.

        Args:
            url (str): Full URL for Milvus/Zilliz server.
            token (str): Token/api_key for the Zilliz server.
            collection_name (str): Name of the collection.
            embedding_model_dims (int): Dimensions of the embedding model.
            metric_type (MetricType): Metric type for similarity search.
            db_name (str): Name of the database.
        """
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.metric_type = metric_type
        self.client = MilvusClient(uri=url, token=token, db_name=db_name)
        self.create_col(
            collection_name=self.collection_name,
            vector_size=self.embedding_model_dims,
            metric_type=self.metric_type,
        )

    def create_col(
        self,
        collection_name: str,
        vector_size: int,
        metric_type: MetricType = MetricType.COSINE,
    ) -> None:
        """Create a new collection with index_type AUTOINDEX.

        Nothing is created when a collection with that name already exists.

        Args:
            collection_name (str): Name of the collection.
            vector_size (int): Dimensions of the embedding model.
            metric_type (MetricType, optional): Metric type for similarity search.
                Defaults to MetricType.COSINE.
        """
        if self.client.has_collection(collection_name):
            logger.info(f"Collection {collection_name} already exists. Skipping creation.")
            return

        fields = [
            FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=512),
            FieldSchema(name="vectors", dtype=DataType.FLOAT_VECTOR, dim=vector_size),
            FieldSchema(name="metadata", dtype=DataType.JSON),
        ]
        schema = CollectionSchema(fields, enable_dynamic_field=True)

        index = self.client.prepare_index_params(
            field_name="vectors", metric_type=metric_type, index_type="AUTOINDEX", index_name="vector_index"
        )
        self.client.create_collection(collection_name=collection_name, schema=schema, index_params=index)

    def insert(self, ids, vectors, payloads=None, **kwargs):
        """Insert vectors into the collection in a single batch.

        Args:
            ids (List[str]): List of IDs corresponding to vectors.
            vectors (List[List[float]]): List of vectors to insert.
            payloads (List[Dict], optional): List of payloads corresponding to
                vectors. When omitted, every record gets an empty metadata dict.
            **kwargs: Extra keyword arguments forwarded to ``MilvusClient.insert``.
        """
        if payloads is None:
            # Keep zip() from failing when the caller has no metadata to store.
            payloads = [{} for _ in ids]

        # Batch insert all records at once for better performance and consistency
        data = [
            {"id": idx, "vectors": embedding, "metadata": metadata}
            for idx, embedding, metadata in zip(ids, vectors, payloads)
        ]
        self.client.insert(collection_name=self.collection_name, data=data, **kwargs)

    def _create_filter(self, filters: dict):
        """Build a filter expression over the ``metadata`` JSON field.

        Each key/value pair becomes an equality test; the tests are joined
        with ``and``. String values are double-quoted, everything else is
        interpolated with its ``str()`` form.

        Args:
            filters (dict): filters [user_id, agent_id, run_id]

        Returns:
            str: formatted filter.
        """
        operands = []
        for key, value in filters.items():
            if isinstance(value, str):
                # Escape backslashes and double quotes so the value cannot
                # terminate the string literal or change the expression.
                escaped = value.replace("\\", "\\\\").replace('"', '\\"')
                operands.append(f'(metadata["{key}"] == "{escaped}")')
            else:
                operands.append(f'(metadata["{key}"] == {value})')

        return " and ".join(operands)

    def _parse_output(self, data: list):
        """Convert raw search hits into ``OutputData`` objects.

        Args:
            data (list): Hits, each a mapping with ``id``, ``distance`` and an
                ``entity`` mapping that holds ``metadata``.

        Returns:
            List[OutputData]: Parsed output data, in the order received.
        """
        return [
            OutputData(
                id=hit.get("id"),
                score=hit.get("distance"),
                payload=hit.get("entity", {}).get("metadata"),
            )
            for hit in data
        ]

    def search(self, query: str, vectors: list, limit: int = 5, filters: Optional[dict] = None) -> list:
        """Search for similar vectors.

        Args:
            query (str): Query text; not used by this implementation.
            vectors (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            list: Search results as ``OutputData`` objects.
        """
        query_filter = self._create_filter(filters) if filters else None
        hits = self.client.search(
            collection_name=self.collection_name,
            data=[vectors],
            limit=limit,
            filter=query_filter,
            output_fields=["*"],
        )
        # A single query vector is sent, so only the first result set is used.
        return self._parse_output(data=hits[0])

    def delete(self, vector_id):
        """Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        self.client.delete(collection_name=self.collection_name, ids=vector_id)

    def update(self, vector_id=None, vector=None, payload=None):
        """Upsert a vector and its payload.

        Args:
            vector_id (str): ID of the vector to update.
            vector (List[float], optional): Updated vector.
            payload (Dict, optional): Updated payload.
        """
        # NOTE(review): all three values are sent as given, including None;
        # presumably callers always supply both vector and payload - confirm.
        schema = {"id": vector_id, "vectors": vector, "metadata": payload}
        self.client.upsert(collection_name=self.collection_name, data=schema)

    def get(self, vector_id):
        """Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            Optional[OutputData]: Retrieved record (``score`` is always None),
            or None when no record with that ID was returned.
        """
        result = self.client.get(collection_name=self.collection_name, ids=vector_id)
        if not result:
            # Unknown ID: report "not found" instead of raising IndexError.
            return None

        record = result[0]
        return OutputData(
            id=record.get("id", None),
            score=None,
            payload=record.get("metadata", None),
        )

    def list_cols(self):
        """List all collections.

        Returns:
            List[str]: List of collection names.
        """
        return self.client.list_collections()

    def delete_col(self):
        """Delete the collection this store is bound to."""
        return self.client.drop_collection(collection_name=self.collection_name)

    def col_info(self):
        """Get information about the collection.

        Returns:
            Dict[str, Any]: Collection statistics as reported by the client.
        """
        return self.client.get_collection_stats(collection_name=self.collection_name)

    def list(self, filters: Optional[dict] = None, limit: int = 100) -> list:
        """List vectors in the collection.

        Args:
            filters (Dict, optional): Filters to apply to the list.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            list: A one-element list whose single item is the list of
            ``OutputData`` records (``score`` is always None).
        """
        query_filter = self._create_filter(filters) if filters else None
        result = self.client.query(collection_name=self.collection_name, filter=query_filter, limit=limit)
        memories = [OutputData(id=row.get("id"), score=None, payload=row.get("metadata")) for row in result]
        # The records are deliberately wrapped in an outer list.
        return [memories]

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.create_col(self.collection_name, self.embedding_model_dims, self.metric_type)
| #251 |