repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...
git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...
2fa351d6 docs: add automaton and perps launch sources (15d ago)
| #1 | #!/usr/bin/env python3 |
| #2 | """ |
| #3 | BEAM SOTA Benchmark: Mnemosyne vs ICLR 2026 BEAM Dataset |
| #4 | ========================================================= |
| #5 | Evaluates Mnemosyne's BEAM architecture against the official BEAM benchmark |
| #6 | (Tavakoli et al., ICLR 2026) across all scales: 100K, 500K, 1M, 10M tokens. |
| #7 | |
| #8 | Metrics: |
| #9 | - Recall@K (K=1,3,5,10) |
| #10 | - MRR (Mean Reciprocal Rank) |
| #11 | - NDCG@K |
| #12 | - Robustness-δ@K (δ=0.1, 0.3, 0.5) |
| #13 | - Latency (avg, p50, p95, p99) |
| #14 | - Throughput (queries/sec) |
| #15 | |
| #16 | Modes: |
| #17 | - full: All BEAM tiers active (working + episodic + scratchpad) |
| #18 | - fts5_only: FTS5 text search only (no vectors) |
| #19 | - vec_only: Vector search only (no FTS5) |
| #20 | - keyword_only: Simple keyword fallback (no FTS5, no vectors) |
| #21 | - no_scratchpad: Ablation - scratchpad disabled |
| #22 | - no_episodic: Ablation - episodic memory disabled |
| #23 | |
| #24 | Run: PYTHONPATH=. python tests/benchmark_beam_sota.py --scales 100K,500K,1M |
| #25 | """ |
| #26 | |
| #27 | import argparse |
| #28 | import ast |
| #29 | import gc |
| #30 | import hashlib |
| #31 | import json |
| #32 | import math |
| #33 | import os |
| #34 | import resource |
| #35 | import statistics |
| #36 | import sys |
| #37 | import tempfile |
| #38 | import time |
| #39 | from collections import defaultdict |
| #40 | from datetime import datetime |
| #41 | from pathlib import Path |
| #42 | from typing import Dict, List, Optional, Tuple |
| #43 | |
| #44 | import numpy as np |
| #45 | |
# --- Config ---
# Resolve the project root first so that every `mnemosyne` import below
# (including the embedding warm-up) can find the local checkout even when
# PYTHONPATH is not set.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

# --- Raise file descriptor limit (must happen early) ---
import resource as _resource

_NOFILE_TARGET = 65536
try:
    _soft, _hard = _resource.getrlimit(_resource.RLIMIT_NOFILE)
    # An unprivileged process may only raise its soft limit up to the current
    # hard limit. Requesting more fails and would leave the soft limit as-is,
    # so clamp the request instead of asking for a fixed (65536, 65536).
    if _hard == _resource.RLIM_INFINITY:
        _wanted = _NOFILE_TARGET
    else:
        _wanted = min(_NOFILE_TARGET, _hard)
    if _soft != _resource.RLIM_INFINITY and _soft < _wanted:
        _resource.setrlimit(_resource.RLIMIT_NOFILE, (_wanted, _hard))
except (ValueError, OSError):
    # Best effort: keep running with whatever limit the process already has.
    pass

# Pre-load embedding model before datasets consume file descriptors
print(" Pre-loading embedding model...")
try:
    from mnemosyne.core import embeddings as _emb
    _ = _emb.embed(["warmup"])
    print(f" Embeddings ready: {_emb.available()}")
except Exception:
    print(" Embeddings not available (will use FTS5 fallback)")

import gc as _gc

from mnemosyne.core.beam import BeamMemory, init_beam
from mnemosyne.core import embeddings as _embeddings

# Defaults
DEFAULT_SCALES = ["100K"]
DEFAULT_TOP_K = 10
DEFAULT_WARMUP = 3
BENCHMARK_QUERIES_PER_SCALE = 50  # Cap probing questions per scale
WORKING_MEMORY_BATCH = 500
SCRATCHPAD_MAX = 200
EMBEDDING_DIM = 384
VEC_TYPE = os.environ.get("MNEMOSYNE_VEC_TYPE", "int8")

# Robustness thresholds
ROBUSTNESS_DELTAS = [0.1, 0.3, 0.5]
| #83 | |
| #84 | # --- Utility --- |
| #85 | |
def fmt_ms(val: float) -> str:
    """Format a duration given in milliseconds using µs, ms or s as the unit."""
    if val < 1:
        magnitude, unit = val * 1000, "µs"
    elif val < 1000:
        magnitude, unit = val, "ms"
    else:
        magnitude, unit = val / 1000, "s"
    return f"{magnitude:.1f} {unit}"
| #93 | |
def fmt_size(size_bytes: int) -> str:
    """Format a byte count as B, KB (one decimal) or MB (two decimals)."""
    kib = 1024
    mib = 1024**2
    if size_bytes < kib:
        return f"{size_bytes} B"
    if size_bytes < mib:
        return "{:.1f} KB".format(size_bytes / kib)
    return "{:.2f} MB".format(size_bytes / mib)
| #101 | |
def pcnt(val: float) -> str:
    """Format a fraction (e.g. 0.25) as a percentage string with one decimal."""
    percentage = val * 100
    return "{:.1f}%".format(percentage)
| #104 | |
| #105 | # --- Data Loading --- |
| #106 | |
def _parse_probing_questions(raw) -> List[Dict]:
    """Flatten a ``probing_questions`` field into {ability, question, ideal_answer} records.

    String values are parsed with ``ast.literal_eval``. Anything that does not
    end up as a dict (``None``, a list, an unparsable string) yields an empty
    list instead of raising, so one odd sample cannot abort a whole scale.
    """
    if isinstance(raw, str):
        try:
            raw = ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return []
    if not isinstance(raw, dict):
        return []

    flat_questions = []
    for ability, questions in raw.items():
        if not isinstance(questions, list):
            continue
        for q in questions:
            if isinstance(q, dict):
                flat_questions.append({
                    "ability": ability,
                    "question": q.get("question", ""),
                    "ideal_answer": q.get("ideal_answer", q.get("ideal_response", "")),
                })
    return flat_questions


