提交 ea3e138c 编写于 作者: C chenguowei01

update train.py

上级 d1067776
......@@ -57,7 +57,7 @@ def train(model,
if nranks > 1:
strategy = fluid.dygraph.prepare_context()
model_parallel = fluid.dygraph.DataParallel(model, strategy)
ddp_model = fluid.dygraph.DataParallel(model, strategy)
batch_sampler = DistributedBatchSampler(
train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
......@@ -86,10 +86,10 @@ def train(model,
images = data[0]
labels = data[1].astype('int64')
if nranks > 1:
loss = model_parallel(images, labels)
loss = model_parallel.scale_loss(loss)
loss = ddp_model(images, labels)
loss = ddp_model.scale_loss(loss)
loss.backward()
model_parallel.apply_collective_grads()
ddp_model.apply_collective_grads()
else:
loss = model(images, labels)
loss.backward()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册