repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...2fa351d6docs: add automaton and perps launch sources16d ago| #1 | import json |
| #2 | import logging |
| #3 | import uuid |
| #4 | from typing import Any, Dict, List, Optional |
| #5 | |
| #6 | import numpy as np |
| #7 | from pydantic import BaseModel |
| #8 | |
| #9 | try: |
| #10 | from cassandra.cluster import Cluster |
| #11 | from cassandra.auth import PlainTextAuthProvider |
| #12 | except ImportError: |
| #13 | raise ImportError( |
| #14 | "Apache Cassandra vector store requires cassandra-driver. " |
| #15 | "Please install it using 'pip install cassandra-driver'" |
| #16 | ) |
| #17 | |
| #18 | from mem0.vector_stores.base import VectorStoreBase |
| #19 | |
# Module-scoped logger, named after this module, shared by every method in the file.
logger = logging.getLogger(name=__name__)
| #21 | |
| #22 | |
class OutputData(BaseModel):
    """Record returned by CassandraDB.search(), get() and list()."""

    id: Optional[str]  # value of the table's `id text PRIMARY KEY` column
    score: Optional[float]  # set by search() to 1 - cosine similarity (lower = closer); None from get()/list()
    payload: Optional[dict]  # JSON-decoded contents of the `payload text` column ({} when the column is empty)
