From 7a359597c298836ffaea3f62cdcf03116c66cbec Mon Sep 17 00:00:00 2001
From: Chengmo
Date: Wed, 17 Jun 2020 17:00:31 +0800
Subject: [PATCH] fix split files at PY3 (#103)

* fix split files at PY3

* fix linux at PY3

* fix desc error

* fix collective cards and worknum

Co-authored-by: tangwei
---
 core/trainers/framework/dataset.py |  5 +++--
 core/utils/dataloader_instance.py  | 10 ++++++---
 core/utils/envs.py                 |  7 ++++++
 core/utils/util.py                 | 36 +++++++++++++++++++++++++++---
 run.py                             |  8 +++----
 5 files changed, 54 insertions(+), 12 deletions(-)

diff --git a/core/trainers/framework/dataset.py b/core/trainers/framework/dataset.py
index 00652e35..273e3a2a 100644
--- a/core/trainers/framework/dataset.py
+++ b/core/trainers/framework/dataset.py
@@ -15,13 +15,13 @@
 from __future__ import print_function
 
 import os
-import warnings
 
 import paddle.fluid as fluid
 from paddlerec.core.utils import envs
 from paddlerec.core.utils import dataloader_instance
 from paddlerec.core.reader import SlotReader
 from paddlerec.core.trainer import EngineMode
+from paddlerec.core.utils.util import split_files
 
 __all__ = ["DatasetBase", "DataLoader", "QueueDataset"]
 
@@ -123,7 +123,8 @@ class QueueDataset(DatasetBase):
             for x in os.listdir(train_data_path)
         ]
         if context["engine"] == EngineMode.LOCAL_CLUSTER:
-            file_list = context["fleet"].split_files(file_list)
+            file_list = split_files(file_list, context["fleet"].worker_index(),
+                                    context["fleet"].worker_num())
 
         dataset.set_filelist(file_list)
         for model_dict in context["phases"]:
diff --git a/core/utils/dataloader_instance.py b/core/utils/dataloader_instance.py
index e3062bb7..c66d1b36 100755
--- a/core/utils/dataloader_instance.py
+++ b/core/utils/dataloader_instance.py
@@ -19,6 +19,7 @@ from paddlerec.core.utils.envs import get_global_env
 from paddlerec.core.utils.envs import get_runtime_environ
 from paddlerec.core.reader import SlotReader
 from paddlerec.core.trainer import EngineMode
+from paddlerec.core.utils.util import split_files
 
 
 def dataloader_by_name(readerclass,
@@ -39,7 +40,8 @@ def dataloader_by_name(readerclass,
     files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
 
     if context["engine"] == EngineMode.LOCAL_CLUSTER:
-        files = context["fleet"].split_files(files)
+        files = split_files(files, context["fleet"].worker_index(),
+                            context["fleet"].worker_num())
     print("file_list : {}".format(files))
 
     reader = reader_class(yaml_file)
@@ -80,7 +82,8 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context):
     files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
 
     if context["engine"] == EngineMode.LOCAL_CLUSTER:
-        files = context["fleet"].split_files(files)
+        files = split_files(files, context["fleet"].worker_index(),
+                            context["fleet"].worker_num())
     print("file_list: {}".format(files))
 
     sparse = get_global_env(name + "sparse_slots", "#")
@@ -133,7 +136,8 @@ def slotdataloader(readerclass, train, yaml_file, context):
     files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
 
     if context["engine"] == EngineMode.LOCAL_CLUSTER:
-        files = context["fleet"].split_files(files)
+        files = split_files(files, context["fleet"].worker_index(),
+                            context["fleet"].worker_num())
     print("file_list: {}".format(files))
 
     sparse = get_global_env("sparse_slots", "#", namespace)
diff --git a/core/utils/envs.py b/core/utils/envs.py
index 29403420..ddcc9a94 100755
--- a/core/utils/envs.py
+++ b/core/utils/envs.py
@@ -18,6 +18,7 @@ import copy
 import os
 import socket
 import sys
+import six
 import traceback
 
 import six
@@ -102,6 +103,12 @@ def set_global_envs(envs):
             name = ".".join(["dataset", dataset["name"], "type"])
             global_envs[name] = "DataLoader"
 
+    if get_platform() == "LINUX" and six.PY3:
+        print("QueueDataset can not support PY3, change to DataLoader")
+        for dataset in envs["dataset"]:
+            name = ".".join(["dataset", dataset["name"], "type"])
+            global_envs[name] = "DataLoader"
+
 
 def get_global_env(env_name, default_value=None, namespace=None):
     """
diff --git a/core/utils/util.py b/core/utils/util.py
index 381d35ca..4eba912c 100755
--- a/core/utils/util.py
+++ b/core/utils/util.py
@@ -19,11 +19,8 @@ import time
 import numpy as np
 from paddle import fluid
 
-from paddlerec.core.utils import fs as fs
-
 
 def save_program_proto(path, program=None):
-
     if program is None:
         _program = fluid.default_main_program()
     else:
@@ -171,6 +168,39 @@ def print_cost(cost, params):
     return log_str
 
 
+def split_files(files, trainer_id, trainers):
+    """
+    Split files before distributed training.
+    example 1: files is [a, b, c, d, e] and trainer_num = 2, then trainer
+    0 gets [a, b, c] and trainer 1 gets [d, e].
+    example 2: files is [a, b] and trainer_num = 3, then trainer 0 gets
+    [a], trainer 1 gets [b], trainer 2 gets [].
+
+    Args:
+        files(list): file list to be read.
+
+    Returns:
+        list: the files belonging to this worker.
+    """
+    if not isinstance(files, list):
+        raise TypeError("files should be a list of files to be read.")
+
+    remainder = len(files) % trainers
+    blocksize = int(len(files) / trainers)
+
+    blocks = [blocksize] * trainers
+    for i in range(remainder):
+        blocks[i] += 1
+
+    trainer_files = [[]] * trainers
+    begin = 0
+    for i in range(trainers):
+        trainer_files[i] = files[begin:begin + blocks[i]]
+        begin += blocks[i]
+
+    return trainer_files[trainer_id]
+
+
 class CostPrinter(object):
     """
     For count cost time && print cost log
diff --git a/run.py b/run.py
index 699d48f9..b9e15a50 100755
--- a/run.py
+++ b/run.py
@@ -139,8 +139,8 @@ def get_engine(args, running_config, mode):
         engine = "LOCAL_CLUSTER_TRAIN"
 
     if engine not in engine_choices:
-        raise ValueError("{} can not be chosen in {}".format(engine_class,
-                                                             engine_choices))
+        raise ValueError("{} can only be chosen in {}".format(engine_class,
+                                                              engine_choices))
 
     run_engine = engines[transpiler].get(engine, None)
     return run_engine
@@ -439,8 +439,8 @@ def local_cluster_engine(args):
     if fleet_mode == "COLLECTIVE":
         cluster_envs["selected_gpus"] = selected_gpus
         gpus = selected_gpus.split(",")
-        gpu_num = get_worker_num(run_extras, len(gpus))
-        cluster_envs["selected_gpus"] = ','.join(gpus[:gpu_num])
+        worker_num = get_worker_num(run_extras, len(gpus))
+        cluster_envs["selected_gpus"] = ','.join(gpus[:worker_num])
 
     cluster_envs["server_num"] = server_num
     cluster_envs["worker_num"] = worker_num
--
GitLab
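
Note on the new helper: split_files hands the remainder files to the lowest-ranked trainers, so shard sizes differ by at most one file, and trainers ranked past len(files) receive an empty list, which is why callers must tolerate an empty filelist. A minimal check script (hypothetical, not part of the patch) that reproduces the docstring examples, assuming a paddlerec installation that already contains this change:

    # Hypothetical check script; assumes paddlerec with this patch installed.
    from paddlerec.core.utils.util import split_files

    # 5 files over 2 trainers: the remainder file goes to the lowest-ranked worker.
    files = ["a", "b", "c", "d", "e"]
    assert split_files(files, 0, 2) == ["a", "b", "c"]
    assert split_files(files, 1, 2) == ["d", "e"]

    # 2 files over 3 trainers: surplus workers receive an empty shard.
    assert split_files(["a", "b"], 0, 3) == ["a"]
    assert split_files(["a", "b"], 1, 3) == ["b"]
    assert split_files(["a", "b"], 2, 3) == []

    print("split_files behaves as documented")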