未验证 提交 7a359597 编写于 作者: C Chengmo 提交者: GitHub

fix split files at PY3 (#103)

* fix split files at PY3

* fix linux at PY3

* fix desc error

* fix collective cards and worknum
Co-authored-by: Ntangwei <tangwei12@baidu.com>
上级 947395bb
...@@ -15,13 +15,13 @@ ...@@ -15,13 +15,13 @@
from __future__ import print_function from __future__ import print_function
import os import os
import warnings
import paddle.fluid as fluid import paddle.fluid as fluid
from paddlerec.core.utils import envs from paddlerec.core.utils import envs
from paddlerec.core.utils import dataloader_instance from paddlerec.core.utils import dataloader_instance
from paddlerec.core.reader import SlotReader from paddlerec.core.reader import SlotReader
from paddlerec.core.trainer import EngineMode from paddlerec.core.trainer import EngineMode
from paddlerec.core.utils.util import split_files
__all__ = ["DatasetBase", "DataLoader", "QueueDataset"] __all__ = ["DatasetBase", "DataLoader", "QueueDataset"]
...@@ -123,7 +123,8 @@ class QueueDataset(DatasetBase): ...@@ -123,7 +123,8 @@ class QueueDataset(DatasetBase):
for x in os.listdir(train_data_path) for x in os.listdir(train_data_path)
] ]
if context["engine"] == EngineMode.LOCAL_CLUSTER: if context["engine"] == EngineMode.LOCAL_CLUSTER:
file_list = context["fleet"].split_files(file_list) file_list = split_files(file_list, context["fleet"].worker_index(),
context["fleet"].worker_num())
dataset.set_filelist(file_list) dataset.set_filelist(file_list)
for model_dict in context["phases"]: for model_dict in context["phases"]:
......
...@@ -19,6 +19,7 @@ from paddlerec.core.utils.envs import get_global_env ...@@ -19,6 +19,7 @@ from paddlerec.core.utils.envs import get_global_env
from paddlerec.core.utils.envs import get_runtime_environ from paddlerec.core.utils.envs import get_runtime_environ
from paddlerec.core.reader import SlotReader from paddlerec.core.reader import SlotReader
from paddlerec.core.trainer import EngineMode from paddlerec.core.trainer import EngineMode
from paddlerec.core.utils.util import split_files
def dataloader_by_name(readerclass, def dataloader_by_name(readerclass,
...@@ -39,7 +40,8 @@ def dataloader_by_name(readerclass, ...@@ -39,7 +40,8 @@ def dataloader_by_name(readerclass,
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)] files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
if context["engine"] == EngineMode.LOCAL_CLUSTER: if context["engine"] == EngineMode.LOCAL_CLUSTER:
files = context["fleet"].split_files(files) files = split_files(files, context["fleet"].worker_index(),
context["fleet"].worker_num())
print("file_list : {}".format(files)) print("file_list : {}".format(files))
reader = reader_class(yaml_file) reader = reader_class(yaml_file)
...@@ -80,7 +82,8 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context): ...@@ -80,7 +82,8 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context):
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)] files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
if context["engine"] == EngineMode.LOCAL_CLUSTER: if context["engine"] == EngineMode.LOCAL_CLUSTER:
files = context["fleet"].split_files(files) files = split_files(files, context["fleet"].worker_index(),
context["fleet"].worker_num())
print("file_list: {}".format(files)) print("file_list: {}".format(files))
sparse = get_global_env(name + "sparse_slots", "#") sparse = get_global_env(name + "sparse_slots", "#")
...@@ -133,7 +136,8 @@ def slotdataloader(readerclass, train, yaml_file, context): ...@@ -133,7 +136,8 @@ def slotdataloader(readerclass, train, yaml_file, context):
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)] files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
if context["engine"] == EngineMode.LOCAL_CLUSTER: if context["engine"] == EngineMode.LOCAL_CLUSTER:
files = context["fleet"].split_files(files) files = split_files(files, context["fleet"].worker_index(),
context["fleet"].worker_num())
print("file_list: {}".format(files)) print("file_list: {}".format(files))
sparse = get_global_env("sparse_slots", "#", namespace) sparse = get_global_env("sparse_slots", "#", namespace)
......
...@@ -18,6 +18,7 @@ import copy ...@@ -18,6 +18,7 @@ import copy
import os import os
import socket import socket
import sys import sys
import six
import traceback import traceback
import six import six
...@@ -102,6 +103,12 @@ def set_global_envs(envs): ...@@ -102,6 +103,12 @@ def set_global_envs(envs):
name = ".".join(["dataset", dataset["name"], "type"]) name = ".".join(["dataset", dataset["name"], "type"])
global_envs[name] = "DataLoader" global_envs[name] = "DataLoader"
if get_platform() == "LINUX" and six.PY3:
print("QueueDataset can not support PY3, change to DataLoader")
for dataset in envs["dataset"]:
name = ".".join(["dataset", dataset["name"], "type"])
global_envs[name] = "DataLoader"
def get_global_env(env_name, default_value=None, namespace=None): def get_global_env(env_name, default_value=None, namespace=None):
""" """
......
...@@ -19,11 +19,8 @@ import time ...@@ -19,11 +19,8 @@ import time
import numpy as np import numpy as np
from paddle import fluid from paddle import fluid
from paddlerec.core.utils import fs as fs
def save_program_proto(path, program=None): def save_program_proto(path, program=None):
if program is None: if program is None:
_program = fluid.default_main_program() _program = fluid.default_main_program()
else: else:
...@@ -171,6 +168,39 @@ def print_cost(cost, params): ...@@ -171,6 +168,39 @@ def print_cost(cost, params):
return log_str return log_str
def split_files(files, trainer_id, trainers):
"""
split files before distributed training,
example 1: files is [a, b, c ,d, e] and trainer_num = 2, then trainer
0 gets [a, b, c] and trainer 1 gets [d, e].
example 2: files is [a, b], and trainer_num = 3, then trainer 0 gets
[a], trainer 1 gets [b], trainer 2 gets []
Args:
files(list): file list need to be read.
Returns:
list: files belongs to this worker.
"""
if not isinstance(files, list):
raise TypeError("files should be a list of file need to be read.")
remainder = len(files) % trainers
blocksize = int(len(files) / trainers)
blocks = [blocksize] * trainers
for i in range(remainder):
blocks[i] += 1
trainer_files = [[]] * trainers
begin = 0
for i in range(trainers):
trainer_files[i] = files[begin:begin + blocks[i]]
begin += blocks[i]
return trainer_files[trainer_id]
class CostPrinter(object): class CostPrinter(object):
""" """
For count cost time && print cost log For count cost time && print cost log
......
...@@ -139,7 +139,7 @@ def get_engine(args, running_config, mode): ...@@ -139,7 +139,7 @@ def get_engine(args, running_config, mode):
engine = "LOCAL_CLUSTER_TRAIN" engine = "LOCAL_CLUSTER_TRAIN"
if engine not in engine_choices: if engine not in engine_choices:
raise ValueError("{} can not be chosen in {}".format(engine_class, raise ValueError("{} can only be chosen in {}".format(engine_class,
engine_choices)) engine_choices))
run_engine = engines[transpiler].get(engine, None) run_engine = engines[transpiler].get(engine, None)
...@@ -439,8 +439,8 @@ def local_cluster_engine(args): ...@@ -439,8 +439,8 @@ def local_cluster_engine(args):
if fleet_mode == "COLLECTIVE": if fleet_mode == "COLLECTIVE":
cluster_envs["selected_gpus"] = selected_gpus cluster_envs["selected_gpus"] = selected_gpus
gpus = selected_gpus.split(",") gpus = selected_gpus.split(",")
gpu_num = get_worker_num(run_extras, len(gpus)) worker_num = get_worker_num(run_extras, len(gpus))
cluster_envs["selected_gpus"] = ','.join(gpus[:gpu_num]) cluster_envs["selected_gpus"] = ','.join(gpus[:worker_num])
cluster_envs["server_num"] = server_num cluster_envs["server_num"] = server_num
cluster_envs["worker_num"] = worker_num cluster_envs["worker_num"] = worker_num
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册