Unverified commit 0e7bea51 · authored by littletomatodonkey · committed by GitHub

Merge pull request #223 from littletomatodonkey/fix_single_card_dyg

Fix the single-card dygraph training process
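Before this change, the training step in `program.run` unconditionally called `net.scale_loss(...)` and `net.apply_collective_grads()`, which only exist on a model wrapped in `fluid.dygraph.parallel.DataParallel`, so single-card training broke. The fix derives a `use_data_parallel` flag from the `PADDLE_TRAINERS_NUM` environment variable and gates the parallel-only calls on it. A minimal sketch of the resulting pattern (assuming `net`, `loss`, and `optimizer` are already defined; not the literal repository code):

    import os

    # The distributed launcher sets PADDLE_TRAINERS_NUM; unset or 1 means single card.
    use_data_parallel = int(os.getenv("PADDLE_TRAINERS_NUM", 1)) != 1

    if use_data_parallel:
        # Multi-card: scale the loss and all-reduce gradients across trainers.
        avg_loss = net.scale_loss(loss)
        avg_loss.backward()
        net.apply_collective_grads()
    else:
        # Single card: a plain backward pass is enough.
        avg_loss = loss
        avg_loss.backward()
    optimizer.minimize(avg_loss)
    net.clear_gradients()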
@@ -329,9 +329,13 @@ def run(dataloader, config, net, optimizer=None, epoch=0, mode='train'):
         feeds = create_feeds(batch, use_mix)
         fetchs = create_fetchs(feeds, net, config, mode)
         if mode == 'train':
-            avg_loss = net.scale_loss(fetchs['loss'])
-            avg_loss.backward()
-            net.apply_collective_grads()
+            if config["use_data_parallel"]:
+                avg_loss = net.scale_loss(fetchs['loss'])
+                avg_loss.backward()
+                net.apply_collective_grads()
+            else:
+                avg_loss = fetchs['loss']
+                avg_loss.backward()
             optimizer.minimize(avg_loss)
             net.clear_gradients()
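The companion change in `main` computes the flag once from the environment, stores it in the config, and only prepares a parallel context and wraps the model in `DataParallel` when it is set: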
@@ -52,10 +52,14 @@ def main(args):
     gpu_id = fluid.dygraph.parallel.Env().dev_id
     place = fluid.CUDAPlace(gpu_id)
+    use_data_parallel = int(os.getenv("PADDLE_TRAINERS_NUM", 1)) != 1
+    config["use_data_parallel"] = use_data_parallel
     with fluid.dygraph.guard(place):
-        strategy = fluid.dygraph.parallel.prepare_context()
         net = program.create_model(config.ARCHITECTURE, config.classes_num)
-        net = fluid.dygraph.parallel.DataParallel(net, strategy)
+        if config["use_data_parallel"]:
+            strategy = fluid.dygraph.parallel.prepare_context()
+            net = fluid.dygraph.parallel.DataParallel(net, strategy)
         optimizer = program.create_optimizer(
             config, parameter_list=net.parameters())
@@ -79,7 +83,8 @@ def main(args):
             program.run(train_dataloader, config, net, optimizer, epoch_id,
                         'train')
-            if fluid.dygraph.parallel.Env().local_rank == 0:
+            if not config["use_data_parallel"] or fluid.dygraph.parallel.Env(
+            ).local_rank == 0:
                 # 2. validate with validate dataset
                 if config.validate and epoch_id % config.valid_interval == 0:
                     net.eval()
@@ -108,4 +113,4 @@ def main(args):
 if __name__ == '__main__':
     args = parse_args()
-    main(args)
\ No newline at end of file
+    main(args)
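With this in place, single-card runs can invoke the training script directly, while multi-card runs still go through Paddle's distributed launcher, which sets `PADDLE_TRAINERS_NUM` in each trainer process. An illustrative invocation (the script and config paths are assumptions, and `--selected_gpus` follows the fluid-era launcher, so check it against your Paddle version):

    # single card: PADDLE_TRAINERS_NUM is unset, so use_data_parallel is False
    python tools/train.py -c configs/ResNet50.yaml

    # four cards: the launcher sets PADDLE_TRAINERS_NUM=4 in each process
    python -m paddle.distributed.launch --selected_gpus=0,1,2,3 \
        tools/train.py -c configs/ResNet50.yaml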