def _collect_chat_messages(chat_blocks) -> List[Dict]:
    """Flatten a list of chat blocks (each a list of message dicts) into one message list."""
    if not isinstance(chat_blocks, (list, tuple)):
        return []

    messages = []
    for block in chat_blocks:
        if not isinstance(block, list):
            continue
        for msg in block:
            if isinstance(msg, dict):
                messages.append({
                    "role": msg.get("role", "unknown"),
                    # Normalise a present-but-None content to "" so that
                    # downstream string handling never sees None.
                    "content": msg.get("content") or "",
                    "time_anchor": msg.get("time_anchor", ""),
                    "index": msg.get("index", len(messages)),
                })
    return messages


def load_beam_dataset(scales: Optional[List[str]] = None,
                      max_conversations: Optional[int] = None):
    """Load BEAM dataset from HuggingFace. Returns dict[scale] -> list[conversation].

    Args:
        scales: Split names to load; defaults to ["100K", "500K", "1M"].
        max_conversations: Optional cap on conversations read per scale.

    Returns:
        A dict mapping each successfully loaded scale to a list of conversation
        dicts with keys "id", "messages", "questions", "seed" and "scale".
        Scales that are missing or fail to load are absent from the result.

    Exits the process with status 1 when the ``datasets`` package is missing.
    """
    try:
        from datasets import load_dataset
    except ImportError:
        print("ERROR: 'datasets' package not installed. Run: pip install datasets")
        sys.exit(1)

    if scales is None:
        scales = ["100K", "500K", "1M"]

    data = {}

    for scale in scales:
        print(f" Loading BEAM {scale}...")
        try:
            ds = load_dataset("Mohammadta/BEAM", streaming=True)
            if scale not in ds:
                print(f" WARNING: split '{scale}' not found. Available: {list(ds.keys())}")
                continue

            conversations = []
            for i, sample in enumerate(ds[scale]):
                if max_conversations and i >= max_conversations:
                    break

                conversations.append({
                    "id": sample.get("conversation_id", str(i)),
                    "messages": _collect_chat_messages(sample.get("chat")),
                    "questions": _parse_probing_questions(sample.get("probing_questions")),
                    "seed": sample.get("conversation_seed", {}),
                    "scale": scale,
                })

            data[scale] = conversations

            # Release dataset handles
            try:
                ds.cleanup_cache_files()
            except Exception:
                pass
            del ds
            _gc.collect()  # Force GC to release file handles
            print(f" Loaded {len(conversations)} conversations")
        except Exception as e:
            print(f" ERROR loading {scale}: {e}")
            import traceback
            traceback.print_exc()

    # Count only what was actually kept; a scale that failed part-way through
    # contributes nothing to `data` and must not inflate the total.
    total_loaded = sum(len(convs) for convs in data.values())
    print(f" Total: {total_loaded} conversations across {len(data)} scales")
    return data
| #197 | |
| #198 | |
def _parse_10m_probing_questions(raw) -> List[Dict]:
    """Flatten a top-level ``probing_questions`` field into question records.

    String values are parsed with ``ast.literal_eval``. Values that are not
    (or do not parse to) a dict yield an empty list rather than raising.
    """
    if isinstance(raw, str):
        try:
            raw = ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return []
    if not isinstance(raw, dict):
        return []

    all_questions = []
    for ability, questions in raw.items():
        if not isinstance(questions, list):
            continue
        for q in questions:
            if isinstance(q, dict):
                all_questions.append({
                    "ability": ability,
                    "question": q.get("question", ""),
                    "ideal_answer": q.get("ideal_answer", q.get("ideal_response", "")),
                })
    return all_questions


def _collect_plan_messages(plans) -> List[Dict]:
    """Concatenate the chat messages of every plan, numbering them sequentially."""
    if not isinstance(plans, (list, tuple)):
        return []

    all_messages = []
    for plan in plans:
        chat_blocks = plan.get("chat", []) if isinstance(plan, dict) else []
        if not isinstance(chat_blocks, (list, tuple)):
            continue
        for block in chat_blocks:
            if not isinstance(block, list):
                continue
            for msg in block:
                if isinstance(msg, dict):
                    all_messages.append({
                        "role": msg.get("role", "unknown"),
                        # A present-but-None content becomes "".
                        "content": msg.get("content") or "",
                        "time_anchor": msg.get("time_anchor", ""),
                        "index": len(all_messages),
                    })
    return all_messages


def load_beam_10m(max_conversations: Optional[int] = None):
    """Load BEAM-10M dataset from HuggingFace. Has special multi-plan structure.

    Args:
        max_conversations: Optional cap on the number of conversations read.

    Returns:
        A list of conversation dicts (keys "id", "messages", "questions",
        "seed", "scale"), or an empty list when the ``datasets`` package is
        missing or loading fails.
    """
    try:
        from datasets import load_dataset
    except ImportError:
        print("ERROR: 'datasets' package not installed.")
        return []

    print(" Loading BEAM-10M...")
    try:
        ds = load_dataset("Mohammadta/BEAM-10M", streaming=True)
        conversations = []

        # BEAM-10M has a single split named "10M" or similar
        split_name = "10M" if "10M" in ds else list(ds.keys())[0]
        for i, sample in enumerate(ds[split_name]):
            if max_conversations and i >= max_conversations:
                break

            conversations.append({
                "id": sample.get("conversation_id", str(i)),
                # Messages live inside the per-sample plans ...
                "messages": _collect_plan_messages(sample.get("plans")),
                # ... while probing questions are read from the sample's top level.
                "questions": _parse_10m_probing_questions(sample.get("probing_questions")),
                "seed": sample.get("conversation_seed", {}),
                "scale": "10M",
            })

        print(f" Loaded {len(conversations)} conversations")
        return conversations
    except Exception as e:
        print(f" ERROR loading 10M: {e}")
        return []
