From 99d466a1a6cdafe6321f9dfc380070abc872cf43 Mon Sep 17 00:00:00 2001 From: tangwei Date: Thu, 11 Jun 2020 14:40:21 +0800 Subject: [PATCH] fix windows adapter --- core/model.py | 29 +++-------------------------- core/trainers/framework/network.py | 14 ++++++++++---- core/trainers/framework/runner.py | 2 +- core/utils/envs.py | 5 +++++ 4 files changed, 19 insertions(+), 31 deletions(-) diff --git a/core/model.py b/core/model.py index e39573c2..bb2040e4 100755 --- a/core/model.py +++ b/core/model.py @@ -88,7 +88,7 @@ class ModelBase(object): self._data_var.append(l) self._sparse_data_var.append(l) - dataset_class = dataset["type"] + dataset_class = envs.get_global_env(name + "type") if dataset_class == "DataLoader": self._init_dataloader() @@ -204,31 +204,8 @@ class ModelBase(object): def net(self, is_infer=False): return None - def _construct_reader(self, is_infer=False): - if is_infer: - self._infer_data_loader = fluid.io.DataLoader.from_generator( - feed_list=self._infer_data_var, - capacity=64, - use_double_buffer=False, - iterable=False) - else: - dataset_class = envs.get_global_env("dataset_class", None, - "train.reader") - if dataset_class == "DataLoader": - self._data_loader = fluid.io.DataLoader.from_generator( - feed_list=self._data_var, - capacity=64, - use_double_buffer=False, - iterable=False) - def train_net(self): - input_data = self.input_data(is_infer=False) - self._data_var = input_data - self._construct_reader(is_infer=False) - self.net(input_data, is_infer=False) + pass def infer_net(self): - input_data = self.input_data(is_infer=True) - self._infer_data_var = input_data - self._construct_reader(is_infer=True) - self.net(input_data, is_infer=True) + pass diff --git a/core/trainers/framework/network.py b/core/trainers/framework/network.py index 7d1e276e..71f2a4e7 100644 --- a/core/trainers/framework/network.py +++ b/core/trainers/framework/network.py @@ -97,7 +97,8 @@ class SingleNetwork(NetworkBase): context["dataset"] = {} for dataset in context["env"]["dataset"]: - if dataset["type"] != "DataLoader": + type = envs.get_global_env("dataset." + dataset["name"] + ".type") + if type != "DataLoader": dataset_class = QueueDataset(context) context["dataset"][dataset[ "name"]] = dataset_class.create_dataset(dataset["name"], @@ -155,7 +156,9 @@ class PSNetwork(NetworkBase): context["fleet"].init_worker() context["dataset"] = {} for dataset in context["env"]["dataset"]: - if dataset["type"] != "DataLoader": + type = envs.get_global_env("dataset." + dataset["name"] + + ".type") + if type != "DataLoader": dataset_class = QueueDataset(context) context["dataset"][dataset[ "name"]] = dataset_class.create_dataset( @@ -248,7 +251,9 @@ class PslibNetwork(NetworkBase): else: context["dataset"] = {} for dataset in context["env"]["dataset"]: - if dataset["type"] != "DataLoader": + type = envs.get_global_env("dataset." + dataset["name"] + + ".type") + if type != "DataLoader": dataset_class = QueueDataset(context) context["dataset"][dataset[ "name"]] = dataset_class.create_dataset( @@ -312,7 +317,8 @@ class CollectiveNetwork(NetworkBase): context["dataset"] = {} for dataset in context["env"]["dataset"]: - if dataset["type"] != "DataLoader": + type = envs.get_global_env("dataset." + dataset["name"] + ".type") + if type != "DataLoader": dataset_class = QueueDataset(context) context["dataset"][dataset[ "name"]] = dataset_class.create_dataset(dataset["name"], diff --git a/core/trainers/framework/runner.py b/core/trainers/framework/runner.py index 7c61c412..52be3e00 100644 --- a/core/trainers/framework/runner.py +++ b/core/trainers/framework/runner.py @@ -155,7 +155,7 @@ class RunnerBase(object): gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized else: raise ValueError( - "Unsurpported config. gradient_scale_strategy must be one of [0, 1, 2]." + "Unsupported config. gradient_scale_strategy must be one of [0, 1, 2]." ) _build_strategy.gradient_scale_strategy = gradient_scale_strategy diff --git a/core/utils/envs.py b/core/utils/envs.py index 98bc6a8c..f70db966 100755 --- a/core/utils/envs.py +++ b/core/utils/envs.py @@ -96,6 +96,11 @@ def set_global_envs(envs): value = os_path_adapter(workspace_adapter(value)) global_envs[name] = value + if get_platform() != "LINUX": + for dataset in envs["dataset"]: + name = ".".join("dataset", dataset["name"], "type") + global_envs[name] = "DataLoader" + def get_global_env(env_name, default_value=None, namespace=None): """ -- GitLab