| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
# Pooling mode used by DenseEncoder.encode:
#   "Dense"  -> L2-normalized mean pooling (see mean_pooling)
#   "Sparse" -> vocabulary-sized lexical embedding (see sparse_pooling)
#   any other value -> dense + sparse concatenation (see concat_pooling)
MODE = "Dense"
|
|
# English MTEB classification tasks.
TASK_LIST_CLASSIFICATION = [
    "AmazonCounterfactualClassification",
    "AmazonPolarityClassification",
    "AmazonReviewsClassification",
    "Banking77Classification",
    "EmotionClassification",
    "ImdbClassification",
    "MassiveIntentClassification",
    "MassiveScenarioClassification",
    "MTOPDomainClassification",
    "MTOPIntentClassification",
    "ToxicConversationsClassification",
    "TweetSentimentExtractionClassification",
]
|
|
# English MTEB clustering tasks.
TASK_LIST_CLUSTERING = [
    "ArxivClusteringP2P",
    "ArxivClusteringS2S",
    "BiorxivClusteringP2P",
    "BiorxivClusteringS2S",
    "MedrxivClusteringP2P",
    "MedrxivClusteringS2S",
    "RedditClustering",
    "RedditClusteringP2P",
    "StackExchangeClustering",
    "StackExchangeClusteringP2P",
    "TwentyNewsgroupsClustering",
]
|
|
# English MTEB pair-classification tasks.
TASK_LIST_PAIR_CLASSIFICATION = [
    "SprintDuplicateQuestions",
    "TwitterSemEval2015",
    "TwitterURLCorpus",
]
|
|
# English MTEB reranking tasks.
TASK_LIST_RERANKING = [
    "AskUbuntuDupQuestions",
    "MindSmallReranking",
    "SciDocsRR",
    "StackOverflowDupQuestions",
]
|
|
# English MTEB retrieval tasks (BEIR-style), including the twelve
# CQADupstack sub-tasks.
TASK_LIST_RETRIEVAL = [
    "ArguAna",
    "FiQA2018",
    "QuoraRetrieval",
    "SCIDOCS",
    "SciFact",
    "Touche2020",
    "TRECCOVID",
    "NFCorpus",
    "NQ",
    "ClimateFEVER",
    "CQADupstackAndroidRetrieval",
    "CQADupstackEnglishRetrieval",
    "CQADupstackGamingRetrieval",
    "CQADupstackGisRetrieval",
    "CQADupstackMathematicaRetrieval",
    "CQADupstackPhysicsRetrieval",
    "CQADupstackProgrammersRetrieval",
    "CQADupstackStatsRetrieval",
    "CQADupstackTexRetrieval",
    "CQADupstackUnixRetrieval",
    "CQADupstackWebmastersRetrieval",
    "CQADupstackWordpressRetrieval",
    "DBPedia",
    "HotpotQA",
    "MSMARCO",
    "FEVER",
]
|
|
# English MTEB STS tasks (plus SummEval, a summarization task evaluated
# alongside STS in the original MTEB grouping).
TASK_LIST_STS = [
    "BIOSSES",
    "SICK-R",
    "STS12",
    "STS13",
    "STS14",
    "STS15",
    "STS16",
    "STS17",
    "STS22",
    "STSBenchmark",
    "SummEval",
]
|
|
# Full English MTEB suite, retrieval tasks first.
MTEB_TASK_LIST = (
    TASK_LIST_RETRIEVAL
    + TASK_LIST_CLASSIFICATION
    + TASK_LIST_CLUSTERING
    + TASK_LIST_PAIR_CLASSIFICATION
    + TASK_LIST_RERANKING
    + TASK_LIST_STS
)
|
|
|
|
# Chinese (C-MTEB) tasks, grouped by type: classification, clustering,
# pair classification, reranking, retrieval, STS.
# Fix: "MultilingualSentiment" was listed twice in the original, which made
# the driver loop evaluate that task twice.
CMTEB_TASK_LIST = [
    # classification
    "TNews",
    "IFlyTek",
    "MultilingualSentiment",
    "JDReview",
    "OnlineShopping",
    "Waimai",
    "AmazonReviewsClassification",
    "MassiveIntentClassification",
    "MassiveScenarioClassification",
    # clustering
    "CLSClusteringS2S",
    "CLSClusteringP2P",
    "ThuNewsClusteringS2S",
    "ThuNewsClusteringP2P",
    # pair classification
    "Ocnli",
    "Cmnli",
    # reranking
    "T2Reranking",
    "MMarcoReranking",
    "CMedQAv1-reranking",
    "CMedQAv2-reranking",
    # retrieval
    "T2Retrieval",
    "MMarcoRetrieval",
    "DuRetrieval",
    "CovidRetrieval",
    "CmedqaRetrieval",
    "EcomRetrieval",
    "MedicalRetrieval",
    "VideoRetrieval",
    # STS
    "ATEC",
    "BQ",
    "LCQMC",
    "PAWSX",
    "STSB",
    "AFQMC",
    "QBQTC",
    "STS22",
]
|
|
# Run the Chinese (C-MTEB) tasks before the English ones.
MTEB_TASK_LIST = CMTEB_TASK_LIST + MTEB_TASK_LIST
|
|
| import torch |
| import torch.nn.functional as F |
| import tqdm |
| import numpy as np |
| import math |
|
|
| from functools import partial |
| from torch.utils.data import DataLoader |
| from datasets import Dataset |
| from transformers import AutoModel, AutoTokenizer, DataCollatorWithPadding, PreTrainedTokenizerFast, BatchEncoding |
| from transformers.modeling_outputs import BaseModelOutput |
| from typing import List, Dict |
| from mteb import MTEB |
|
|
def get_detailed_instruct(task_description: str) -> str:
    """Wrap a task description in the instruction template used for queries.

    Returns the empty string when no description is given (e.g. STS tasks,
    which are embedded without an instruction); otherwise returns the
    "Instruction: ... Query: " prefix expected by the embedding model.
    """
    if not task_description:
        return ""
    return f"Instruction: {task_description} Query: "
