提交 16c7f9b7 编写于 作者: C chenguowei01

update optimizer and collect loss

上级 334a4b30
......@@ -16,6 +16,7 @@ import os
import paddle
from paddle.distributed import ParallelEnv
from paddle.distributed import init_parallel_env
from paddle.io import DistributedBatchSampler
from paddle.io import DataLoader
import paddle.nn.functional as F
......@@ -77,11 +78,14 @@ def train(model,
os.makedirs(save_dir)
if nranks > 1:
# Initialize parallel training environment.
init_parallel_env()
strategy = paddle.distributed.prepare_context()
ddp_model = paddle.DataParallel(model, strategy)
batch_sampler = DistributedBatchSampler(
train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
......@@ -115,7 +119,6 @@ def train(model,
if nranks > 1:
logits = ddp_model(images)
loss = loss_computation(logits, labels, losses)
# loss = ddp_model(images, labels)
# apply_collective_grads sum grads over multiple gpus.
loss = ddp_model.scale_loss(loss)
loss.backward()
......@@ -125,8 +128,15 @@ def train(model,
loss = loss_computation(logits, labels, losses)
# loss = model(images, labels)
loss.backward()
optimizer.minimize(loss)
# optimizer.minimize(loss)
optimizer.step()
if isinstance(optimizer._learning_rate,
paddle.optimizer._LRScheduler):
optimizer._learning_rate.step()
model.clear_gradients()
# Sum loss over all ranks
if nranks > 1:
paddle.distributed.all_reduce(loss)
avg_loss += loss.numpy()[0]
lr = optimizer.get_lr()
train_batch_cost += timer.elapsed_time()
......@@ -141,10 +151,10 @@ def train(model,
logger.info(
"[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.4f} | ETA {}"
.format((iter - 1) // iters_per_epoch + 1, iter, iters,
avg_loss * nranks, lr, avg_train_batch_cost,
avg_loss, lr, avg_train_batch_cost,
avg_train_reader_cost, eta))
if use_vdl:
log_writer.add_scalar('Train/loss', avg_loss * nranks, iter)
log_writer.add_scalar('Train/loss', avg_loss, iter)
log_writer.add_scalar('Train/lr', lr, iter)
log_writer.add_scalar('Train/batch_cost',
avg_train_batch_cost, iter)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册