| #269 | |
| #270 | |
| #271 | # --- Mnemosyne Ingestion --- |
| #272 | |
def ingest_conversation(beam: "BeamMemory", messages: List[Dict],
                        use_scratchpad: bool = True,
                        use_episodic: bool = True) -> Dict:
    """Ingest a conversation into Mnemosyne BEAM tiers using batch writes.

    Args:
        beam: Memory object providing ``remember_batch``, ``scratchpad_write``,
            ``consolidate_to_episodic``, ``conn`` and ``session_id``.
        messages: Message dicts; only "content" and "role" are read.
        use_scratchpad: Write every 10th message (longer than 50 chars) to
            the scratchpad.
        use_episodic: After each batch, consolidate the oldest working-memory
            rows of this session into one episodic entry and delete them.

    Returns:
        Stats dict with "wm_count", "ep_count", "sp_count", "total_chars",
        "sp_errors", "ep_errors" and "ingest_time_ms".
    """
    start_time = time.perf_counter()
    stats = {"wm_count": 0, "ep_count": 0, "sp_count": 0, "total_chars": 0,
             "sp_errors": 0, "ep_errors": 0}

    BATCH_SIZE = 500

    # Process in batches for efficiency
    for batch_start in range(0, len(messages), BATCH_SIZE):
        batch_items = []
        for offset, msg in enumerate(messages[batch_start:batch_start + BATCH_SIZE]):
            position = batch_start + offset
            # `or ""` also covers a "content" key that is present but None.
            content = msg.get("content") or ""
            if not isinstance(content, str) or not content.strip():
                continue
            batch_items.append({
                "content": content,
                "source": f"beam_{msg.get('role', 'unknown')}",
                "importance": 0.3 + (0.1 * (position % 5)),
            })
            stats["total_chars"] += len(content)

            # Scratchpad every 10 messages
            if use_scratchpad and position % 10 == 0 and len(content) > 50:
                try:
                    beam.scratchpad_write(f"[t={position}] {content[:300]}")
                    stats["sp_count"] += 1
                except Exception:
                    # Best effort, but keep a count so failures are visible.
                    stats["sp_errors"] += 1

        if not batch_items:
            continue

        # Batch insert into working memory
        beam.remember_batch(batch_items)
        stats["wm_count"] += len(batch_items)

        # Episodic consolidation per batch
        if use_episodic:
            try:
                cursor = beam.conn.cursor()
                # Oldest working-memory rows of this session, at most as many
                # as were just inserted (never more than BATCH_SIZE).
                cursor.execute("""
                    SELECT id, content FROM working_memory
                    WHERE session_id = ?
                    ORDER BY timestamp ASC
                    LIMIT ?
                """, (beam.session_id, len(batch_items)))
                wm_rows = cursor.fetchall()

                if wm_rows:
                    wm_ids = [row["id"] for row in wm_rows]
                    snippets = [row["content"][:100] for row in wm_rows[:3]]
                    summary = f"Conversation batch {batch_start // BATCH_SIZE}: " + " | ".join(snippets)
                    if len(summary) > 500:
                        summary = summary[:497] + "..."

                    beam.consolidate_to_episodic(
                        summary=summary,
                        source_wm_ids=wm_ids,
                        source="beam_consolidation",
                        importance=0.4,
                        scope="global",
                    )
                    stats["ep_count"] += 1

                    # Remove consolidated items from working memory
                    placeholders = ",".join("?" * len(wm_ids))
                    cursor.execute(f"DELETE FROM working_memory WHERE id IN ({placeholders})", wm_ids)
                    stats["wm_count"] -= len(wm_ids)
                    beam.conn.commit()
            except Exception:
                # Best-effort consolidation: never abort ingestion, but count it.
                stats["ep_errors"] += 1

    stats["ingest_time_ms"] = (time.perf_counter() - start_time) * 1000
    return stats
| #353 | |
| #354 | |
def ingest_for_ablation(beam: BeamMemory, messages: List[Dict],
                        mode: str) -> Dict:
    """Ingest *messages*, disabling the tier named by an ablation *mode*.

    "no_scratchpad" turns scratchpad writes off and "no_episodic" turns
    episodic consolidation off; every other mode keeps both enabled.
    """
    return ingest_conversation(
        beam,
        messages,
        use_scratchpad=mode != "no_scratchpad",
        use_episodic=mode != "no_episodic",
    )
| #361 | |
| #362 | |
| #363 | # --- Retrieval Evaluation --- |
| #364 | |
def compute_relevance(retrieved_content: str, ideal_answer: str,
                      use_embeddings: bool = True) -> float:
    """
    Compute relevance score between retrieved content and ideal answer.
    Hybrid: token overlap + containment + embedding cosine similarity.

    Token overlap and containment capture exact lexical matches.
    Embedding similarity captures semantic relevance (critical for
    evaluating vector-based retrievers fairly).

    Weights are 30/30/40 (Jaccard / containment / embedding) whenever an
    embedding similarity could be computed, and 50/50 lexical otherwise.

    Args:
        retrieved_content: Text returned by the retriever.
        ideal_answer: Reference answer for the question.
        use_embeddings: Set to False to force the lexical-only score.

    Returns:
        A score in [0, 1]; 0.0 when either text is empty.
    """
    if not retrieved_content or not ideal_answer:
        return 0.0

    ret_lower = retrieved_content.lower()
    ideal_lower = ideal_answer.lower()

    ans_words = ideal_lower.split()
    if not ans_words:
        return 0.0

    # 1. Token overlap (Jaccard)
    ret_tokens = set(ret_lower.split())
    ans_tokens = set(ans_words)
    jaccard = len(ret_tokens & ans_tokens) / len(ret_tokens | ans_tokens)

    # 2. Substring containment: full containment in either direction scores
    #    1.0, otherwise the fraction of answer word-trigrams found verbatim.
    if ideal_lower in ret_lower or ret_lower in ideal_lower:
        containment_score = 1.0
    elif len(ans_words) >= 3:
        trigram_count = len(ans_words) - 2
        matches = sum(
            1 for i in range(trigram_count)
            if " ".join(ans_words[i:i + 3]) in ret_lower
        )
        containment_score = matches / trigram_count
    else:
        containment_score = 0.0

    # 3. Embedding cosine similarity. `None` means "could not be computed",
    #    which is different from a computed similarity of 0.
    embed_score = None
    if use_embeddings:
        try:
            from mnemosyne.core import embeddings as _emb_eval
            if _emb_eval.available():
                # Truncate long texts to avoid OOM
                vecs = _emb_eval.embed([retrieved_content[:1000], ideal_answer[:1000]])
                if vecs is not None and len(vecs) == 2:
                    a = np.asarray(vecs[0], dtype=float)
                    b = np.asarray(vecs[1], dtype=float)
                    a_norm = a / (np.linalg.norm(a) + 1e-8)
                    b_norm = b / (np.linalg.norm(b) + 1e-8)
                    # Negative cosine is treated as "no similarity".
                    embed_score = max(0.0, float(np.dot(a_norm, b_norm)))
        except Exception:
            embed_score = None  # Embedding eval is best-effort

    if embed_score is None:
        # Fallback: pure lexical (50/50)
        return 0.5 * jaccard + 0.5 * containment_score
    # Pick the weighting by availability, not by embed_score > 0: otherwise a
    # result with zero similarity would outscore one with slight similarity.
    return 0.3 * jaccard + 0.3 * containment_score + 0.4 * embed_score
