提交 683adcda 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: support AMP infer

上级 5f88903e
......@@ -105,7 +105,6 @@ DataLoader:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
output_fp16: True
channel_num: *image_channel
sampler:
name: DistributedBatchSampler
......@@ -132,7 +131,6 @@ Infer:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
output_fp16: True
channel_num: *image_channel
- ToCHWImage:
PostProcess:
......
......@@ -99,7 +99,6 @@ DataLoader:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
output_fp16: True
channel_num: *image_channel
sampler:
name: DistributedBatchSampler
......@@ -126,7 +125,6 @@ Infer:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
output_fp16: True
channel_num: *image_channel
- ToCHWImage:
PostProcess:
......
......@@ -239,7 +239,7 @@ class Engine(object):
self.amp_eval = self.config["AMP"].get("use_fp16_test", False)
# TODO(gaotingquan): Paddle not yet support FP32 evaluation when training with AMPO2
if self.config["Global"].get(
if self.mode == "train" and self.config["Global"].get(
"eval_during_train",
True) and self.amp_level == "O2" and self.amp_eval == False:
msg = "PaddlePaddle only support FP16 evaluation when training with AMP O2 now. "
......@@ -269,10 +269,11 @@ class Engine(object):
save_dtype='float32')
# paddle version >= 2.3.0 or develop
else:
self.model = paddle.amp.decorate(
models=self.model,
level=self.amp_level,
save_dtype='float32')
if self.mode == "train" or self.amp_eval:
self.model = paddle.amp.decorate(
models=self.model,
level=self.amp_level,
save_dtype='float32')
if self.mode == "train" and len(self.train_loss_func.parameters(
)) > 0:
......@@ -431,7 +432,17 @@ class Engine(object):
image_file_list.append(image_file)
if len(batch_data) >= batch_size or idx == len(image_list) - 1:
batch_tensor = paddle.to_tensor(batch_data)
out = self.model(batch_tensor)
if self.amp and self.amp_eval:
with paddle.amp.auto_cast(
custom_black_list={
"flatten_contiguous_range", "greater_than"
},
level=self.amp_level):
out = self.model(batch_tensor)
else:
out = self.model(batch_tensor)
if isinstance(out, list):
out = out[0]
if isinstance(out, dict) and "logits" in out:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册