#!/usr/bin/env python3
"""Remove too-long and too-short utterances from a formatted jsonl manifest."""
import argparse
import logging

import jsonlines

from paddlespeech.s2t.utils.cli_utils import get_commandline_args

# manifest after format
# each jsonline looks like this
# {
#   "input": [{"name": "input1", "shape": (100, 83), "feat": "xxx.ark:123"}],
#   "output": [{"name":"target1", "shape": (40, 5002), "text": "a b c de"}],
#   "utt2spk": "111-2222",
#   "utt": "111-2222-333"
# }


def get_parser():
    """Build the command-line parser for the manifest length filter.

    Returns:
        argparse.ArgumentParser: parser accepting the verbosity level, the
        input/output index to inspect, frame and character length bounds,
        the frame stride in milliseconds, and the input (``rspecifier``) and
        output (``wspecifier_or_wxfilename``) manifest paths.
    """
    parser = argparse.ArgumentParser(
        description="remove longshort data from format manifest",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
    parser.add_argument(
        "--verbose", "-V", default=0, type=int, help="Verbose option")
    # Which entry of line['input'] / line['output'] the length checks use.
    parser.add_argument(
        "--iaxis",
        default=0,
        type=int,
        help="multi inputs index, 0 is the first")
    parser.add_argument(
        "--oaxis",
        default=0,
        type=int,
        help="multi outputs index, 0 is the first")
    # Inclusive bounds; entries outside them are dropped.
    parser.add_argument("--maxframes", default=2000, type=int, help="maxframes")
    parser.add_argument("--minframes", default=10, type=int, help="minframes")
    # Compared against len() of the output 'text' string.
    parser.add_argument("--maxchars", default=200, type=int, help="max tokens")
    parser.add_argument("--minchars", default=0, type=int, help="min tokens")
    # Used to convert a raw-audio input length into a frame count.
    parser.add_argument(
        "--stride_ms", default=10, type=int, help="stride in ms unit.")
    parser.add_argument(
        "rspecifier",
        type=str,
        help="jsonl format manifest. e.g. manifest.jsonl")
    parser.add_argument(
        "wspecifier_or_wxfilename",
        type=str,
        help="Write specifier. e.g. manifest.jsonl")
    return parser


def filter_input(args, line):
    """Return True if the selected input of ``line`` should be dropped.

    Args:
        args: parsed options. Reads ``iaxis``, ``minframes``, ``maxframes``
            and ``stride_ms``, plus an optional ``sound`` flag that is not
            defined by the parser but attached by the caller.
        line: one manifest entry. ``line['input'][args.iaxis]['shape'][0]``
            is the length that is checked.

    Returns:
        bool: True when the frame count is below ``minframes`` or above
        ``maxframes``; False otherwise.
    """
    tmp = line['input'][args.iaxis]
    length = tmp['shape'][0]
    # `sound` is not a CLI option; fall back to "not raw audio" when the
    # caller did not set it, instead of raising AttributeError.
    if getattr(args, 'sound', False):
        # Raw audio: the original code treats shape[0] as seconds and
        # converts it to frames using the stride (presumably seconds — verify
        # against the manifest producer).
        nframe = length * 1000 / args.stride_ms
    else:
        nframe = length
    return nframe < args.minframes or nframe > args.maxframes


def filter_output(args, line):
    """Return True if the selected output text of ``line`` should be dropped.

    Args:
        args: parsed options. Reads ``oaxis``, ``minchars`` and ``maxchars``.
        line: one manifest entry. ``line['output'][args.oaxis]['text']`` is
            the string whose length is checked.

    Returns:
        bool: True when the text length is below ``minchars`` or above
        ``maxchars``; False otherwise.
    """
    # Outputs are selected with --oaxis; the previous code used --iaxis here,
    # which silently ignored the output index option.
    nchars = len(line['output'][args.oaxis]['text'])
    return nchars < args.minchars or nchars > args.maxchars

def _is_sound_feature(feat):
    """Return True if ``feat`` does not reference a Kaldi ``.ark``/``.scp`` file."""
    # Entries look like "xxx.ark:123": take the text after the last '.',
    # drop any ":offset" suffix, and compare exact extensions (the old
    # substring test against "ark, scp" misclassified "ark:123").
    ext = feat.rsplit('.', 1)[-1].split(':', 1)[0]
    return ext not in ('ark', 'scp')


def main():
    """Filter a jsonl manifest by input length and output text length.

    Reads every entry from ``args.rspecifier``. Whether inputs are raw audio
    is decided from the ``feat`` field of the first entry. Entries for which
    ``filter_input`` or ``filter_output`` returns True are dropped; the rest
    are written to ``args.wspecifier_or_wxfilename``. Kept and dropped counts
    are logged at INFO level.
    """
    args = get_parser().parse_args()

    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    with jsonlines.open(args.rspecifier, 'r') as reader:
        lines = list(reader)
    logging.info(f"Example: {len(lines)}")

    # filter_input reads args.sound; always define it (the old code set a
    # misspelled `args.soud`) and guard against an empty manifest.
    args.sound = False
    if lines:
        feat = lines[0]['input'][args.iaxis]['feat']
        args.sound = _is_sound_feature(feat)

    count = 0
    n_filtered = 0
    with jsonlines.open(args.wspecifier_or_wxfilename, 'w') as writer:
        for line in lines:
            if filter_input(args, line) or filter_output(args, line):
                n_filtered += 1
                continue
            writer.write(line)
            count += 1
    # Escaped backslash: same rendered text, no invalid "\{" escape sequence.
    logging.info(f"Example after filter: {count}\\{n_filtered}")

if __name__ == '__main__':
    main()