diff --git a/core/model.py b/core/model.py index e39573c22b2aa65c8a2333a3953f9ed0e607168c..bb2040e44458d5a402d54bfc780ae501c9c9d06d 100755 --- a/core/model.py +++ b/core/model.py @@ -88,7 +88,7 @@ class ModelBase(object): self._data_var.append(l) self._sparse_data_var.append(l) - dataset_class = dataset["type"] + dataset_class = envs.get_global_env(name + "type") if dataset_class == "DataLoader": self._init_dataloader() @@ -204,31 +204,8 @@ class ModelBase(object): def net(self, is_infer=False): return None - def _construct_reader(self, is_infer=False): - if is_infer: - self._infer_data_loader = fluid.io.DataLoader.from_generator( - feed_list=self._infer_data_var, - capacity=64, - use_double_buffer=False, - iterable=False) - else: - dataset_class = envs.get_global_env("dataset_class", None, - "train.reader") - if dataset_class == "DataLoader": - self._data_loader = fluid.io.DataLoader.from_generator( - feed_list=self._data_var, - capacity=64, - use_double_buffer=False, - iterable=False) - def train_net(self): - input_data = self.input_data(is_infer=False) - self._data_var = input_data - self._construct_reader(is_infer=False) - self.net(input_data, is_infer=False) + pass def infer_net(self): - input_data = self.input_data(is_infer=True) - self._infer_data_var = input_data - self._construct_reader(is_infer=True) - self.net(input_data, is_infer=True) + pass diff --git a/core/trainers/framework/network.py b/core/trainers/framework/network.py index 7d1e276ea2d1496a2225572d917fcb3f9b2fc1fb..71f2a4e7fa6ef671c3a07724183edf2e759aec5e 100644 --- a/core/trainers/framework/network.py +++ b/core/trainers/framework/network.py @@ -97,7 +97,8 @@ class SingleNetwork(NetworkBase): context["dataset"] = {} for dataset in context["env"]["dataset"]: - if dataset["type"] != "DataLoader": + type = envs.get_global_env("dataset." +
+ dataset["name"] + ".type") + if type != "DataLoader": dataset_class = QueueDataset(context) context["dataset"][dataset[ "name"]] = dataset_class.create_dataset(dataset["name"], @@ -155,7 +156,9 @@ class PSNetwork(NetworkBase): context["fleet"].init_worker() context["dataset"] = {} for dataset in context["env"]["dataset"]: - if dataset["type"] != "DataLoader": + type = envs.get_global_env("dataset." + dataset["name"] + + ".type") + if type != "DataLoader": dataset_class = QueueDataset(context) context["dataset"][dataset[ "name"]] = dataset_class.create_dataset( @@ -248,7 +251,9 @@ class PslibNetwork(NetworkBase): else: context["dataset"] = {} for dataset in context["env"]["dataset"]: - if dataset["type"] != "DataLoader": + type = envs.get_global_env("dataset." + dataset["name"] + + ".type") + if type != "DataLoader": dataset_class = QueueDataset(context) context["dataset"][dataset[ "name"]] = dataset_class.create_dataset( @@ -312,7 +317,8 @@ class CollectiveNetwork(NetworkBase): context["dataset"] = {} for dataset in context["env"]["dataset"]: - if dataset["type"] != "DataLoader": + type = envs.get_global_env("dataset." + dataset["name"] + ".type") + if type != "DataLoader": dataset_class = QueueDataset(context) context["dataset"][dataset[ "name"]] = dataset_class.create_dataset(dataset["name"], diff --git a/core/trainers/framework/runner.py b/core/trainers/framework/runner.py index 7c61c41259c610f27c5e643277f5dfc91d834164..52be3e003edf7f18d4b21c214277b622c30aeba8 100644 --- a/core/trainers/framework/runner.py +++ b/core/trainers/framework/runner.py @@ -155,7 +155,7 @@ class RunnerBase(object): gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized else: raise ValueError( - "Unsurpported config. gradient_scale_strategy must be one of [0, 1, 2]." + "Unsupported config. gradient_scale_strategy must be one of [0, 1, 2]." 
) _build_strategy.gradient_scale_strategy = gradient_scale_strategy diff --git a/core/utils/envs.py b/core/utils/envs.py index 98bc6a8c9f789ae9480b98a17a09743ed75bd085..f70db966ae750a0c61e776aac6b6c8a3331081cd 100755 --- a/core/utils/envs.py +++ b/core/utils/envs.py @@ -96,6 +96,11 @@ def set_global_envs(envs): value = os_path_adapter(workspace_adapter(value)) global_envs[name] = value + if get_platform() != "LINUX": + for dataset in envs["dataset"]: + name = ".".join(["dataset", dataset["name"], "type"]) + global_envs[name] = "DataLoader" + def get_global_env(env_name, default_value=None, namespace=None): """