diff --git a/core/trainers/framework/network.py b/core/trainers/framework/network.py index 74d2c97540419b15e6a5d0f87b3c5af368a7e9b3..d2a6b71e4f74a639095eb404a82c9c1fefaf7fdf 100644 --- a/core/trainers/framework/network.py +++ b/core/trainers/framework/network.py @@ -99,7 +99,8 @@ class SingleNetwork(NetworkBase): context["dataset"] = {} for dataset in context["env"]["dataset"]: type = envs.get_global_env("dataset." + dataset["name"] + ".type") - if type != "DataLoader": + + if type == "QueueDataset": dataset_class = QueueDataset(context) context["dataset"][dataset[ "name"]] = dataset_class.create_dataset(dataset["name"], @@ -133,9 +134,7 @@ class PSNetwork(NetworkBase): if envs.get_global_env("dataset." + dataset_name + ".type") == "DataLoader": model._init_dataloader(is_infer=False) - data_loader = DataLoader(context) - data_loader.get_dataloader(context, dataset_name, - model._data_loader) + model.net(model._data_var, False) optimizer = model.optimizer() strategy = self._build_strategy(context) @@ -160,7 +159,11 @@ class PSNetwork(NetworkBase): for dataset in context["env"]["dataset"]: type = envs.get_global_env("dataset." + dataset["name"] + ".type") - if type != "DataLoader": + if type == "DataLoader": + data_loader = DataLoader(context) + data_loader.get_dataloader(context, dataset_name, + model._data_loader) + elif type == "QueueDataset": dataset_class = QueueDataset(context) context["dataset"][dataset[ "name"]] = dataset_class.create_dataset( @@ -229,9 +232,6 @@ class PslibNetwork(NetworkBase): if envs.get_global_env("dataset." + dataset_name + ".type") == "DataLoader": model._init_dataloader(is_infer=False) - data_loader = DataLoader(context) - data_loader.get_dataloader(context, dataset_name, - model._data_loader) model.net(model._data_var, False) optimizer = model.optimizer() @@ -257,7 +257,11 @@ class PslibNetwork(NetworkBase): for dataset in context["env"]["dataset"]: type = envs.get_global_env("dataset." + dataset["name"] + ".type") - if type != "DataLoader": + if type == "DataLoader": + data_loader = DataLoader(context) + data_loader.get_dataloader(context, dataset_name, context[ + "model"][model_dict["name"]]["model"]._data_loader) + elif type == "QueueDataset": dataset_class = QueueDataset(context) context["dataset"][dataset[ "name"]] = dataset_class.create_dataset( @@ -323,7 +327,10 @@ class CollectiveNetwork(NetworkBase): context["dataset"] = {} for dataset in context["env"]["dataset"]: type = envs.get_global_env("dataset." + dataset["name"] + ".type") - if type != "DataLoader": + if type == "QueueDataset": + raise ValueError( + "Collective don't support QueueDataset training, please use DataLoader." + ) dataset_class = QueueDataset(context) context["dataset"][dataset[ "name"]] = dataset_class.create_dataset(dataset["name"],