| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel, PretrainedConfig |
| from transformers.modeling_outputs import SequenceClassifierOutput |
| from transformers import SeamlessM4TModel |
| import logging |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class SeamlessBasicConfig(PretrainedConfig):
    """Configuration for :class:`HFSeamlessBasic`.

    Args:
        seamless_model_name: Hub id of the SeamlessM4T backbone to load.
        hidden_size: Dimension the audio and text embeddings are projected to.
        dropout_prob: Dropout probability used throughout the regression head.
    """

    model_type = "seamless_basic"

    def __init__(
        self,
        seamless_model_name="facebook/hf-seamless-m4t-medium",
        hidden_size=1024,
        dropout_prob=0.1,
        **kwargs,
    ):
        # Let PretrainedConfig consume its own kwargs first, then record ours.
        super().__init__(**kwargs)
        self.seamless_model_name = seamless_model_name
        self.hidden_size = hidden_size
        self.dropout_prob = dropout_prob
|
|
|
|
class HFSeamlessBasic(PreTrainedModel):
    """Basic SeamlessM4T model for HuggingFace Hub - processes audio and text only.

    Frozen SeamlessM4T speech and text encoders produce per-token states that
    are mean-pooled over the sequence dimension, projected to
    ``config.hidden_size``, concatenated, and passed through an MLP head that
    emits one scalar per example (regression; trained with MSE loss).
    """

    config_class = SeamlessBasicConfig
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Load the pretrained backbone and keep handles to its two encoders.
        # NOTE(review): the full SeamlessM4TModel (incl. decoders) stays
        # registered on this module, so it is saved/loaded with the model.
        self.seamless_model = SeamlessM4TModel.from_pretrained(config.seamless_model_name)
        self.seamless_model_speech_encoder = self.seamless_model.speech_encoder
        self.seamless_model_text_encoder = self.seamless_model.text_encoder

        # Freeze both encoders: only the projections and the head are trained.
        for param in self.seamless_model_speech_encoder.parameters():
            param.requires_grad = False
        for param in self.seamless_model_text_encoder.parameters():
            param.requires_grad = False

        # Project each modality's pooled embedding into a shared space.
        self.audio_proj = nn.Linear(
            self.seamless_model_speech_encoder.config.hidden_size,
            config.hidden_size
        )
        self.text_proj = nn.Linear(
            self.seamless_model_text_encoder.config.hidden_size,
            config.hidden_size
        )

        # Regression head over the concatenated [audio; text] embedding.
        # BUG FIX: the first layer's input size was hard-coded to 2048, which
        # only matched the default hidden_size=1024 (2 * 1024). Derive it from
        # the config so non-default hidden sizes work.
        fused_dim = 2 * config.hidden_size
        self.fc = nn.Sequential(
            nn.Linear(fused_dim, 1024),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(256, 1)
        )

    def forward(
        self,
        input_features,
        input_ids,
        text_attention_mask,
        audio_attention_mask=None,
        labels=None,
        **kwargs
    ):
        """Run both encoders, fuse the pooled embeddings, and score them.

        Args:
            input_features: Audio features for the speech encoder.
            input_ids: Token ids for the text encoder.
            text_attention_mask: Attention mask for ``input_ids``.
            audio_attention_mask: Optional attention mask for the audio input.
            labels: Optional regression targets; when given, MSE loss is
                computed against the scalar logits.

        Returns:
            SequenceClassifierOutput with ``logits`` of shape (batch,) and
            ``loss`` set only when ``labels`` is provided.
        """
        # Mean-pool the speech encoder states over time, then project.
        # NOTE(review): mean() ignores the attention mask, so padded frames
        # are averaged in — confirm this is intended.
        audio_emb = self.seamless_model_speech_encoder(
            input_features=input_features,
            attention_mask=audio_attention_mask
        ).last_hidden_state.mean(dim=1)
        audio_emb = self.audio_proj(audio_emb)

        # Mean-pool the text encoder states over tokens, then project.
        text_emb = self.seamless_model_text_encoder(
            input_ids=input_ids,
            attention_mask=text_attention_mask
        ).last_hidden_state.mean(dim=1)
        text_emb = self.text_proj(text_emb)

        # Fuse modalities: (batch, 2 * hidden_size).
        combined = torch.cat([audio_emb, text_emb], dim=1)

        logits = self.fc(combined).squeeze(-1)

        # Cast labels to the logits dtype: mse_loss raises on integer/long
        # targets and silently upcasts on mixed float dtypes.
        loss = None
        if labels is not None:
            loss = F.mse_loss(logits, labels.to(logits.dtype))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None
        )