|
|
|
|
|
|
def get_task_def_by_task_name_and_type(
    task_name: str,
    task_type: str,
    default_instruct="",
):
    """Look up the query instruction for an MTEB/C-MTEB task.

    Args:
        task_name: task name as reported by MTEB metadata, e.g. "NFCorpus".
        task_type: task type as reported by MTEB metadata, e.g. "Retrieval".
        default_instruct: returned only for task types not handled below.

    Returns:
        The instruction string; None for STS tasks or for an unknown name
        within a handled type; ``default_instruct`` for unhandled types.
    """
    # STS tasks are embedded without any instruction prefix.
    if task_type in ["STS"]:
        return None

    if task_type in ["Summarization"]:
        return "Given a news summary, retrieve other semantically similar summaries"

    if task_type in ["Classification"]:
        task_name_to_instruct: Dict[str, str] = {
            "AmazonCounterfactualClassification": "Classify a given Amazon customer review text as either counterfactual or not-counterfactual.",
            "AmazonPolarityClassification": "Classify Amazon reviews into positive or negative sentiment.",
            "AmazonReviewsClassification": "Classify the given Amazon review into its appropriate rating category.",
            "Banking77Classification": "Given a online banking query, find the corresponding intents.",
            "EmotionClassification": "Classify the emotion expressed in the given Twitter message into one of the six emotions: anger, fear, joy, love, sadness, and surprise.",
            "ImdbClassification": "Classify the sentiment expressed in the given movie review text from the IMDB dataset.",
            "MassiveIntentClassification": "Given a user utterance as query, find the user intents.",
            "MassiveScenarioClassification": "Given a user utterance as query, find the user scenarios.",
            "MTOPDomainClassification": "Classify the intent domain of the given utterance in task-oriented conversation.",
            "MTOPIntentClassification": "Classify the intent of the given utterance in task-oriented conversation.",
            "ToxicConversationsClassification": "Classify the given comments as either toxic or not toxic.",
            "TweetSentimentExtractionClassification": "Classify the sentiment of a given tweet as either positive, negative, or neutral.",

            # Chinese (C-MTEB) classification tasks — instructions are in
            # Chinese because the model is prompted in the task language.
            "TNews": "根据标题确定新闻的类别。",
            "IFlyTek": "根据描述确定APP的类别。",
            "MultilingualSentiment": "将亚马逊评论分为积极、消极或中立情绪。",
            "JDReview": "将商品评论分为积极或消极情绪。",
            "OnlineShopping": "将商品评论分为积极或消极情绪。",
            "Waimai": "将外卖评论分为积极或消极情绪。",
        }
        return task_name_to_instruct.get(task_name, None)

    if task_type in ["Clustering"]:
        task_name_to_instruct: Dict[str, str] = {
            "ArxivClusteringP2P": "Identify the main and secondary category of Arxiv papers based on the titles and abstracts.",
            "ArxivClusteringS2S": "Identify the main and secondary category of Arxiv papers based on the titles.",
            "BiorxivClusteringP2P": "Identify the main category of Biorxiv papers based on the titles and abstracts.",
            "BiorxivClusteringS2S": "Identify the main category of Biorxiv papers based on the titles.",
            "MedrxivClusteringP2P": "Identify the main category of Medrxiv papers based on the titles and abstracts.",
            "MedrxivClusteringS2S": "Identify the main category of Medrxiv papers based on the titles.",
            "RedditClustering": "Identify the topic or theme of Reddit posts based on the titles.",
            "RedditClusteringP2P": "Identify the topic or theme of Reddit posts based on the titles and posts.",
            "StackExchangeClustering": "Identify the topic or theme of StackExchange posts based on the titles.",
            "StackExchangeClusteringP2P": "Identify the topic or theme of StackExchange posts based on the given paragraphs.",
            "TwentyNewsgroupsClustering": "Identify the topic or theme of the given news articles.",

            # Chinese (C-MTEB) clustering tasks.
            "CLSClusteringS2S": "根据标题确定文章的类别。",
            "CLSClusteringP2P": "根据标题和摘要确定文章的类别。",
            "ThuNewsClusteringS2S": "根据标题确定新闻的类别。",
            "ThuNewsClusteringP2P": "根据标题和摘要确定新闻的类别。",
        }
        return task_name_to_instruct.get(task_name, None)

    if task_type in ["Reranking", "PairClassification"]:
        task_name_to_instruct: Dict[str, str] = {
            "AskUbuntuDupQuestions": "Retrieve duplicate questions from AskUbuntu forum.",
            "MindSmallReranking": "Retrieve relevant news articles based on user browsing history.",
            "SciDocsRR": "Given a title of a scientific paper, retrieve the titles of other relevant papers.",
            "StackOverflowDupQuestions": "Retrieve duplicate questions from StackOverflow forum.",
            "SprintDuplicateQuestions": "Retrieve duplicate questions from Sprint forum.",
            "TwitterSemEval2015": "Retrieve tweets that are semantically similar to the given tweet.",
            "TwitterURLCorpus": "Retrieve tweets that are semantically similar to the given tweet.",

            # Chinese (C-MTEB) reranking tasks.
            "T2Reranking": "为这个问题检索相关段落。",
            "MMarcoReranking": "为这个查询检索相关段落。",
            "CMedQAv1-reranking": "为这个医疗问题检索相关回答。",
            "CMedQAv2-reranking": "为这个医疗问题检索相关回答。",
        }

        return task_name_to_instruct.get(task_name, None)

    if task_type in ["Retrieval"]:
        # All twelve CQADupstack sub-tasks share one instruction.
        if task_name.lower().startswith("cqadupstack"):
            return "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question"

        task_name_to_instruct: Dict[str, str] = {
            "ArguAna": "Given a claim, find documents that refute the claim.",
            "ClimateFEVER": "Given a claim about climate change, retrieve documents that support or refute the claim.",
            "DBPedia": "Given a query, retrieve relevant entity descriptions from DBPedia.",
            "FEVER": "Given a claim, retrieve documents that support or refute the claim.",
            "FiQA2018": "Given a financial question, retrieve user replies that best answer the question.",
            "HotpotQA": "Given a multi-hop question, retrieve documents that can help answer the question.",
            "MSMARCO": "Given a web search query, retrieve relevant passages that answer the query.",
            "NFCorpus": "Given a question, retrieve relevant documents that best answer the question.",
            "NQ": "Given a question, retrieve Wikipedia passages that answer the question.",
            "QuoraRetrieval": "Given a question, retrieve questions that are semantically equivalent to the given question.",
            "SCIDOCS": "Given a scientific paper title, retrieve paper abstracts that are cited by the given paper.",
            "SciFact": "Given a scientific claim, retrieve documents that support or refute the claim.",
            "Touche2020": "Given a question, retrieve detailed and persuasive arguments that answer the question.",
            "TRECCOVID": "Given a query on COVID-19, retrieve documents that answer the query.",

            # Chinese (C-MTEB) retrieval tasks.
            "T2Retrieval": "为这个问题检索相关段落。",
            "MMarcoRetrieval": "为这个查询检索相关段落。",
            "DuRetrieval": "为这个问题检索相关百度知道回答。",
            "CovidRetrieval": "为这个问题检索相关政策回答。",
            "CmedqaRetrieval": "为这个医疗问题检索相关回答。",
            "EcomRetrieval": "为这个查询检索相关商品标题。",
            "MedicalRetrieval": "为这个医疗问题检索相关回答。",
            "VideoRetrieval": "为这个电影标题检索相关段落。",
        }

        # Also index by lower-cased name so case-variant task names resolve.
        # NOTE(review): only the Retrieval map gets this treatment — the
        # other branches are case-sensitive; confirm whether that is intended.
        task_name_to_instruct.update({k.lower(): v for k, v in task_name_to_instruct.items()})

        return task_name_to_instruct.get(task_name, None)
    return default_instruct
