提交 70b6bbe8 编写于 作者: L liuyuhui

fix PSRunner multi dataset_name

上级 921d6c03
...@@ -238,8 +238,8 @@ class PSNetwork(NetworkBase): ...@@ -238,8 +238,8 @@ class PSNetwork(NetworkBase):
else: else:
context["fleet"].init_worker() context["fleet"].init_worker()
context["dataset"] = {} context["dataset"] = {}
for dataset in context["env"]["dataset"]: for phase in context["env"]["phase"]:
type = envs.get_global_env("dataset." + dataset["name"] + type = envs.get_global_env("dataset." + phase["dataset_name"] +
".type") ".type")
if type == "DataLoader": if type == "DataLoader":
data_loader = DataLoader(context) data_loader = DataLoader(context)
...@@ -247,9 +247,9 @@ class PSNetwork(NetworkBase): ...@@ -247,9 +247,9 @@ class PSNetwork(NetworkBase):
model._data_loader) model._data_loader)
elif type == "QueueDataset": elif type == "QueueDataset":
dataset_class = QueueDataset(context) dataset_class = QueueDataset(context)
context["dataset"][dataset[ context["dataset"][phase[
"name"]] = dataset_class.create_dataset( "dataset_name"]] = dataset_class.create_dataset(
dataset["name"], context) phase["dataset_name"], context)
context["status"] = "startup_pass" context["status"] = "startup_pass"
def _build_strategy(self, context): def _build_strategy(self, context):
...@@ -336,7 +336,7 @@ class PslibNetwork(NetworkBase): ...@@ -336,7 +336,7 @@ class PslibNetwork(NetworkBase):
self._server(context) self._server(context)
else: else:
context["dataset"] = {} context["dataset"] = {}
for dataset in context["env"]["dataset"]: for phase in context["env"]["phase"]:
type = envs.get_global_env("dataset." + dataset["name"] + type = envs.get_global_env("dataset." + dataset["name"] +
".type") ".type")
if type == "DataLoader": if type == "DataLoader":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册