Unverified commit 89bfcdaf, authored by Wenyu, committed by GitHub

rename fp16 -> amp (#5268)

Parent: ea2148ab
@@ -362,8 +362,8 @@ class Trainer(object):
             model = paddle.DataParallel(
                 self.model, find_unused_parameters=find_unused_parameters)
-        # initial fp16
-        if self.cfg.get('fp16', False):
+        # enable auto mixed precision mode
+        if self.cfg.get('amp', False):
             scaler = amp.GradScaler(
                 enable=self.cfg.use_gpu, init_loss_scaling=1024)
@@ -401,7 +401,7 @@ class Trainer(object):
                 self._compose_callback.on_step_begin(self.status)
                 data['epoch_id'] = epoch_id
-                if self.cfg.get('fp16', False):
+                if self.cfg.get('amp', False):
                     with amp.auto_cast(enable=self.cfg.use_gpu):
                         # model forward
                         outputs = model(data)
......
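For context, these two Trainer hunks together implement the standard Paddle AMP pattern: the forward pass runs under `amp.auto_cast`, and the loss is scaled through the `amp.GradScaler` created above before the optimizer step. Below is a minimal, self-contained sketch of that pattern; the model, optimizer, and data are stand-ins for illustration, not PaddleDetection's actual Trainer code.

```python
import paddle
from paddle import amp, nn, optimizer

# Stand-in model and optimizer; the real Trainer builds these from cfg.
model = nn.Linear(10, 1)
opt = optimizer.SGD(learning_rate=0.01, parameters=model.parameters())

# Mirrors `enable=self.cfg.use_gpu` in the hunk above.
use_amp = paddle.is_compiled_with_cuda()
scaler = amp.GradScaler(enable=use_amp, init_loss_scaling=1024)

x = paddle.randn([8, 10])
y = paddle.randn([8, 1])

with amp.auto_cast(enable=use_amp):
    # Selected ops run in float16 when AMP is enabled.
    loss = nn.functional.mse_loss(model(x), y)

scaled = scaler.scale(loss)   # multiply the loss by the current loss scale
scaled.backward()             # gradients carry the same scale factor
scaler.minimize(opt, scaled)  # unscale, skip the step on inf/nan, update scale
opt.clear_grad()
```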
@@ -60,10 +60,10 @@ def parse_args():
         help="If set True, enable continuous evaluation job."
         "This flag is only used for internal test.")
     parser.add_argument(
-        "--fp16",
+        "--amp",
         action='store_true',
         default=False,
-        help="Enable mixed precision training.")
+        help="Enable auto mixed precision training.")
     parser.add_argument(
         "--fleet", action='store_true', default=False, help="Use fleet or not")
     parser.add_argument(
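With this rename, mixed precision training is requested at launch with `--amp` instead of `--fp16`, e.g. `python tools/train.py -c <config> --amp` (the config path here is illustrative).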
@@ -130,7 +130,7 @@ def run(FLAGS, cfg):
 def main():
     FLAGS = parse_args()
     cfg = load_config(FLAGS.config)
-    cfg['fp16'] = FLAGS.fp16
+    cfg['amp'] = FLAGS.amp
     cfg['fleet'] = FLAGS.fleet
     cfg['use_vdl'] = FLAGS.use_vdl
     cfg['vdl_log_dir'] = FLAGS.vdl_log_dir
......
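One thing worth noting in this last hunk: `cfg['amp'] = FLAGS.amp` assigns unconditionally, and the flag defaults to `False`, so an `amp` key set directly in the YAML config is always overwritten by the command-line value.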