提交 1726a417 编写于 作者: L liuyuhui

fix UT failure

上级 765870b2
...@@ -49,7 +49,8 @@ def dataloader_by_name(readerclass, ...@@ -49,7 +49,8 @@ def dataloader_by_name(readerclass,
files.sort() files.sort()
# for local cluster: discard some files if files cannot be divided equally between GPUs # for local cluster: discard some files if files cannot be divided equally between GPUs
if (context["device"] == "GPU"): if (context["device"] == "GPU"
) and os.getenv("PADDLEREC_GPU_NUMS") is not None:
selected_gpu_nums = int(os.getenv("PADDLEREC_GPU_NUMS")) selected_gpu_nums = int(os.getenv("PADDLEREC_GPU_NUMS"))
discard_file_nums = len(files) % selected_gpu_nums discard_file_nums = len(files) % selected_gpu_nums
if (discard_file_nums != 0): if (discard_file_nums != 0):
...@@ -121,7 +122,8 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context): ...@@ -121,7 +122,8 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context):
files.sort() files.sort()
# for local cluster: discard some files if files cannot be divided equally between GPUs # for local cluster: discard some files if files cannot be divided equally between GPUs
if (context["device"] == "GPU"): if (context["device"] == "GPU"
) and os.getenv("PADDLEREC_GPU_NUMS") is not None:
selected_gpu_nums = int(os.getenv("PADDLEREC_GPU_NUMS")) selected_gpu_nums = int(os.getenv("PADDLEREC_GPU_NUMS"))
discard_file_nums = len(files) % selected_gpu_nums discard_file_nums = len(files) % selected_gpu_nums
if (discard_file_nums != 0): if (discard_file_nums != 0):
...@@ -201,7 +203,8 @@ def slotdataloader(readerclass, train, yaml_file, context): ...@@ -201,7 +203,8 @@ def slotdataloader(readerclass, train, yaml_file, context):
files.sort() files.sort()
# for local cluster: discard some files if files cannot be divided equally between GPUs # for local cluster: discard some files if files cannot be divided equally between GPUs
if (context["device"] == "GPU"): if (context["device"] == "GPU"
) and os.getenv("PADDLEREC_GPU_NUMS") is not None:
selected_gpu_nums = int(os.getenv("PADDLEREC_GPU_NUMS")) selected_gpu_nums = int(os.getenv("PADDLEREC_GPU_NUMS"))
discard_file_nums = len(files) % selected_gpu_nums discard_file_nums = len(files) % selected_gpu_nums
if (discard_file_nums != 0): if (discard_file_nums != 0):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册