| def _transform_func(tokenizer: PreTrainedTokenizerFast, |
| examples: Dict[str, List]) -> BatchEncoding: |
| batch_dict = tokenizer(examples['input_texts'], |
| max_length=1024, |
| padding=True, |
| truncation=True) |
|
|
| return batch_dict |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
def mean_pooling(hidden, attention_mask):
    """Average token states over valid (unmasked) positions.

    Args:
        hidden: (batch, seq, dim) float tensor of token hidden states.
        attention_mask: (batch, seq) 0/1 mask of real tokens.

    Returns:
        (batch, dim) tensor of masked mean embeddings.
    """
    mask = attention_mask.unsqueeze(-1).float()
    summed = (hidden * mask).sum(dim=1)
    counts = attention_mask.sum(dim=1, keepdim=True).float()
    return summed / counts
|
|
def wmean_pooling(hidden, attention_mask):
    """Position-weighted mean pooling: the i-th valid token gets weight i+1.

    Later tokens contribute more, which suits causal (left-to-right)
    encoders whose final tokens summarize the sequence.

    Args:
        hidden: (batch, seq, dim) float tensor of token hidden states.
        attention_mask: (batch, seq) 0/1 mask of real tokens.

    Returns:
        (batch, dim) tensor of position-weighted mean embeddings.
    """
    # Cumulative sum turns the 0/1 mask into 1..n weights on valid tokens.
    weights = attention_mask * attention_mask.cumsum(dim=1)
    weighted = hidden * weights.unsqueeze(-1).float()
    total_weight = weights.sum(dim=1, keepdim=True).float()
    return weighted.sum(dim=1) / total_weight
