diff --git a/tools/program.py b/tools/program.py index 34541043623ea9c4f78387488203f57e7fa8a0c7..55900b98599ecc15055f7368a9fea9b1430e852f 100644 --- a/tools/program.py +++ b/tools/program.py @@ -329,9 +329,13 @@ def run(dataloader, config, net, optimizer=None, epoch=0, mode='train'): feeds = create_feeds(batch, use_mix) fetchs = create_fetchs(feeds, net, config, mode) if mode == 'train': - avg_loss = net.scale_loss(fetchs['loss']) - avg_loss.backward() - net.apply_collective_grads() + if config["use_data_parallel"]: + avg_loss = net.scale_loss(fetchs['loss']) + avg_loss.backward() + net.apply_collective_grads() + else: + avg_loss = fetchs['loss'] + avg_loss.backward() optimizer.minimize(avg_loss) net.clear_gradients() diff --git a/tools/train.py b/tools/train.py index c244dd490297afa1132a794f4a0f3b85578c7408..976136e359f6631235cc009ac90e7ee9b72e594e 100644 --- a/tools/train.py +++ b/tools/train.py @@ -52,10 +52,14 @@ def main(args): gpu_id = fluid.dygraph.parallel.Env().dev_id place = fluid.CUDAPlace(gpu_id) + use_data_parallel = int(os.getenv("PADDLE_TRAINERS_NUM", 1)) != 1 + config["use_data_parallel"] = use_data_parallel + with fluid.dygraph.guard(place): - strategy = fluid.dygraph.parallel.prepare_context() net = program.create_model(config.ARCHITECTURE, config.classes_num) - net = fluid.dygraph.parallel.DataParallel(net, strategy) + if config["use_data_parallel"]: + strategy = fluid.dygraph.parallel.prepare_context() + net = fluid.dygraph.parallel.DataParallel(net, strategy) optimizer = program.create_optimizer( config, parameter_list=net.parameters()) @@ -79,7 +83,8 @@ def main(args): program.run(train_dataloader, config, net, optimizer, epoch_id, 'train') - if fluid.dygraph.parallel.Env().local_rank == 0: + if not config["use_data_parallel"] or fluid.dygraph.parallel.Env( + ).local_rank == 0: # 2. validate with validate dataset if config.validate and epoch_id % config.valid_interval == 0: net.eval() @@ -108,4 +113,4 @@ def main(args): if __name__ == '__main__': args = parse_args() - main(args) \ No newline at end of file + main(args)