From 0ac93273e84741a25b25097f78bbb97aad016352 Mon Sep 17 00:00:00 2001 From: Chengmo Date: Tue, 4 Aug 2020 14:09:18 +0800 Subject: [PATCH] fix dataloader in distributed training (#160) * fix * test * remove setup change * fix * fix Co-authored-by: tangwei12 --- core/trainers/framework/network.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/core/trainers/framework/network.py b/core/trainers/framework/network.py index 74d2c975..d2a6b71e 100644 --- a/core/trainers/framework/network.py +++ b/core/trainers/framework/network.py @@ -99,7 +99,8 @@ class SingleNetwork(NetworkBase): context["dataset"] = {} for dataset in context["env"]["dataset"]: type = envs.get_global_env("dataset." + dataset["name"] + ".type") - if type != "DataLoader": + + if type == "QueueDataset": dataset_class = QueueDataset(context) context["dataset"][dataset[ "name"]] = dataset_class.create_dataset(dataset["name"], @@ -133,9 +134,7 @@ class PSNetwork(NetworkBase): if envs.get_global_env("dataset." + dataset_name + ".type") == "DataLoader": model._init_dataloader(is_infer=False) - data_loader = DataLoader(context) - data_loader.get_dataloader(context, dataset_name, - model._data_loader) + model.net(model._data_var, False) optimizer = model.optimizer() strategy = self._build_strategy(context) @@ -160,7 +159,11 @@ class PSNetwork(NetworkBase): for dataset in context["env"]["dataset"]: type = envs.get_global_env("dataset." + dataset["name"] + ".type") - if type != "DataLoader": + if type == "DataLoader": + data_loader = DataLoader(context) + data_loader.get_dataloader(context, dataset_name, + model._data_loader) + elif type == "QueueDataset": dataset_class = QueueDataset(context) context["dataset"][dataset[ "name"]] = dataset_class.create_dataset( @@ -229,9 +232,6 @@ class PslibNetwork(NetworkBase): if envs.get_global_env("dataset." + dataset_name + ".type") == "DataLoader": model._init_dataloader(is_infer=False) - data_loader = DataLoader(context) - data_loader.get_dataloader(context, dataset_name, - model._data_loader) model.net(model._data_var, False) optimizer = model.optimizer() @@ -257,7 +257,11 @@ class PslibNetwork(NetworkBase): for dataset in context["env"]["dataset"]: type = envs.get_global_env("dataset." + dataset["name"] + ".type") - if type != "DataLoader": + if type == "DataLoader": + data_loader = DataLoader(context) + data_loader.get_dataloader(context, dataset_name, context[ + "model"][model_dict["name"]]["model"]._data_loader) + elif type == "QueueDataset": dataset_class = QueueDataset(context) context["dataset"][dataset[ "name"]] = dataset_class.create_dataset( @@ -323,7 +327,10 @@ class CollectiveNetwork(NetworkBase): context["dataset"] = {} for dataset in context["env"]["dataset"]: type = envs.get_global_env("dataset." + dataset["name"] + ".type") - if type != "DataLoader": + if type == "QueueDataset": + raise ValueError( + "Collective don't support QueueDataset training, please use DataLoader." + ) dataset_class = QueueDataset(context) context["dataset"][dataset[ "name"]] = dataset_class.create_dataset(dataset["name"], -- GitLab