File size: 4,526 Bytes
3c50954
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import argparse
import os
import shutil

from tqdm import tqdm

from ort_common import WenetONNXRunner, pack_calibration_dataset


def get_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding the input/output locations, the model
        parts to generate data for, the integer sizing options that ``main``
        forwards to the runner, and the run-control flags.
    """
    parser = argparse.ArgumentParser(
        description="Generate calibration_dataset for exported ONNX models")

    # Required inputs.
    parser.add_argument("--input", "-i", nargs="+", required=True,
                        help="Input wav file(s) or directory/directories")
    parser.add_argument("--config", required=True,
                        help="yaml file in checkpoint path")
    parser.add_argument("--vocab", required=True,
                        help=("pretrained units.txt, for example "
                              "pretrained/<model>/units.txt"))

    # Locations on disk.
    parser.add_argument("--onnx_dir", default="onnx_model",
                        help="directory containing exported ONNX models")
    parser.add_argument("--calib_data_path", default="calibration_dataset",
                        help="output calibration dataset directory")

    # Which parts of the model to produce calibration inputs for.
    parser.add_argument("--parts", nargs="+",
                        choices=["all", "offline", "online", "decoder"],
                        default=["all"],
                        help="which model inputs to generate")

    # Plain integer options (no help text), registered in a fixed order.
    int_options = (
        ("--offline_seq_len", 1024),
        ("--decoder_len", 32),
        ("--decoding_chunk_size", 16),
        ("--num_decoding_left_chunks", 5),
    )
    for flag, default_value in int_options:
        parser.add_argument(flag, type=int, default=default_value)

    # Run control.
    parser.add_argument("--max_num", type=int, default=100,
                        help=("maximum number of audio files used for "
                              "calibration; set <= 0 to use all"))
    parser.add_argument("--keep_existing", action="store_true",
                        help="append to an existing calibration directory")
    return parser.parse_args()


def expand_audio_inputs(inputs):
    """Resolve ``--input`` values into a sorted, de-duplicated file list.

    Directories are walked recursively. Only files whose extension
    (case-insensitive) is one of the known audio suffixes are collected.
    Paths that are not directories are taken as-is without an extension
    check, so an explicitly named file with an unusual suffix is still used.

    Args:
        inputs: iterable of file and/or directory paths.

    Returns:
        Sorted list of unique file paths.

    Raises:
        FileNotFoundError: if an input path does not exist, or if no audio
            files were found at all.
    """
    inputs = list(inputs)
    audio_exts = {".wav", ".flac", ".mp3", ".m4a", ".ogg"}
    # A set collapses duplicates from overlapping inputs (e.g. a directory
    # plus a file inside it), which would otherwise be calibrated twice.
    audio_files = set()
    for path in inputs:
        if os.path.isdir(path):
            for root, _, files in os.walk(path):
                for filename in files:
                    if os.path.splitext(filename)[1].lower() in audio_exts:
                        audio_files.add(os.path.join(root, filename))
        elif os.path.exists(path):
            audio_files.add(path)
        else:
            # Fail fast: main() calls this before clearing the output
            # directory, so a typo'd path no longer costs an existing dataset.
            raise FileNotFoundError(f"Input path does not exist: {path}")
    if not audio_files:
        raise FileNotFoundError(f"No audio files found in {inputs}")
    return sorted(audio_files)


def normalize_parts(parts):
    """Expand the ``--parts`` selection into a set of concrete part names.

    If ``"all"`` appears anywhere in *parts*, every part is selected.
    Otherwise the given names are returned as a set, so duplicates collapse.
    """
    every_part = {"offline", "online", "decoder"}
    return every_part if "all" in parts else set(parts)


def limit_audio_files(audio_files, max_num):
    """Cap *audio_files* to its first *max_num* entries.

    A *max_num* of ``None`` or any value ``<= 0`` means "no cap". In that
    case the very same list object is handed back rather than a copy.
    """
    if max_num is not None and max_num > 0:
        return audio_files[:max_num]
    return audio_files


def main():
    """Generate calibration samples for the exported ONNX models and pack them.

    Parses the CLI and resolves and caps the audio inputs. Builds a
    ``WenetONNXRunner`` and prepares the output directory, clearing it
    unless ``--keep_existing`` is set. Saves calibration samples for every
    audio file. Finally calls ``pack_calibration_dataset`` on the output
    directory and prints the per-part sample counts for this run.
    """
    args = get_args()
    parts = normalize_parts(args.parts)
    audio_files = limit_audio_files(expand_audio_inputs(args.input),
                                    args.max_num)

    # Construct the runner before touching the output directory, so a
    # failure while setting it up (e.g. a bad --onnx_dir or --config) does
    # not first delete a previously generated dataset.
    runner = WenetONNXRunner(
        args.config,
        args.vocab,
        onnx_dir=args.onnx_dir,
        offline_seq_len=args.offline_seq_len,
        decoder_len=args.decoder_len,
        decoding_chunk_size=args.decoding_chunk_size,
        num_decoding_left_chunks=args.num_decoding_left_chunks,
    )

    if os.path.exists(args.calib_data_path) and not args.keep_existing:
        shutil.rmtree(args.calib_data_path)
    os.makedirs(args.calib_data_path, exist_ok=True)

    # NOTE(review): audio_idx restarts at 0 on every run. If the runner
    # names its output files by audio_idx, --keep_existing may overwrite
    # samples from an earlier run instead of appending to them — confirm.
    counts = {"offline": 0, "online": 0, "decoder": 0}
    # The context manager closes the progress bar even if a file fails.
    with tqdm(audio_files,
              desc="Generating calibration data",
              unit="wav") as progress:
        for audio_idx, audio_file in enumerate(progress):
            sample_counts = runner.save_calibration_for_audio(
                audio_file, parts, args.calib_data_path, audio_idx)
            for key, value in sample_counts.items():
                # .get() tolerates part names beyond the three tracked here.
                counts[key] = counts.get(key, 0) + value
            progress.set_postfix(offline=counts["offline"],
                                 online=counts["online"],
                                 decoder=counts["decoder"])

    print("Packing calibration dataset...")
    pack_calibration_dataset(args.calib_data_path)
    print(f"Generated calibration data in {args.calib_data_path}")
    print(f"offline samples: {counts['offline']}")
    print(f"online samples: {counts['online']}")
    print(f"decoder samples: {counts['decoder']}")


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()