repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...
git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...
2fa351d6 docs: add automaton and perps launch sources 16d ago
| #1 | import logging |
| #2 | import uuid |
| #3 | from typing import Dict, List, Mapping, Optional |
| #4 | from urllib.parse import urlparse |
| #5 | |
| #6 | from pydantic import BaseModel |
| #7 | |
| #8 | try: |
| #9 | import weaviate |
| #10 | except ImportError: |
| #11 | raise ImportError( |
| #12 | "The 'weaviate' library is required. Please install it using 'pip install weaviate-client weaviate'." |
| #13 | ) |
| #14 | |
| #15 | import weaviate.classes.config as wvcc |
| #16 | from weaviate.classes.init import AdditionalConfig, Auth, Timeout |
| #17 | from weaviate.classes.query import Filter, MetadataQuery |
| #18 | from weaviate.util import get_valid_uuid |
| #19 | |
| #20 | from mem0.vector_stores.base import VectorStoreBase |
| #21 | |
| #22 | logger = logging.getLogger(__name__) |
| #23 | |
| #24 | |
class OutputData(BaseModel):
    # Result record produced by the store's read paths (search / get / list).
    id: str  # object UUID rendered as a string
    score: float  # similarity score; the non-ranked read paths fill in 1.0
    payload: Dict  # returned object properties, with an extra "id" entry added
