From 703306ceeb7a22f083dd7365148927dd8f4db617 Mon Sep 17 00:00:00 2001 From: liuyuhui Date: Sun, 27 Sep 2020 15:53:58 +0800 Subject: [PATCH] fix Collective multi dataset_name --- core/trainers/framework/network.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/core/trainers/framework/network.py b/core/trainers/framework/network.py index 2d06703b..2a9a3a40 100644 --- a/core/trainers/framework/network.py +++ b/core/trainers/framework/network.py @@ -363,6 +363,7 @@ class CollectiveNetwork(NetworkBase): def build_network(self, context): context["model"] = {} if len(context["env"]["phase"]) > 1: + print("CollectiveNetwork phase:{}".format(context["env"]["phase"])) warnings.warn( "Cluster Train Only Support One Phase.", category=UserWarning, @@ -407,16 +408,17 @@ class CollectiveNetwork(NetworkBase): context["model"][model_dict["name"]]["compiled_program"] = None context["dataset"] = {} - for dataset in context["env"]["dataset"]: - type = envs.get_global_env("dataset." + dataset["name"] + ".type") + for phase in context["env"]["phase"]: + type = envs.get_global_env("dataset." + phase["dataset_name"] + + ".type") if type == "QueueDataset": raise ValueError( "Collective don't support QueueDataset training, please use DataLoader." ) dataset_class = QueueDataset(context) - context["dataset"][dataset[ - "name"]] = dataset_class.create_dataset(dataset["name"], - context) + context["dataset"][phase[ + "dataset_name"]] = dataset_class.create_dataset( + phase["dataset_name"], context) context["status"] = "startup_pass" def _build_strategy(self, context): -- GitLab