| #428 | |
| #429 | |
def evaluate_retrieval(beam: BeamMemory, questions: List[Dict], top_k: int = 10) -> Dict:
    """Score ``beam.recall`` against probing questions.

    At most BENCHMARK_QUERIES_PER_SCALE questions are evaluated. For each one
    the recall latency is recorded and every hit is scored against the ideal
    answer with ``compute_relevance``; Recall@K, MRR and NDCG@K (K in 1, 3, 5,
    10) are derived from those scores. Returns per-query lists keyed by
    "recall", "mrr", "ndcg", "latency_ms" and "relevance_scores", or {} when
    there are no questions.
    """
    if not questions:
        return {}

    cutoffs = [1, 3, 5, 10]
    # A hit counts as relevant once its score reaches this threshold.
    RELEVANCE_THRESHOLD = 0.15

    metrics = {
        "recall": {k: [] for k in cutoffs},
        "mrr": [],
        "ndcg": {k: [] for k in cutoffs},
        "latency_ms": [],
        "relevance_scores": [],
    }

    def discounted_gain(rels):
        # DCG with exponential gain: (2^rel - 1) / log2(position + 2).
        return sum(
            (2**rel - 1) / math.log2(pos + 2)
            for pos, rel in enumerate(rels)
        )

    for item in questions[:BENCHMARK_QUERIES_PER_SCALE]:
        query = item["question"]
        ideal = item["ideal_answer"]

        # Time the recall
        started = time.perf_counter()
        try:
            hits = beam.recall(query, top_k=top_k)
        except Exception as e:
            print(f" Recall error for '{query[:60]}...': {e}")
            hits = []
        metrics["latency_ms"].append((time.perf_counter() - started) * 1000)

        # One relevance score per hit; a lone 0.0 stands in for "no hits".
        relevances = [compute_relevance(hit.get("content", ""), ideal) for hit in hits] or [0.0]
        is_relevant = [rel >= RELEVANCE_THRESHOLD for rel in relevances]

        # Recall@K: 1.0 when any of the first K hits is relevant.
        for k in cutoffs:
            metrics["recall"][k].append(1.0 if any(is_relevant[:k]) else 0.0)

        # MRR: reciprocal rank of the first relevant hit, 0.0 if none.
        metrics["mrr"].append(
            next(
                (1.0 / rank for rank, rel in enumerate(relevances, 1)
                 if rel >= RELEVANCE_THRESHOLD),
                0.0,
            )
        )

        # NDCG@K, normalised by the best ordering of the retrieved hits.
        for k in cutoffs:
            dcg = discounted_gain(relevances[:k])
            idcg = discounted_gain(sorted(relevances, reverse=True)[:k])
            metrics["ndcg"][k].append(dcg / idcg if idcg > 0 else 0.0)

        metrics["relevance_scores"].append(max(relevances))

    return metrics
| #512 | |
| #513 | |
def compute_robustness(recall_values: List[float], delta: float) -> float:
    """Robustness-δ@K: share of queries whose recall value is at least *delta*.

    Returns 0.0 for an empty input.
    """
    if not recall_values:
        return 0.0
    passing = [value for value in recall_values if value >= delta]
    return float(len(passing)) / len(recall_values)
| #519 | |
| #520 | |
def aggregate_metrics(metrics: Dict) -> Dict:
    """Reduce per-query metric lists to summary statistics.

    Produces mean Recall@K with Robustness-δ@K for every delta in
    ROBUSTNESS_DELTAS, mean MRR, mean NDCG@K, latency statistics (avg, p50,
    p95, p99, min, max), queries per second and mean relevance. Keys for
    empty recall/NDCG/latency lists are omitted; "mrr" and "avg_relevance"
    default to 0.0.
    """
    summary = {}

    # Recall@K and its robustness variants
    for cutoff, samples in metrics.get("recall", {}).items():
        if not samples:
            continue
        summary[f"recall@{cutoff}"] = statistics.mean(samples)
        summary.update(
            (f"robustness_{delta}@k{cutoff}", compute_robustness(samples, delta))
            for delta in ROBUSTNESS_DELTAS
        )

    # MRR
    reciprocal_ranks = metrics.get("mrr", [])
    summary["mrr"] = statistics.mean(reciprocal_ranks) if reciprocal_ranks else 0.0

    # NDCG@K
    for cutoff, samples in metrics.get("ndcg", {}).items():
        if samples:
            summary[f"ndcg@{cutoff}"] = statistics.mean(samples)

    # Latency and throughput
    latencies = metrics.get("latency_ms", [])
    if latencies:
        ordered = sorted(latencies)
        last = len(ordered) - 1
        mean_latency = statistics.mean(latencies)

        def at_fraction(fraction):
            # Value at index floor(n * fraction), clamped to the last element.
            return ordered[min(int(len(ordered) * fraction), last)]

        summary["latency_avg_ms"] = mean_latency
        summary["latency_p50_ms"] = at_fraction(0.50)
        summary["latency_p95_ms"] = at_fraction(0.95)
        summary["latency_p99_ms"] = at_fraction(0.99)
        summary["latency_min_ms"] = min(latencies)
        summary["latency_max_ms"] = max(latencies)
        # Queries per second implied by the mean per-query latency.
        summary["qps"] = 1000.0 / mean_latency if mean_latency > 0 else 0

    # Average relevance
    best_scores = metrics.get("relevance_scores", [])
    summary["avg_relevance"] = statistics.mean(best_scores) if best_scores else 0.0

    return summary
