提交 6d21e2e7 编写于 作者: Yang Zhang 提交者: GitHub

Add support for mixed precision training (#3406)

* Add support for mixed precision training

* Disable `fuse_all_reduce_ops` when training in mixed precision

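This is a temporary workaround: `fuse_all_reduce_ops` stays disabled for mixed precision training while waiting for an upstream fix.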
上级 f4d74a90
...@@ -36,6 +36,7 @@ set_paddle_flags( ...@@ -36,6 +36,7 @@ set_paddle_flags(
) )
from paddle import fluid from paddle import fluid
from paddle.fluid.contrib import mixed_precision
from ppdet.core.workspace import load_config, merge_config, create from ppdet.core.workspace import load_config, merge_config, create
from ppdet.data.data_feed import create_reader from ppdet.data.data_feed import create_reader
...@@ -115,6 +116,11 @@ def main(): ...@@ -115,6 +116,11 @@ def main():
loss = train_fetches['loss'] loss = train_fetches['loss']
lr = lr_builder() lr = lr_builder()
optimizer = optim_builder(lr) optimizer = optim_builder(lr)
if FLAGS.fp16:
optimizer = mixed_precision.decorate(
optimizer=optimizer,
init_loss_scaling=FLAGS.loss_scale,
use_dynamic_loss_scaling=False)
optimizer.minimize(loss) optimizer.minimize(loss)
# parse train fetches # parse train fetches
...@@ -145,6 +151,8 @@ def main(): ...@@ -145,6 +151,8 @@ def main():
# compile program for multi-devices # compile program for multi-devices
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn' sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn'
if FLAGS.fp16:
build_strategy.fuse_all_reduce_ops = False
# only enable sync_bn in multi GPU devices # only enable sync_bn in multi GPU devices
build_strategy.sync_batch_norm = sync_bn and devices_num > 1 \ build_strategy.sync_batch_norm = sync_bn and devices_num > 1 \
and cfg.use_gpu and cfg.use_gpu
...@@ -268,6 +276,16 @@ def main(): ...@@ -268,6 +276,16 @@ def main():
if __name__ == '__main__': if __name__ == '__main__':
parser = ArgsParser() parser = ArgsParser()
parser.add_argument(
"--fp16",
action='store_true',
default=False,
help="Enable mixed precision training.")
parser.add_argument(
"--loss_scale",
default=8.,
type=float,
help="Mixed precision training loss scale.")
parser.add_argument( parser.add_argument(
"-r", "-r",
"--resume_checkpoint", "--resume_checkpoint",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册