split_data.py 1.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
"""
This tool is used for splitting data into each node of
paddle cloud by total trainer count and current trainer id.
The meaning of trainer is a instance of k8s cluster.
This script should be called in paddle cloud.
"""
import os
import json
import argparse

# Command-line options. The module docstring doubles as the --help description.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--in_manifest_path",
    default='./cloud.train.manifest',
    type=str,
    help="Input manifest path. (default: %(default)s)")
parser.add_argument(
    "--data_tar_path",
    default='./cloud.train.tar',
    type=str,
    help="Data tar file path. (default: %(default)s)")
parser.add_argument(
    "--out_manifest_path",
    default='./local.train.manifest',
    type=str,
    help="Out manifest file path. (default: %(default)s)")
# NOTE(review): parsed at import time, so merely importing this module
# consumes sys.argv; `args` is read by the __main__ entry point below.
args = parser.parse_args()


def split_data(in_manifest, tar_path, out_manifest,
               trainer_id_path="/trainer_id",
               trainer_count_path="/trainer_count"):
    """Write this trainer's shard of a manifest, pointing audio paths into a tar.

    Line ``i`` (0-based) of ``in_manifest`` is kept when
    ``i % trainer_count == trainer_id``. Each kept line is parsed as JSON, its
    ``audio_filepath`` is rewritten to ``tar:<absolute tar_path>#<old path>``,
    and the result is written one JSON object per line to ``out_manifest``.

    Args:
        in_manifest: Path of the input manifest, one JSON object per line.
        tar_path: Path of the data tar file; converted to an absolute path.
        out_manifest: Path of the output manifest; overwritten if present.
        trainer_id_path: File whose first line holds this trainer's id.
        trainer_count_path: File whose first line holds the trainer count.

    Raises:
        ValueError: If either trainer file does not contain an integer, if the
            trainer count is not positive, or if the trainer id is outside
            ``[0, trainer_count)``.
    """
    # strip() rather than [:-1]: the file may have no trailing newline, in
    # which case chopping the last character would corrupt the number.
    with open(trainer_id_path, "r") as f:
        trainer_id = int(f.readline().strip())
    with open(trainer_count_path, "r") as f:
        trainer_count = int(f.readline().strip())

    # Fail loudly on misconfiguration instead of dividing by zero or
    # silently producing an empty shard.
    if trainer_count <= 0:
        raise ValueError(
            "trainer count must be positive, got %d" % trainer_count)
    if not 0 <= trainer_id < trainer_count:
        raise ValueError("trainer id %d out of range [0, %d)" %
                         (trainer_id, trainer_count))

    tar_path = os.path.abspath(tar_path)
    result = []
    with open(in_manifest) as manifest_in:
        for index, json_line in enumerate(manifest_in):
            # Round-robin assignment of manifest lines to trainers.
            if index % trainer_count != trainer_id:
                continue
            json_data = json.loads(json_line)
            json_data['audio_filepath'] = "tar:%s#%s" % (
                tar_path, json_data['audio_filepath'])
            result.append("%s\n" % json.dumps(json_data))
    with open(out_manifest, 'w') as manifest:
        manifest.writelines(result)


if __name__ == '__main__':
    # Script entry point: shard the manifest using the parsed CLI options.
    split_data(
        in_manifest=args.in_manifest_path,
        tar_path=args.data_tar_path,
        out_manifest=args.out_manifest_path,
    )