| #29 | |
| #30 | |
class Weaviate(VectorStoreBase):
    """Vector store implementation backed by a single Weaviate collection."""

    # Properties requested back from Weaviate on every read path.
    _RETURN_PROPERTIES = (
        "hash",
        "created_at",
        "updated_at",
        "user_id",
        "agent_id",
        "run_id",
        "data",
        "category",
    )
    # Only these keys of a ``filters`` dict are turned into equality filters.
    _FILTER_KEYS = ("user_id", "agent_id", "run_id")

    def __init__(
        self,
        collection_name: str,
        embedding_model_dims: int,
        cluster_url: Optional[str] = None,
        auth_client_secret: Optional[str] = None,
        additional_headers: Optional[dict] = None,
    ):
        """
        Initialize the Weaviate vector store and make sure the collection exists.

        Connection strategy, in order:
          1. URL contains "localhost"   -> ``weaviate.connect_to_local``
          2. ``auth_client_secret`` set -> ``weaviate.connect_to_weaviate_cloud``
          3. otherwise                  -> ``weaviate.connect_to_custom`` using the
             host/port/scheme parsed from the URL and gRPC on port 50051.

        Args:
            collection_name (str): Name of the collection/class in Weaviate.
            embedding_model_dims (int): Dimensions of the embedding model.
            cluster_url (str): URL for Weaviate server. Required despite the
                ``None`` default kept for signature compatibility.
            auth_client_secret (str, optional): API key for Weaviate Cloud. Defaults to None.
            additional_headers (dict, optional): Additional headers for requests. Defaults to None.

        Raises:
            ValueError: If ``cluster_url`` is missing or empty.
        """
        # Previously a missing URL surfaced as an opaque TypeError from
        # `"localhost" in None`; fail with an actionable message instead.
        if not cluster_url:
            raise ValueError("'cluster_url' must be provided to connect to Weaviate.")

        if "localhost" in cluster_url:
            self.client = weaviate.connect_to_local(headers=additional_headers)
        elif auth_client_secret:
            self.client = weaviate.connect_to_weaviate_cloud(
                cluster_url=cluster_url,
                auth_credentials=Auth.api_key(auth_client_secret),
                headers=additional_headers,
            )
        else:
            parsed = urlparse(cluster_url)  # e.g., http://mem0_store:8080
            http_host = parsed.hostname or "localhost"
            http_port = parsed.port or (443 if parsed.scheme == "https" else 8080)
            http_secure = parsed.scheme == "https"

            # Weaviate gRPC defaults (inside Docker network)
            grpc_host = http_host
            grpc_port = 50051
            grpc_secure = False

            self.client = weaviate.connect_to_custom(
                http_host,
                http_port,
                http_secure,
                grpc_host,
                grpc_port,
                grpc_secure,
                headers=additional_headers,
                skip_init_checks=True,
                additional_config=AdditionalConfig(timeout=Timeout(init=2.0)),
            )

        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.create_col(embedding_model_dims)

    def _build_filter(self, filters: Optional[Dict]):
        """
        Translate a ``filters`` dict into a combined Weaviate filter.

        Only truthy values under the keys in ``_FILTER_KEYS`` are used; each
        becomes an equality condition and all conditions are AND-ed together.

        Returns:
            The combined filter, or None when no usable condition was supplied.
        """
        if not filters:
            return None
        conditions = [
            Filter.by_property(key).equal(value)
            for key, value in filters.items()
            if value and key in self._FILTER_KEYS
        ]
        return Filter.all_of(conditions) if conditions else None

    def _parse_output(self, data: Dict) -> List[OutputData]:
        """
        Parse column-oriented output data into a list of records.

        Args:
            data (Dict): Output data with optional "ids", "distances" and
                "metadatas" entries. A list-of-lists entry is unwrapped to its
                first inner list.

        Returns:
            List[OutputData]: One record per row; empty when no column is a list.

        NOTE(review): a column shorter than the longest one yields None for that
        field, which OutputData's non-optional fields will reject.
        """
        columns = []
        for key in ("ids", "distances", "metadatas"):
            value = data.get(key, [])
            if isinstance(value, list) and value and isinstance(value[0], list):
                value = value[0]
            columns.append(value)

        ids, distances, metadatas = columns
        # default=0 keeps max() from raising when no column is a list.
        row_count = max((len(col) for col in columns if isinstance(col, list)), default=0)

        def pick(column, index):
            return column[index] if isinstance(column, list) and index < len(column) else None

        return [
            OutputData(id=pick(ids, i), score=pick(distances, i), payload=pick(metadatas, i))
            for i in range(row_count)
        ]

    def create_col(self, vector_size, distance="cosine"):
        """
        Create the collection with the fixed text-property schema, if missing.

        Args:
            vector_size (int): Size of the vectors to be stored. Not forwarded to
                Weaviate by this method.
            distance (str, optional): Distance metric for vector similarity.
                Defaults to "cosine". Currently not forwarded either: the HNSW
                index is created with ``Configure.VectorIndex.hnsw()`` defaults.
        """
        if self.client.collections.exists(self.collection_name):
            logger.debug("Collection %s already exists. Skipping creation.", self.collection_name)
            return

        properties = [
            wvcc.Property(name="ids", data_type=wvcc.DataType.TEXT),
            wvcc.Property(name="hash", data_type=wvcc.DataType.TEXT),
            wvcc.Property(
                name="metadata",
                data_type=wvcc.DataType.TEXT,
                description="Additional metadata",
            ),
            wvcc.Property(name="data", data_type=wvcc.DataType.TEXT),
            wvcc.Property(name="created_at", data_type=wvcc.DataType.TEXT),
            wvcc.Property(name="category", data_type=wvcc.DataType.TEXT),
            wvcc.Property(name="updated_at", data_type=wvcc.DataType.TEXT),
            wvcc.Property(name="user_id", data_type=wvcc.DataType.TEXT),
            wvcc.Property(name="agent_id", data_type=wvcc.DataType.TEXT),
            wvcc.Property(name="run_id", data_type=wvcc.DataType.TEXT),
        ]

        # Vectors are supplied by the caller, so no server-side vectorizer.
        vectorizer_config = wvcc.Configure.Vectorizer.none()
        vector_index_config = wvcc.Configure.VectorIndex.hnsw()

        self.client.collections.create(
            self.collection_name,
            vectorizer_config=vectorizer_config,
            vector_index_config=vector_index_config,
            properties=properties,
        )

    def insert(self, vectors, payloads=None, ids=None):
        """
        Insert vectors into a collection.

        Args:
            vectors (list): List of vectors to insert.
            payloads (list, optional): List of payloads corresponding to vectors.
                Missing entries are stored with empty properties. The caller's
                dicts are not modified. Defaults to None.
            ids (list, optional): List of IDs corresponding to vectors. A random
                UUID is generated for any vector without one. Defaults to None.
        """
        logger.info("Inserting %s vectors into collection %s", len(vectors), self.collection_name)
        with self.client.batch.fixed_size(batch_size=100) as batch:
            for idx, vector in enumerate(vectors):
                raw_id = ids[idx] if ids and idx < len(ids) else str(uuid.uuid4())
                object_id = get_valid_uuid(raw_id)

                source = payloads[idx] if payloads and idx < len(payloads) else None
                # Drop 'ids' from the stored properties (the identifier is passed
                # as the Weaviate object UUID). Build a copy rather than deleting
                # the key from the caller's dict.
                properties = {key: value for key, value in (source or {}).items() if key != "ids"}

                batch.add_object(
                    collection=self.collection_name,
                    properties=properties,
                    uuid=object_id,
                    vector=vector,
                )

    def search(
        self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
    ) -> List[OutputData]:
        """
        Search for similar vectors.

        Args:
            query (str): Accepted for interface compatibility; not forwarded. The
                hybrid call is issued with an empty query string.
            vectors (List[float]): Query embedding.
            limit (int, optional): Maximum number of results. Defaults to 5.
            filters (Dict, optional): Equality filters; only "user_id",
                "agent_id" and "run_id" are honoured. Defaults to None.

        Returns:
            List[OutputData]: Matches with their score and payload.
        """
        collection = self.client.collections.get(str(self.collection_name))
        response = collection.query.hybrid(
            query="",
            vector=vectors,
            limit=limit,
            filters=self._build_filter(filters),
            return_properties=list(self._RETURN_PROPERTIES),
            return_metadata=MetadataQuery(score=True),
        )

        results = []
        for obj in response.objects:
            payload = obj.properties.copy()

            # Scope identifiers that were never set come back as None; omit them.
            for id_field in ("run_id", "agent_id", "user_id"):
                if id_field in payload and payload[id_field] is None:
                    del payload[id_field]

            object_id = str(obj.uuid)
            payload["id"] = object_id  # Include the id in the payload

            if obj.metadata.distance is not None:
                score = 1 - obj.metadata.distance  # Convert distance to similarity score
            elif obj.metadata.score is not None:
                score = obj.metadata.score
            else:
                score = 1.0  # Default score if none provided

            results.append(OutputData(id=object_id, score=score, payload=payload))
        return results

    def delete(self, vector_id):
        """
        Delete a vector by ID.

        Args:
            vector_id: ID of the vector to delete.
        """
        collection = self.client.collections.get(str(self.collection_name))
        collection.data.delete_by_id(vector_id)

    def update(self, vector_id, vector=None, payload=None):
        """
        Update a vector and/or its payload.

        Only the supplied parts are sent, in a single partial update; nothing is
        sent when neither is provided.

        Args:
            vector_id: ID of the vector to update.
            vector (list, optional): Updated vector. Defaults to None.
            payload (dict, optional): Updated payload. Defaults to None.
        """
        # The previous implementation re-read the object and wrote
        # dict(OutputData) back as properties, i.e. {"score": ..., "payload": ...}
        # instead of the real property map. Sending only what changed avoids
        # both the extra round-trip and the bogus properties.
        changes = {}
        if payload:
            changes["properties"] = payload
        if vector is not None and len(vector) > 0:
            changes["vector"] = vector
        if not changes:
            return

        collection = self.client.collections.get(str(self.collection_name))
        collection.data.update(uuid=vector_id, **changes)

    def get(self, vector_id):
        """
        Retrieve a vector by ID.

        Args:
            vector_id: ID of the vector to retrieve.

        Returns:
            OutputData: Stored properties (plus "id") with a fixed score of 1.0,
            or None when no object with that ID exists.
        """
        vector_id = get_valid_uuid(vector_id)
        collection = self.client.collections.get(str(self.collection_name))

        response = collection.query.fetch_object_by_id(
            uuid=vector_id,
            return_properties=list(self._RETURN_PROPERTIES),
        )
        if response is None:
            return None

        object_id = str(response.uuid)
        payload = response.properties.copy()
        payload["id"] = object_id
        return OutputData(id=object_id, score=1.0, payload=payload)

    def list_cols(self):
        """
        List all collections.

        Returns:
            dict: ``{"collections": [{"name": <collection name>}, ...]}``.
        """
        collections = self.client.collections.list_all()
        logger.debug("collections: %s", collections)

        if isinstance(collections, Mapping):
            # Name -> config mapping: the keys are the collection names.
            names = list(collections)
        else:
            # Iterable of config objects (or plain names).
            names = [getattr(col, "name", col) for col in collections]
        return {"collections": [{"name": name} for name in names]}

    def delete_col(self):
        """Delete a collection."""
        self.client.collections.delete(self.collection_name)

    def col_info(self):
        """
        Get information about a collection.

        Returns:
            The collection handle returned by the client, or None if it is falsy.
        """
        schema = self.client.collections.get(self.collection_name)
        if schema:
            return schema
        return None

    def list(self, filters=None, limit=100) -> List[List[OutputData]]:
        """
        List vectors in a collection.

        Args:
            filters (dict, optional): Equality filters; only "user_id",
                "agent_id" and "run_id" are honoured. Defaults to None.
            limit (int, optional): Maximum number of objects. Defaults to 100.

        Returns:
            A single-element list wrapping the list of records, each with a
            fixed score of 1.0. NOTE(review): the extra nesting is kept as-is
            for callers; presumably they index ``[0]`` — confirm before flattening.
        """
        collection = self.client.collections.get(self.collection_name)
        response = collection.query.fetch_objects(
            limit=limit,
            filters=self._build_filter(filters),
            return_properties=list(self._RETURN_PROPERTIES),
        )

        results = []
        for obj in response.objects:
            object_id = str(obj.uuid)
            payload = obj.properties.copy()
            payload["id"] = object_id
            results.append(OutputData(id=object_id, score=1.0, payload=payload))
        return [results]

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning("Resetting index %s...", self.collection_name)
        self.delete_col()
        # create_col requires the vector size; calling it bare raised TypeError.
        self.create_col(self.embedding_model_dims)
| #344 |