diff --git a/core/trainers/framework/dataset.py b/core/trainers/framework/dataset.py index 5c5a2357ff4a07d54d4e0c56e692b4d79fcb2095..278df79f4d6e116a19e106ad141e417779f6c02d 100644 --- a/core/trainers/framework/dataset.py +++ b/core/trainers/framework/dataset.py @@ -21,7 +21,7 @@ from paddlerec.core.utils import envs from paddlerec.core.utils import dataloader_instance from paddlerec.core.reader import SlotReader from paddlerec.core.trainer import EngineMode -from paddlerec.core.utils.util import split_files +from paddlerec.core.utils.util import split_files, check_filelist __all__ = ["DatasetBase", "DataLoader", "QueueDataset"] @@ -119,10 +119,15 @@ class QueueDataset(DatasetBase): dataset.set_pipe_command(pipe_cmd) train_data_path = envs.get_global_env(name + "data_path") - file_list = [ - os.path.join(train_data_path, x) - for x in os.listdir(train_data_path) - ] + hidden_file_list, file_list = check_filelist( + hidden_file_list=[], + data_file_list=[], + train_data_path=train_data_path) + if (hidden_file_list is not None): + print( + "Warning:please make sure there are no hidden files in the dataset folder and check these hidden files:{}". 
+ format(hidden_file_list)) + file_list.sort() need_split_files = False if context["engine"] == EngineMode.LOCAL_CLUSTER: diff --git a/core/utils/dataloader_instance.py b/core/utils/dataloader_instance.py index d878f08415c7b0405bc593f06ab4541801aa5501..f484626a73f69481f3ae51b35fc5b6e717870938 100755 --- a/core/utils/dataloader_instance.py +++ b/core/utils/dataloader_instance.py @@ -19,7 +19,7 @@ from paddlerec.core.utils.envs import get_global_env from paddlerec.core.utils.envs import get_runtime_environ from paddlerec.core.reader import SlotReader from paddlerec.core.trainer import EngineMode -from paddlerec.core.utils.util import split_files +from paddlerec.core.utils.util import split_files, check_filelist def dataloader_by_name(readerclass, @@ -38,7 +38,13 @@ def dataloader_by_name(readerclass, assert package_base is not None data_path = os.path.join(package_base, data_path.split("::")[1]) - files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)] + hidden_file_list, files = check_filelist( + hidden_file_list=[], data_file_list=[], train_data_path=data_path) + if (hidden_file_list is not None): + print( + "Warning:please make sure there are no hidden files in the dataset folder and check these hidden files:{}". 
+ format(hidden_file_list)) + files.sort() need_split_files = False @@ -54,8 +60,6 @@ def dataloader_by_name(readerclass, files = split_files(files, context["fleet"].worker_index(), context["fleet"].worker_num()) - print("file_list : {}".format(files)) - reader = reader_class(yaml_file) reader.init() @@ -92,7 +96,13 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context): assert package_base is not None data_path = os.path.join(package_base, data_path.split("::")[1]) - files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)] + hidden_file_list, files = check_filelist( + hidden_file_list=[], data_file_list=[], train_data_path=data_path) + if (hidden_file_list is not None): + print( + "Warning:please make sure there are no hidden files in the dataset folder and check these hidden files:{}". + format(hidden_file_list)) + files.sort() need_split_files = False @@ -156,7 +166,13 @@ def slotdataloader(readerclass, train, yaml_file, context): assert package_base is not None data_path = os.path.join(package_base, data_path.split("::")[1]) - files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)] + hidden_file_list, files = check_filelist( + hidden_file_list=[], data_file_list=[], train_data_path=data_path) + if (hidden_file_list is not None): + print( + "Warning:please make sure there are no hidden files in the dataset folder and check these hidden files:{}". 
def check_filelist(hidden_file_list, data_file_list, train_data_path):
    """
    Recursively split the entries under a directory into hidden and data files.

    An entry is "hidden" when its name starts with '.'. Hidden files and
    hidden directories are recorded in ``hidden_file_list``; hidden
    directories are not descended into. Every other file found at any depth
    is recorded in ``data_file_list``.

    Args:
        hidden_file_list(list|None): list extended in place with the paths of
            hidden files/directories. A new list is created when None.
        data_file_list(list|None): list extended in place with the paths of
            non-hidden files. A new list is created when None.
        train_data_path(str): directory to scan.

    Returns:
        tuple: ``(hidden_file_list, data_file_list)``. Both are always lists
        (possibly empty) and never None, even when ``train_data_path`` does
        not exist or is not a directory, so the result can always be
        unpacked. Test ``if hidden_file_list:`` to know whether any hidden
        entry was found.
    """
    if hidden_file_list is None:
        hidden_file_list = []
    if data_file_list is None:
        data_file_list = []

    # Only the first os.walk() entry (the directory itself) is consumed;
    # sub-directories are handled by the recursive call below so that hidden
    # directories can be skipped. os.walk() yields nothing for a missing or
    # non-directory path, which previously made this function return a bare
    # None and broke tuple unpacking in the callers.
    top_level = next(os.walk(train_data_path), None)
    if top_level is None:
        return hidden_file_list, data_file_list

    _, dirs, files = top_level

    for file_name in files:
        file_path = os.path.join(train_data_path, file_name)
        if file_name.startswith('.'):
            hidden_file_list.append(file_path)
        else:
            data_file_list.append(file_path)

    for dir_name in dirs:
        dir_path = os.path.join(train_data_path, dir_name)
        if dir_name.startswith('.'):
            # Report the hidden directory itself, but do not look inside it.
            hidden_file_list.append(dir_path)
        else:
            check_filelist(hidden_file_list, data_file_list, dir_path)

    return hidden_file_list, data_file_list