diff --git a/core/trainers/framework/dataset.py b/core/trainers/framework/dataset.py index 8059eeb09a482671b8329fb88f5b52cfd64f163b..b0d18acd8cbcc0326a0e37da4398b8fdca2d002c 100644 --- a/core/trainers/framework/dataset.py +++ b/core/trainers/framework/dataset.py @@ -89,6 +89,24 @@ class QueueDataset(DatasetBase): else: return self._get_dataset(dataset_name, context) + def check_filelist(self, file_list, train_data_path): + for root, dirs, files in os.walk(train_data_path): + files = [f for f in files if not f[0] == '.'] + dirs[:] = [d for d in dirs if not d[0] == '.'] + if (files == None and dirs == None): + return None + else: + # use files and dirs + for file_name in files: + file_list.append(os.path.join(train_data_path, file_name)) + print(os.path.join(train_data_path, file_name)) + for dirs_name in dirs: + dir_root.append(os.path.join(train_data_path, dirs_name)) + check_filelist(file_list, + os.path.join(train_data_path, dirs_name)) + print(os.path.join(train_data_path, dirs_name)) + return file_list + def _get_dataset(self, dataset_name, context): name = "dataset." + dataset_name + "." reader_class = envs.get_global_env(name + "data_converter") @@ -119,10 +137,13 @@ class QueueDataset(DatasetBase): dataset.set_pipe_command(pipe_cmd) train_data_path = envs.get_global_env(name + "data_path") - file_list = [ - os.path.join(train_data_path, x) - for x in os.listdir(train_data_path) - ] + # file_list = [ + # os.path.join(train_data_path, x) + # for x in os.listdir(train_data_path) + # ] + file_list = [] + file_list = self.check_filelist(file_list, train_data_path) + if context["engine"] == EngineMode.LOCAL_CLUSTER: file_list = split_files(file_list, context["fleet"].worker_index(), context["fleet"].worker_num()) diff --git a/core/utils/dataloader_instance.py b/core/utils/dataloader_instance.py index 2461473aa79a51133db8aa319f4ee7d45981d815..76b113b72086faa5ec866f62518d86a00a0d0793 100755 --- a/core/utils/dataloader_instance.py +++ b/core/utils/dataloader_instance.py @@ -38,7 +38,27 @@ def dataloader_by_name(readerclass, assert package_base is not None data_path = os.path.join(package_base, data_path.split("::")[1]) - files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)] + def check_filelist(file_list, train_data_path): + for root, dirs, files in os.walk(train_data_path): + files = [f for f in files if not f[0] == '.'] + dirs[:] = [d for d in dirs if not d[0] == '.'] + if (files == None and dirs == None): + return None + else: + # use files and dirs + for file_name in files: + file_list.append(os.path.join(train_data_path, file_name)) + print(os.path.join(train_data_path, file_name)) + for dirs_name in dirs: + dir_root.append(os.path.join(train_data_path, dirs_name)) + check_filelist(file_list, + os.path.join(train_data_path, dirs_name)) + print(os.path.join(train_data_path, dirs_name)) + return file_list + + #files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)] + files = [] + files = check_filelist(files, data_path) if context["engine"] == EngineMode.LOCAL_CLUSTER: files = split_files(files, context["fleet"].worker_index(), context["fleet"].worker_num()) @@ -80,7 +100,27 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context): assert package_base is not None data_path = os.path.join(package_base, data_path.split("::")[1]) - files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)] + def check_filelist(file_list, train_data_path): + for root, dirs, files in os.walk(train_data_path): + files = [f for f in files if not f[0] == '.'] + dirs[:] = [d for d in dirs if not d[0] == '.'] + if (files == None and dirs == None): + return None + else: + # use files and dirs + for file_name in files: + file_list.append(os.path.join(train_data_path, file_name)) + print(os.path.join(train_data_path, file_name)) + for dirs_name in dirs: + dir_root.append(os.path.join(train_data_path, dirs_name)) + check_filelist(file_list, + os.path.join(train_data_path, dirs_name)) + print(os.path.join(train_data_path, dirs_name)) + return file_list + + #files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)] + files = [] + files = check_filelist(files, data_path) if context["engine"] == EngineMode.LOCAL_CLUSTER: files = split_files(files, context["fleet"].worker_index(), context["fleet"].worker_num())