提交 6d21e2e7 编写于 作者: Y Yang Zhang 提交者: GitHub

Add support for mixed precision training (#3406)

* Add support for mixed precision training

* Disable `fuse_all_reduce_ops` when training in mixed precision

waiting for upstream fix
上级 f4d74a90
......@@ -36,6 +36,7 @@ set_paddle_flags(
)
from paddle import fluid
from paddle.fluid.contrib import mixed_precision
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.data.data_feed import create_reader
......@@ -115,6 +116,11 @@ def main():
loss = train_fetches['loss']
lr = lr_builder()
optimizer = optim_builder(lr)
if FLAGS.fp16:
optimizer = mixed_precision.decorate(
optimizer=optimizer,
init_loss_scaling=FLAGS.loss_scale,
use_dynamic_loss_scaling=False)
optimizer.minimize(loss)
# parse train fetches
......@@ -145,6 +151,8 @@ def main():
# compile program for multi-devices
build_strategy = fluid.BuildStrategy()
sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn'
if FLAGS.fp16:
build_strategy.fuse_all_reduce_ops = False
# only enable sync_bn in multi GPU devices
build_strategy.sync_batch_norm = sync_bn and devices_num > 1 \
and cfg.use_gpu
......@@ -268,6 +276,16 @@ def main():
if __name__ == '__main__':
parser = ArgsParser()
parser.add_argument(
"--fp16",
action='store_true',
default=False,
help="Enable mixed precision training.")
parser.add_argument(
"--loss_scale",
default=8.,
type=float,
help="Mixed precision training loss scale.")
parser.add_argument(
"-r",
"--resume_checkpoint",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册