| #561 | |
| #562 | |
| #563 | # --- Baseline Retrievers --- |
| #564 | |
class BaselineRetriever:
    """Common base for baseline retrievers that query the SQLite file directly."""

    def __init__(self, db_path: Path):
        """Open a connection to *db_path* whose rows are ``sqlite3.Row`` objects."""
        import sqlite3
        self._db_path = db_path
        connection = sqlite3.connect(str(db_path))
        connection.row_factory = sqlite3.Row
        self.conn = connection

    def search(self, query: str, top_k: int = 10) -> List[Dict]:
        """Return up to *top_k* result dicts for *query*; subclasses must override."""
        raise NotImplementedError

    def close(self):
        """Close the underlying database connection."""
        self.conn.close()
| #579 | |
| #580 | |
class KeywordRetriever(BaselineRetriever):
    """Naive keyword baseline over episodic + working memory (no FTS5, no vectors)."""

    def search(self, query: str, top_k: int = 10) -> List[Dict]:
        """Rank rows by how many distinct query words occur in their content.

        Words are the lower-cased, whitespace-separated pieces of *query* and
        are matched as substrings of the lower-cased content. Up to 50000 of
        the most recent rows per table are considered; rows without any match
        are dropped.
        """
        terms = set(query.lower().split())
        if not terms:
            return []

        cur = self.conn.cursor()

        # Episodic rows first, then working rows, each newest-first.
        cur.execute("""
            SELECT id, content, source, timestamp, importance, 'episodic' as tier
            FROM episodic_memory
            ORDER BY timestamp DESC
            LIMIT 50000
        """)
        candidates = [dict(row) for row in cur.fetchall()]

        cur.execute("""
            SELECT id, content, source, timestamp, importance, 'working' as tier
            FROM working_memory
            ORDER BY timestamp DESC
            LIMIT 50000
        """)
        candidates += [dict(row) for row in cur.fetchall()]

        ranked = []
        for candidate in candidates:
            text = (candidate.get("content") or "").lower()
            hits = len([term for term in terms if term in text])
            if hits > 0:
                ranked.append((hits, candidate))

        # Stable sort: ties keep the candidate order established above.
        ranked = sorted(ranked, key=lambda pair: pair[0], reverse=True)
        return [candidate for _, candidate in ranked[:top_k]]
| #619 | |
| #620 | |
class FTS5OnlyRetriever(BaselineRetriever):
    """FTS5 text search only, no vector embedding. Searches both working + episodic FTS5."""

    _EPISODIC_SQL = """
        SELECT e.id, e.content, e.source, e.timestamp, e.importance, f.rank, 'episodic' as tier
        FROM fts_episodes f
        JOIN episodic_memory e ON f.rowid = e.rowid
        WHERE fts_episodes MATCH ?
        ORDER BY f.rank
        LIMIT ?
    """

    _WORKING_SQL = """
        SELECT wm.id, wm.content, wm.source, wm.timestamp, wm.importance, wf.rank, 'working' as tier
        FROM fts_working wf
        JOIN working_memory wm ON wf.id = wm.id
        WHERE fts_working MATCH ?
        ORDER BY wf.rank
        LIMIT ?
    """

    @staticmethod
    def _match_expression(query: str) -> str:
        """Turn free text into a safe FTS5 expression: quoted words joined by OR.

        Raw questions contain characters such as "?" or "'" that are syntax
        errors in an FTS5 MATCH string. Quoting every word avoids that and
        also neutralises bare AND/OR/NOT keywords.
        """
        import re
        return " OR ".join(f'"{term}"' for term in re.findall(r"\w+", query))

    @staticmethod
    def _rank_to_score(rank) -> float:
        """Map an FTS5 rank to a score in [0, 1) that grows with match quality.

        FTS5 ranks are negative and better matches are more negative, so the
        magnitude is used; a missing rank keeps the neutral 0.5.
        """
        if rank is None:
            return 0.5
        strength = abs(rank)
        return strength / (1.0 + strength)

    def _search_tier(self, cursor, sql: str, match_expr: str, top_k: int) -> List[Dict]:
        """Run one tier's FTS5 query; a failing query yields no rows for that tier."""
        import sqlite3
        try:
            cursor.execute(sql, (match_expr, top_k))
            rows = cursor.fetchall()
        except sqlite3.Error:
            # e.g. the FTS table is missing from this database.
            return []

        scored = []
        for row in rows:
            d = dict(row)
            d["score"] = self._rank_to_score(row["rank"])
            scored.append(d)
        return scored

    def search(self, query: str, top_k: int = 10) -> List[Dict]:
        """Return the *top_k* best FTS5 matches across episodic and working memory.

        Falls back to ``KeywordRetriever`` when FTS5 yields nothing (no
        searchable words in the query, no matches, or failing queries).
        """
        results = []
        match_expr = self._match_expression(query)
        if match_expr:
            cursor = self.conn.cursor()
            results.extend(self._search_tier(cursor, self._EPISODIC_SQL, match_expr, top_k))
            results.extend(self._search_tier(cursor, self._WORKING_SQL, match_expr, top_k))

        if not results:
            # The fallback opens its own connection; close it so repeated
            # queries do not accumulate open file descriptors.
            fallback = KeywordRetriever(self._db_path)
            try:
                return fallback.search(query, top_k)
            finally:
                fallback.close()

        # Sort by score descending, keep top_k
        results.sort(key=lambda x: x.get("score", 0), reverse=True)
        return results[:top_k]