| #27 | |
| #28 | |
class CassandraDB(VectorStoreBase):
    """Vector store backed by Apache Cassandra (or DataStax Astra DB).

    Every record lives in ``<keyspace>.<collection_name>`` with the schema
    ``(id text PRIMARY KEY, vector list<float>, payload text)``, where
    ``payload`` holds the JSON-encoded metadata dict. The ``list<float>``
    column has no similarity index, so :meth:`search` reads the whole table
    and ranks rows by cosine distance in Python. Its cost grows linearly with
    the collection size.
    """

    def __init__(
        self,
        contact_points: List[str],
        port: int = 9042,
        username: Optional[str] = None,
        password: Optional[str] = None,
        keyspace: str = "mem0",
        collection_name: str = "memories",
        embedding_model_dims: int = 1536,
        secure_connect_bundle: Optional[str] = None,
        protocol_version: int = 4,
        load_balancing_policy: Optional[Any] = None,
    ):
        """
        Initialize the Apache Cassandra vector store.

        Connects to the cluster, then creates the keyspace and table if they
        do not already exist.

        Args:
            contact_points (List[str]): List of contact point addresses (e.g., ['127.0.0.1']).
                Ignored when ``secure_connect_bundle`` is given.
            port (int): Cassandra port (default: 9042). Ignored when ``secure_connect_bundle`` is given.
            username (str, optional): Database username (used only together with ``password``)
            password (str, optional): Database password (used only together with ``username``)
            keyspace (str): Keyspace name (default: "mem0")
            collection_name (str): Table name (default: "memories")
            embedding_model_dims (int): Dimension of the embedding vector (default: 1536).
                Informational only; the ``list<float>`` column does not enforce it.
            secure_connect_bundle (str, optional): Path to secure connect bundle for Astra DB
            protocol_version (int): CQL protocol version (default: 4)
            load_balancing_policy (Any, optional): Custom load balancing policy passed to ``Cluster``

        Raises:
            Exception: any driver error raised while connecting or creating the schema.
        """
        self.contact_points = contact_points
        self.port = port
        self.username = username
        self.password = password
        self.keyspace = keyspace
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.secure_connect_bundle = secure_connect_bundle
        self.protocol_version = protocol_version
        self.load_balancing_policy = load_balancing_policy

        # Initialize connection
        self.cluster = None
        self.session = None
        # CQL text -> prepared statement, filled lazily by _prepare().
        self._prepared_statements: Dict[str, Any] = {}
        self._setup_connection()

        # Create keyspace and table if they don't exist
        self._create_keyspace()
        self._create_table()

    # ------------------------------------------------------------------ helpers

    def _prepare(self, query: str):
        """Return a prepared statement for *query*, preparing it at most once per instance.

        The driver's guidance is to prepare a statement once and reuse it.
        Re-preparing on every call costs an extra round trip to the cluster.
        """
        cache = getattr(self, "_prepared_statements", None)
        if cache is None:
            cache = self._prepared_statements = {}
        statement = cache.get(query)
        if statement is None:
            statement = self.session.prepare(query)
            cache[query] = statement
        return statement

    @staticmethod
    def _parse_payload(raw: Optional[str], row_id: Optional[str] = None) -> Optional[Dict]:
        """Decode a stored JSON payload.

        Returns:
            ``{}`` for a missing or empty payload, the decoded dict otherwise,
            or ``None`` (after logging a warning) when the text is not a JSON object.
        """
        if not raw:
            return {}
        try:
            payload = json.loads(raw)
        except json.JSONDecodeError:
            payload = None
        if not isinstance(payload, dict):
            logger.warning(f"Ignoring row {row_id!r}: payload is not a valid JSON object")
            return None
        return payload

    @staticmethod
    def _matches_filters(payload: Dict, filters: Optional[Dict]) -> bool:
        """Return True when *payload* equals every ``key: value`` pair in *filters*.

        Missing keys compare as ``None``. Empty or absent filters match everything.
        """
        if not filters:
            return True
        return all(payload.get(key) == value for key, value in filters.items())

    def _table_ddl(self, table_name: str) -> str:
        """Return the ``CREATE TABLE IF NOT EXISTS`` statement for *table_name* in this keyspace."""
        return (
            f"CREATE TABLE IF NOT EXISTS {self.keyspace}.{table_name} ("
            " id text PRIMARY KEY,"
            " vector list<float>,"
            " payload text"
            ")"
        )

    # --------------------------------------------------------------- connection

    def _setup_connection(self):
        """Create ``self.cluster`` and open ``self.session``.

        Uses the secure connect bundle (Astra DB) when configured. Otherwise it
        uses the explicit ``contact_points`` and ``port``. Credentials are
        attached only when both username and password are set, and the load
        balancing policy is applied in both modes.

        Raises:
            Exception: whatever the driver raised; it is logged and re-raised. A
            partially initialized cluster is shut down first so its threads and
            sockets are not leaked.
        """
        cluster_kwargs: Dict[str, Any] = {"protocol_version": self.protocol_version}
        if self.secure_connect_bundle:
            # Connect to Astra DB: the bundle supplies endpoints and TLS settings.
            cluster_kwargs["cloud"] = {"secure_connect_bundle": self.secure_connect_bundle}
        else:
            # Connect to standard Cassandra cluster
            cluster_kwargs["contact_points"] = self.contact_points
            cluster_kwargs["port"] = self.port

        if self.username and self.password:
            cluster_kwargs["auth_provider"] = PlainTextAuthProvider(
                username=self.username,
                password=self.password,
            )
        if self.load_balancing_policy:
            cluster_kwargs["load_balancing_policy"] = self.load_balancing_policy

        cluster = None
        try:
            cluster = Cluster(**cluster_kwargs)
            session = cluster.connect()
        except Exception as e:
            logger.error(f"Failed to connect to Cassandra: {e}")
            if cluster is not None:
                try:
                    cluster.shutdown()
                except Exception:
                    # Best effort: the original connection error is the one worth raising.
                    pass
            raise

        self.cluster = cluster
        self.session = session
        logger.info("Successfully connected to Cassandra cluster")

    def _create_keyspace(self):
        """Create the keyspace if needed and make it the session's default keyspace.

        Replication is fixed to ``SimpleStrategy`` with ``replication_factor`` 1,
        which suits a single-datacenter / development cluster. Production
        deployments would typically use ``NetworkTopologyStrategy``.

        Raises:
            Exception: any driver error, logged and re-raised.
        """
        try:
            self.session.execute(
                f"CREATE KEYSPACE IF NOT EXISTS {self.keyspace} "
                "WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}"
            )
            self.session.set_keyspace(self.keyspace)
            logger.info(f"Keyspace '{self.keyspace}' is ready")
        except Exception as e:
            logger.error(f"Failed to create keyspace: {e}")
            raise

    def _create_table(self):
        """Create this store's table (vector as ``list<float>``, payload as JSON text) if missing.

        Raises:
            Exception: any driver error, logged and re-raised.
        """
        try:
            self.session.execute(self._table_ddl(self.collection_name))
            logger.info(f"Table '{self.collection_name}' is ready")
        except Exception as e:
            logger.error(f"Failed to create table: {e}")
            raise

    # -------------------------------------------------------------- collections

    def create_col(self, name: str = None, vector_size: int = None, distance: str = "cosine"):
        """
        Create a new collection (table in Cassandra) with this store's schema.

        This only creates the table. The instance keeps reading and writing
        ``self.collection_name``.

        Args:
            name (str, optional): Collection name (uses self.collection_name if not provided)
            vector_size (int, optional): Vector dimension (uses self.embedding_model_dims if not
                provided). Only logged; ``list<float>`` does not enforce a dimension.
            distance (str): Accepted for interface compatibility but unused;
                :meth:`search` always ranks by cosine distance.

        Raises:
            Exception: any driver error, logged and re-raised.
        """
        table_name = name or self.collection_name
        dims = vector_size or self.embedding_model_dims

        try:
            self.session.execute(self._table_ddl(table_name))
            logger.info(f"Created collection '{table_name}' with vector dimension {dims}")
        except Exception as e:
            logger.error(f"Failed to create collection: {e}")
            raise

    def list_cols(self) -> List[str]:
        """
        List all collections (tables in the keyspace).

        Returns:
            List[str]: Table names. On error the failure is logged and ``[]`` is returned.
        """
        try:
            # Bind the keyspace as a parameter rather than splicing it into a CQL string literal.
            rows = self.session.execute(
                "SELECT table_name FROM system_schema.tables WHERE keyspace_name = %s",
                (self.keyspace,),
            )
            return [row.table_name for row in rows]
        except Exception as e:
            logger.error(f"Failed to list collections: {e}")
            return []

    def delete_col(self):
        """Drop the collection table if it exists.

        Raises:
            Exception: any driver error, logged and re-raised.
        """
        try:
            self.session.execute(f"DROP TABLE IF EXISTS {self.keyspace}.{self.collection_name}")
            # Statements prepared against the dropped table must not be reused.
            self._prepared_statements = {}
            logger.info(f"Deleted collection '{self.collection_name}'")
        except Exception as e:
            logger.error(f"Failed to delete collection: {e}")
            raise

    def col_info(self) -> Dict[str, Any]:
        """
        Get information about the collection.

        The row count is exact but comes from ``COUNT(*)``, which scans the
        whole table. It may be slow or time out on large collections.

        Returns:
            Dict[str, Any]: ``name``, ``keyspace``, ``count`` and ``vector_dims``.
            On error the failure is logged and ``{}`` is returned.
        """
        try:
            row = self.session.execute(
                f"SELECT COUNT(*) FROM {self.keyspace}.{self.collection_name}"
            ).one()
            # Positional access: a `count` attribute would collide with tuple.count on the row type.
            count = row[0] if row else 0

            return {
                "name": self.collection_name,
                "keyspace": self.keyspace,
                "count": count,
                "vector_dims": self.embedding_model_dims,
            }
        except Exception as e:
            logger.error(f"Failed to get collection info: {e}")
            return {}

    def reset(self):
        """Remove every row from the collection (``TRUNCATE``), keeping the table.

        Raises:
            Exception: any driver error, logged and re-raised.
        """
        try:
            logger.warning(f"Resetting collection {self.collection_name}...")
            self.session.execute(f"TRUNCATE TABLE {self.keyspace}.{self.collection_name}")
            logger.info(f"Collection '{self.collection_name}' has been reset")
        except Exception as e:
            logger.error(f"Failed to reset collection: {e}")
            raise

    # ------------------------------------------------------------------ vectors

    def insert(
        self,
        vectors: List[List[float]],
        payloads: Optional[List[Dict]] = None,
        ids: Optional[List[str]] = None,
    ):
        """
        Insert vectors into the collection.

        CQL ``INSERT`` is an upsert, so an existing id is overwritten.

        Args:
            vectors (List[List[float]]): Vectors to insert
            payloads (List[Dict], optional): One JSON-serializable dict per vector
                (defaults to an empty dict for each)
            ids (List[str], optional): One id per vector (random UUID4 strings if omitted)

        Raises:
            ValueError: if ``payloads`` or ``ids`` does not have exactly one entry per vector.
            Exception: any driver or JSON serialization error, logged and re-raised.
        """
        logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")

        if payloads is None:
            payloads = [{} for _ in vectors]
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in vectors]
        # zip() would silently drop unmatched entries; fail loudly instead.
        if len(payloads) != len(vectors) or len(ids) != len(vectors):
            raise ValueError(
                f"Expected one payload and one id per vector, got {len(vectors)} vectors, "
                f"{len(payloads)} payloads and {len(ids)} ids"
            )

        try:
            prepared = self._prepare(
                f"INSERT INTO {self.keyspace}.{self.collection_name} (id, vector, payload) "
                "VALUES (?, ?, ?)"
            )
            for vector, payload, vec_id in zip(vectors, payloads, ids):
                self.session.execute(prepared, (vec_id, vector, json.dumps(payload)))
        except Exception as e:
            logger.error(f"Failed to insert vectors: {e}")
            raise

    def search(
        self,
        query: str,
        vectors: List[float],
        limit: int = 5,
        filters: Optional[Dict] = None,
    ) -> List[OutputData]:
        """
        Search for the stored vectors closest to ``vectors`` by cosine distance.

        Every row of the table is read and scored in Python.

        Args:
            query (str): Query string (not used in vector search)
            vectors (List[float]): Query vector
            limit (int): Maximum number of results to return (None returns all)
            filters (Dict, optional): Exact-match filters on payload keys; a row is kept
                only if ``payload.get(key) == value`` for every item

        Returns:
            List[OutputData]: Results ordered by ascending ``score``. ``score`` is the
            cosine *distance* ``1 - cosine_similarity``, so lower means more similar.
            If either vector has zero norm the similarity is undefined and is taken
            as 0 (distance 1.0). Rows with an empty vector or a payload that is not
            a JSON object are skipped.

        Raises:
            Exception: any driver or numeric error (e.g. mismatched vector dimensions),
            logged and re-raised.
        """
        try:
            rows = self.session.execute(
                f"SELECT id, vector, payload FROM {self.keyspace}.{self.collection_name}"
            )

            query_vec = np.asarray(vectors, dtype=float)
            # Hoisted out of the loop: identical for every row.
            query_norm = float(np.linalg.norm(query_vec))

            scored = []  # (distance, id, payload)
            for row in rows:
                if not row.vector:
                    continue

                # Filter before scoring so excluded rows cost no vector math.
                payload = self._parse_payload(row.payload, row.id)
                if payload is None or not self._matches_filters(payload, filters):
                    continue

                vec = np.asarray(row.vector, dtype=float)
                denom = query_norm * float(np.linalg.norm(vec))
                # A zero denominator would give NaN, which makes the sort order meaningless.
                similarity = float(np.dot(query_vec, vec)) / denom if denom else 0.0
                scored.append((1.0 - similarity, row.id, payload))

            # Stable sort: rows with equal distance keep their scan order.
            scored.sort(key=lambda item: item[0])
            if limit is not None:
                scored = scored[: max(limit, 0)]

            return [
                OutputData(id=vec_id, score=distance, payload=payload)
                for distance, vec_id, payload in scored
            ]
        except Exception as e:
            logger.error(f"Search failed: {e}")
            raise

    def delete(self, vector_id: str):
        """
        Delete a vector by ID (a no-op in Cassandra if the id does not exist).

        Args:
            vector_id (str): ID of the vector to delete

        Raises:
            Exception: any driver error, logged and re-raised.
        """
        try:
            prepared = self._prepare(
                f"DELETE FROM {self.keyspace}.{self.collection_name} WHERE id = ?"
            )
            self.session.execute(prepared, (vector_id,))
            logger.info(f"Deleted vector with id: {vector_id}")
        except Exception as e:
            logger.error(f"Failed to delete vector: {e}")
            raise

    def update(
        self,
        vector_id: str,
        vector: Optional[List[float]] = None,
        payload: Optional[Dict] = None,
    ):
        """
        Update a vector and/or its payload.

        Both columns are written in a single ``UPDATE`` statement, so a
        combined update cannot half-apply. CQL ``UPDATE`` is an upsert, so an
        unknown id creates a row containing only the given columns. When both
        arguments are None, nothing is written.

        Args:
            vector_id (str): ID of the vector to update
            vector (List[float], optional): Replacement vector
            payload (Dict, optional): Replacement payload (replaces the whole stored dict)

        Raises:
            Exception: any driver or JSON serialization error, logged and re-raised.
        """
        assignments = []
        params: List[Any] = []
        if vector is not None:
            assignments.append("vector = ?")
            params.append(vector)
        if payload is not None:
            assignments.append("payload = ?")
            params.append(json.dumps(payload))

        if not assignments:
            logger.debug(f"Nothing to update for vector with id: {vector_id}")
            return

        try:
            prepared = self._prepare(
                f"UPDATE {self.keyspace}.{self.collection_name} "
                f"SET {', '.join(assignments)} WHERE id = ?"
            )
            self.session.execute(prepared, (*params, vector_id))
            logger.info(f"Updated vector with id: {vector_id}")
        except Exception as e:
            logger.error(f"Failed to update vector: {e}")
            raise

    def get(self, vector_id: str) -> Optional[OutputData]:
        """
        Retrieve a record by ID.

        Args:
            vector_id (str): ID of the vector to retrieve

        Returns:
            OutputData: The record with ``score=None``; the vector itself is not
            included. Returns None if the id does not exist, if its payload is not
            a JSON object, or if the query failed. Errors are logged, not raised.
        """
        try:
            # Only id and payload are returned, so don't fetch the vector column.
            prepared = self._prepare(
                f"SELECT id, payload FROM {self.keyspace}.{self.collection_name} WHERE id = ?"
            )
            row = self.session.execute(prepared, (vector_id,)).one()

            if not row:
                return None

            payload = self._parse_payload(row.payload, row.id)
            if payload is None:
                return None

            return OutputData(id=row.id, score=None, payload=payload)
        except Exception as e:
            logger.error(f"Failed to get vector: {e}")
            return None

    def list(
        self,
        filters: Optional[Dict] = None,
        limit: int = 100,
    ) -> List[List[OutputData]]:
        """
        List vectors in the collection, optionally filtered by payload.

        Args:
            filters (Dict, optional): Exact-match filters on payload keys
            limit (int): Maximum number of matching records to return (None returns all)

        Returns:
            List[List[OutputData]]: A single-element list wrapping the matching records
            (``score=None``). Rows whose payload is not a JSON object are skipped.
            On error the failure is logged and ``[[]]`` is returned.
        """
        if limit is not None and limit <= 0:
            return [[]]

        try:
            query = f"SELECT id, payload FROM {self.keyspace}.{self.collection_name}"
            if limit is not None and not filters:
                # Without filters every row qualifies, so the server can apply the limit.
                # With filters it must not: LIMIT would cut rows *before* filtering and
                # could hide matches further down the table.
                query += f" LIMIT {int(limit)}"
            rows = self.session.execute(query)

            results = []
            for row in rows:
                payload = self._parse_payload(row.payload, row.id)
                if payload is None or not self._matches_filters(payload, filters):
                    continue

                results.append(OutputData(id=row.id, score=None, payload=payload))
                # Stop early; the driver fetches further pages only as iteration continues.
                if limit is not None and len(results) >= limit:
                    break

            return [results]
        except Exception as e:
            logger.error(f"Failed to list vectors: {e}")
            return [[]]

    def __del__(self):
        """Best-effort shutdown of the driver cluster when the store is garbage-collected."""
        try:
            cluster = getattr(self, "cluster", None)
            if cluster is not None:
                cluster.shutdown()
                logger.info("Cassandra cluster connection closed")
        except Exception:
            # Finalizers must never raise (e.g. during interpreter shutdown).
            pass
| #497 |