repositories
loading repo index
repositories
loading repo index
repository
loading code, commits, and activity
public Clawd ADK gateway launch mirror
stars
latest
clone command
git clone gitlawb://did:key:z6Mkq5mY...iFZ5/my-project-publ...git clone gitlawb://did:key:z6Mkq5mY.../my-project-publ...2fa351d6docs: add automaton and perps launch sources16d ago| #1 | from typing import List, Dict, Any, Union |
| #2 | import numpy as np |
| #3 | |
| #4 | from mem0.reranker.base import BaseReranker |
| #5 | from mem0.configs.rerankers.base import BaseRerankerConfig |
| #6 | from mem0.configs.rerankers.huggingface import HuggingFaceRerankerConfig |
| #7 | |
| #8 | try: |
| #9 | from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| #10 | import torch |
| #11 | TRANSFORMERS_AVAILABLE = True |
| #12 | except ImportError: |
| #13 | TRANSFORMERS_AVAILABLE = False |
| #14 | |
| #15 | |
class HuggingFaceReranker(BaseReranker):
    """HuggingFace Transformers based reranker implementation.

    Each (query, document) pair is scored with a model loaded through
    ``AutoModelForSequenceClassification``, and documents are returned sorted
    by that score (highest first) with the score stored under ``rerank_score``.
    """

    def __init__(self, config: Union[BaseRerankerConfig, HuggingFaceRerankerConfig, Dict]):
        """
        Initialize HuggingFace reranker.

        Args:
            config: A ``HuggingFaceRerankerConfig``, a dict of its fields, or a
                generic ``BaseRerankerConfig``. For a generic config, ``provider``,
                ``model``, ``api_key`` and ``top_k`` are carried over. The
                HuggingFace-specific fields get the defaults below (device
                auto-detected, batch_size=32, max_length=512, normalize=True).

        Raises:
            ImportError: If ``transformers`` or ``torch`` could not be imported.
        """
        if not TRANSFORMERS_AVAILABLE:
            raise ImportError("transformers package is required for HuggingFaceReranker. Install with: pip install transformers torch")

        # Normalize the accepted config shapes into a HuggingFaceRerankerConfig.
        if isinstance(config, dict):
            config = HuggingFaceRerankerConfig(**config)
        elif isinstance(config, BaseRerankerConfig) and not isinstance(config, HuggingFaceRerankerConfig):
            config = HuggingFaceRerankerConfig(
                provider=getattr(config, 'provider', 'huggingface'),
                model=getattr(config, 'model', 'BAAI/bge-reranker-base'),
                api_key=getattr(config, 'api_key', None),
                top_k=getattr(config, 'top_k', None),
                device=None,  # auto-detected below
                batch_size=32,
                max_length=512,
                normalize=True,
            )

        self.config = config

        # Prefer CUDA when no device was configured and a GPU is visible.
        if self.config.device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = self.config.device

        # Load model and tokenizer once; eval() disables training-only layers
        # such as dropout so scores are deterministic.
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model)
        self.model = AutoModelForSequenceClassification.from_pretrained(self.config.model)
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _document_text(doc: Dict[str, Any]) -> str:
        """Pick the text to score: 'memory', then 'text', then 'content', else str(doc)."""
        for key in ('memory', 'text', 'content'):
            if key in doc:
                return doc[key]
        return str(doc)

    def _score_texts(self, query: str, texts: List[str]) -> List[float]:
        """Return one raw model score per text, computed in batches of ``config.batch_size``."""
        batch_size = self.config.batch_size
        scores: List[float] = []
        for start in range(0, len(texts), batch_size):
            batch_pairs = [[query, text] for text in texts[start:start + batch_size]]
            inputs = self.tokenizer(
                batch_pairs,
                padding=True,
                truncation=True,
                max_length=self.config.max_length,
                return_tensors="pt",
            ).to(self.device)

            with torch.no_grad():
                logits = self.model(**inputs).logits
                # squeeze(-1) drops a trailing single-logit dimension;
                # atleast_1d keeps a 0-d result from breaking .tolist() extension.
                batch_scores = np.atleast_1d(logits.squeeze(-1).cpu().numpy())

            scores.extend(batch_scores.tolist())
        return scores

    def rerank(self, query: str, documents: List[Dict[str, Any]], top_k: Union[int, None] = None) -> List[Dict[str, Any]]:
        """
        Rerank documents using HuggingFace cross-encoder model.

        Args:
            query: The search query.
            documents: Documents to rerank. The input dicts are never modified;
                the result contains shallow copies.
            top_k: Maximum number of documents to return. A falsy value
                (None or 0) falls back to ``config.top_k``; if that is also
                falsy, all documents are returned.

        Returns:
            Shallow copies of the documents sorted by descending
            ``rerank_score``. If scoring fails for any reason, the documents
            are returned in their original order with ``rerank_score`` 0.0,
            and a warning is logged.
        """
        if not documents:
            return documents

        doc_texts = [self._document_text(doc) for doc in documents]
        final_top_k = top_k or self.config.top_k

        try:
            scores = self._score_texts(query, doc_texts)

            if self.config.normalize:
                # Min-max scaling to roughly [0, 1] across this result set.
                # NOTE(review): a single document, or all-equal scores,
                # normalizes to 0.0 here.
                score_array = np.array(scores)
                score_array = (score_array - score_array.min()) / (score_array.max() - score_array.min() + 1e-8)
                scores = score_array.tolist()

            ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
            if final_top_k:
                ranked = ranked[:final_top_k]

            reranked_docs = []
            for doc, score in ranked:
                reranked_doc = doc.copy()
                reranked_doc['rerank_score'] = float(score)
                reranked_docs.append(reranked_doc)
            return reranked_docs

        except Exception:
            # Best-effort: keep serving results in the original order, but
            # leave a trace instead of failing silently.
            import logging

            logging.getLogger(__name__).warning(
                "HuggingFace reranking failed; returning documents in original order",
                exc_info=True,
            )
            # Copy instead of mutating the caller's dicts, matching the success path.
            fallback = []
            for doc in documents:
                fallback_doc = doc.copy()
                fallback_doc['rerank_score'] = 0.0
                fallback.append(fallback_doc)
            return fallback[:final_top_k] if final_top_k else fallback