diff --git a/core/trainers/framework/network.py b/core/trainers/framework/network.py index 7d7a8273b6a402bd163f653a7beb3900de899ae3..2d06703b58ca6d933898e26113db0b7f551a1cbc 100644 --- a/core/trainers/framework/network.py +++ b/core/trainers/framework/network.py @@ -238,8 +238,8 @@ class PSNetwork(NetworkBase): else: context["fleet"].init_worker() context["dataset"] = {} - for dataset in context["env"]["dataset"]: - type = envs.get_global_env("dataset." + dataset["name"] + + for phase in context["env"]["phase"]: + type = envs.get_global_env("dataset." + phase["dataset_name"] + ".type") if type == "DataLoader": data_loader = DataLoader(context) @@ -247,9 +247,9 @@ class PSNetwork(NetworkBase): model._data_loader) elif type == "QueueDataset": dataset_class = QueueDataset(context) - context["dataset"][dataset[ - "name"]] = dataset_class.create_dataset( - dataset["name"], context) + context["dataset"][phase[ + "dataset_name"]] = dataset_class.create_dataset( + phase["dataset_name"], context) context["status"] = "startup_pass" def _build_strategy(self, context): @@ -336,7 +336,7 @@ class PslibNetwork(NetworkBase): self._server(context) else: context["dataset"] = {} - for dataset in context["env"]["dataset"]: + for phase in context["env"]["phase"]: type = envs.get_global_env("dataset." + dataset["name"] + ".type") if type == "DataLoader":