diff --git a/core/trainers/framework/network.py b/core/trainers/framework/network.py index 2d06703b58ca6d933898e26113db0b7f551a1cbc..2a9a3a4003f36627aff4fe7ab4f86a4979c46525 100644 --- a/core/trainers/framework/network.py +++ b/core/trainers/framework/network.py @@ -363,6 +363,7 @@ class CollectiveNetwork(NetworkBase): def build_network(self, context): context["model"] = {} if len(context["env"]["phase"]) > 1: + print("CollectiveNetwork phase:{}".format(context["env"]["phase"])) warnings.warn( "Cluster Train Only Support One Phase.", category=UserWarning, @@ -407,16 +408,17 @@ class CollectiveNetwork(NetworkBase): context["model"][model_dict["name"]]["compiled_program"] = None context["dataset"] = {} - for dataset in context["env"]["dataset"]: - type = envs.get_global_env("dataset." + dataset["name"] + ".type") + for phase in context["env"]["phase"]: + type = envs.get_global_env("dataset." + phase["dataset_name"] + + ".type") if type == "QueueDataset": raise ValueError( "Collective don't support QueueDataset training, please use DataLoader." ) dataset_class = QueueDataset(context) - context["dataset"][dataset[ - "name"]] = dataset_class.create_dataset(dataset["name"], - context) + context["dataset"][phase[ + "dataset_name"]] = dataset_class.create_dataset( + phase["dataset_name"], context) context["status"] = "startup_pass" def _build_strategy(self, context):