未验证 提交 af7c0e7d 编写于 作者: W wuzhihua 提交者: GitHub

Merge pull request #174 from vslyu/vslyu-fixhidefiles

Fix reader bug: skip hidden files when listing the dataset folder
......@@ -21,7 +21,7 @@ from paddlerec.core.utils import envs
from paddlerec.core.utils import dataloader_instance
from paddlerec.core.reader import SlotReader
from paddlerec.core.trainer import EngineMode
from paddlerec.core.utils.util import split_files
from paddlerec.core.utils.util import split_files, check_filelist
__all__ = ["DatasetBase", "DataLoader", "QueueDataset"]
......@@ -119,10 +119,15 @@ class QueueDataset(DatasetBase):
dataset.set_pipe_command(pipe_cmd)
train_data_path = envs.get_global_env(name + "data_path")
file_list = [
os.path.join(train_data_path, x)
for x in os.listdir(train_data_path)
]
hidden_file_list, file_list = check_filelist(
hidden_file_list=[],
data_file_list=[],
train_data_path=train_data_path)
if (hidden_file_list is not None):
print(
"Warning:please make sure there are no hidden files in the dataset folder and check these hidden files:{}".
format(hidden_file_list))
file_list.sort()
need_split_files = False
if context["engine"] == EngineMode.LOCAL_CLUSTER:
......
......@@ -19,7 +19,7 @@ from paddlerec.core.utils.envs import get_global_env
from paddlerec.core.utils.envs import get_runtime_environ
from paddlerec.core.reader import SlotReader
from paddlerec.core.trainer import EngineMode
from paddlerec.core.utils.util import split_files
from paddlerec.core.utils.util import split_files, check_filelist
def dataloader_by_name(readerclass,
......@@ -38,7 +38,13 @@ def dataloader_by_name(readerclass,
assert package_base is not None
data_path = os.path.join(package_base, data_path.split("::")[1])
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
hidden_file_list, files = check_filelist(
hidden_file_list=[], data_file_list=[], train_data_path=data_path)
if (hidden_file_list is not None):
print(
"Warning:please make sure there are no hidden files in the dataset folder and check these hidden files:{}".
format(hidden_file_list))
files.sort()
need_split_files = False
......@@ -54,8 +60,6 @@ def dataloader_by_name(readerclass,
files = split_files(files, context["fleet"].worker_index(),
context["fleet"].worker_num())
print("file_list : {}".format(files))
reader = reader_class(yaml_file)
reader.init()
......@@ -92,7 +96,13 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context):
assert package_base is not None
data_path = os.path.join(package_base, data_path.split("::")[1])
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
hidden_file_list, files = check_filelist(
hidden_file_list=[], data_file_list=[], train_data_path=data_path)
if (hidden_file_list is not None):
print(
"Warning:please make sure there are no hidden files in the dataset folder and check these hidden files:{}".
format(hidden_file_list))
files.sort()
need_split_files = False
......@@ -156,7 +166,13 @@ def slotdataloader(readerclass, train, yaml_file, context):
assert package_base is not None
data_path = os.path.join(package_base, data_path.split("::")[1])
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
hidden_file_list, files = check_filelist(
hidden_file_list=[], data_file_list=[], train_data_path=data_path)
if (hidden_file_list is not None):
print(
"Warning:please make sure there are no hidden files in the dataset folder and check these hidden files:{}".
format(hidden_file_list))
files.sort()
need_split_files = False
......
......@@ -201,6 +201,28 @@ def split_files(files, trainer_id, trainers):
return trainer_files[trainer_id]
def check_filelist(hidden_file_list, data_file_list, train_data_path):
    """Split the entries under ``train_data_path`` into hidden and data files.

    An entry is treated as hidden when its name starts with ``'.'``.
    Hidden files and hidden directories are appended to
    ``hidden_file_list``; hidden directories are not descended into.
    Every other file is appended to ``data_file_list`` and every other
    directory is scanned recursively.

    Args:
        hidden_file_list: list that collects paths of hidden files and
            hidden directories. It is extended in place.
        data_file_list: list that collects paths of non-hidden files.
            It is extended in place.
        train_data_path: directory to scan.

    Returns:
        The tuple ``(hidden_file_list, data_file_list)``, i.e. the same two
        list objects that were passed in. If ``os.walk`` yields nothing for
        ``train_data_path`` (for example the path does not exist), the
        lists are returned unchanged so callers can always unpack the
        result.
    """
    # Only the first os.walk() entry (the directory itself) is needed here:
    # sub-directories are visited through explicit recursion below so that
    # hidden directories can be reported without being scanned.
    top_level = next(os.walk(train_data_path), None)
    if top_level is None:
        # os.walk() silently yields nothing when the path cannot be listed.
        return hidden_file_list, data_file_list

    _, dir_names, file_names = top_level

    for file_name in file_names:
        file_path = os.path.join(train_data_path, file_name)
        if file_name.startswith('.'):
            hidden_file_list.append(file_path)
        else:
            data_file_list.append(file_path)

    for dir_name in dir_names:
        dir_path = os.path.join(train_data_path, dir_name)
        if dir_name.startswith('.'):
            hidden_file_list.append(dir_path)
        else:
            check_filelist(hidden_file_list, data_file_list, dir_path)

    return hidden_file_list, data_file_list
class CostPrinter(object):
"""
For count cost time && print cost log
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册