提交 53b6f16a 编写于 作者: T tangwei

fix windows adapter

上级 62efc8dc
...@@ -37,7 +37,7 @@ class ModelBase(object): ...@@ -37,7 +37,7 @@ class ModelBase(object):
self._fetch_interval = 20 self._fetch_interval = 20
self._platform = envs.get_platform() self._platform = envs.get_platform()
self._init_hyper_parameters() self._init_hyper_parameters()
self._env = config context["env"] = config
self._slot_inited = False self._slot_inited = False
def _init_hyper_parameters(self): def _init_hyper_parameters(self):
...@@ -49,11 +49,11 @@ class ModelBase(object): ...@@ -49,11 +49,11 @@ class ModelBase(object):
self._slot_inited = True self._slot_inited = True
dataset = {} dataset = {}
model_dict = {} model_dict = {}
for i in self._env("phase"): for i in context["env"]("phase"):
if i["name"] == kargs["name"]: if i["name"] == kargs["name"]:
model_dict = i model_dict = i
break break
for i in self._env("dataset"): for i in context["env"]("dataset"):
if i["name"] == model_dict["dataset_name"]: if i["name"] == model_dict["dataset_name"]:
dataset = i dataset = i
break break
......
...@@ -96,7 +96,7 @@ class SingleNetwork(NetworkBase): ...@@ -96,7 +96,7 @@ class SingleNetwork(NetworkBase):
"default_main_program"] = train_program.clone() "default_main_program"] = train_program.clone()
context["dataset"] = {} context["dataset"] = {}
for dataset in self._env["dataset"]: for dataset in context["env"]["dataset"]:
if dataset["type"] != "DataLoader": if dataset["type"] != "DataLoader":
dataset_class = QueueDataset(context) dataset_class = QueueDataset(context)
context["dataset"][dataset[ context["dataset"][dataset[
...@@ -113,12 +113,12 @@ class PSNetwork(NetworkBase): ...@@ -113,12 +113,12 @@ class PSNetwork(NetworkBase):
def build_network(self, context): def build_network(self, context):
context["model"] = {} context["model"] = {}
if len(self._env["phase"]) > 1: if len(context["env"]["phase"]) > 1:
warnings.warn( warnings.warn(
"Cluster Train Only Support One Phase.", "Cluster Train Only Support One Phase.",
category=UserWarning, category=UserWarning,
stacklevel=2) stacklevel=2)
model_dict = self._env["phase"][0] model_dict = context["env"]["phase"][0]
context["model"][model_dict["name"]] = {} context["model"][model_dict["name"]] = {}
dataset_name = model_dict["dataset_name"] dataset_name = model_dict["dataset_name"]
...@@ -154,7 +154,7 @@ class PSNetwork(NetworkBase): ...@@ -154,7 +154,7 @@ class PSNetwork(NetworkBase):
else: else:
context["fleet"].init_worker() context["fleet"].init_worker()
context["dataset"] = {} context["dataset"] = {}
for dataset in self._env["dataset"]: for dataset in context["env"]["dataset"]:
if dataset["type"] != "DataLoader": if dataset["type"] != "DataLoader":
dataset_class = QueueDataset(context) dataset_class = QueueDataset(context)
context["dataset"][dataset[ context["dataset"][dataset[
...@@ -200,12 +200,12 @@ class PslibNetwork(NetworkBase): ...@@ -200,12 +200,12 @@ class PslibNetwork(NetworkBase):
def build_network(self, context): def build_network(self, context):
context["model"] = {} context["model"] = {}
if len(self._env["phase"]) > 1: if len(context["env"]["phase"]) > 1:
warnings.warn( warnings.warn(
"Cluster Train Only Support One Phase.", "Cluster Train Only Support One Phase.",
category=UserWarning, category=UserWarning,
stacklevel=2) stacklevel=2)
model_dict = self._env["phase"][0] model_dict = context["env"]["phase"][0]
train_program = fluid.Program() train_program = fluid.Program()
startup_program = fluid.Program() startup_program = fluid.Program()
scope = fluid.Scope() scope = fluid.Scope()
...@@ -247,7 +247,7 @@ class PslibNetwork(NetworkBase): ...@@ -247,7 +247,7 @@ class PslibNetwork(NetworkBase):
self._server(context) self._server(context)
else: else:
context["dataset"] = {} context["dataset"] = {}
for dataset in self._env["dataset"]: for dataset in context["env"]["dataset"]:
if dataset["type"] != "DataLoader": if dataset["type"] != "DataLoader":
dataset_class = QueueDataset(context) dataset_class = QueueDataset(context)
context["dataset"][dataset[ context["dataset"][dataset[
...@@ -267,12 +267,12 @@ class CollectiveNetwork(NetworkBase): ...@@ -267,12 +267,12 @@ class CollectiveNetwork(NetworkBase):
def build_network(self, context): def build_network(self, context):
context["model"] = {} context["model"] = {}
if len(self._env["phase"]) > 1: if len(context["env"]["phase"]) > 1:
warnings.warn( warnings.warn(
"Cluster Train Only Support One Phase.", "Cluster Train Only Support One Phase.",
category=UserWarning, category=UserWarning,
stacklevel=2) stacklevel=2)
model_dict = self._env["phase"][0] model_dict = context["env"]["phase"][0]
context["model"][model_dict["name"]] = {} context["model"][model_dict["name"]] = {}
dataset_name = model_dict["dataset_name"] dataset_name = model_dict["dataset_name"]
...@@ -311,7 +311,7 @@ class CollectiveNetwork(NetworkBase): ...@@ -311,7 +311,7 @@ class CollectiveNetwork(NetworkBase):
"default_main_program"] = train_program "default_main_program"] = train_program
context["dataset"] = {} context["dataset"] = {}
for dataset in self._env["dataset"]: for dataset in context["env"]["dataset"]:
if dataset["type"] != "DataLoader": if dataset["type"] != "DataLoader":
dataset_class = QueueDataset(context) dataset_class = QueueDataset(context)
context["dataset"][dataset[ context["dataset"][dataset[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册