From 16c7f9b7af62e41606f1f4d1908ffb6bc04b5d58 Mon Sep 17 00:00:00 2001
From: chenguowei01
Date: Fri, 18 Sep 2020 16:40:49 +0800
Subject: [PATCH] update optimizer and collect loss

---
 dygraph/paddleseg/core/train.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/dygraph/paddleseg/core/train.py b/dygraph/paddleseg/core/train.py
index e60502cf..d1854a0b 100644
--- a/dygraph/paddleseg/core/train.py
+++ b/dygraph/paddleseg/core/train.py
@@ -16,6 +16,7 @@ import os
 
 import paddle
 from paddle.distributed import ParallelEnv
+from paddle.distributed import init_parallel_env
 from paddle.io import DistributedBatchSampler
 from paddle.io import DataLoader
 import paddle.nn.functional as F
@@ -77,11 +78,14 @@ def train(model,
         os.makedirs(save_dir)
 
     if nranks > 1:
+        # Initialize parallel training environment.
+        init_parallel_env()
         strategy = paddle.distributed.prepare_context()
         ddp_model = paddle.DataParallel(model, strategy)
 
     batch_sampler = DistributedBatchSampler(
         train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
+
     loader = DataLoader(
         train_dataset,
         batch_sampler=batch_sampler,
@@ -115,7 +119,6 @@ def train(model,
             if nranks > 1:
                 logits = ddp_model(images)
                 loss = loss_computation(logits, labels, losses)
-                # loss = ddp_model(images, labels)
                 # apply_collective_grads sum grads over multiple gpus.
                 loss = ddp_model.scale_loss(loss)
                 loss.backward()
@@ -125,8 +128,15 @@ def train(model,
                 loss = loss_computation(logits, labels, losses)
                 # loss = model(images, labels)
                 loss.backward()
-            optimizer.minimize(loss)
+            # optimizer.minimize(loss)
+            optimizer.step()
+            if isinstance(optimizer._learning_rate,
+                          paddle.optimizer._LRScheduler):
+                optimizer._learning_rate.step()
             model.clear_gradients()
+            # Sum loss over all ranks
+            if nranks > 1:
+                paddle.distributed.all_reduce(loss)
             avg_loss += loss.numpy()[0]
             lr = optimizer.get_lr()
             train_batch_cost += timer.elapsed_time()
@@ -141,10 +151,10 @@ def train(model,
             logger.info(
                 "[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.4f} | ETA {}"
                 .format((iter - 1) // iters_per_epoch + 1, iter, iters,
-                        avg_loss * nranks, lr, avg_train_batch_cost,
+                        avg_loss, lr, avg_train_batch_cost,
                         avg_train_reader_cost, eta))
             if use_vdl:
-                log_writer.add_scalar('Train/loss', avg_loss * nranks, iter)
+                log_writer.add_scalar('Train/loss', avg_loss, iter)
                 log_writer.add_scalar('Train/lr', lr, iter)
                 log_writer.add_scalar('Train/batch_cost',
                                       avg_train_batch_cost, iter)
--
GitLab
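
For context, below is a minimal sketch of the per-iteration pattern this patch moves to: optimizer.step() plus an explicit learning-rate-scheduler step in place of optimizer.minimize(loss), and an all_reduce of the loss so every rank logs the same value (which is why the "* nranks" factor is dropped from the log line and VisualDL scalars). Names such as model, optimizer, losses, loss_computation, and nranks stand in for objects built earlier in train.py; the scheduler check mirrors the patch and is version-dependent, so treat this as an illustration rather than the exact PaddleSeg code.

    import paddle


    def train_one_step(model, optimizer, images, labels, losses,
                       loss_computation, nranks):
        """One training iteration following the pattern in this patch (sketch).

        In multi-GPU mode, `model` is assumed to be the paddle.DataParallel
        wrapper built after init_parallel_env()/prepare_context().
        """
        logits = model(images)
        loss = loss_computation(logits, labels, losses)

        if nranks > 1:
            # With paddle.DataParallel, scale the loss before backward and
            # collect gradients afterwards (apply_collective_grads sums
            # grads over multiple gpus, per the comment in train.py).
            loss = model.scale_loss(loss)
            loss.backward()
            model.apply_collective_grads()
        else:
            loss.backward()

        # optimizer.step() replaces optimizer.minimize(loss); the LR scheduler
        # is no longer stepped implicitly, so step it here. The private
        # _learning_rate attribute and _LRScheduler class mirror the patch.
        optimizer.step()
        if isinstance(optimizer._learning_rate, paddle.optimizer._LRScheduler):
            optimizer._learning_rate.step()
        model.clear_gradients()

        # Sum the loss over all ranks so the logged value is not rank-local.
        if nranks > 1:
            paddle.distributed.all_reduce(loss)
        return loss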