提交 99d466a1 编写于 作者: T tangwei

fix windows adapter

上级 211f7e38
......@@ -88,7 +88,7 @@ class ModelBase(object):
self._data_var.append(l)
self._sparse_data_var.append(l)
dataset_class = dataset["type"]
dataset_class = envs.get_global_env(name + "type")
if dataset_class == "DataLoader":
self._init_dataloader()
......@@ -204,31 +204,8 @@ class ModelBase(object):
def net(self, is_infer=False):
return None
def _construct_reader(self, is_infer=False):
if is_infer:
self._infer_data_loader = fluid.io.DataLoader.from_generator(
feed_list=self._infer_data_var,
capacity=64,
use_double_buffer=False,
iterable=False)
else:
dataset_class = envs.get_global_env("dataset_class", None,
"train.reader")
if dataset_class == "DataLoader":
self._data_loader = fluid.io.DataLoader.from_generator(
feed_list=self._data_var,
capacity=64,
use_double_buffer=False,
iterable=False)
def train_net(self):
input_data = self.input_data(is_infer=False)
self._data_var = input_data
self._construct_reader(is_infer=False)
self.net(input_data, is_infer=False)
pass
def infer_net(self):
input_data = self.input_data(is_infer=True)
self._infer_data_var = input_data
self._construct_reader(is_infer=True)
self.net(input_data, is_infer=True)
pass
......@@ -97,7 +97,8 @@ class SingleNetwork(NetworkBase):
context["dataset"] = {}
for dataset in context["env"]["dataset"]:
if dataset["type"] != "DataLoader":
type = envs.get_global_env("dataset." + dataset["name"] + ".type")
if type != "DataLoader":
dataset_class = QueueDataset(context)
context["dataset"][dataset[
"name"]] = dataset_class.create_dataset(dataset["name"],
......@@ -155,7 +156,9 @@ class PSNetwork(NetworkBase):
context["fleet"].init_worker()
context["dataset"] = {}
for dataset in context["env"]["dataset"]:
if dataset["type"] != "DataLoader":
type = envs.get_global_env("dataset." + dataset["name"] +
".type")
if type != "DataLoader":
dataset_class = QueueDataset(context)
context["dataset"][dataset[
"name"]] = dataset_class.create_dataset(
......@@ -248,7 +251,9 @@ class PslibNetwork(NetworkBase):
else:
context["dataset"] = {}
for dataset in context["env"]["dataset"]:
if dataset["type"] != "DataLoader":
type = envs.get_global_env("dataset." + dataset["name"] +
".type")
if type != "DataLoader":
dataset_class = QueueDataset(context)
context["dataset"][dataset[
"name"]] = dataset_class.create_dataset(
......@@ -312,7 +317,8 @@ class CollectiveNetwork(NetworkBase):
context["dataset"] = {}
for dataset in context["env"]["dataset"]:
if dataset["type"] != "DataLoader":
type = envs.get_global_env("dataset." + dataset["name"] + ".type")
if type != "DataLoader":
dataset_class = QueueDataset(context)
context["dataset"][dataset[
"name"]] = dataset_class.create_dataset(dataset["name"],
......
......@@ -155,7 +155,7 @@ class RunnerBase(object):
gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized
else:
raise ValueError(
"Unsurpported config. gradient_scale_strategy must be one of [0, 1, 2]."
"Unsupported config. gradient_scale_strategy must be one of [0, 1, 2]."
)
_build_strategy.gradient_scale_strategy = gradient_scale_strategy
......
......@@ -96,6 +96,11 @@ def set_global_envs(envs):
value = os_path_adapter(workspace_adapter(value))
global_envs[name] = value
if get_platform() != "LINUX":
for dataset in envs["dataset"]:
name = ".".join("dataset", dataset["name"], "type")
global_envs[name] = "DataLoader"
def get_global_env(env_name, default_value=None, namespace=None):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册