repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...2fa351d6docs: add automaton and perps launch sources16d ago| #1 | import json |
import logging
import uuid
from datetime import date, datetime
from typing import List, Optional

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.catalog import ColumnInfo, ColumnTypeName, DataSourceFormat, TableType
from databricks.sdk.service.catalog import PrimaryKeyConstraint, TableConstraint
from databricks.sdk.service.vectorsearch import (
    DeltaSyncVectorIndexSpecRequest,
    DirectAccessVectorIndexSpec,
    EmbeddingSourceColumn,
    EmbeddingVectorColumn,
    EndpointType,
    PipelineType,
    VectorIndexType,
)
from pydantic import BaseModel

from mem0.memory.utils import extract_json
from mem0.vector_stores.base import VectorStoreBase
| #19 | |
# Module-level logger named after this module; every method of the Databricks store logs through it.
logger = logging.getLogger(name=__name__)
| #21 | |
| #22 | |
class MemoryResult(BaseModel):
    """A single memory returned by the Databricks vector store (search, get and list)."""

    id: Optional[str] = None  # memory_id of the row the result was built from
    score: Optional[float] = None  # similarity score; set by search(), left as None by get()/list()
    payload: Optional[dict] = None  # row values keyed by column name plus derived keys such as "data"
| #27 | |
| #28 | |
# Core payload keys (identifiers, content hash/text, timestamps) that the Databricks store
# handles specially instead of treating them as ordinary payload fields.
excluded_keys = {
    "agent_id",
    "created_at",
    "data",
    "hash",
    "run_id",
    "updated_at",
    "user_id",
}
| #30 | |
| #31 | |
class Databricks(VectorStoreBase):
    """mem0 vector store backed by Databricks Vector Search.

    Rows live in the Unity Catalog Delta table ``catalog.schema.table_name``. They are written
    with SQL statements executed on a SQL warehouse. ``search``, ``get`` and ``list`` read
    through the Vector Search index ``catalog.schema.collection_name`` hosted on
    ``endpoint_name``. A missing endpoint, source table or index is created on construction.

    NOTE(review): with ``DIRECT_ACCESS`` this class still writes rows to the Delta table with
    SQL. Databricks direct-access indexes are normally populated through the index upsert API
    rather than synced from a table. Confirm that path works before relying on it.
    """

    # Update-payload keys that are written to their own columns even though they are listed in
    # ``excluded_keys``. The new memory text arrives as "data" and is stored in "memory".
    _UPDATABLE_CORE_KEYS = {"data": "memory", "hash": "hash", "updated_at": "updated_at"}

    def __init__(
        self,
        workspace_url: str,
        access_token: Optional[str] = None,
        client_id: Optional[str] = None,
        client_secret: Optional[str] = None,
        azure_client_id: Optional[str] = None,
        azure_client_secret: Optional[str] = None,
        endpoint_name: Optional[str] = None,
        catalog: Optional[str] = None,
        schema: Optional[str] = None,
        table_name: Optional[str] = None,
        collection_name: str = "mem0",
        index_type: str = "DELTA_SYNC",
        embedding_model_endpoint_name: Optional[str] = None,
        embedding_dimension: int = 1536,
        endpoint_type: str = "STANDARD",
        pipeline_type: str = "TRIGGERED",
        warehouse_name: Optional[str] = None,
        query_type: str = "ANN",
    ):
        """
        Initialize the Databricks Vector Search vector store.

        This builds a ``WorkspaceClient`` and resolves the SQL warehouse. It then makes sure the
        vector search endpoint, the source table and the index exist, creating any that are missing.

        Args:
            workspace_url (str): Databricks workspace URL.
            access_token (str, optional): Personal access token for authentication.
            client_id (str, optional): Service principal client ID for authentication.
            client_secret (str, optional): Service principal client secret for authentication.
            azure_client_id (str, optional): Azure AD application client ID (for Azure Databricks).
            azure_client_secret (str, optional): Azure AD application client secret (for Azure Databricks).
            endpoint_name (str): Vector search endpoint name.
            catalog (str): Unity Catalog catalog name.
            schema (str): Unity Catalog schema name.
            table_name (str): Source Delta table name.
            collection_name (str, optional): Vector search index name (default: "mem0").
            index_type (str or VectorIndexType, optional): "DELTA_SYNC" or "DIRECT_ACCESS"
                (default: "DELTA_SYNC").
            embedding_model_endpoint_name (str, optional): Embedding model endpoint for
                Databricks-computed embeddings.
            embedding_dimension (int, optional): Vector embedding dimensions (default: 1536).
            endpoint_type (str or EndpointType, optional): Endpoint type, either "STANDARD" or
                "STORAGE_OPTIMIZED" (default: "STANDARD").
            pipeline_type (str or PipelineType, optional): Sync pipeline type, either "TRIGGERED"
                or "CONTINUOUS" (default: "TRIGGERED").
            warehouse_name (str, optional): Name of the SQL warehouse that runs the
                INSERT/UPDATE/DELETE statements.
            query_type (str, optional): Query type, either "ANN" or "HYBRID" (default: "ANN").

        String options are matched case-insensitively against the SDK enums. If ``index_type``
        is not recognised, ``create_col`` rejects it with ``ValueError``.
        """
        # Basic identifiers
        self.workspace_url = workspace_url
        self.endpoint_name = endpoint_name
        self.catalog = catalog
        self.schema = schema
        self.table_name = table_name
        self.fully_qualified_table_name = f"{self.catalog}.{self.schema}.{self.table_name}"
        self.index_name = collection_name
        self.fully_qualified_index_name = f"{self.catalog}.{self.schema}.{self.index_name}"

        # Configuration. The SDK expects enum members for these three options. A plain string
        # never compares equal to a VectorIndexType member, so convert them up front.
        self.index_type = self._coerce_enum(VectorIndexType, index_type)
        self.embedding_model_endpoint_name = embedding_model_endpoint_name
        self.embedding_dimension = embedding_dimension
        self.endpoint_type = self._coerce_enum(EndpointType, endpoint_type)
        self.pipeline_type = self._coerce_enum(PipelineType, pipeline_type)
        self.query_type = query_type

        # Schema of the source Delta table
        self.columns = self._build_columns()
        self.column_names = [col.name for col in self.columns]

        # Initialize Databricks workspace client
        client_config = self._build_client_config(
            workspace_url, access_token, client_id, client_secret, azure_client_id, azure_client_secret
        )
        try:
            self.client = WorkspaceClient(**client_config)
            logger.info("Initialized Databricks workspace client")
        except Exception as e:
            logger.error(f"Failed to initialize Databricks workspace client: {e}")
            raise

        self.warehouse_id = self._resolve_warehouse_id(warehouse_name)

        # Initialize endpoint (required in Databricks)
        self._ensure_endpoint_exists()

        # Create the index (and its source table) when the endpoint does not list it yet
        if self.fully_qualified_index_name not in self.list_cols():
            self.create_col()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _coerce_enum(enum_cls, value):
        """Convert a string option to an ``enum_cls`` member; unknown values are returned unchanged."""
        if value is None or isinstance(value, enum_cls):
            return value
        try:
            return enum_cls(str(value).upper())
        except ValueError:
            # Leave validation to the caller (e.g. create_col) or to the SDK.
            return value

    @staticmethod
    def _enum_label(value):
        """Return an enum member's value (e.g. "DELTA_SYNC"), or the value itself for plain strings."""
        return getattr(value, "value", value)

    def _build_columns(self) -> List[ColumnInfo]:
        """Return the source-table schema; DIRECT_ACCESS indexes also get an ``embedding`` column."""
        # (name, type name, type text, comment) in table order; position follows list order.
        scalar_columns = [
            ("memory_id", ColumnTypeName.STRING, "string", "Primary key"),
            ("hash", ColumnTypeName.STRING, "string", "Hash of the memory content"),
            ("agent_id", ColumnTypeName.STRING, "string", "ID of the agent"),
            ("run_id", ColumnTypeName.STRING, "string", "ID of the run"),
            ("user_id", ColumnTypeName.STRING, "string", "ID of the user"),
            ("memory", ColumnTypeName.STRING, "string", "Memory content"),
            ("metadata", ColumnTypeName.STRING, "string", "Additional metadata"),
            ("created_at", ColumnTypeName.TIMESTAMP, "timestamp", "Creation timestamp"),
            ("updated_at", ColumnTypeName.TIMESTAMP, "timestamp", "Last update timestamp"),
        ]
        columns = [
            ColumnInfo(
                name=name,
                type_name=type_name,
                type_text=type_text,
                type_json=json.dumps({"type": type_text}, separators=(",", ":")),
                # Only the primary key is declared NOT NULL; other columns leave nullability unset.
                nullable=False if name == "memory_id" else None,
                comment=comment,
                position=position,
            )
            for position, (name, type_name, type_text, comment) in enumerate(scalar_columns)
        ]
        if self.index_type == VectorIndexType.DIRECT_ACCESS:
            columns.append(
                ColumnInfo(
                    name="embedding",
                    type_name=ColumnTypeName.ARRAY,
                    type_text="array<float>",
                    type_json='{"type":"array","element":"float","element_nullable":false}',
                    nullable=True,
                    comment="Embedding vector",
                    position=len(columns),
                )
            )
        return columns

    @staticmethod
    def _build_client_config(
        workspace_url, access_token, client_id, client_secret, azure_client_id, azure_client_secret
    ) -> dict:
        """Return WorkspaceClient kwargs for the first complete credential set.

        With no credentials, only the host is passed and the SDK's automatic authentication applies.
        """
        config = {"host": workspace_url}
        if client_id and client_secret:
            config.update(client_id=client_id, client_secret=client_secret)
        elif azure_client_id and azure_client_secret:
            config.update(azure_client_id=azure_client_id, azure_client_secret=azure_client_secret)
        elif access_token:
            config["token"] = access_token
        return config

    def _resolve_warehouse_id(self, warehouse_name: Optional[str]) -> Optional[str]:
        """Look up a SQL warehouse ID by name; None when no name is given or no warehouse matches."""
        if not warehouse_name:
            return None
        warehouse_id = next(
            (w.id for w in self.client.warehouses.list() if w.name == warehouse_name), None
        )
        if warehouse_id is None:
            logger.warning(
                f"SQL warehouse '{warehouse_name}' not found; insert/update/delete/reset need a warehouse"
            )
        return warehouse_id

    def _execute_sql(self, statement: str):
        """Run ``statement`` on the configured SQL warehouse, waiting up to 30 seconds for the result."""
        return self.client.statement_execution.execute_statement(
            statement=statement, warehouse_id=self.warehouse_id, wait_timeout="30s"
        )

    @staticmethod
    def _statement_succeeded(response) -> bool:
        """Return True when a statement-execution response reports the SUCCEEDED state."""
        state = getattr(getattr(response, "status", None), "state", None)
        # The SDK returns a StatementState enum; compare its value so plain strings also work.
        return getattr(state, "value", state) == "SUCCEEDED"

    @staticmethod
    def _statement_error(response):
        """Return the error attached to a statement-execution response, if any."""
        return getattr(getattr(response, "status", None), "error", None)

    @staticmethod
    def _quote_sql_string(text: str) -> str:
        """Single-quote ``text`` as a Databricks SQL string literal, escaping backslashes and quotes."""
        # Databricks SQL uses backslash as the escape character inside string literals.
        escaped = text.replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    def _format_sql_value(self, v):
        """
        Format a Python value into a SQL literal for Databricks.

        Conversions:
            - None becomes NULL; bools become TRUE/FALSE; ints and floats are written verbatim.
            - Dates and datetimes become quoted ISO-8601 strings.
            - Lists become ``array(...)``, with each element formatted recursively.
            - Dicts become quoted JSON.
            - Anything else becomes a quoted string.
        """
        if v is None:
            return "NULL"
        if isinstance(v, bool):
            return "TRUE" if v else "FALSE"
        if isinstance(v, (int, float)):
            return str(v)
        if isinstance(v, (datetime, date)):
            return self._quote_sql_string(v.isoformat())
        if isinstance(v, list):
            return f"array({', '.join(self._format_sql_value(x) for x in v)})"
        if isinstance(v, dict):
            try:
                text = json.dumps(v)
            except (TypeError, ValueError):
                text = str(v)
            return self._quote_sql_string(text)
        # Fallback: treat as string
        return self._quote_sql_string(str(v))

    def _extract_metadata(self, payload: dict) -> Optional[dict]:
        """Collect payload keys that are neither core keys nor table columns; None when there are none."""
        extra = {
            key: value
            for key, value in payload.items()
            if key not in excluded_keys and key not in self.column_names
        }
        return extra or None

    @staticmethod
    def _parse_metadata(raw) -> dict:
        """Decode a stored ``metadata`` value into a dict; {} when empty, unparsable or not a JSON object."""
        if not raw:
            return {}
        if isinstance(raw, dict):
            return raw
        try:
            parsed = json.loads(extract_json(raw))
        except (ValueError, TypeError):
            logger.warning(f"Failed to parse metadata: {raw}")
            return {}
        if not isinstance(parsed, dict):
            logger.warning(f"Ignoring metadata that is not a JSON object: {raw}")
            return {}
        return parsed

    @staticmethod
    def _data_array(sdk_results) -> List:
        """Return the result rows of a ``query_index`` response, or an empty list when there are none."""
        result_data = getattr(sdk_results, "result", sdk_results)
        return getattr(result_data, "data_array", None) or []

    def _row_to_dict(self, row) -> dict:
        """Map a positional result row onto the requested column names (rows that are already mappings pass through)."""
        if isinstance(row, (list, tuple)):
            return dict(zip(self.column_names, row))
        return row

    def _row_to_payload(self, row_dict: dict) -> dict:
        """Build a payload from a result row: every column, then decoded metadata keys, then "data" = memory."""
        payload = {name: row_dict.get(name) for name in self.column_names}
        payload.update(self._parse_metadata(payload.get("metadata")))
        payload["data"] = payload.get("memory")
        return payload

    def _update_assignments(self, payload: dict) -> List[str]:
        """Translate an update payload into ``column = literal`` SQL SET assignments."""
        values = {
            column: payload[key] for key, column in self._UPDATABLE_CORE_KEYS.items() if key in payload
        }
        for key, value in payload.items():
            if key in excluded_keys or key in ("memory_id", "embedding"):
                continue
            # Only real table columns are assigned, so payload keys never reach the SQL as identifiers.
            if key in self.column_names:
                values[key] = value
        if "metadata" not in payload:
            extra = self._extract_metadata(payload)
            if extra is not None:
                values["metadata"] = extra
        return [f"{column} = {self._format_sql_value(value)}" for column, value in values.items()]

    # ------------------------------------------------------------------ setup

    def _ensure_endpoint_exists(self):
        """Ensure the vector search endpoint exists; create it and wait for it if it doesn't."""
        try:
            self.client.vector_search_endpoints.get_endpoint(endpoint_name=self.endpoint_name)
            logger.info(f"Vector search endpoint '{self.endpoint_name}' already exists")
        except Exception:
            # Any lookup failure is treated as "endpoint missing".
            # NOTE(review): auth/network errors also end up here and trigger a create attempt.
            endpoint_label = self._enum_label(self.endpoint_type)
            try:
                logger.info(f"Creating vector search endpoint '{self.endpoint_name}' with type '{endpoint_label}'")
                self.client.vector_search_endpoints.create_endpoint_and_wait(
                    name=self.endpoint_name, endpoint_type=self.endpoint_type
                )
                logger.info(f"Successfully created vector search endpoint '{self.endpoint_name}'")
            except Exception as e:
                logger.error(f"Failed to create vector search endpoint '{self.endpoint_name}': {e}")
                raise

    def _ensure_source_table_exists(self):
        """Ensure the source Delta table exists.

        A newly created table gets Change Data Feed enabled and a primary key on memory_id.
        """
        if self.client.tables.exists(self.fully_qualified_table_name).table_exists:
            logger.info(f"Source table '{self.fully_qualified_table_name}' already exists")
            return

        logger.info(f"Source table '{self.fully_qualified_table_name}' does not exist, creating it...")
        self.client.tables.create(
            name=self.table_name,
            catalog_name=self.catalog,
            schema_name=self.schema,
            table_type=TableType.MANAGED,
            data_source_format=DataSourceFormat.DELTA,
            storage_location=None,  # Use default storage location
            columns=self.columns,
            # Delta Sync indexes read incremental changes from the table's Change Data Feed.
            properties={"delta.enableChangeDataFeed": "true"},
        )
        logger.info(f"Successfully created source table '{self.fully_qualified_table_name}'")

        # Primary key on this store's own table; the name is derived from the table name.
        self.client.table_constraints.create(
            full_name_arg=self.fully_qualified_table_name,
            constraint=TableConstraint(
                primary_key_constraint=PrimaryKeyConstraint(
                    name=f"pk_{self.table_name}",
                    child_columns=["memory_id"],
                )
            ),
        )
        logger.info(
            f"Successfully created primary key constraint on 'memory_id' for table '{self.fully_qualified_table_name}'"
        )

    def create_col(self, name=None, vector_size=None, distance=None):
        """
        Create the vector search index, creating its source Delta table first if needed.

        Args:
            name (str, optional): Accepted for interface compatibility. The index is always
                ``catalog.schema.collection_name``.
            vector_size (int, optional): Embedding dimension for DIRECT_ACCESS indexes
                (defaults to ``embedding_dimension``).
            distance (str, optional): Accepted for interface compatibility; not used by Databricks.

        Returns:
            The index object returned by the SDK.

        Raises:
            ValueError: if ``index_type`` is neither DELTA_SYNC nor DIRECT_ACCESS.
            Exception: any SDK error raised while creating the index, after it is logged.
        """
        # Validate before touching the workspace so a bad configuration creates nothing.
        if self.index_type not in (VectorIndexType.DELTA_SYNC, VectorIndexType.DIRECT_ACCESS):
            raise ValueError("index_type must be either 'DELTA_SYNC' or 'DIRECT_ACCESS'")

        embedding_dims = vector_size or self.embedding_dimension
        index_label = self._enum_label(self.index_type)
        # Embeddings are computed from the "memory" column by the configured model endpoint.
        embedding_source_columns = [
            EmbeddingSourceColumn(
                name="memory",
                embedding_model_endpoint_name=self.embedding_model_endpoint_name,
            )
        ]

        logger.info(f"Creating vector search index '{self.fully_qualified_index_name}'")

        # First, ensure the source Delta table exists
        self._ensure_source_table_exists()

        if self.index_type == VectorIndexType.DELTA_SYNC:
            index_spec = {
                "delta_sync_index_spec": DeltaSyncVectorIndexSpecRequest(
                    source_table=self.fully_qualified_table_name,
                    pipeline_type=self.pipeline_type,
                    columns_to_sync=self.column_names,
                    embedding_source_columns=embedding_source_columns,
                )
            }
        else:
            index_spec = {
                "direct_access_index_spec": DirectAccessVectorIndexSpec(
                    embedding_source_columns=embedding_source_columns,
                    embedding_vector_columns=[
                        EmbeddingVectorColumn(name="embedding", embedding_dimension=embedding_dims)
                    ],
                )
            }

        try:
            index = self.client.vector_search_indexes.create_index(
                name=self.fully_qualified_index_name,
                endpoint_name=self.endpoint_name,
                primary_key="memory_id",
                index_type=self.index_type,
                **index_spec,
            )
        except Exception as e:
            logger.error(f"Error making index_type: {index_label} for index {self.fully_qualified_index_name}: {e}")
            raise

        logger.info(f"Successfully created vector search index '{self.fully_qualified_index_name}' with {index_label} type")
        return index

    # ------------------------------------------------------------------ data operations

    def insert(self, vectors: list, payloads: list = None, ids: list = None):
        """
        Insert vectors by appending rows to the source Delta table with one INSERT statement.

        Args:
            vectors (List[List[float]]): One vector per item. It is stored only when the table has
                an ``embedding`` column (DIRECT_ACCESS).
            payloads (List[Dict], optional): One payload per item. Values are mapped as follows:
                - "data" goes to ``memory``.
                - Keys that match table columns go to those columns.
                - Remaining keys outside ``excluded_keys`` are stored as JSON in ``metadata``,
                  unless the payload has its own "metadata" entry.
            ids (List[str], optional): IDs per item. Items without an ID get a random UUID4.

        Raises:
            RuntimeError: if the INSERT does not report SUCCEEDED.
            Exception: any error raised while executing the statement.
        """
        # Determine the number of items to process
        num_items = len(payloads) if payloads else len(vectors) if vectors else 0
        if num_items == 0:
            # "INSERT ... VALUES" with no tuples is invalid SQL; there is nothing to write.
            logger.info("No items to insert")
            return

        value_tuples = []
        for i in range(num_items):
            payload = (payloads[i] if payloads and i < len(payloads) else None) or {}
            row = []
            for col in self.column_names:
                if col == "memory_id":
                    val = ids[i] if ids and i < len(ids) else str(uuid.uuid4())
                elif col == "embedding":
                    val = vectors[i] if vectors and i < len(vectors) else []
                elif col == "memory":
                    val = payload.get("data")
                elif col == "metadata" and "metadata" not in payload:
                    val = self._extract_metadata(payload)
                else:
                    val = payload.get(col)
                row.append(self._format_sql_value(val))
            value_tuples.append(f"({', '.join(row)})")

        insert_sql = (
            f"INSERT INTO {self.fully_qualified_table_name} ({', '.join(self.column_names)}) "
            f"VALUES {', '.join(value_tuples)}"
        )

        try:
            response = self._execute_sql(insert_sql)
        except Exception as e:
            logger.error(f"Insert operation failed: {e}")
            raise

        if not self._statement_succeeded(response):
            error = self._statement_error(response)
            logger.error(f"Failed to insert items: {error}")
            raise RuntimeError(f"Insert operation failed: {error}")

        logger.info(f"Successfully inserted {num_items} items into Delta table {self.fully_qualified_table_name}")

    def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> List[MemoryResult]:
        """
        Search for similar vectors or text using the Databricks Vector Search index.

        Args:
            query (str): Search query text (used for DELTA_SYNC indexes).
            vectors (list): Query vector (used for DIRECT_ACCESS indexes).
            limit (int): Maximum number of results.
            filters (dict): Filters to apply, sent as ``filters_json``.

        Returns:
            List of MemoryResult objects. Each payload holds the row's columns, decoded
            metadata keys, and "data".

        Raises:
            ValueError: if the input required by the index type is missing.
            Exception: any SDK error, after it is logged.
        """
        try:
            if self.index_type == VectorIndexType.DELTA_SYNC and query:
                # Text-based search
                query_input = {"query_text": query}
            elif self.index_type == VectorIndexType.DIRECT_ACCESS and vectors:
                # Vector-based search
                query_input = {"query_vector": vectors}
            else:
                raise ValueError("Must provide query text for DELTA_SYNC or vectors for DIRECT_ACCESS.")

            sdk_results = self.client.vector_search_indexes.query_index(
                index_name=self.fully_qualified_index_name,
                columns=self.column_names,
                num_results=limit,
                query_type=self.query_type,
                filters_json=json.dumps(filters) if filters else None,
                **query_input,
            )

            memory_results = []
            for row in self._data_array(sdk_results):
                row_dict = self._row_to_dict(row)
                # Assumes any value after the requested columns is the similarity score.
                score = row_dict.get("score") or (
                    row[-1] if isinstance(row, (list, tuple)) and len(row) > len(self.column_names) else None
                )
                memory_id = row_dict.get("memory_id") or row_dict.get("id")
                memory_results.append(
                    MemoryResult(id=memory_id, score=score, payload=self._row_to_payload(row_dict))
                )
            return memory_results

        except Exception as e:
            logger.error(f"Search failed: {e}")
            raise

    def delete(self, vector_id):
        """
        Delete a vector by ID from the Delta table.

        A statement that does not report SUCCEEDED is logged but not raised.

        Args:
            vector_id (str): ID of the vector to delete.

        Raises:
            Exception: any error raised while executing the statement, after it is logged.
        """
        try:
            logger.info(f"Deleting vector with ID {vector_id} from Delta table {self.fully_qualified_table_name}")

            delete_sql = (
                f"DELETE FROM {self.fully_qualified_table_name} "
                f"WHERE memory_id = {self._quote_sql_string(str(vector_id))}"
            )
            response = self._execute_sql(delete_sql)

            if self._statement_succeeded(response):
                logger.info(f"Successfully deleted vector with ID {vector_id}")
            else:
                logger.error(f"Failed to delete vector with ID {vector_id}: {self._statement_error(response)}")

        except Exception as e:
            logger.error(f"Delete operation failed for vector ID {vector_id}: {e}")
            raise

    def update(self, vector_id=None, vector=None, payload=None):
        """
        Update a vector and its payload in the Delta table.

        Args:
            vector_id (str): ID of the vector to update.
            vector (list, optional): New embedding. It is written only when the table has an
                ``embedding`` column (DIRECT_ACCESS). DELTA_SYNC indexes embed the ``memory``
                column themselves.
            payload (dict, optional): New payload data, applied as follows:
                - "data" updates ``memory``; "hash" and "updated_at" update their columns.
                - Other keys that match table columns are written directly.
                - Remaining keys outside ``excluded_keys`` replace the JSON in ``metadata``,
                  unless the payload has its own "metadata" entry.
                - ``memory_id``, the user/agent/run IDs and ``created_at`` are never changed.

        Invalid arguments are logged, and the call returns without touching the table. A
        statement that does not report SUCCEEDED is logged but not raised.

        Raises:
            Exception: any error raised while executing the statement, after it is logged.
        """
        if not vector_id:
            logger.error("vector_id is required for update operation")
            return

        set_clauses = []
        if vector is not None:
            if not isinstance(vector, list):
                logger.error("vector must be a list of float values")
                return
            if "embedding" in self.column_names:
                set_clauses.append(f"embedding = {self._format_sql_value(vector)}")
        if payload:
            if not isinstance(payload, dict):
                logger.error("payload must be a dictionary")
                return
            set_clauses.extend(self._update_assignments(payload))

        if not set_clauses:
            logger.error("No fields to update")
            return

        update_sql = (
            f"UPDATE {self.fully_qualified_table_name} SET {', '.join(set_clauses)} "
            f"WHERE memory_id = {self._quote_sql_string(str(vector_id))}"
        )
        try:
            logger.info(f"Updating vector with ID {vector_id} in Delta table {self.fully_qualified_table_name}")

            response = self._execute_sql(update_sql)

            if self._statement_succeeded(response):
                logger.info(f"Successfully updated vector with ID {vector_id}")
            else:
                logger.error(f"Failed to update vector with ID {vector_id}: {self._statement_error(response)}")
        except Exception as e:
            logger.error(f"Update operation failed for vector ID {vector_id}: {e}")
            raise

    def get(self, vector_id) -> MemoryResult:
        """
        Retrieve a vector by ID through the index, using a ``memory_id`` filter.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            MemoryResult: The retrieved vector. Its payload holds hash, data, timestamps,
            any present agent/run/user IDs, and decoded metadata keys.

        Raises:
            KeyError: if no row matches ``vector_id``.
            Exception: any SDK error, after it is logged.

        NOTE(review): the placeholder ``query_text`` assumes the index can embed text. That may
        not hold for DIRECT_ACCESS indexes without an embedding model endpoint.
        """
        try:
            results = self.client.vector_search_indexes.query_index(
                index_name=self.fully_qualified_index_name,
                columns=self.column_names,
                query_text=" ",  # Blank query; the memory_id filter selects the row
                num_results=1,
                query_type=self.query_type,
                filters_json=json.dumps({"memory_id": vector_id}),
            )

            data_array = self._data_array(results)
            if not data_array:
                raise KeyError(f"Vector with ID {vector_id} not found")

            manifest = getattr(results, "manifest", None)
            columns = (
                [col.name for col in manifest.columns]
                if manifest and manifest.columns
                else self.column_names  # Fall back to the requested column order
            )
            row_data = dict(zip(columns, data_array[0]))

            # Build payload following the standard schema
            payload = {
                "hash": row_data.get("hash", "unknown"),
                "data": row_data.get("memory", row_data.get("data", "unknown")),
                "created_at": row_data.get("created_at"),
            }
            if "updated_at" in row_data:
                payload["updated_at"] = row_data.get("updated_at")
            for field in ("agent_id", "run_id", "user_id"):
                if field in row_data:
                    payload[field] = row_data[field]
            payload.update(self._parse_metadata(row_data.get("metadata")))

            return MemoryResult(id=row_data.get("memory_id", vector_id), payload=payload)

        except Exception as e:
            logger.error(f"Failed to get vector with ID {vector_id}: {e}")
            raise

    # ------------------------------------------------------------------ collection management

    def list_cols(self) -> List[str]:
        """
        List all collections (indexes) on the configured endpoint.

        Returns:
            List of index names as reported by the SDK.
        """
        try:
            indexes = self.client.vector_search_indexes.list_indexes(endpoint_name=self.endpoint_name)
            return [idx.name for idx in indexes]
        except Exception as e:
            logger.error(f"Failed to list collections: {e}")
            raise

    def delete_col(self):
        """Delete the current collection (index), trying the fully qualified name before the short one."""
        try:
            try:
                self.client.vector_search_indexes.delete_index(index_name=self.fully_qualified_index_name)
                logger.info(f"Successfully deleted index '{self.fully_qualified_index_name}'")
            except Exception:
                self.client.vector_search_indexes.delete_index(index_name=self.index_name)
                logger.info(f"Successfully deleted index '{self.index_name}' (short name)")
        except Exception as e:
            logger.error(f"Failed to delete index '{self.index_name}': {e}")
            raise

    def col_info(self, name=None):
        """
        Get information about a collection (index).

        Args:
            name (str, optional): Index name. Defaults to the current fully qualified index name,
                matching the name used to create and query it.

        Returns:
            Dict: the index name reported by the SDK and this store's column definitions.
        """
        index_name = name or self.fully_qualified_index_name
        try:
            index = self.client.vector_search_indexes.get_index(index_name=index_name)
            return {"name": index.name, "fields": self.columns}
        except Exception as e:
            logger.error(f"Failed to get info for index '{index_name}': {e}")
            raise

    def list(self, filters: dict = None, limit: int = None) -> List[List[MemoryResult]]:
        """
        List memories from the vector store (up to ``limit``, default 100).

        Args:
            filters (dict, optional): Filters to apply, sent as ``filters_json``.
            limit (int, optional): Maximum number of results.

        Returns:
            A one-element list wrapping the list of MemoryResult objects. When the query fails,
            the error is logged and ``[[]]`` is returned so callers indexing ``[0]`` still work.
        """
        try:
            sdk_results = self.client.vector_search_indexes.query_index(
                index_name=self.fully_qualified_index_name,
                columns=self.column_names,
                query_text=" ",
                num_results=limit or 100,
                query_type=self.query_type,
                filters_json=json.dumps(filters) if filters else None,
            )

            memory_results = []
            for row in self._data_array(sdk_results):
                row_dict = self._row_to_dict(row)
                memory_id = row_dict.get("memory_id") or row_dict.get("id")
                memory_results.append(MemoryResult(id=memory_id, payload=self._row_to_payload(row_dict)))
            return [memory_results]
        except Exception as e:
            logger.error(f"Failed to list memories: {e}")
            return [[]]

    def reset(self):
        """Reset the vector search index and the underlying source table.

        Steps:
            1. Delete the existing index, trying the fully qualified name and then the short name.
            2. Drop the backing Delta table.
            3. Recreate the table and the index through ``create_col``.

        Index deletion and table drop failures are logged and tolerated; recreation failures
        are raised. All existing data is removed.
        """
        fq_index = self.fully_qualified_index_name
        logger.warning(f"Resetting Databricks vector search index '{fq_index}'...")
        try:
            # Try deleting via fully qualified name first
            try:
                self.client.vector_search_indexes.delete_index(index_name=fq_index)
                logger.info(f"Deleted index '{fq_index}'")
            except Exception as e_fq:
                logger.debug(f"Failed deleting fully qualified index name '{fq_index}': {e_fq}. Trying short name...")
                try:
                    # Fallback to existing helper which may use short name
                    self.delete_col()
                except Exception as e_short:
                    logger.debug(f"Failed deleting short index name '{self.index_name}': {e_short}")

            # Drop the backing table (if it exists)
            try:
                resp = self._execute_sql(f"DROP TABLE IF EXISTS {self.fully_qualified_table_name}")
                if self._statement_succeeded(resp):
                    logger.info(f"Dropped table '{self.fully_qualified_table_name}'")
                else:
                    logger.warning(
                        f"Attempted to drop table '{self.fully_qualified_table_name}' but state was "
                        f"{getattr(resp.status, 'state', 'UNKNOWN')}: {getattr(resp.status, 'error', None)}"
                    )
            except Exception as e_drop:
                logger.warning(f"Failed to drop table '{self.fully_qualified_table_name}': {e_drop}")

            # Recreate table & index (create_col recreates the source table before the index)
            self.create_col()
            logger.info(f"Successfully reset index '{fq_index}'")
        except Exception as e:
            logger.error(f"Error resetting index '{fq_index}': {e}")
            raise
| #762 |