| #668 | |
| #669 | |
class VecOnlyRetriever(BaselineRetriever):
    """Vector-only search (sqlite-vec), no FTS5."""

    def _keyword_fallback(self, query: str, top_k: int) -> List[Dict]:
        """Answer via a temporary ``KeywordRetriever`` on the same database file."""
        # sqlite3 connections have no `path` attribute; the database location
        # is the `_db_path` stored by BaselineRetriever.__init__.
        fallback = KeywordRetriever(self._db_path)
        try:
            return fallback.search(query, top_k)
        finally:
            fallback.close()

    def search(self, query: str, top_k: int = 10) -> List[Dict]:
        """Return the *top_k* nearest episodic rows to the embedded *query*.

        Falls back to keyword search when embeddings are unavailable or the
        vector query fails; returns [] when the query cannot be embedded.
        Each result carries a "score" of 1 - distance / max_distance.
        """
        if not _embeddings.available():
            return self._keyword_fallback(query, top_k)

        query_vec = _embeddings.embed([query])
        if query_vec is None:
            return []

        cursor = self.conn.cursor()
        try:
            vec_json = json.dumps(query_vec[0].tolist())
            # NOTE(review): the second argument of vec_distance_L2 is
            # `e.rowid`, not a stored vector column. This looks wrong, but the
            # vec_episodes schema is not visible here, so it is left as-is.
            # TODO confirm the embedding column name and use it instead.
            cursor.execute("""
                SELECT e.id, e.content, e.source, e.timestamp, e.importance,
                       vec_distance_L2(?, e.rowid) as distance
                FROM vec_episodes v
                JOIN episodic_memory e ON v.rowid = e.rowid
                ORDER BY distance
                LIMIT ?
            """, (vec_json, top_k))
            rows = cursor.fetchall()
        except Exception:
            # Fallback if vec_episodes doesn't exist or L2 function not available
            return self._keyword_fallback(query, top_k)

        max_dist = max((row["distance"] for row in rows), default=1.0)
        results = []
        for row in rows:
            d = dict(row)
            # Normalise so the farthest hit scores 0 and an exact match 1.
            d["score"] = 1.0 - (row["distance"] / max_dist) if max_dist > 0 else 1.0
            results.append(d)
        return results
| #704 | |
| #705 | |
| #706 | # --- Benchmark Runner --- |
| #707 | |
class _SearchAsRecall:
    """Adapter exposing a baseline retriever's ``search`` under the name ``recall``.

    Lets ``evaluate_retrieval`` score baseline retrievers with exactly the
    same metric code that is used for BEAM-native retrieval.
    """

    def __init__(self, retriever):
        self._retriever = retriever

    def recall(self, query: str, top_k: int = 10) -> List[Dict]:
        return self._retriever.search(query, top_k=top_k)


def run_benchmark_scale(scale: str, conversations: List[Dict],
                        modes: Optional[List[str]] = None,
                        top_k: int = DEFAULT_TOP_K) -> Dict:
    """Run benchmark for a specific scale across all modes.

    For every mode a fresh database is created in a temporary directory, all
    conversations are ingested, and each conversation's probing questions are
    evaluated either through ``beam.recall`` (modes "full", "no_scratchpad",
    "no_episodic") or through a baseline retriever reading the same database
    (modes "fts5_only", "vec_only", "keyword_only").

    Args:
        scale: Label of the dataset scale (used in output and file names).
        conversations: Conversation dicts with "messages" and "questions".
        modes: Modes to run; defaults to all six.
        top_k: Number of results requested per query.

    Returns:
        Dict mapping each executed mode to its aggregated metrics.
    """
    if modes is None:
        modes = ["full", "fts5_only", "vec_only", "keyword_only", "no_scratchpad", "no_episodic"]

    native_modes = ("full", "no_scratchpad", "no_episodic")
    baseline_classes = {
        "fts5_only": FTS5OnlyRetriever,
        "vec_only": VecOnlyRetriever,
        "keyword_only": KeywordRetriever,
    }
    cutoffs = [1, 3, 5, 10]

    results = {}
    total_messages = sum(len(c.get("messages", [])) for c in conversations)
    total_questions = sum(len(c.get("questions", [])) for c in conversations)
    print(f"\n{'='*70}")
    print(f" SCALE: {scale} | Conversations: {len(conversations)} | Messages: {total_messages:,} | Questions: {total_questions}")
    print(f"{'='*70}")

    for mode in modes:
        print(f"\n --- Mode: {mode} ---")

        if mode not in native_modes and mode not in baseline_classes:
            # Nothing could be evaluated for it, so do not pay for an ingest.
            print(f" WARNING: unknown mode '{mode}', skipping")
            continue

        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = Path(tmpdir) / f"bench_{scale}_{mode}.db"
            init_beam(db_path)

            use_episodic = mode != "no_episodic"
            use_scratchpad = mode != "no_scratchpad"
            beam = BeamMemory(session_id=f"beam_{scale}_{mode}", db_path=db_path)
            retriever = None

            try:
                # --- Ingest ---
                ingest_start = time.perf_counter()
                total_ingest_stats = {"wm_count": 0, "ep_count": 0, "sp_count": 0, "total_chars": 0}
                for conv in conversations:
                    stats = ingest_conversation(
                        beam, conv["messages"],
                        use_scratchpad=use_scratchpad,
                        use_episodic=use_episodic,
                    )
                    for k in total_ingest_stats:
                        if k in stats:
                            total_ingest_stats[k] += stats[k]

                ingest_time = time.perf_counter() - ingest_start

                # DB size - get stats BEFORE closing connection
                db_size = os.path.getsize(db_path)
                wm_stats = beam.get_working_stats()
                ep_stats = beam.get_episodic_stats()

                print(f" Ingest: {fmt_ms(ingest_time*1000)} | "
                      f"WM: {wm_stats.get('total', 0)} | EP: {ep_stats.get('total', 0)} | "
                      f"SP: {total_ingest_stats['sp_count']} | DB: {fmt_size(db_size)}")

                # --- Pick the object that answers queries ---
                if mode in native_modes:
                    searcher = beam
                else:
                    # Close beam connection so retriever can open its own
                    try:
                        beam.conn.close()
                    except Exception:
                        pass
                    retriever = baseline_classes[mode](db_path)
                    searcher = _SearchAsRecall(retriever)

                # --- Retrieval ---
                all_metrics = {"recall": {k: [] for k in cutoffs},
                               "mrr": [], "ndcg": {k: [] for k in cutoffs},
                               "latency_ms": [], "relevance_scores": []}

                for conv in conversations:
                    conv_metrics = evaluate_retrieval(searcher, conv["questions"], top_k=top_k)
                    for key, bucket in all_metrics.items():
                        if isinstance(bucket, dict):
                            for subkey in bucket:
                                bucket[subkey].extend(conv_metrics.get(key, {}).get(subkey, []))
                        else:
                            bucket.extend(conv_metrics.get(key, []))

                agg = aggregate_metrics(all_metrics)
                agg["ingest_time_ms"] = ingest_time * 1000
                agg["db_size_bytes"] = db_size
                agg["db_size"] = fmt_size(db_size)
                agg["wm_items"] = wm_stats.get("total", 0)
                agg["ep_items"] = ep_stats.get("total", 0)
                agg["ep_vectors"] = ep_stats.get("vectors", 0)
                agg["messages_ingested"] = total_messages
                # One latency sample is recorded per evaluated question, so
                # this is the exact count even with the per-conversation cap.
                agg["questions_evaluated"] = len(all_metrics["latency_ms"])
                results[mode] = agg
            finally:
                # Always release connections, including on errors, to prevent
                # "Too many open files" across modes and scales.
                if retriever is not None:
                    try:
                        retriever.close()
                    except Exception:
                        pass
                try:
                    beam.conn.close()
                except Exception:
                    pass
                gc.collect()

    return results
