repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...
git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...
2fa351d6 docs: add automaton and perps launch sources 16d ago
| #1 | import json |
| #2 | import logging |
| #3 | from datetime import datetime |
| #4 | from functools import reduce |
| #5 | |
| #6 | import numpy as np |
| #7 | import pytz |
| #8 | import redis |
| #9 | from redis.commands.search.query import Query |
| #10 | from redisvl.index import SearchIndex |
| #11 | from redisvl.query import VectorQuery |
| #12 | from redisvl.query.filter import Tag |
| #13 | |
| #14 | from mem0.memory.utils import extract_json |
| #15 | from mem0.vector_stores.base import VectorStoreBase |
| #16 | |
logger = logging.getLogger(__name__)

# TODO: Revisit this field list; from Redis's point of view these are not ideal
# and some of them may be dropped.
# The vector field must stay LAST: the store patches ``fields[-1]["attrs"]``.
DEFAULT_FIELDS = [
    # Exact-match identifiers, indexed as tags.
    *(
        {"name": tag_name, "type": "tag"}
        for tag_name in ("memory_id", "hash", "agent_id", "run_id", "user_id")
    ),
    # Free-text content and the JSON-encoded extra metadata.
    *({"name": text_name, "type": "text"} for text_name in ("memory", "metadata")),
    # TODO: Declared numeric, although a string is accepted as well.
    *({"name": numeric_name, "type": "numeric"} for numeric_name in ("created_at", "updated_at")),
    {
        "name": "embedding",
        "type": "vector",
        "attrs": {"distance_metric": "cosine", "algorithm": "flat", "datatype": "float32"},
    },
]

# Payload keys that have their own column (or are handled explicitly) and are
# therefore left out of the serialized ``metadata`` blob.
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
| #39 | |
| #40 | |
class MemoryResult:
    """Simple container for one stored memory.

    Attributes are set verbatim from the constructor arguments:
    ``id`` (record identifier), ``payload`` (dict of memory data) and
    ``score`` (search score, ``None`` when the record did not come from a search).
    """

    def __init__(self, id: str, payload: dict, score: float = None):
        # No validation or copying: the caller's objects are stored as-is.
        self.id, self.payload, self.score = id, payload, score
