未验证 提交 89bfcdaf 编写于 作者: W Wenyu 提交者: GitHub

rename fp16 -> amp (#5268)

上级 ea2148ab
...@@ -362,8 +362,8 @@ class Trainer(object): ...@@ -362,8 +362,8 @@ class Trainer(object):
model = paddle.DataParallel( model = paddle.DataParallel(
self.model, find_unused_parameters=find_unused_parameters) self.model, find_unused_parameters=find_unused_parameters)
# initial fp16 # enabel auto mixed precision mode
if self.cfg.get('fp16', False): if self.cfg.get('amp', False):
scaler = amp.GradScaler( scaler = amp.GradScaler(
enable=self.cfg.use_gpu, init_loss_scaling=1024) enable=self.cfg.use_gpu, init_loss_scaling=1024)
...@@ -401,7 +401,7 @@ class Trainer(object): ...@@ -401,7 +401,7 @@ class Trainer(object):
self._compose_callback.on_step_begin(self.status) self._compose_callback.on_step_begin(self.status)
data['epoch_id'] = epoch_id data['epoch_id'] = epoch_id
if self.cfg.get('fp16', False): if self.cfg.get('amp', False):
with amp.auto_cast(enable=self.cfg.use_gpu): with amp.auto_cast(enable=self.cfg.use_gpu):
# model forward # model forward
outputs = model(data) outputs = model(data)
......
...@@ -60,10 +60,10 @@ def parse_args(): ...@@ -60,10 +60,10 @@ def parse_args():
help="If set True, enable continuous evaluation job." help="If set True, enable continuous evaluation job."
"This flag is only used for internal test.") "This flag is only used for internal test.")
parser.add_argument( parser.add_argument(
"--fp16", "--amp",
action='store_true', action='store_true',
default=False, default=False,
help="Enable mixed precision training.") help="Enable auto mixed precision training.")
parser.add_argument( parser.add_argument(
"--fleet", action='store_true', default=False, help="Use fleet or not") "--fleet", action='store_true', default=False, help="Use fleet or not")
parser.add_argument( parser.add_argument(
...@@ -130,7 +130,7 @@ def run(FLAGS, cfg): ...@@ -130,7 +130,7 @@ def run(FLAGS, cfg):
def main(): def main():
FLAGS = parse_args() FLAGS = parse_args()
cfg = load_config(FLAGS.config) cfg = load_config(FLAGS.config)
cfg['fp16'] = FLAGS.fp16 cfg['amp'] = FLAGS.amp
cfg['fleet'] = FLAGS.fleet cfg['fleet'] = FLAGS.fleet
cfg['use_vdl'] = FLAGS.use_vdl cfg['use_vdl'] = FLAGS.use_vdl
cfg['vdl_log_dir'] = FLAGS.vdl_log_dir cfg['vdl_log_dir'] = FLAGS.vdl_log_dir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册