Commit cd039a7b authored by zhangbo9674, committed by Tingquan Gao

add save_dtype

Parent d437bb0a
@@ -218,7 +218,7 @@ class Engine(object):
                 init_loss_scaling=self.scale_loss,
                 use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
             if self.config['AMP']['use_pure_fp16'] is True:
-                self.model = paddle.amp.decorate(models=self.model, level='O2')
+                self.model = paddle.amp.decorate(models=self.model, level='O2', save_dtype='float32')
         # for distributed
         self.config["Global"][
......
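Note (not part of the diff): a minimal sketch of what `save_dtype` changes, assuming a PaddlePaddle 2.x dynamic-graph setup; the toy `Linear` model and `SGD` optimizer below are placeholders for `engine.model` and the real optimizer.

```python
import paddle

# Illustrative toy model/optimizer; the real code decorates engine.model.
model = paddle.nn.Linear(4, 2)
optimizer = paddle.optimizer.SGD(parameters=model.parameters())

# level='O2' casts parameters to float16 for pure-fp16 training;
# save_dtype='float32' makes state_dict() return float32 tensors, so the
# saved checkpoint stays loadable by fp32 (or O1) evaluation code.
model, optimizer = paddle.amp.decorate(
    models=model, optimizers=optimizer, level='O2', save_dtype='float32')

paddle.save(model.state_dict(), "model.pdparams")  # weights serialized as float32
```

With `save_dtype='float32'`, checkpoints written from a pure-fp16 model remain compatible with fp32 evaluation and inference, which is what makes the manual fp32 round-trip in eval (removed below) unnecessary.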
@@ -40,17 +40,6 @@ def classification_eval(engine, epoch_id=0):
         dataset) if not engine.use_dali else engine.eval_dataloader.size
     max_iter = len(engine.eval_dataloader) - 1 if platform.system(
     ) == "Windows" else len(engine.eval_dataloader)
-    # print("========================fp16 layer")
-    # for layer in engine.model.sublayers(include_self=True):
-    # print(type(layer), layer._dtype)
-    # run eval in fp32
-    engine.model.to(dtype='float32')
-    # print("========================to fp32 layer")
-    # for layer in engine.model.sublayers(include_self=True):
-    # print(type(layer), layer._dtype)
     for iter_id, batch in enumerate(engine.eval_dataloader):
         if iter_id >= max_iter:
             break
@@ -68,7 +57,6 @@ def classification_eval(engine, epoch_id=0):
         if not engine.config["Global"].get("use_multilabel", False):
             batch[1] = batch[1].reshape([-1, 1]).astype("int64")
-        '''
         # image input
         if engine.amp:
             amp_level = 'O1'
@@ -92,19 +80,6 @@ def classification_eval(engine, epoch_id=0):
                         if key not in output_info:
                             output_info[key] = AverageMeter(key, '7.5f')
                         output_info[key].update(loss_dict[key].numpy()[0], batch_size)
-        '''
-        #========================================================
-        out = engine.model(batch[0])
-        # calc loss
-        if engine.eval_loss_func is not None:
-            loss_dict = engine.eval_loss_func(out, batch[1])
-            for key in loss_dict:
-                if key not in output_info:
-                    output_info[key] = AverageMeter(key, '7.5f')
-                output_info[key].update(loss_dict[key].numpy()[0], batch_size)
-        #========================================================
         # just for DistributedBatchSampler issue: repeat sampling
         current_samples = batch_size * paddle.distributed.get_world_size()
@@ -176,16 +151,6 @@ def classification_eval(engine, epoch_id=0):
                     len(engine.eval_dataloader), metric_msg, time_msg, ips_msg))
             tic = time.time()
-    # if eval was run under AMP O2, switch the model back to pure-fp16 (O2) mode afterwards
-    if engine.amp:
-        if engine.config['AMP']['use_pure_fp16'] is True:
-            paddle.fluid.dygraph.amp.auto_cast.pure_fp16_initialize([engine.model])
-    # print("========================to fp16 layer")
-    # for layer in engine.model.sublayers(include_self=True):
-    # print(type(layer), layer._dtype)
-    # import sys
-    # sys.exit()
     if engine.use_dali:
         engine.eval_dataloader.reset()
     metric_msg = ", ".join([
......
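Note (illustrative only): with `save_dtype='float32'` set at decoration time, the manual fp32 round-trip deleted above is no longer needed. Below is a rough sketch of evaluating the O2-decorated model directly under `auto_cast`; the toy model and random input stand in for `engine.model` and a real dataloader batch.

```python
import paddle

# Placeholder O2-decorated model; in PaddleClas this is engine.model.
model = paddle.amp.decorate(models=paddle.nn.Linear(4, 2), level='O2',
                            save_dtype='float32')
model.eval()

x = paddle.randn([8, 4], dtype='float32')
with paddle.no_grad():
    # The forward pass runs in float16 under O2 auto_cast; no manual
    # model.to(dtype='float32') before eval and no pure_fp16_initialize
    # afterwards are required.
    with paddle.amp.auto_cast(level='O2'):
        out = model(x)
print(out.dtype)  # expected: paddle.float16
```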