|
|
def reverse_wmean_pooling(hidden, attention_mask):
    """Undo the positional weighting assumed baked into ``hidden``.

    Rescales each token state by (sum of positional weights / token count)
    and divides out its own positional weight, so per-token states become
    comparable again before the sparse head is applied.

    Args:
        hidden: (batch, seq, dim) tensor of (weighted) token states.
        attention_mask: (batch, seq) 0/1 mask of real tokens.

    Returns:
        (batch, seq, dim) float tensor of de-weighted token states.
        Masked positions are meaningless (divided by the 1e-9 clamp).
    """
    weights = attention_mask * attention_mask.cumsum(dim=1)
    scale = weights.sum(dim=1, keepdim=True).unsqueeze(1).float() \
        / attention_mask.sum(dim=1, keepdim=True).unsqueeze(1).float()
    rescaled = hidden.float() * scale
    # Clamp avoids division by zero at masked (weight-0) positions.
    return rescaled / torch.clamp(weights.unsqueeze(-1).float(), min=1e-9)
|
|
|
|
def sparse_pooling(head, model, items, hidden, attention_mask):
    """Compute a SPLADE-style sparse lexical embedding over the vocabulary.

    Args:
        head: linear "sparse head" mapping hidden states to per-token scores.
        model: base model; only ``model.embed_tokens`` is read, to obtain
            the vocabulary size.
        items: tokenized batch; only ``items["input_ids"]`` is used.
        hidden: (batch, seq, dim) last hidden states.
        attention_mask: (batch, seq) 0/1 mask of real tokens.

    Returns:
        (batch, vocab_size) tensor where each input token id holds the max
        ReLU score over its occurrences; all other slots are zero.
    """
    # Undo the positional weighting so per-token states are comparable
    # (see reverse_wmean_pooling).
    hidden = reverse_wmean_pooling(hidden, attention_mask)
    # Normalize by the largest token norm per sequence before the head to
    # keep scores in a numerically stable range.
    max_hidden_norm = torch.max(torch.norm(hidden, dim=-1), dim=-1).values
    token_weights = torch.relu(head(hidden.float() / max_hidden_norm.unsqueeze(-1).unsqueeze(-1)))
    vocab_size = model.embed_tokens.weight.size(0)
    input_ids = items["input_ids"]
    sparse_embedding_chunks = []
    # Scatter into a (chunk, seq, vocab) buffer one mini-chunk at a time to
    # bound peak memory (the full tensor is batch*seq*vocab), then max-pool
    # over the sequence dimension.
    mini_chunk_size = 1
    mini_chunk_size = min(mini_chunk_size, hidden.shape[0])
    for i in range(0, token_weights.size(0), mini_chunk_size):
        now_chunk_size = min(mini_chunk_size, token_weights.size(0) - i)
        sparse_embedding = torch.zeros(now_chunk_size, input_ids.size(1), vocab_size,
                                       dtype=token_weights.dtype,
                                       device=token_weights.device)
        sparse_embedding_chunks.append(torch.max((torch.scatter(sparse_embedding, dim=-1, index=input_ids[i:i + now_chunk_size, :].unsqueeze(-1), src=token_weights[i:i + now_chunk_size, :])), dim=1).values)
    sparse_embedding = torch.concat(sparse_embedding_chunks, dim=0)
    # Zero out special tokens so they contribute no lexical weight.
    # NOTE(review): 0/1/2 are presumably pad/BOS/EOS and 73440 another
    # special id for this tokenizer — confirm against the model vocabulary.
    unused_tokens = [0, 1, 2, 73440]
    sparse_embedding[:, unused_tokens] *= 0.
    return sparse_embedding
