split_data.py 1.2 KB
Newer Older
W
wanghaoshuang 已提交
1
"""This tool is used for splitting data into each node of
paddlecloud. This script should be called in paddlecloud.
"""
W
wanghaoshuang 已提交
4 5 6
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
7

8 9 10 11 12 13 14 15
import os
import json
import argparse

# Command-line interface: both manifest paths are mandatory options.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--in_manifest_path",
    type=str,
    required=True,
    help="Input manifest path for all nodes.")
parser.add_argument(
    "--out_manifest_path",
    type=str,
    required=True,
    help="Output manifest file path for current node.")
# Parse the command line only when executed as a script. Parsing at import
# time would read the importing program's sys.argv and exit the process
# whenever the required options are absent, making this module unimportable.
args = parser.parse_args() if __name__ == '__main__' else None


26
def _read_int_from_file(path):
    """Return the integer stored on the first line of the file at `path`."""
    with open(path, "r") as f:
        # strip() removes the trailing newline only if one is present;
        # slicing off the last character would corrupt a value that is not
        # newline-terminated (e.g. "10" would be read as 1).
        return int(f.readline().strip())


def split_data(in_manifest_path,
               out_manifest_path,
               trainer_id_path="/trainer_id",
               trainer_count_path="/trainer_count"):
    """Write the current node's share of a manifest to its own file.

    Lines of the input manifest are distributed round-robin: the line at
    zero-based index ``i`` is kept when ``i % trainer_count == trainer_id``.
    Every kept line is stripped of surrounding whitespace and written
    followed by a single newline.

    Args:
        in_manifest_path: Path of the manifest shared by all nodes.
        out_manifest_path: Path the current node's manifest is written to;
            an existing file is overwritten.
        trainer_id_path: File whose first line holds the integer id of the
            current node. Defaults to "/trainer_id".
        trainer_count_path: File whose first line holds the integer number
            of nodes. Defaults to "/trainer_count".

    Raises:
        ValueError: If the first line of either trainer file is not an
            integer.
        ZeroDivisionError: If the trainer count is 0 and the input manifest
            is not empty.
    """
    trainer_id = _read_int_from_file(trainer_id_path)
    trainer_count = _read_int_from_file(trainer_count_path)

    # Read the whole selection before opening the output so that the input
    # is fully consumed (and closed) even if both paths name the same file.
    with open(in_manifest_path, 'r') as in_file:
        out_manifest = [
            "%s\n" % json_line.strip()
            for index, json_line in enumerate(in_file)
            if (index % trainer_count) == trainer_id
        ]
    with open(out_manifest_path, 'w') as out_file:
        out_file.writelines(out_manifest)
38 39 40


if __name__ == '__main__':
    # Script entry point: split the manifest using the paths given on the
    # command line.
    split_data(
        in_manifest_path=args.in_manifest_path,
        out_manifest_path=args.out_manifest_path)