提交 15d8501e 编写于 作者: L liuyuhui

mv check_filelist to util.py and add Warning for hidden files,test=develop

上级 d229f76f
...@@ -21,7 +21,7 @@ from paddlerec.core.utils import envs ...@@ -21,7 +21,7 @@ from paddlerec.core.utils import envs
from paddlerec.core.utils import dataloader_instance from paddlerec.core.utils import dataloader_instance
from paddlerec.core.reader import SlotReader from paddlerec.core.reader import SlotReader
from paddlerec.core.trainer import EngineMode from paddlerec.core.trainer import EngineMode
from paddlerec.core.utils.util import split_files from paddlerec.core.utils.util import split_files, check_filelist
__all__ = ["DatasetBase", "DataLoader", "QueueDataset"] __all__ = ["DatasetBase", "DataLoader", "QueueDataset"]
...@@ -89,24 +89,6 @@ class QueueDataset(DatasetBase): ...@@ -89,24 +89,6 @@ class QueueDataset(DatasetBase):
else: else:
return self._get_dataset(dataset_name, context) return self._get_dataset(dataset_name, context)
def check_filelist(self, file_list, train_data_path):
for root, dirs, files in os.walk(train_data_path):
files = [f for f in files if not f[0] == '.']
dirs[:] = [d for d in dirs if not d[0] == '.']
if (files == None and dirs == None):
return None
else:
# use files and dirs
for file_name in files:
file_list.append(os.path.join(train_data_path, file_name))
print(os.path.join(train_data_path, file_name))
for dirs_name in dirs:
dir_root.append(os.path.join(train_data_path, dirs_name))
check_filelist(file_list,
os.path.join(train_data_path, dirs_name))
print(os.path.join(train_data_path, dirs_name))
return file_list
def _get_dataset(self, dataset_name, context): def _get_dataset(self, dataset_name, context):
name = "dataset." + dataset_name + "." name = "dataset." + dataset_name + "."
reader_class = envs.get_global_env(name + "data_converter") reader_class = envs.get_global_env(name + "data_converter")
...@@ -137,12 +119,14 @@ class QueueDataset(DatasetBase): ...@@ -137,12 +119,14 @@ class QueueDataset(DatasetBase):
dataset.set_pipe_command(pipe_cmd) dataset.set_pipe_command(pipe_cmd)
train_data_path = envs.get_global_env(name + "data_path") train_data_path = envs.get_global_env(name + "data_path")
# file_list = [ hidden_file_list, file_list = check_filelist(
# os.path.join(train_data_path, x) hidden_file_list=[],
# for x in os.listdir(train_data_path) data_file_list=[],
# ] train_data_path=train_data_path)
file_list = [] if (hidden_file_list is not None):
file_list = self.check_filelist(file_list, train_data_path) print(
"Warning:please make sure there are no hidden files in the dataset folder and check these hidden files:{}".
format(hidden_file_list))
if context["engine"] == EngineMode.LOCAL_CLUSTER: if context["engine"] == EngineMode.LOCAL_CLUSTER:
file_list = split_files(file_list, context["fleet"].worker_index(), file_list = split_files(file_list, context["fleet"].worker_index(),
......
...@@ -19,7 +19,7 @@ from paddlerec.core.utils.envs import get_global_env ...@@ -19,7 +19,7 @@ from paddlerec.core.utils.envs import get_global_env
from paddlerec.core.utils.envs import get_runtime_environ from paddlerec.core.utils.envs import get_runtime_environ
from paddlerec.core.reader import SlotReader from paddlerec.core.reader import SlotReader
from paddlerec.core.trainer import EngineMode from paddlerec.core.trainer import EngineMode
from paddlerec.core.utils.util import split_files from paddlerec.core.utils.util import split_files, check_filelist
def dataloader_by_name(readerclass, def dataloader_by_name(readerclass,
...@@ -38,27 +38,13 @@ def dataloader_by_name(readerclass, ...@@ -38,27 +38,13 @@ def dataloader_by_name(readerclass,
assert package_base is not None assert package_base is not None
data_path = os.path.join(package_base, data_path.split("::")[1]) data_path = os.path.join(package_base, data_path.split("::")[1])
def check_filelist(file_list, train_data_path): hidden_file_list, files = check_filelist(
for root, dirs, files in os.walk(train_data_path): hidden_file_list=[], data_file_list=[], train_data_path=data_path)
files = [f for f in files if not f[0] == '.'] if (hidden_file_list is not None):
dirs[:] = [d for d in dirs if not d[0] == '.'] print(
if (files == None and dirs == None): "Warning:please make sure there are no hidden files in the dataset folder and check these hidden files:{}".
return None format(hidden_file_list))
else:
# use files and dirs
for file_name in files:
file_list.append(os.path.join(train_data_path, file_name))
print(os.path.join(train_data_path, file_name))
for dirs_name in dirs:
dir_root.append(os.path.join(train_data_path, dirs_name))
check_filelist(file_list,
os.path.join(train_data_path, dirs_name))
print(os.path.join(train_data_path, dirs_name))
return file_list
#files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
files = []
files = check_filelist(files, data_path)
if context["engine"] == EngineMode.LOCAL_CLUSTER: if context["engine"] == EngineMode.LOCAL_CLUSTER:
files = split_files(files, context["fleet"].worker_index(), files = split_files(files, context["fleet"].worker_index(),
context["fleet"].worker_num()) context["fleet"].worker_num())
...@@ -100,27 +86,13 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context): ...@@ -100,27 +86,13 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context):
assert package_base is not None assert package_base is not None
data_path = os.path.join(package_base, data_path.split("::")[1]) data_path = os.path.join(package_base, data_path.split("::")[1])
def check_filelist(file_list, train_data_path): hidden_file_list, files = check_filelist(
for root, dirs, files in os.walk(train_data_path): hidden_file_list=[], data_file_list=[], train_data_path=data_path)
files = [f for f in files if not f[0] == '.'] if (hidden_file_list is not None):
dirs[:] = [d for d in dirs if not d[0] == '.'] print(
if (files == None and dirs == None): "Warning:please make sure there are no hidden files in the dataset folder and check these hidden files:{}".
return None format(hidden_file_list))
else:
# use files and dirs
for file_name in files:
file_list.append(os.path.join(train_data_path, file_name))
print(os.path.join(train_data_path, file_name))
for dirs_name in dirs:
dir_root.append(os.path.join(train_data_path, dirs_name))
check_filelist(file_list,
os.path.join(train_data_path, dirs_name))
print(os.path.join(train_data_path, dirs_name))
return file_list
#files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
files = []
files = check_filelist(files, data_path)
if context["engine"] == EngineMode.LOCAL_CLUSTER: if context["engine"] == EngineMode.LOCAL_CLUSTER:
files = split_files(files, context["fleet"].worker_index(), files = split_files(files, context["fleet"].worker_index(),
context["fleet"].worker_num()) context["fleet"].worker_num())
......
...@@ -201,6 +201,28 @@ def split_files(files, trainer_id, trainers): ...@@ -201,6 +201,28 @@ def split_files(files, trainer_id, trainers):
return trainer_files[trainer_id] return trainer_files[trainer_id]
def check_filelist(hidden_file_list, data_file_list, train_data_path):
for root, dirs, files in os.walk(train_data_path):
if (files == None and dirs == None):
return None, None
else:
# use files and dirs
for file_name in files:
file_path = os.path.join(train_data_path, file_name)
if file_name[0] == '.':
hidden_file_list.append(file_path)
else:
data_file_list.append(file_path)
for dirs_name in dirs:
dirs_path = os.path.join(train_data_path, dirs_name)
if dirs_name[0] == '.':
hidden_file_list.append(dirs_path)
else:
#train_data_path = os.path.join(train_data_path, dirs_name)
check_filelist(hidden_file_list, data_file_list, dirs_path)
return hidden_file_list, data_file_list
class CostPrinter(object): class CostPrinter(object):
""" """
For count cost time && print cost log For count cost time && print cost log
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册