repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...2fa351d6docs: add automaton and perps launch sources16d ago| #1 | from typing import List, Dict, Any, Union |
| #2 | import numpy as np |
| #3 | |
| #4 | from mem0.reranker.base import BaseReranker |
| #5 | from mem0.configs.rerankers.base import BaseRerankerConfig |
| #6 | from mem0.configs.rerankers.sentence_transformer import SentenceTransformerRerankerConfig |
| #7 | |
| #8 | try: |
| #9 | from sentence_transformers import SentenceTransformer |
| #10 | SENTENCE_TRANSFORMERS_AVAILABLE = True |
| #11 | except ImportError: |
| #12 | SENTENCE_TRANSFORMERS_AVAILABLE = False |
| #13 | |
| #14 | |
class SentenceTransformerReranker(BaseReranker):
    """Sentence Transformer based reranker implementation.

    Each document is scored against the query and the documents are returned
    sorted by descending score. Scoring uses the loaded model's ``predict``
    method on (query, text) pairs when the model has one (the cross-encoder
    API). Otherwise it uses cosine similarity between ``encode`` embeddings of
    the query and of each document. ``SentenceTransformer`` objects expose
    ``encode`` but not ``predict``. Calling ``predict`` unconditionally would
    therefore fail on every call and always land in the zero-score fallback.
    """

    def __init__(self, config: Union[BaseRerankerConfig, SentenceTransformerRerankerConfig, Dict]):
        """
        Initialize Sentence Transformer reranker.

        Args:
            config: One of three forms:
                - a ``SentenceTransformerRerankerConfig``;
                - a plain ``BaseRerankerConfig``, whose provider/model/api_key/top_k
                  are carried over while device, batch_size and show_progress_bar
                  get defaults;
                - a dict of keyword arguments for ``SentenceTransformerRerankerConfig``.

        Raises:
            ImportError: If the ``sentence-transformers`` package is not installed.
        """
        if not SENTENCE_TRANSFORMERS_AVAILABLE:
            raise ImportError("sentence-transformers package is required for SentenceTransformerReranker. Install with: pip install sentence-transformers")

        # Normalize every accepted config form to SentenceTransformerRerankerConfig.
        if isinstance(config, dict):
            config = SentenceTransformerRerankerConfig(**config)
        elif isinstance(config, BaseRerankerConfig) and not isinstance(config, SentenceTransformerRerankerConfig):
            # Convert BaseRerankerConfig to SentenceTransformerRerankerConfig with defaults
            config = SentenceTransformerRerankerConfig(
                provider=getattr(config, 'provider', 'sentence_transformer'),
                model=getattr(config, 'model', 'cross-encoder/ms-marco-MiniLM-L-6-v2'),
                api_key=getattr(config, 'api_key', None),
                top_k=getattr(config, 'top_k', None),
                device=None,  # Will auto-detect
                batch_size=32,  # Default
                show_progress_bar=False,  # Default
            )

        self.config = config
        # NOTE(review): the default model id is a cross-encoder checkpoint, but it is
        # loaded through SentenceTransformer here. Consider sentence_transformers.CrossEncoder
        # for true pairwise scoring — TODO confirm the intended model type.
        self.model = SentenceTransformer(self.config.model, device=self.config.device)

    @staticmethod
    def _document_text(doc: Dict[str, Any]) -> Any:
        """Return the value to score for ``doc``: its 'memory', 'text' or 'content' entry (first found), else ``str(doc)``."""
        for key in ('memory', 'text', 'content'):
            if key in doc:
                return doc[key]
        return str(doc)

    def _score(self, query: str, doc_texts: List[Any]) -> List[float]:
        """Return one float relevance score per entry of ``doc_texts``, in the same order.

        If the model exposes a callable ``predict`` (cross-encoder API), it is
        called with ``[query, text]`` pairs. Otherwise the query and the texts
        are embedded with ``encode`` and scored by cosine similarity.

        Raises:
            ValueError: If the model does not yield exactly one score per text.
        """
        predict = getattr(self.model, 'predict', None)
        if callable(predict):
            raw = predict([[query, text] for text in doc_texts])
        else:
            batch_size = getattr(self.config, 'batch_size', None) or 32
            show_progress_bar = bool(getattr(self.config, 'show_progress_bar', False))
            query_emb = np.asarray(
                self.model.encode([query], show_progress_bar=False), dtype=float
            ).reshape(-1)
            doc_embs = np.asarray(
                self.model.encode(doc_texts, batch_size=batch_size, show_progress_bar=show_progress_bar),
                dtype=float,
            ).reshape(len(doc_texts), -1)
            norms = np.linalg.norm(doc_embs, axis=1) * np.linalg.norm(query_emb)
            # A zero-norm embedding would divide by zero; its dot product is 0, so it scores 0.0.
            safe_norms = np.where(norms == 0, 1.0, norms)
            raw = (doc_embs @ query_emb) / safe_norms

        # Flatten (n,) or (n, 1) outputs to n plain floats.
        scores = [float(s) for s in np.asarray(raw, dtype=float).reshape(-1)]
        if len(scores) != len(doc_texts):
            # zip() would otherwise silently drop documents.
            raise ValueError(f"Expected {len(doc_texts)} scores, got {len(scores)}")
        return scores

    def rerank(self, query: str, documents: List[Dict[str, Any]], top_k: Union[int, None] = None) -> List[Dict[str, Any]]:
        """
        Rerank documents by relevance to ``query``.

        Args:
            query: The search query.
            documents: Documents to rerank. Each document's text is taken from
                the first of its 'memory', 'text' or 'content' keys that is
                present, otherwise ``str(doc)``.
            top_k: Maximum number of documents to return. When falsy,
                ``config.top_k`` is used instead. When that is falsy too,
                every document is returned.

        Returns:
            Shallow copies of the documents with a float 'rerank_score' key,
            sorted by descending score. If scoring raises, the documents are
            returned in their original order with 'rerank_score' set to 0.0.
            The caller's dicts are never modified. An empty ``documents`` is
            returned as-is.
        """
        if not documents:
            return documents

        final_top_k = top_k or self.config.top_k
        doc_texts = [self._document_text(doc) for doc in documents]

        try:
            scores = self._score(query, doc_texts)
        except Exception:
            # Best-effort by design: a scoring failure must not break retrieval,
            # but it should not be invisible either.
            import logging

            logging.getLogger(__name__).warning(
                "Sentence transformer reranking failed; returning documents in original order",
                exc_info=True,
            )
            fallback = []
            for doc in documents:
                fallback_doc = doc.copy()  # copy so the caller's dicts are not mutated
                fallback_doc['rerank_score'] = 0.0
                fallback.append(fallback_doc)
            return fallback[:final_top_k] if final_top_k else fallback

        # sorted() is stable, so equal scores keep their input order.
        ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
        if final_top_k:
            ranked = ranked[:final_top_k]

        reranked_docs = []
        for doc, score in ranked:
            reranked_doc = doc.copy()
            reranked_doc['rerank_score'] = score
            reranked_docs.append(reranked_doc)
        return reranked_docs