|
|
def concat_pooling(head, model, items, hidden, attention_mask):
    """Concatenate the dense and sparse representations of a batch.

    The dense part is the L2-normalized mean-pooled embedding; the sparse
    part is the vocabulary-sized embedding scaled by sqrt(0.3) to balance
    its contribution to inner-product similarity.

    Returns:
        (batch, dim + vocab_size) tensor.
    """
    dense = F.normalize(mean_pooling(hidden, attention_mask), p=2, dim=1)
    sparse = sparse_pooling(head, model, items, hidden, attention_mask) * math.sqrt(0.3)
    return torch.cat([dense, sparse], dim=-1)
|
|
| |
|
|
class DenseEncoder(torch.nn.Module):
    """MTEB-compatible encoder wrapping openbmb/MiniCPM-Embedding-Light.

    Produces dense (mean-pooled), sparse, or concatenated dense+sparse
    embeddings depending on the module-level ``MODE`` flag. Requires CUDA
    (fp16 weights + FlashAttention2); uses DataParallel when more than one
    GPU is visible.
    """

    def __init__(self, **kwargs):
        super().__init__()
        model_path = "openbmb/MiniCPM-Embedding-Light"
        self.encoder = AutoModel.from_pretrained(
            model_path, trust_remote_code=True,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.float16).to("cuda")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.gpu_count = torch.cuda.device_count()
        # Per-task instruction prefix; set by the evaluation driver loop.
        self.instruction = ""

        self.encoder.eval()
        self.encoder.cuda()

        if self.gpu_count > 1:
            self.encoder = torch.nn.DataParallel(self.encoder)

    def _base_model(self):
        """Return the underlying model, unwrapping DataParallel if present.

        Fix: the original accessed ``self.encoder.module`` unconditionally,
        which raises AttributeError on single-GPU runs where the model is
        never wrapped in DataParallel.
        """
        if isinstance(self.encoder, torch.nn.DataParallel):
            return self.encoder.module
        return self.encoder

    @torch.no_grad()
    def encode(self, sentences, is_query=None, **kwargs) -> np.ndarray:
        """Embed a list of sentences.

        Args:
            sentences: List of sentences to encode. Unless ``is_query`` is
                explicitly False, each sentence is prefixed with the current
                task instruction.
            is_query: True/None for queries (instruction applied), False for
                corpus documents (no instruction).

        Returns:
            (len(sentences), embed_dim) numpy array of embeddings.
        """
        if is_query is not False:
            sentences = [self.instruction + s for s in sentences]
        dataset: Dataset = Dataset.from_dict({'input_texts': sentences})
        # Tokenization happens lazily, batch by batch, inside the DataLoader.
        dataset.set_transform(partial(_transform_func, self.tokenizer))

        data_collator = DataCollatorWithPadding(self.tokenizer, pad_to_multiple_of=8)
        data_loader = DataLoader(
            dataset,
            batch_size=128 * self.gpu_count,
            shuffle=False,  # order must match the input for MTEB scoring
            drop_last=False,
            num_workers=2,
            collate_fn=data_collator,
            pin_memory=True)

        encoded_embeds = []
        for batch_dict in tqdm.tqdm(data_loader, desc='encoding', mininterval=10):
            # Fix: the original wrote `with autocast() and no_grad():`, which
            # enters only the RIGHT operand of `and` (autocast() is truthy),
            # so mixed precision was silently never enabled. Enter both
            # contexts explicitly.
            with torch.cuda.amp.autocast(), torch.no_grad():
                for key in batch_dict:
                    batch_dict[key] = batch_dict[key].to("cuda")
                outputs: BaseModelOutput = self.encoder(**batch_dict)
                if MODE == "Dense":
                    embeds = mean_pooling(outputs.last_hidden_state, batch_dict['attention_mask'])
                    embeds = F.normalize(embeds, p=2, dim=1)
                elif MODE == "Sparse":
                    base = self._base_model()
                    embeds = sparse_pooling(base.head, base, batch_dict, outputs.last_hidden_state, batch_dict['attention_mask'])
                else:
                    base = self._base_model()
                    embeds = concat_pooling(base.head, base, batch_dict, outputs.last_hidden_state, batch_dict['attention_mask'])
                encoded_embeds.append(embeds.cpu().numpy())

        return np.concatenate(encoded_embeds, axis=0)

    @torch.no_grad()
    def encode_queries(self, queries: list[str], **kwargs) -> list[np.ndarray] | list[torch.Tensor]:
        """Embed queries (with the task instruction prefix applied)."""
        queries = [query for query in queries]
        return self.encode(queries, is_query=True, **kwargs)

    @torch.no_grad()
    def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs):
        """Embed corpus documents (no instruction prefix).

        Accepts the three corpus layouts MTEB may pass: a dict of parallel
        "title"/"text" lists, a list of {"title", "text"} dicts, or a plain
        list of strings. Title and text are joined with a single space.
        """
        if type(corpus) is dict:
            sentences = [
                (corpus["title"][i] + " " + corpus["text"][i]).strip()
                if "title" in corpus
                else corpus["text"][i].strip()
                for i in range(len(corpus["text"]))
            ]
        elif isinstance(corpus[0], dict):
            sentences = [
                (doc["title"] + " " + doc["text"]).strip()
                if "title" in doc
                else doc["text"].strip()
                for doc in corpus
            ]
        else:
            sentences = corpus
        is_query = False
        return self.encode(sentences, is_query=is_query, **kwargs)
