提交 70b6bbe8 编写于 作者: L liuyuhui

fix PSRunner multi dataset_name

上级 921d6c03
......@@ -238,8 +238,8 @@ class PSNetwork(NetworkBase):
else:
context["fleet"].init_worker()
context["dataset"] = {}
for dataset in context["env"]["dataset"]:
type = envs.get_global_env("dataset." + dataset["name"] +
for phase in context["env"]["phase"]:
type = envs.get_global_env("dataset." + phase["dataset_name"] +
".type")
if type == "DataLoader":
data_loader = DataLoader(context)
......@@ -247,9 +247,9 @@ class PSNetwork(NetworkBase):
model._data_loader)
elif type == "QueueDataset":
dataset_class = QueueDataset(context)
context["dataset"][dataset[
"name"]] = dataset_class.create_dataset(
dataset["name"], context)
context["dataset"][phase[
"dataset_name"]] = dataset_class.create_dataset(
phase["dataset_name"], context)
context["status"] = "startup_pass"
def _build_strategy(self, context):
......@@ -336,7 +336,7 @@ class PslibNetwork(NetworkBase):
self._server(context)
else:
context["dataset"] = {}
for dataset in context["env"]["dataset"]:
for phase in context["env"]["phase"]:
type = envs.get_global_env("dataset." + dataset["name"] +
".type")
if type == "DataLoader":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册