From fe6fc5b360cc32642966cf15ed31d43c9ffb2a5e Mon Sep 17 00:00:00 2001 From: tangwei Date: Wed, 10 Jun 2020 21:03:08 +0800 Subject: [PATCH] fix windows adapter --- core/trainers/framework/network.py | 12 +++++++---- core/utils/envs.py | 33 ++++++++++++------------------ 2 files changed, 21 insertions(+), 24 deletions(-) diff --git a/core/trainers/framework/network.py b/core/trainers/framework/network.py index b012fb7d..6685842c 100644 --- a/core/trainers/framework/network.py +++ b/core/trainers/framework/network.py @@ -58,7 +58,8 @@ class SingleNetwork(NetworkBase): with fluid.program_guard(train_program, startup_program): with fluid.unique_name.guard(): with fluid.scope_guard(scope): - model_path = model_dict["model"] + model_path = envs.os_path_adapter( + envs.workspace_adapter(model_dict["model"])) model = envs.lazy_instance_by_fliename(model_path, "Model")(None) @@ -121,7 +122,8 @@ class PSNetwork(NetworkBase): context["model"][model_dict["name"]] = {} dataset_name = model_dict["dataset_name"] - model_path = model_dict["model"] + model_path = envs.os_path_adapter( + envs.workspace_adapter(model_dict["model"])) model = envs.lazy_instance_by_fliename(model_path, "Model")(None) model._data_var = model.input_data( dataset_name=model_dict["dataset_name"]) @@ -212,7 +214,8 @@ class PslibNetwork(NetworkBase): with fluid.unique_name.guard(): with fluid.scope_guard(scope): context["model"][model_dict["name"]] = {} - model_path = model_dict["model"] + model_path = envs.os_path_adapter( + envs.workspace_adapter(model_dict["model"])) model = envs.lazy_instance_by_fliename(model_path, "Model")(None) model._data_var = model.input_data( @@ -277,7 +280,8 @@ class CollectiveNetwork(NetworkBase): scope = fluid.Scope() with fluid.program_guard(train_program, startup_program): with fluid.scope_guard(scope): - model_path = model_dict["model"] + model_path = envs.os_path_adapter( + envs.workspace_adapter(model_dict["model"])) model = envs.lazy_instance_by_fliename(model_path, "Model")(None) model._data_var = model.input_data( diff --git a/core/utils/envs.py b/core/utils/envs.py index 526d098e..98bc6a8c 100755 --- a/core/utils/envs.py +++ b/core/utils/envs.py @@ -91,9 +91,10 @@ def set_global_envs(envs): fatten_env_namespace([], envs) - workspace_adapter() - os_path_adapter() - reader_adapter() + for name, value in global_envs.items(): + if isinstance(value, str): + value = os_path_adapter(workspace_adapter(value)) + global_envs[name] = value def get_global_env(env_name, default_value=None, namespace=None): @@ -118,27 +119,19 @@ def paddlerec_adapter(path): return path -def os_path_adapter(): - for name, value in global_envs.items(): - if isinstance(value, str): - if get_platform() == "WINDOWS": - value = value.replace("/", "\\") - else: - value = value.replace("\\", "/") - global_envs[name] = value +def os_path_adapter(value): + if get_platform() == "WINDOWS": + value = value.replace("/", "\\") + else: + value = value.replace("\\", "/") + return value -def workspace_adapter(): +def workspace_adapter(value): workspace = global_envs.get("workspace") - if not workspace: - return - workspace = paddlerec_adapter(workspace) - - for name, value in global_envs.items(): - if isinstance(value, str): - value = value.replace("{workspace}", workspace) - global_envs[name] = value + value = value.replace("{workspace}", workspace) + return value def reader_adapter(): -- GitLab