| #854 | |
| #855 | |
| #856 | # --- Report Generation --- |
| #857 | |
def _scale_sort_key(label):
    """Return a sort key ordering scale labels ("100K", "1M", "10M") by size.

    A plain string sort yields "100K" < "10M" < "1M" < "500K"; this key parses
    the numeric part and the K/M/B suffix so scales are listed in ascending
    token count. Labels that cannot be parsed sort after all parsed labels,
    alphabetically among themselves.
    """
    multipliers = {"K": 1_000, "M": 1_000_000, "B": 1_000_000_000}
    text = str(label).strip().upper()
    try:
        if text and text[-1] in multipliers:
            size = float(text[:-1]) * multipliers[text[-1]]
        else:
            size = float(text)
    except ValueError:
        return (1, 0.0, text)
    if not math.isfinite(size):
        # "nan"/"inf" parse as floats but are not meaningful sizes.
        return (1, 0.0, text)
    return (0, size, text)


def _format_metric_cell(metric, val):
    """Render one table cell for *metric*; the format is chosen by metric name."""
    if not isinstance(val, float):
        # Counts, pre-formatted sizes and the "-" placeholder are shown as-is.
        return f" {str(val):<18}"
    if "latency" in metric:
        return f" {fmt_ms(val):<18}"
    if "qps" in metric:
        return f" {val:.1f} qps{'':>12}"
    if any(tag in metric for tag in ("recall", "robustness", "mrr", "ndcg")):
        return f" {pcnt(val):<18}"
    if "relevance" in metric:
        return f" {val:.4f}{'':>13}"
    return f" {val:<18.2f}"


