提交 99d466a1 编写于 作者: T tangwei

fix windows adapter

上级 211f7e38
...@@ -88,7 +88,7 @@ class ModelBase(object): ...@@ -88,7 +88,7 @@ class ModelBase(object):
self._data_var.append(l) self._data_var.append(l)
self._sparse_data_var.append(l) self._sparse_data_var.append(l)
dataset_class = dataset["type"] dataset_class = envs.get_global_env(name + "type")
if dataset_class == "DataLoader": if dataset_class == "DataLoader":
self._init_dataloader() self._init_dataloader()
...@@ -204,31 +204,8 @@ class ModelBase(object): ...@@ -204,31 +204,8 @@ class ModelBase(object):
def net(self, is_infer=False): def net(self, is_infer=False):
return None return None
def _construct_reader(self, is_infer=False):
if is_infer:
self._infer_data_loader = fluid.io.DataLoader.from_generator(
feed_list=self._infer_data_var,
capacity=64,
use_double_buffer=False,
iterable=False)
else:
dataset_class = envs.get_global_env("dataset_class", None,
"train.reader")
if dataset_class == "DataLoader":
self._data_loader = fluid.io.DataLoader.from_generator(
feed_list=self._data_var,
capacity=64,
use_double_buffer=False,
iterable=False)
def train_net(self): def train_net(self):
input_data = self.input_data(is_infer=False) pass
self._data_var = input_data
self._construct_reader(is_infer=False)
self.net(input_data, is_infer=False)
def infer_net(self): def infer_net(self):
input_data = self.input_data(is_infer=True) pass
self._infer_data_var = input_data
self._construct_reader(is_infer=True)
self.net(input_data, is_infer=True)
...@@ -97,7 +97,8 @@ class SingleNetwork(NetworkBase): ...@@ -97,7 +97,8 @@ class SingleNetwork(NetworkBase):
context["dataset"] = {} context["dataset"] = {}
for dataset in context["env"]["dataset"]: for dataset in context["env"]["dataset"]:
if dataset["type"] != "DataLoader": type = envs.get_global_env("dataset." + dataset["name"] + ".type")
if type != "DataLoader":
dataset_class = QueueDataset(context) dataset_class = QueueDataset(context)
context["dataset"][dataset[ context["dataset"][dataset[
"name"]] = dataset_class.create_dataset(dataset["name"], "name"]] = dataset_class.create_dataset(dataset["name"],
...@@ -155,7 +156,9 @@ class PSNetwork(NetworkBase): ...@@ -155,7 +156,9 @@ class PSNetwork(NetworkBase):
context["fleet"].init_worker() context["fleet"].init_worker()
context["dataset"] = {} context["dataset"] = {}
for dataset in context["env"]["dataset"]: for dataset in context["env"]["dataset"]:
if dataset["type"] != "DataLoader": type = envs.get_global_env("dataset." + dataset["name"] +
".type")
if type != "DataLoader":
dataset_class = QueueDataset(context) dataset_class = QueueDataset(context)
context["dataset"][dataset[ context["dataset"][dataset[
"name"]] = dataset_class.create_dataset( "name"]] = dataset_class.create_dataset(
...@@ -248,7 +251,9 @@ class PslibNetwork(NetworkBase): ...@@ -248,7 +251,9 @@ class PslibNetwork(NetworkBase):
else: else:
context["dataset"] = {} context["dataset"] = {}
for dataset in context["env"]["dataset"]: for dataset in context["env"]["dataset"]:
if dataset["type"] != "DataLoader": type = envs.get_global_env("dataset." + dataset["name"] +
".type")
if type != "DataLoader":
dataset_class = QueueDataset(context) dataset_class = QueueDataset(context)
context["dataset"][dataset[ context["dataset"][dataset[
"name"]] = dataset_class.create_dataset( "name"]] = dataset_class.create_dataset(
...@@ -312,7 +317,8 @@ class CollectiveNetwork(NetworkBase): ...@@ -312,7 +317,8 @@ class CollectiveNetwork(NetworkBase):
context["dataset"] = {} context["dataset"] = {}
for dataset in context["env"]["dataset"]: for dataset in context["env"]["dataset"]:
if dataset["type"] != "DataLoader": type = envs.get_global_env("dataset." + dataset["name"] + ".type")
if type != "DataLoader":
dataset_class = QueueDataset(context) dataset_class = QueueDataset(context)
context["dataset"][dataset[ context["dataset"][dataset[
"name"]] = dataset_class.create_dataset(dataset["name"], "name"]] = dataset_class.create_dataset(dataset["name"],
......
...@@ -155,7 +155,7 @@ class RunnerBase(object): ...@@ -155,7 +155,7 @@ class RunnerBase(object):
gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized
else: else:
raise ValueError( raise ValueError(
"Unsurpported config. gradient_scale_strategy must be one of [0, 1, 2]." "Unsupported config. gradient_scale_strategy must be one of [0, 1, 2]."
) )
_build_strategy.gradient_scale_strategy = gradient_scale_strategy _build_strategy.gradient_scale_strategy = gradient_scale_strategy
......
...@@ -96,6 +96,11 @@ def set_global_envs(envs): ...@@ -96,6 +96,11 @@ def set_global_envs(envs):
value = os_path_adapter(workspace_adapter(value)) value = os_path_adapter(workspace_adapter(value))
global_envs[name] = value global_envs[name] = value
if get_platform() != "LINUX":
for dataset in envs["dataset"]:
name = ".".join("dataset", dataset["name"], "type")
global_envs[name] = "DataLoader"
def get_global_env(env_name, default_value=None, namespace=None): def get_global_env(env_name, default_value=None, namespace=None):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册