提交 62efc8dc 编写于 作者: T tangwei

fix windows adapter

上级 68208bc4
......@@ -76,6 +76,7 @@ class Trainer(object):
_config = envs.load_yaml(config)
self._context["env"] = _config
self._context["dataset"] = _config.get("dataset")
phases = []
......
......@@ -96,7 +96,7 @@ class SingleNetwork(NetworkBase):
"default_main_program"] = train_program.clone()
context["dataset"] = {}
for dataset in envs.get_global_env("dataset"):
for dataset in self._env["dataset"]:
if dataset["type"] != "DataLoader":
dataset_class = QueueDataset(context)
context["dataset"][dataset[
......@@ -113,12 +113,12 @@ class PSNetwork(NetworkBase):
def build_network(self, context):
context["model"] = {}
if len(envs.get_global_env("phase")) > 1:
if len(self._env["phase"]) > 1:
warnings.warn(
"Cluster Train Only Support One Phase.",
category=UserWarning,
stacklevel=2)
model_dict = envs.get_global_env("phase")[0]
model_dict = self._env["phase"][0]
context["model"][model_dict["name"]] = {}
dataset_name = model_dict["dataset_name"]
......@@ -154,7 +154,7 @@ class PSNetwork(NetworkBase):
else:
context["fleet"].init_worker()
context["dataset"] = {}
for dataset in envs.get_global_env("dataset"):
for dataset in self._env["dataset"]:
if dataset["type"] != "DataLoader":
dataset_class = QueueDataset(context)
context["dataset"][dataset[
......@@ -200,12 +200,12 @@ class PslibNetwork(NetworkBase):
def build_network(self, context):
context["model"] = {}
if len(envs.get_global_env("phase")) > 1:
if len(self._env["phase"]) > 1:
warnings.warn(
"Cluster Train Only Support One Phase.",
category=UserWarning,
stacklevel=2)
model_dict = envs.get_global_env("phase")[0]
model_dict = self._env["phase"][0]
train_program = fluid.Program()
startup_program = fluid.Program()
scope = fluid.Scope()
......@@ -247,7 +247,7 @@ class PslibNetwork(NetworkBase):
self._server(context)
else:
context["dataset"] = {}
for dataset in envs.get_global_env("dataset"):
for dataset in self._env["dataset"]:
if dataset["type"] != "DataLoader":
dataset_class = QueueDataset(context)
context["dataset"][dataset[
......@@ -267,12 +267,12 @@ class CollectiveNetwork(NetworkBase):
def build_network(self, context):
context["model"] = {}
if len(envs.get_global_env("phase")) > 1:
if len(self._env["phase"]) > 1:
warnings.warn(
"Cluster Train Only Support One Phase.",
category=UserWarning,
stacklevel=2)
model_dict = envs.get_global_env("phase")[0]
model_dict = self._env["phase"][0]
context["model"][model_dict["name"]] = {}
dataset_name = model_dict["dataset_name"]
......@@ -311,7 +311,7 @@ class CollectiveNetwork(NetworkBase):
"default_main_program"] = train_program
context["dataset"] = {}
for dataset in envs.get_global_env("dataset"):
for dataset in self._env["dataset"]:
if dataset["type"] != "DataLoader":
dataset_class = QueueDataset(context)
context["dataset"][dataset[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册