提交 c52c0d25 编写于 作者: L liuyuhui

Revert "fix read hide files bug for core/trainers/framework/dataset.py...

Revert "fix read hide files bug for core/trainers/framework/dataset.py core/utils/dataloader_instance.py"

This reverts commit 348042af.
上级 348042af
......@@ -89,24 +89,6 @@ class QueueDataset(DatasetBase):
else:
return self._get_dataset(dataset_name, context)
def check_filelist(self, file_list, train_data_path):
for root, dirs, files in os.walk(train_data_path):
files = [f for f in files if not f[0] == '.']
dirs[:] = [d for d in dirs if not d[0] == '.']
if (files == None and dirs == None):
return None
else:
# use files and dirs
for file_name in files:
file_list.append(os.path.join(train_data_path, file_name))
print(os.path.join(train_data_path, file_name))
for dirs_name in dirs:
dir_root.append(os.path.join(train_data_path, dirs_name))
check_filelist(file_list, os.path.join(
train_data_path, dirs_name))
print(os.path.join(train_data_path, dirs_name))
return file_list
def _get_dataset(self, dataset_name, context):
name = "dataset." + dataset_name + "."
reader_class = envs.get_global_env(name + "data_converter")
......@@ -137,13 +119,10 @@ class QueueDataset(DatasetBase):
dataset.set_pipe_command(pipe_cmd)
train_data_path = envs.get_global_env(name + "data_path")
# file_list = [
# os.path.join(train_data_path, x)
# for x in os.listdir(train_data_path)
# ]
file_list=[]
file_list = self.check_filelist(file_list,train_data_path)
file_list = [
os.path.join(train_data_path, x)
for x in os.listdir(train_data_path)
]
if context["engine"] == EngineMode.LOCAL_CLUSTER:
file_list = split_files(file_list, context["fleet"].worker_index(),
context["fleet"].worker_num())
......
......@@ -38,27 +38,7 @@ def dataloader_by_name(readerclass,
assert package_base is not None
data_path = os.path.join(package_base, data_path.split("::")[1])
def check_filelist(file_list, train_data_path):
for root, dirs, files in os.walk(train_data_path):
files = [f for f in files if not f[0] == '.']
dirs[:] = [d for d in dirs if not d[0] == '.']
if (files == None and dirs == None):
return None
else:
# use files and dirs
for file_name in files:
file_list.append(os.path.join(train_data_path, file_name))
print(os.path.join(train_data_path, file_name))
for dirs_name in dirs:
dir_root.append(os.path.join(train_data_path, dirs_name))
check_filelist(file_list, os.path.join(
train_data_path, dirs_name))
print(os.path.join(train_data_path, dirs_name))
return file_list
#files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
files = []
files = check_filelist(files, data_path)
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
if context["engine"] == EngineMode.LOCAL_CLUSTER:
files = split_files(files, context["fleet"].worker_index(),
context["fleet"].worker_num())
......@@ -100,27 +80,7 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context):
assert package_base is not None
data_path = os.path.join(package_base, data_path.split("::")[1])
def check_filelist(file_list, train_data_path):
for root, dirs, files in os.walk(train_data_path):
files = [f for f in files if not f[0] == '.']
dirs[:] = [d for d in dirs if not d[0] == '.']
if (files == None and dirs == None):
return None
else:
# use files and dirs
for file_name in files:
file_list.append(os.path.join(train_data_path, file_name))
print(os.path.join(train_data_path, file_name))
for dirs_name in dirs:
dir_root.append(os.path.join(train_data_path, dirs_name))
check_filelist(file_list, os.path.join(
train_data_path, dirs_name))
print(os.path.join(train_data_path, dirs_name))
return file_list
#files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
files = []
files = check_filelist(files, data_path)
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
if context["engine"] == EngineMode.LOCAL_CLUSTER:
files = split_files(files, context["fleet"].worker_index(),
context["fleet"].worker_num())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册