提交 bb1376db 编写于 作者: Z zhiqiu

add flags setting

上级 a9f35981
...@@ -10,19 +10,26 @@ Global: ...@@ -10,19 +10,26 @@ Global:
epochs: 120 epochs: 120
print_batch_step: 10 print_batch_step: 10
use_visualdl: False use_visualdl: False
image_channel: &image_channel 4
# used for static mode and model export # used for static mode and model export
image_shape: [3, 224, 224] image_shape: [*image_channel, 224, 224]
save_inference_dir: ./inference save_inference_dir: ./inference
# training model under @to_static # training model under @to_static
to_static: False to_static: False
use_dali: True use_dali: True
# mixed precision training
AMP: AMP:
scale_loss: 128.0
use_dynamic_loss_scaling: True
use_pure_fp16: &use_pure_fp16 False
# model architecture # model architecture
Arch: Arch:
name: ResNet50 name: ResNet50
class_num: 1000 class_num: 1000
input_image_channel: *image_channel
data_format: "NHWC"
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
...@@ -67,10 +74,12 @@ DataLoader: ...@@ -67,10 +74,12 @@ DataLoader:
mean: [0.485, 0.456, 0.406] mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225] std: [0.229, 0.224, 0.225]
order: '' order: ''
output_fp16: *use_pure_fp16
channel_num: *image_channel
sampler: sampler:
name: DistributedBatchSampler name: DistributedBatchSampler
batch_size: 64 batch_size: 256
drop_last: False drop_last: False
shuffle: True shuffle: True
loader: loader:
...@@ -95,6 +104,8 @@ DataLoader: ...@@ -95,6 +104,8 @@ DataLoader:
mean: [0.485, 0.456, 0.406] mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225] std: [0.229, 0.224, 0.225]
order: '' order: ''
output_fp16: *use_pure_fp16
channel_num: *image_channel
sampler: sampler:
name: DistributedBatchSampler name: DistributedBatchSampler
batch_size: 64 batch_size: 64
...@@ -120,6 +131,8 @@ Infer: ...@@ -120,6 +131,8 @@ Infer:
mean: [0.485, 0.456, 0.406] mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225] std: [0.229, 0.224, 0.225]
order: '' order: ''
output_fp16: *use_pure_fp16
channel_num: *image_channel
- ToCHWImage: - ToCHWImage:
PostProcess: PostProcess:
name: Topk name: Topk
......
...@@ -112,6 +112,12 @@ class Trainer(object): ...@@ -112,6 +112,12 @@ class Trainer(object):
else: else:
self.scale_loss = 1.0 self.scale_loss = 1.0
self.use_dynamic_loss_scaling = False self.use_dynamic_loss_scaling = False
if self.amp:
AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
'FLAGS_max_inplace_grad_add': 8,
}
paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
self.train_loss_func = None self.train_loss_func = None
self.eval_loss_func = None self.eval_loss_func = None
self.train_metric_func = None self.train_metric_func = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册