# MossAudioTokenizer
This is the code for MOSS-Audio-Tokenizer, presented in *MOSS-Audio-Tokenizer: Scaling Audio Tokenizers for Future Audio Foundation Models*.

MOSS-Audio-Tokenizer is a discrete audio tokenizer based on the Cat (Causal Audio Tokenizer with Transformer) architecture. Scaled to 1.6 billion parameters, it functions as a unified discrete interface, delivering both lossless-quality reconstruction and high-level semantic alignment.
Key Features:
- **Extreme Compression & Variable Bitrate:** Compresses 48 kHz stereo audio down to a frame rate of just 12.5 Hz. Its 32-layer residual LFQ quantizer stack supports high-fidelity reconstruction across a wide range of bitrates (see the bitrate sketch below).
- **Pure Transformer Architecture:** A "CNN-free", homogeneous architecture built entirely from causal Transformer blocks. With 1.6B combined parameters (encoder + decoder), it scales well and supports low-latency streaming inference.
- **Large-Scale General Audio Training:** Trained on 3 million hours of diverse audio data, the model encodes and reconstructs all audio domains, including speech, sound effects, and music.
- **Unified Semantic-Acoustic Representation:** While achieving state-of-the-art reconstruction quality, Cat produces discrete tokens that are semantically rich, making them well suited to downstream tasks such as speech understanding (ASR) and generation (TTS).
- **Fully Trained From Scratch:** Cat does not rely on pretrained encoders (such as HuBERT or Whisper) or distillation from teacher models. All representations are learned autonomously from raw data.
- **End-to-End Joint Optimization:** All components, including the encoder, quantizer, decoder, discriminator, and a decoder-only LLM for semantic alignment, are optimized jointly in a single unified training pipeline.
Summary: By combining a simple, scalable architecture with massive-scale data, Cat overcomes the bottlenecks of traditional audio tokenizers. It provides a robust, high-fidelity, and semantically grounded interface for the next generation of native audio foundation models.
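To make the variable-bitrate numbers concrete, here is a back-of-the-envelope sketch. The 12.5 Hz frame rate follows from the 48 kHz sampling rate and the `downsample_rate=3840` used in the examples below; the codebook size is a hypothetical placeholder, since this card does not state it.

```python
import math

frame_rate_hz = 48000 / 3840  # sampling_rate / downsample_rate = 12.5 Hz

def bitrate_kbps(num_layers: int, codebook_size: int) -> float:
    # Each residual LFQ layer contributes log2(codebook_size) bits per frame.
    return frame_rate_hz * num_layers * math.log2(codebook_size) / 1000

# With a hypothetical 2**16 codebook: all 32 layers -> 6.4 kbps,
# the first 8 layers -> 1.6 kbps.
print(bitrate_kbps(32, 2**16), bitrate_kbps(8, 2**16))
```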
This repository contains a lightweight remote-code implementation that mirrors the current 🤗 Transformers `transformers.models.moss_audio_tokenizer` module. It is intended to be uploaded to a Hugging Face Hub model repository and loaded with `trust_remote_code=True` when needed.
## Usage

### Quickstart
```python
import torch
import torchaudio
from transformers import AutoModel

repo_id = "OpenMOSS-Team/MOSS-Audio-Tokenizer"
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()

wav, sr = torchaudio.load("demo/demo_gt.wav")
if sr != model.sampling_rate:
    wav = torchaudio.functional.resample(wav, sr, model.sampling_rate)
if wav.shape[0] == 1:
    wav = wav.repeat(model.config.number_channels, 1)
else:
    wav = wav[: model.config.number_channels]
wav = wav.unsqueeze(0)  # (1, 2, T)

enc = model.encode(wav, return_dict=True)
print(f"enc.audio_codes.shape: {enc.audio_codes.shape}")
dec = model.decode(enc.audio_codes, return_dict=True)
print(f"dec.audio.shape: {dec.audio.shape}")

wav = dec.audio.squeeze(0)
torchaudio.save("demo/demo_rec.wav", wav, sample_rate=model.sampling_rate)

# Decode using only the first 8 layers of the residual quantizer (RVQ)
dec_rvq8 = model.decode(enc.audio_codes[:8], return_dict=True)
wav_rvq8 = dec_rvq8.audio.squeeze(0)
torchaudio.save("demo/demo_rec_rvq8.wav", wav_rvq8, sample_rate=model.sampling_rate)
```
### Attention Backend and Compute Dtype

`config.attention_implementation` controls whether transformer layers prefer `sdpa` or `flash_attention_2`.
`config.compute_dtype` controls the non-quantizer autocast dtype and supports `fp32`, `bf16`, and `fp16`.

```python
model.set_attention_implementation("flash_attention_2")
model.set_compute_dtype("fp16")
```

The quantizer always runs in `fp32`.
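Continuing from a loaded `model`, a defensive way to wire this up is sketched below. The assumption that `flash_attention_2` needs the `flash-attn` package installed is this sketch's, not something the card states; the only model calls used are the two setters documented above.

```python
import torch

try:
    import flash_attn  # noqa: F401  # assumed prerequisite for flash_attention_2
    model.set_attention_implementation("flash_attention_2")
except ImportError:
    model.set_attention_implementation("sdpa")

# Prefer bf16 where the GPU supports it, fp16 otherwise.
# The quantizer stays fp32 regardless of this setting.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    model.set_compute_dtype("bf16")
else:
    model.set_compute_dtype("fp16")
```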
### Streaming

`MossAudioTokenizerModel.encode`, `decode`, `batch_encode`, and `batch_decode` all support streaming through a `chunk_duration` argument.
- `chunk_duration` is expressed in seconds.
- `chunk_duration * MossAudioTokenizerConfig.sampling_rate` must be divisible by `MossAudioTokenizerConfig.downsample_rate` (a quick checker is sketched after this list).
- Streaming batch inference is supported.
- The public waveform interface expects stereo inputs shaped `(2, T)` or batched stereo inputs shaped `(B, 2, T)`.
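A small helper to check the divisibility rule up front, using the `sampling_rate=48000` and `downsample_rate=3840` values that appear in the example below:

```python
SAMPLING_RATE = 48000   # MossAudioTokenizerConfig.sampling_rate
DOWNSAMPLE_RATE = 3840  # MossAudioTokenizerConfig.downsample_rate

def is_valid_chunk_duration(chunk_duration: float) -> bool:
    # Round to dodge floating-point error before the divisibility test.
    samples = round(chunk_duration * SAMPLING_RATE)
    return samples % DOWNSAMPLE_RATE == 0

assert is_valid_chunk_duration(0.08)       # 3840 samples -> exactly one frame
assert not is_valid_chunk_duration(0.05)   # 2400 samples, not divisible by 3840
```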
```python
import torch
from transformers import AutoModel

repo_id = "OpenMOSS-Team/MOSS-Audio-Tokenizer"
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()

audio = torch.randn(2, 48000 * 6)  # dummy stereo waveform
# 6.0s @ 48kHz = 288000 samples, divisible by downsample_rate=3840
enc = model.encode(audio.unsqueeze(0), return_dict=True, chunk_duration=0.08)
dec = model.decode(enc.audio_codes, return_dict=True, chunk_duration=0.08)

batch_enc = model.batch_encode([audio, audio[:, : 48000 * 3]], chunk_duration=0.08)
codes_list = [
    batch_enc.audio_codes[:, i, : batch_enc.audio_codes_lengths[i]]
    for i in range(batch_enc.audio_codes.shape[1])
]
batch_dec = model.batch_decode(codes_list, chunk_duration=0.08)
```
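Continuing the session above, one way to sanity-check streaming is to encode the same waveform with and without `chunk_duration` and compare the codes. Close agreement is what a causal tokenizer should give; exact equality is an assumption of this sketch, not a documented guarantee.

```python
enc_offline = model.encode(audio.unsqueeze(0), return_dict=True)
enc_stream = model.encode(audio.unsqueeze(0), return_dict=True, chunk_duration=0.08)
# Exact match is assumed here; small numerical drift is possible in practice.
print(torch.equal(enc_offline.audio_codes, enc_stream.audio_codes))
```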
### Continuous Batch Streaming Decode

For decoder-side continuous batching, prefer `batch_decode(..., streaming=True, ...)`.
- The first streaming call may pass `max_batch_size=...`. If it is omitted, the first call's batch size reserves the fixed-slot decoder budget for that public stream.
- Same-size calls continue the existing logical rows in order.
- If a later call is larger, the new rows are admitted by tail append.
- `finalize_indices` means "decode these rows one last time, then evict them". The indices are interpreted against the pre-call logical order.
- After a finalize call returns, the next streaming call may use the smaller survivor batch.
- `reset_stream=True` discards the hidden public streaming state and starts a fresh stream.
Milestone 1 boundaries:

- decode-only continuous batching
- one active streaming decode state per model instance
- fixed-slot decoder reservation from `max_batch_size`
- no encode-side continuous batching
- no physical compaction of surviving decode slots
- no multi-session concurrency on one model instance
```python
import torch
from transformers import AutoModel

repo_id = "OpenMOSS-Team/MOSS-Audio-Tokenizer"
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()

num_quantizers = model.config.quantizer_kwargs["num_quantizers"]
codes_a0 = torch.randint(0, 8, (num_quantizers, 2))
codes_b0 = torch.randint(0, 8, (num_quantizers, 3))
codes_a1 = torch.randint(0, 8, (num_quantizers, 2))
codes_b1 = torch.randint(0, 8, (num_quantizers, 2))
codes_c0 = torch.randint(0, 8, (num_quantizers, 1))
codes_a2 = torch.randint(0, 8, (num_quantizers, 1))
codes_b2 = torch.randint(0, 8, (num_quantizers, 2))
codes_c1 = torch.randint(0, 8, (num_quantizers, 2))
codes_b3 = torch.randint(0, 8, (num_quantizers, 1))
codes_c2 = torch.randint(0, 8, (num_quantizers, 1))

# First call reserves 3 fixed decoder slots for A and B.
out_ab0 = model.batch_decode(
    [codes_a0, codes_b0],
    streaming=True,
    max_batch_size=3,
    reset_stream=True,
)

# Same logical rows continue in-order; C is a tail append.
out_abc1 = model.batch_decode(
    [codes_a1, codes_b1, codes_c0],
    streaming=True,
)

# Finalize A against the pre-call logical order. A still decodes in this call,
# then is evicted immediately afterward.
out_abc2 = model.batch_decode(
    [codes_a2, codes_b2, codes_c1],
    streaming=True,
    finalize_indices=[0],
)

# The next call can shrink to the surviving logical rows only.
out_bc3 = model.batch_decode(
    [codes_b3, codes_c2],
    streaming=True,
)
```
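To consume each call's output row by row, something like the following should work. The `audio` and `audio_lengths` attribute names are assumptions by analogy with `decode` and `batch_encode` above; verify them against the object `batch_decode` actually returns.

```python
# Hypothetical attribute names -- check against the actual return type.
rows = [
    out_abc2.audio[i, :, : out_abc2.audio_lengths[i]]
    for i in range(out_abc2.audio.shape[0])
]
```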
## Repository layout

- `configuration_moss_audio_tokenizer.py`
- `modeling_moss_audio_tokenizer.py`
- `__init__.py`
- `config.json`
- model weights
## Citation

If you use this code or our results in your paper, please cite our work as: