From 70b6bbe8f979f524784fd48f140ec39552ad4fcb Mon Sep 17 00:00:00 2001 From: liuyuhui Date: Sun, 27 Sep 2020 11:15:37 +0800 Subject: [PATCH] fix PSRunner multi dataset_name --- core/trainers/framework/network.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/core/trainers/framework/network.py b/core/trainers/framework/network.py index 7d7a8273..2d06703b 100644 --- a/core/trainers/framework/network.py +++ b/core/trainers/framework/network.py @@ -238,8 +238,8 @@ class PSNetwork(NetworkBase): else: context["fleet"].init_worker() context["dataset"] = {} - for dataset in context["env"]["dataset"]: - type = envs.get_global_env("dataset." + dataset["name"] + + for phase in context["env"]["phase"]: + type = envs.get_global_env("dataset." + phase["dataset_name"] + ".type") if type == "DataLoader": data_loader = DataLoader(context) @@ -247,9 +247,9 @@ class PSNetwork(NetworkBase): model._data_loader) elif type == "QueueDataset": dataset_class = QueueDataset(context) - context["dataset"][dataset[ - "name"]] = dataset_class.create_dataset( - dataset["name"], context) + context["dataset"][phase[ + "dataset_name"]] = dataset_class.create_dataset( + phase["dataset_name"], context) context["status"] = "startup_pass" def _build_strategy(self, context): @@ -336,7 +336,7 @@ class PslibNetwork(NetworkBase): self._server(context) else: context["dataset"] = {} - for dataset in context["env"]["dataset"]: + for phase in context["env"]["phase"]: type = envs.get_global_env("dataset." + dataset["name"] + ".type") if type == "DataLoader": -- GitLab