| #46 | |
| #47 | |
class RedisDB(VectorStoreBase):
    """Vector store backed by a Redis search index (via ``redisvl``).

    Records are keyed as ``mem0:<collection_name>:<memory_id>``. Timestamps are
    stored as integer epoch seconds and rendered back as ISO-8601 strings in the
    ``US/Pacific`` timezone when records are read.
    """

    # Timezone used when converting stored epoch seconds back to ISO strings.
    _DISPLAY_TIMEZONE = "US/Pacific"
    # Optional identifier columns copied between payloads and records when present.
    _OPTIONAL_TAG_FIELDS = ("agent_id", "run_id", "user_id")

    def __init__(
        self,
        redis_url: str,
        collection_name: str,
        embedding_model_dims: int,
    ):
        """
        Initialize the Redis vector store.

        Connects to Redis and creates the search index with ``overwrite=True``.

        Args:
            redis_url (str): Redis URL.
            collection_name (str): Collection name.
            embedding_model_dims (int): Embedding model dimensions.
        """
        self.embedding_model_dims = embedding_model_dims
        self.schema = {
            "index": {"name": collection_name, "prefix": f"mem0:{collection_name}"},
            "fields": self._build_fields(embedding_model_dims),
        }

        self.client = redis.Redis.from_url(redis_url)
        self.index = SearchIndex.from_dict(self.schema)
        self.index.set_client(self.client)
        self.index.create(overwrite=True)

    @staticmethod
    def _build_fields(dims, distance_metric=None):
        """Return a private copy of DEFAULT_FIELDS with the vector field configured.

        ``DEFAULT_FIELDS.copy()`` is only a shallow copy, so patching
        ``fields[-1]["attrs"]`` on it would silently rewrite the module-level
        defaults for every later instance. Each field dict (and its nested
        ``attrs`` dict) is copied here before being modified.
        """
        fields = []
        for field in DEFAULT_FIELDS:
            field_copy = dict(field)
            if "attrs" in field_copy:
                field_copy["attrs"] = dict(field_copy["attrs"])
            fields.append(field_copy)

        # The vector field is, by convention of DEFAULT_FIELDS, the last entry.
        vector_attrs = fields[-1]["attrs"]
        vector_attrs["dims"] = dims
        if distance_metric is not None:
            vector_attrs["distance_metric"] = distance_metric
        return fields

    @staticmethod
    def _build_filter(filters):
        """AND together ``Tag(key) == value`` for every non-None filter value.

        Returns None when ``filters`` is None/empty or contains only None
        values, so callers can run an unfiltered query instead of crashing.
        """
        conditions = [Tag(key) == value for key, value in (filters or {}).items() if value is not None]
        if not conditions:
            return None
        return reduce(lambda left, right: left & right, conditions)

    @classmethod
    def _format_timestamp(cls, value):
        """Render stored epoch seconds as an ISO-8601 string with microseconds."""
        return datetime.fromtimestamp(int(value), tz=pytz.timezone(cls._DISPLAY_TIMEZONE)).isoformat(
            timespec="microseconds"
        )

    @classmethod
    def _build_record(cls, memory_id, vector, payload, include_updated_at=False):
        """Convert a payload and its vector into the flat record written to Redis.

        Raises:
            KeyError: If ``payload`` lacks ``hash``, ``data`` or ``created_at``
                (or ``updated_at`` when ``include_updated_at`` is True).
        """
        record = {
            "memory_id": memory_id,
            "hash": payload["hash"],
            "memory": payload["data"],
            "created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
            "embedding": np.array(vector, dtype=np.float32).tobytes(),
        }
        if include_updated_at:
            record["updated_at"] = int(datetime.fromisoformat(payload["updated_at"]).timestamp())

        # Copy optional identifiers; None values are skipped rather than written.
        for field in cls._OPTIONAL_TAG_FIELDS:
            if payload.get(field) is not None:
                record[field] = payload[field]

        # Everything without a dedicated column goes into the JSON metadata blob.
        record["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
        return record

    @classmethod
    def _to_payload(cls, record):
        """Rebuild the caller-facing payload dict from a stored record mapping.

        Key order and precedence follow the original inline construction:
        metadata keys are merged last and therefore win on collisions.
        """
        payload = {
            "hash": record["hash"],
            "data": record["memory"],
            "created_at": cls._format_timestamp(record["created_at"]),
        }
        updated_at = record.get("updated_at")
        if updated_at not in (None, ""):
            payload["updated_at"] = cls._format_timestamp(updated_at)
        for field in cls._OPTIONAL_TAG_FIELDS:
            if field in record:
                payload[field] = record[field]
        payload.update(json.loads(extract_json(record["metadata"])))
        return payload

    def _record_key(self, vector_id):
        """Return the full Redis key (``<prefix>:<id>``) for a memory id."""
        return f"{self.schema['index']['prefix']}:{vector_id}"

    def create_col(self, name=None, vector_size=None, distance=None):
        """
        Create a new collection (index) in Redis.

        Args:
            name (str, optional): Name for the collection. Defaults to None, which uses the current collection_name.
            vector_size (int, optional): Size of the vector embeddings. Defaults to None, which uses the current embedding_model_dims.
            distance (str, optional): Distance metric to use. Defaults to None, which uses 'cosine'.

        Returns:
            The created index object.
        """
        # Use provided parameters or fall back to instance attributes
        collection_name = name or self.schema["index"]["name"]
        embedding_dims = vector_size or self.embedding_model_dims
        distance_metric = distance or "cosine"

        schema = {
            "index": {"name": collection_name, "prefix": f"mem0:{collection_name}"},
            # Built from a private copy so the module-level defaults stay untouched.
            "fields": self._build_fields(embedding_dims, distance_metric),
        }

        index = SearchIndex.from_dict(schema)
        index.set_client(self.client)
        index.create(overwrite=True)

        # Only switch this instance over when an explicit name was requested.
        if name:
            self.schema = schema
            self.index = index

        return index

    def insert(self, vectors: list, payloads: list = None, ids: list = None):
        """
        Insert memories into the index.

        ``vectors``, ``payloads`` and ``ids`` are consumed in lockstep via
        ``zip``, so despite their ``None`` defaults ``payloads`` and ``ids``
        must be supplied, and extra items in a longer list are ignored.

        Args:
            vectors (list): Embedding vectors.
            payloads (list): Payload dicts; each needs ``hash``, ``data`` and an ISO ``created_at``.
            ids (list): Memory ids, used as the record id field.
        """
        records = [
            self._build_record(memory_id, vector, payload)
            for vector, payload, memory_id in zip(vectors, payloads, ids)
        ]
        self.index.load(records, id_field="memory_id")

    def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None):
        """
        Run a vector similarity search.

        Args:
            query (str): Not used by this implementation.
            vectors (list): Query embedding.
            limit (int): Maximum number of results. Defaults to 5.
            filters (dict, optional): Field -> value tag filters; None values are
                ignored. With no effective filter the search is unfiltered.

        Returns:
            list[MemoryResult]: Results whose ``score`` is the reported
            ``vector_distance`` converted to float.
        """
        vector_query = VectorQuery(
            vector=np.array(vectors, dtype=np.float32).tobytes(),
            vector_field_name="embedding",
            # "updated_at" is requested so the payload can actually include it.
            return_fields=[
                "memory_id",
                "hash",
                "agent_id",
                "run_id",
                "user_id",
                "memory",
                "metadata",
                "created_at",
                "updated_at",
            ],
            filter_expression=self._build_filter(filters),
            num_results=limit,
        )

        return [
            MemoryResult(
                id=hit["memory_id"],
                score=float(hit["vector_distance"]),
                payload=self._to_payload(hit),
            )
            for hit in self.index.query(vector_query)
        ]

    def delete(self, vector_id):
        """Delete the record stored for ``vector_id``."""
        self.index.drop_keys(self._record_key(vector_id))

    def update(self, vector_id=None, vector=None, payload=None):
        """
        Overwrite the record stored for ``vector_id``.

        Args:
            vector_id: Memory id of the record to replace.
            vector: New embedding vector.
            payload (dict): New payload; needs ``hash``, ``data`` and ISO
                ``created_at`` / ``updated_at`` values.
        """
        record = self._build_record(vector_id, vector, payload, include_updated_at=True)
        self.index.load(data=[record], keys=[self._record_key(vector_id)], id_field="memory_id")

    def get(self, vector_id):
        """
        Fetch a single memory by id.

        Returns:
            MemoryResult | None: The stored memory, or None when the fetch
            returns nothing for ``vector_id``.
        """
        record = self.index.fetch(vector_id)
        if not record:
            return None
        return MemoryResult(id=record["memory_id"], payload=self._to_payload(record))

    def list_cols(self):
        """Return whatever the index's ``listall()`` reports."""
        return self.index.listall()

    def delete_col(self):
        """Delete the current index."""
        self.index.delete()

    def col_info(self, name):
        """Return info about the current index; ``name`` is accepted but not used."""
        return self.index.info()

    def reset(self):
        """
        Reset the index by deleting and recreating it.

        The index is rebuilt once from the current schema, so its dimensions
        and distance metric are preserved.
        """
        collection_name = self.schema["index"]["name"]
        logger.warning("Resetting index %s...", collection_name)
        self.delete_col()

        self.index = SearchIndex.from_dict(self.schema)
        self.index.set_client(self.client)
        self.index.create(overwrite=True)

    def list(self, filters: dict = None, limit: int = None) -> list:
        """
        List memories from the vector store, most recently created first.

        Args:
            filters (dict, optional): Field -> value tag filters; None values are
                ignored. With no effective filter every record matches.
            limit (int, optional): Maximum number of records to return.

        Returns:
            list: A one-element list wrapping the list of MemoryResult objects.
        """
        filter_expression = self._build_filter(filters)
        query_string = "*" if filter_expression is None else str(filter_expression)

        search_query = Query(query_string).sort_by("created_at", asc=False)
        # NOTE(review): without a limit no paging is set, so the Query's own
        # default page size applies — confirm this is intended.
        if limit is not None:
            search_query = search_query.paging(0, limit)

        memories = []
        for doc in self.index.search(search_query).docs:
            # Result documents expose their fields as attributes; view them as a mapping.
            fields = vars(doc)
            memories.append(MemoryResult(id=fields["memory_id"], payload=self._to_payload(fields)))
        return [memories]
| #296 |