未验证 提交 0e7bea51 编写于 作者: littletomatodonkey 提交者: GitHub

Merge pull request #223 from littletomatodonkey/fix_single_card_dyg

Fix the single-card dygraph training process
...@@ -329,9 +329,13 @@ def run(dataloader, config, net, optimizer=None, epoch=0, mode='train'): ...@@ -329,9 +329,13 @@ def run(dataloader, config, net, optimizer=None, epoch=0, mode='train'):
feeds = create_feeds(batch, use_mix) feeds = create_feeds(batch, use_mix)
fetchs = create_fetchs(feeds, net, config, mode) fetchs = create_fetchs(feeds, net, config, mode)
if mode == 'train': if mode == 'train':
if config["use_data_parallel"]:
avg_loss = net.scale_loss(fetchs['loss']) avg_loss = net.scale_loss(fetchs['loss'])
avg_loss.backward() avg_loss.backward()
net.apply_collective_grads() net.apply_collective_grads()
else:
avg_loss = fetchs['loss']
avg_loss.backward()
optimizer.minimize(avg_loss) optimizer.minimize(avg_loss)
net.clear_gradients() net.clear_gradients()
......
...@@ -52,9 +52,13 @@ def main(args): ...@@ -52,9 +52,13 @@ def main(args):
gpu_id = fluid.dygraph.parallel.Env().dev_id gpu_id = fluid.dygraph.parallel.Env().dev_id
place = fluid.CUDAPlace(gpu_id) place = fluid.CUDAPlace(gpu_id)
use_data_parallel = int(os.getenv("PADDLE_TRAINERS_NUM", 1)) != 1
config["use_data_parallel"] = use_data_parallel
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
strategy = fluid.dygraph.parallel.prepare_context()
net = program.create_model(config.ARCHITECTURE, config.classes_num) net = program.create_model(config.ARCHITECTURE, config.classes_num)
if config["use_data_parallel"]:
strategy = fluid.dygraph.parallel.prepare_context()
net = fluid.dygraph.parallel.DataParallel(net, strategy) net = fluid.dygraph.parallel.DataParallel(net, strategy)
optimizer = program.create_optimizer( optimizer = program.create_optimizer(
...@@ -79,7 +83,8 @@ def main(args): ...@@ -79,7 +83,8 @@ def main(args):
program.run(train_dataloader, config, net, optimizer, epoch_id, program.run(train_dataloader, config, net, optimizer, epoch_id,
'train') 'train')
if fluid.dygraph.parallel.Env().local_rank == 0: if not config["use_data_parallel"] or fluid.dygraph.parallel.Env(
).local_rank == 0:
# 2. validate with validate dataset # 2. validate with validate dataset
if config.validate and epoch_id % config.valid_interval == 0: if config.validate and epoch_id % config.valid_interval == 0:
net.eval() net.eval()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册