|
|
|
|
# ---- Driver: run the MTEB evaluation over the selected tasks. ----
model = DenseEncoder()
task_names = MTEB_TASK_LIST
# NOTE(review): the full task list above is immediately overridden here, so
# only NFCorpus actually runs — looks like a debugging leftover; confirm
# before a full benchmark run.
task_names = ["NFCorpus"]
lang = ["en", "zh", "zh-CN"]

for task in task_names:
    try:
        evaluation = MTEB(tasks=[task], task_langs=lang)
        task_cls = evaluation.tasks[0]
        task_name: str = task_cls.metadata_dict["name"]
        task_type: str = task_cls.metadata_dict["type"]
        # Configure the per-task instruction prefix on the shared encoder.
        instruction = get_task_def_by_task_name_and_type(task_name, task_type)
        model.instruction = get_detailed_instruct(instruction)
        print(model.instruction)
        if task == "MSMARCO":
            # MSMARCO is conventionally evaluated on its dev split.
            eval_splits = ["dev"]
        elif task in CMTEB_TASK_LIST:
            # Chinese tasks use whatever splits their metadata declares.
            eval_splits = task_cls.metadata_dict["eval_splits"]
        else:
            eval_splits = ["test"]
        evaluation.run(model, eval_splits=eval_splits, overwrite_results=True)

    # Keep going when a single task fails (download/metadata issues); the
    # traceback is printed so the failure remains visible in the log.
    except Exception as e:
        import traceback
        print(traceback.format_exc())
        continue