def print_results(all_results: Dict):
    """Print formatted benchmark results.

    Args:
        all_results: Mapping of scale label -> {mode name -> metrics dict}.
            Metrics missing for a mode are rendered as "-".

    Output is written to stdout in three parts: a header, one metric table per
    scale (one column per mode), and a per-scale summary of the "full" mode.
    Scales are listed in ascending numeric order (100K, 500K, 1M, 10M).
    """
    banner = "=" * 80
    rule = "─" * 80

    print(f"\n\n{banner}")
    print(" MNEMOSYNE BEAM SOTA BENCHMARK RESULTS")
    print(f" Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f" Embedding: BAAI/bge-small-en-v1.5 ({EMBEDDING_DIM}d, {VEC_TYPE})")
    print(f" Top-K: {DEFAULT_TOP_K}")
    print(banner)

    # Rows shown in every per-scale table, in display order.
    metrics_to_show = [
        "recall@1", "recall@3", "recall@5", "recall@10",
        "mrr", "ndcg@10",
        "robustness_0.3@k10",
        "latency_avg_ms", "latency_p95_ms",
        "qps", "avg_relevance",
        "messages_ingested", "db_size", "wm_items", "ep_items",
    ]

    ordered_scales = sorted(all_results, key=_scale_sort_key)

    for scale in ordered_scales:
        modes = all_results[scale]
        print(f"\n{rule}")
        print(f" SCALE: {scale}")
        print(rule)

        # Header
        mode_names = list(modes.keys())
        print(f" {'Metric':<30}", end="")
        for mode in mode_names:
            print(f" {mode:<18}", end="")
        print()

        # Data rows
        for metric in metrics_to_show:
            print(f" {metric:<30}", end="")
            for mode in mode_names:
                val = modes.get(mode, {}).get(metric, "-")
                print(_format_metric_cell(metric, val), end="")
            print()

    # --- SOTA Comparison Table ---
    print(f"\n\n{banner}")
    print(" SOTA COMPARISON: Mnemosyne BEAM vs Published Baselines")
    print(banner)
    print(" Note: Published numbers from Tavakoli et al., ICLR 2026 (Table 3)")
    print(" Mnemosyne uses identical BEAM dataset; metrics are retrieval-only (no LLM generation).")
    print(" Published numbers are end-to-end QA accuracy with LLM-as-judge.")
    print(" Direct comparison is APPROXIMATE -- retrieval quality correlates with QA accuracy.")
    print("")

    # The paper's key finding: LIGHT framework improves 3.5%-12.69% over baselines
    # We want to show Mnemosyne's retrieval quality at each scale
    print(" Methodology per ICLR 2026 paper:")
    print(" - BEAM dataset: 100 conversations, 2,000 probing questions")
    print(" - 10 memory abilities tested")
    print(" - LIGHT framework: episodic + working + scratchpad (identical to Mnemosyne BEAM)")
    print(" - Key metric: Robustness-δ@K (δ=0.3) for retrieval reliability")
    print("")

    # Summarize the "full" mode per scale; scales without it are skipped.
    for scale in ordered_scales:
        full = all_results[scale].get("full", {})
        if not full:
            continue
        print(f" Scale {scale}:")
        print(f" Mnemosyne Recall@10: {pcnt(full.get('recall@10', 0))}")
        print(f" Mnemosyne MRR: {full.get('mrr', 0):.4f}")
        print(f" Robustness-0.3@10: {pcnt(full.get('robustness_0.3@k10', 0))}")
        print(f" Avg Latency: {fmt_ms(full.get('latency_avg_ms', 0))}")
        print(f" P95 Latency: {fmt_ms(full.get('latency_p95_ms', 0))}")
        print(f" QPS (queries/sec): {full.get('qps', 0):.1f}")
        print(f" DB Size: {full.get('db_size', 'N/A')}")
        print("")

    print(banner)
    print(" BENCHMARK COMPLETE")
    print(banner)
| #947 | |
| #948 | |
| #949 | # --- Main --- |
| #950 | |
def main():
    """Parse CLI arguments, run the benchmark for each scale, and report.

    Command-line options:
        --scales    Comma-separated scale labels (default "100K,500K,1M").
        --top-k     Top-K for retrieval (default DEFAULT_TOP_K).
        --max-conv  Maximum conversations per scale (default: all).
        --modes     Comma-separated benchmark modes.
        --output    Optional path of a JSON file to write results to.
        --skip-10m  Skip the 10M scale even if it was requested.

    Returns:
        Dict mapping scale label -> per-mode results as returned by
        run_benchmark_scale.
    """
    parser = argparse.ArgumentParser(description="BEAM SOTA Benchmark for Mnemosyne")
    parser.add_argument("--scales", type=str, default="100K,500K,1M",
                        help="Comma-separated scales to benchmark (100K,500K,1M,10M)")
    parser.add_argument("--top-k", type=int, default=DEFAULT_TOP_K,
                        help=f"Top-K for retrieval (default: {DEFAULT_TOP_K})")
    parser.add_argument("--max-conv", type=int, default=None,
                        help="Max conversations per scale (default: all)")
    parser.add_argument("--modes", type=str,
                        default="full,fts5_only,vec_only,keyword_only,no_scratchpad,no_episodic",
                        help="Comma-separated modes to benchmark")
    parser.add_argument("--output", type=str, default=None,
                        help="Output JSON file for results")
    parser.add_argument("--skip-10m", action="store_true",
                        help="Skip 10M scale (very large)")

    args = parser.parse_args()
    # Drop empty tokens so a trailing comma ("100K,500K,") is harmless.
    scales = [s.strip() for s in args.scales.split(",") if s.strip()]
    modes = [m.strip() for m in args.modes.split(",") if m.strip()]

    # The 10M scale is loaded through a separate loader further down, so it is
    # taken out of the list handed to the regular dataset loader.
    has_10m = "10M" in scales
    if has_10m:
        scales.remove("10M")

    print(f"╔{'═'*78}╗")
    print("║ MNEMOSYNE BEAM SOTA BENCHMARK ║")
    print("║ ICLR 2026 BEAM Dataset: Beyond a Million Tokens ║")
    print(f"║ Scales: {', '.join(scales):<67s}║")
    print(f"║ Modes: {', '.join(modes):<68s}║")
    print(f"║ Top-K: {args.top_k:<70d}║")
    print(f"╚{'═'*78}╝")

    # --- Download dataset ---
    print("\n📥 Downloading BEAM dataset...")
    data = load_beam_dataset(scales, max_conversations=args.max_conv)
    # NOTE(review): `_gc` is not bound in this block; presumably a module-level
    # alias of the `gc` module defined elsewhere in the file -- confirm.
    _gc.collect()  # Ensure all dataset connections are released

    # --- Run benchmarks ---
    all_results = {}
    total_start = time.perf_counter()

    for scale in sorted(data.keys()):
        all_results[scale] = run_benchmark_scale(
            scale, data[scale],
            modes=modes,
            top_k=args.top_k,
        )

    # --- 10M scale (if requested) ---
    if has_10m and not args.skip_10m:
        print("\n📥 Loading BEAM-10M dataset...")
        convs_10m = load_beam_10m(max_conversations=args.max_conv)
        if convs_10m:
            all_results["10M"] = run_benchmark_scale(
                "10M", convs_10m,
                modes=modes,
                top_k=args.top_k,
            )
        else:
            print(" WARNING: Could not load 10M dataset. Skipping.")

    total_time = time.perf_counter() - total_start
    print(f"\n⏱️ Total benchmark time: {fmt_ms(total_time*1000)}")

    # --- Print results ---
    print_results(all_results)

    # --- Save results ---
    if args.output:
        output_path = Path(args.output)
        # Convert to serializable format: non-finite floats (NaN/inf) are not
        # valid JSON, so they are written as 0.0.
        serializable = {}
        # The loop variable must not be named `modes`: that would rebind the
        # list of requested mode names recorded in "_meta" below.
        for scale, mode_results in all_results.items():
            serializable[scale] = {}
            for mode, metrics in mode_results.items():
                serializable[scale][mode] = {
                    k: (v if not isinstance(v, float) or math.isfinite(v) else 0.0)
                    for k, v in metrics.items()
                }

        serializable["_meta"] = {
            "date": datetime.now().isoformat(),
            "benchmark": "ICLR 2026 BEAM",
            "framework": "Mnemosyne BEAM Architecture",
            "scales": list(all_results.keys()),
            "modes": modes,
            "top_k": args.top_k,
            "total_time_ms": total_time * 1000,
            "embedding_model": "BAAI/bge-small-en-v1.5",
            "embedding_dim": EMBEDDING_DIM,
            "vec_type": VEC_TYPE,
        }

        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(serializable, f, indent=2)
        print(f"\n📁 Results saved to: {output_path}")

    return all_results
| #1053 | |
| #1054 | |
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
| #1057 |