Unverified commit de73d276, authored by Leo Chen, committed by GitHub

add amp support (#395)

* add amp support

* fix data_format

* fix yaml

* follow comments

* follow comments

* refine code

* follow comments

* follow comments
Parent bec79a00
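For context on what "amp support" means here: in static-graph PaddlePaddle 1.x, automatic mixed precision is typically enabled by wrapping the optimizer with fluid.contrib.mixed_precision.decorate, which rewrites the program to run eligible ops in FP16 and applies loss scaling to the backward pass. A minimal sketch of that pattern follows; it is illustrative only, since this PR wires the equivalent options through the YAML config and the fleet strategy rather than necessarily calling decorate directly:

import paddle.fluid as fluid
from paddle.fluid.contrib.mixed_precision import decorate

# Build a plain FP32 optimizer first, then wrap it for AMP.
optimizer = fluid.optimizer.Momentum(learning_rate=0.045, momentum=0.9)

# decorate() casts eligible ops to FP16 and scales the loss so that small
# FP16 gradients do not flush to zero.
optimizer = decorate(
    optimizer,
    init_loss_scaling=128,           # cf. amp_scale_loss in the config below
    use_dynamic_loss_scaling=True)   # cf. use_dynamic_loss_scaling

# optimizer.minimize(loss) is then called exactly as in FP32 training.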
mode: 'train'
ARCHITECTURE:
    name: 'InceptionV3'
    params:
        data_format: 'NHWC'
pretrained_model: ""
model_save_dir: "./output/"
classes_num: 1000
total_images: 1281167
save_interval: 1
validate: True
valid_interval: 1
epochs: 200
topk: 5
image_shape: [3, 299, 299]

# mixed precision training
use_mix: True
ls_epsilon: 0.1
use_fp16: True # cannot be enabled together with DALI
amp_scale_loss: 128
use_dynamic_loss_scaling: True
fuse_elewise_add_act_ops: True
fuse_bn_act_ops: True
fuse_bn_add_act_ops: True
use_dali: False
enable_addto: True

LEARNING_RATE:
    function: 'Cosine'
    params:
        lr: 0.045

OPTIMIZER:
    function: 'Momentum'
    params:
        momentum: 0.9
    regularizer:
        function: 'L2'
        factor: 0.00010

TRAIN:
    batch_size: 256
    num_workers: 4
    file_list: "./dataset/ILSVRC2012/train_list.txt"
    data_dir: "./dataset/ILSVRC2012/"
    shuffle_seed: 0
    transforms:
        - DecodeImage:
            to_rgb: True
            to_np: False
            channel_first: False
        - RandCropImage:
            size: 299
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1./255.
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:
    mix:
        - MixupOperator:
            alpha: 0.2

VALID:
    batch_size: 16
    num_workers: 4
    file_list: "./dataset/ILSVRC2012/val_list.txt"
    data_dir: "./dataset/ILSVRC2012/"
    shuffle_seed: 0
    transforms:
        - DecodeImage:
            to_rgb: True
            to_np: False
            channel_first: False
        - ResizeImage:
            resize_short: 320
        - CropImage:
            size: 299
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:
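The amp_scale_loss and use_dynamic_loss_scaling keys above control loss scaling: the loss is multiplied by a scale factor before backprop so that small FP16 gradients survive, and with dynamic scaling the factor adapts over time. A toy illustration of the dynamic scheme, where the growth interval and ratios are assumed placeholder values rather than Paddle's exact defaults:

# Toy dynamic loss scaling: grow the scale after a run of finite steps,
# shrink it whenever a NaN/Inf gradient is seen.
scale = 128.0        # initial scale, cf. amp_scale_loss: 128
good_steps = 0

def update_scale(grads_are_finite,
                 incr_every_n_steps=1000,  # assumed placeholder
                 incr_ratio=2.0,           # assumed placeholder
                 decr_ratio=0.5):          # assumed placeholder
    global scale, good_steps
    if grads_are_finite:
        good_steps += 1
        if good_steps >= incr_every_n_steps:
            scale *= incr_ratio   # safe for a while: try a larger scale
            good_steps = 0
    else:
        scale *= decr_ratio       # overflow: back off and retry the step
        good_steps = 0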
@@ -24,11 +24,14 @@ from paddle.fluid.param_attr import ParamAttr
 __all__ = ["InceptionV3"]

 class InceptionV3():
-    def __init__(self):
+    def __init__(self, data_format="NCHW"):
         self.inception_a_list = [32, 64, 64]
         self.inception_c_list = [128, 160, 160, 192]
+        self.data_format = data_format
+        self.concat_axis = 3 if self.data_format=="NHWC" else 1

     def net(self, input, class_dim=1000):
         x = self.inception_stem(input)
         for i, pool_features in enumerate(self.inception_a_list):
             x = self.inceptionA(x, pool_features, name=str(i+1))
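The new concat_axis attribute captures where the channel dimension lives in each layout: axis 1 for NCHW, axis 3 for NHWC. A quick shape check with plain NumPy makes the point:

import numpy as np

nchw = np.zeros((8, 64, 35, 35))     # layout: N, C, H, W
nhwc = nchw.transpose(0, 2, 3, 1)    # layout: N, H, W, C

# Inception blocks concatenate branch outputs along the channel axis,
# so the axis index must follow the layout:
print(np.concatenate([nchw, nchw], axis=1).shape)  # (8, 128, 35, 35)
print(np.concatenate([nhwc, nhwc], axis=3).shape)  # (8, 35, 35, 128)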
@@ -39,7 +42,7 @@ class InceptionV3():
         x = self.inceptionE(x, name="1")
         x = self.inceptionE(x, name="2")
-        pool = fluid.layers.pool2d(input=x, pool_type="avg", global_pooling=True)
+        pool = fluid.layers.pool2d(input=x, pool_type="avg", global_pooling=True, data_format=self.data_format)
         drop = fluid.layers.dropout(x=pool, dropout_prob=0.2)
@@ -70,13 +73,15 @@ class InceptionV3():
                           act=None,
                           param_attr=ParamAttr(name=name+"_weights"),
                           bias_attr=False,
-                          name=name)
+                          name=name,
+                          data_format=self.data_format)
         return fluid.layers.batch_norm(input=conv,
                           act=act,
                           param_attr = ParamAttr(name=name+"_bn_scale"),
                           bias_attr=ParamAttr(name=name+"_bn_offset"),
                           moving_mean_name=name+"_bn_mean",
-                          moving_variance_name=name+"_bn_variance")
+                          moving_variance_name=name+"_bn_variance",
+                          data_layout=self.data_format)

     def inception_stem(self, x):
         x = self.conv_bn_layer(x,
@@ -98,7 +103,7 @@ class InceptionV3():
                               act="relu",
                               name="conv_2b_3x3")
-        x = fluid.layers.pool2d(input=x, pool_size=3, pool_stride=2, pool_type="max")
+        x = fluid.layers.pool2d(input=x, pool_size=3, pool_stride=2, pool_type="max", data_format=self.data_format)
         x = self.conv_bn_layer(x,
                               num_filters=80,
@@ -111,7 +116,7 @@ class InceptionV3():
                               act="relu",
                               name="conv_4a_3x3")
-        x = fluid.layers.pool2d(input=x, pool_size=3, pool_stride=2, pool_type="max")
+        x = fluid.layers.pool2d(input=x, pool_size=3, pool_stride=2, pool_type="max", data_format=self.data_format)
         return x
@@ -150,14 +155,14 @@ class InceptionV3():
                               padding=1,
                               act="relu",
                               name="inception_a_branch3x3dbl_3_"+name)
-        branch_pool = fluid.layers.pool2d(x, pool_size=3, pool_padding=1, pool_type="avg")
+        branch_pool = fluid.layers.pool2d(x, pool_size=3, pool_padding=1, pool_type="avg", data_format=self.data_format)
         branch_pool = self.conv_bn_layer(branch_pool,
                               num_filters=pool_features,
                               filter_size=1,
                               act="relu",
                               name="inception_a_branch_pool_"+name)
-        concat = fluid.layers.concat([branch1x1, branch5x5, branch3x3dbl, branch_pool], axis=1)
+        concat = fluid.layers.concat([branch1x1, branch5x5, branch3x3dbl, branch_pool], axis=self.concat_axis)
         return concat
@@ -187,10 +192,9 @@ class InceptionV3():
                               stride=2,
                               act="relu",
                               name="inception_b_branch3x3dbl_3_"+name)
-        branch_pool = fluid.layers.pool2d(x, pool_size=3, pool_stride=2, pool_type="max")
-        concat = fluid.layers.concat([branch3x3, branch3x3dbl, branch_pool], axis=1)
+        branch_pool = fluid.layers.pool2d(x, pool_size=3, pool_stride=2, pool_type="max", data_format=self.data_format)
+        concat = fluid.layers.concat([branch3x3, branch3x3dbl, branch_pool], axis=self.concat_axis)
         return concat
@@ -252,14 +256,13 @@ class InceptionV3():
                               act="relu",
                               name="inception_c_branch7x7dbl_5_"+name)
-        branch_pool = fluid.layers.pool2d(x, pool_size=3, pool_stride=1, pool_padding=1, pool_type="avg")
+        branch_pool = fluid.layers.pool2d(x, pool_size=3, pool_stride=1, pool_padding=1, pool_type="avg", data_format=self.data_format)
         branch_pool = self.conv_bn_layer(branch_pool,
                               num_filters=192,
                               filter_size=1,
                               act="relu",
                               name="inception_c_branch_pool_"+name)
-        concat = fluid.layers.concat([branch1x1, branch7x7, branch7x7dbl, branch_pool], axis=1)
+        concat = fluid.layers.concat([branch1x1, branch7x7, branch7x7dbl, branch_pool], axis=self.concat_axis)
         return concat
@@ -299,8 +302,9 @@ class InceptionV3():
                               stride=2,
                               act="relu",
                               name="inception_d_branch7x7x3_4_"+name)
-        branch_pool = fluid.layers.pool2d(x, pool_size=3, pool_stride=2, pool_type="max")
-        concat = fluid.layers.concat([branch3x3, branch7x7x3, branch_pool], axis=1)
+        branch_pool = fluid.layers.pool2d(x, pool_size=3, pool_stride=2, pool_type="max", data_format=self.data_format)
+        concat = fluid.layers.concat([branch3x3, branch7x7x3, branch_pool], axis=self.concat_axis)
         return concat
@@ -329,7 +333,7 @@ class InceptionV3():
                               act="relu",
                               name="inception_e_branch3x3_2b_"+name)
-        branch3x3 = fluid.layers.concat([branch3x3_2a, branch3x3_2b], axis=1)
+        branch3x3 = fluid.layers.concat([branch3x3_2a, branch3x3_2b], axis=self.concat_axis)
         branch3x3dbl = self.conv_bn_layer(x,
                               num_filters=448,
                               filter_size=1,
@@ -353,14 +357,13 @@ class InceptionV3():
                               padding=(1, 0),
                               act="relu",
                               name="inception_e_branch3x3dbl_3b_"+name)
-        branch3x3dbl = fluid.layers.concat([branch3x3dbl_3a, branch3x3dbl_3b], axis=1)
+        branch3x3dbl = fluid.layers.concat([branch3x3dbl_3a, branch3x3dbl_3b], axis=self.concat_axis)
-        branch_pool = fluid.layers.pool2d(x, pool_size=3, pool_stride=1, pool_padding=1, pool_type="avg")
+        branch_pool = fluid.layers.pool2d(x, pool_size=3, pool_stride=1, pool_padding=1, pool_type="avg", data_format=self.data_format)
         branch_pool = self.conv_bn_layer(branch_pool,
                               num_filters=192,
                               filter_size=1,
                               act="relu",
                               name="inception_e_branch_pool_"+name)
-        concat = fluid.layers.concat([branch1x1, branch3x3, branch3x3dbl, branch_pool], axis=1)
+        concat = fluid.layers.concat([branch1x1, branch3x3, branch3x3dbl, branch_pool], axis=self.concat_axis)
         return concat
\ No newline at end of file
@@ -36,6 +36,7 @@ from ppcls.utils import logger
 from paddle.fluid.incubate.fleet.collective import fleet
 from paddle.fluid.incubate.fleet.collective import DistributedStrategy
+import paddle.fluid as fluid

 from ema import ExponentialMovingAverage
@@ -104,9 +105,14 @@ def create_model(architecture, image, classes_num, is_train):
     """
     name = architecture["name"]
     params = architecture.get("params", {})
     if "is_test" in params:
         params['is_test'] = not is_train
     model = architectures.__dict__[name](**params)
+    if "data_format" in params and params["data_format"] == "NHWC":
+        image = fluid.layers.transpose(image, [0, 2, 3, 1])
+        image.stop_gradient = True
     out = model.net(input=image, class_dim=classes_num)
     return out
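The reader pipeline still emits NCHW batches (image_shape: [3, 299, 299] plus the ToCHWImage transform), so for an NHWC model the input is transposed once at the graph entry, as the hunk above shows. A standalone sketch of the same pattern:

import paddle.fluid as fluid

# Input arrives as N, C, H, W from the data pipeline.
image = fluid.layers.data(name='image', shape=[3, 299, 299], dtype='float32')

# Permute to N, H, W, C before feeding an NHWC network; the transpose is a
# pure layout change, so no gradient needs to flow through it.
image_nhwc = fluid.layers.transpose(image, perm=[0, 2, 3, 1])
image_nhwc.stop_gradient = True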
@@ -341,6 +347,7 @@ def build(config, main_prog, startup_prog, is_train=True, is_distributed=True):
         use_distillation = config.get('use_distillation')
         feeds = create_feeds(config.image_shape, use_mix=use_mix)
         dataloader = create_dataloader(feeds.values())
         out = create_model(config.ARCHITECTURE, feeds['image'],
                            config.classes_num, is_train)
         fetchs = create_fetchs(
@@ -361,6 +368,7 @@ def build(config, main_prog, startup_prog, is_train=True, is_distributed=True):
         if is_distributed:
             optimizer = dist_optimizer(config, optimizer)
         optimizer.minimize(fetchs['loss'][0])
         if config.get('use_ema'):
             global_steps = fluid.layers.learning_rate_scheduler._decay_step_counter(
@@ -392,6 +400,40 @@ def compile(config, program, loss_name=None, share_prog=None):
     exec_strategy.num_threads = 1
     exec_strategy.num_iteration_per_drop_scope = 10

+    use_fp16 = config.get('use_fp16', False)
+    fuse_bn_act_ops = config.get('fuse_bn_act_ops', True)
+    fuse_elewise_add_act_ops = config.get('fuse_elewise_add_act_ops', True)
+    fuse_bn_add_act_ops = config.get('fuse_bn_add_act_ops', True)
+    enable_addto = config.get('enable_addto', True)
+
+    if use_fp16:
+        try:
+            build_strategy.fuse_bn_act_ops = fuse_bn_act_ops
+        except Exception as e:
+            logger.info(
+                "PaddlePaddle version 1.7.0 or higher is "
+                "required when you want to fuse batch_norm and activation ops.")
+        try:
+            build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
+        except Exception as e:
+            logger.info(
+                "PaddlePaddle version 1.7.0 or higher is "
+                "required when you want to fuse elementwise_add and activation ops.")
+        try:
+            build_strategy.fuse_bn_add_act_ops = fuse_bn_add_act_ops
+        except Exception as e:
+            logger.info(
+                "PaddlePaddle 2.0-rc or higher is "
+                "required when you want to enable the fuse_bn_add_act_ops strategy.")
+        try:
+            build_strategy.enable_addto = enable_addto
+        except Exception as e:
+            logger.info(
+                "PaddlePaddle 2.0-rc or higher is "
+                "required when you want to enable the addto strategy.")
+
     compiled_program = fluid.CompiledProgram(program).with_data_parallel(
         share_vars_from=share_prog,
         loss_name=loss_name,
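Each fusion switch above is assigned inside its own try/except because older Paddle releases do not expose the corresponding BuildStrategy attribute: on such versions the assignment raises, and the pass is skipped with an informational log instead of aborting training. The guard in isolation (a sketch; fuse_bn_add_act_ops is one of the attributes from this hunk):

import paddle.fluid as fluid

build_strategy = fluid.BuildStrategy()
try:
    # Only present in sufficiently new Paddle releases (2.0-rc per the log above).
    build_strategy.fuse_bn_add_act_ops = True
except Exception:
    print("fuse_bn_add_act_ops not supported by this Paddle version; skipping")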
@@ -466,6 +508,7 @@ def run(dataloader,
                     if idx == 0 else epoch_str,
                     logger.coloring(step_str, "PURPLE"),
                     logger.coloring(fetchs_str, 'OKGREEN')))

     end_str = ''.join([str(m.mean) + ' '
                        for m in metric_list] + [batch_time.total]) + 's'
......
@@ -66,6 +66,16 @@ def main(args):
     fleet.init(role)

     config = get_config(args.config, overrides=args.override, show=True)
+    use_fp16 = config.get('use_fp16', False)
+    if use_fp16:
+        AMP_RELATED_FLAGS_SETTING = {
+            'FLAGS_cudnn_exhaustive_search': 1,
+            'FLAGS_conv_workspace_size_limit': 4000,
+            'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
+            'FLAGS_max_inplace_grad_add': 8,
+        }
+        os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '1'
+        paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
     # assign the place
     gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
     place = fluid.CUDAPlace(gpu_id)
......
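The FLAGS_* entries above are Paddle's gflags-backed runtime switches: exhaustive cuDNN algorithm search, a larger conv workspace, the persistent spatial batch-norm kernel, and in-place gradient accumulation all help FP16 throughput. Because they are ordinary gflags, they can equivalently be set as environment variables before the process initializes CUDA; a sketch of that form:

import os

# Environment-variable form of the same flags; must be set before Paddle
# initializes the GPU context for them to take effect.
os.environ['FLAGS_cudnn_exhaustive_search'] = '1'             # autotune conv algorithms
os.environ['FLAGS_conv_workspace_size_limit'] = '4000'        # cuDNN workspace limit (MB)
os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '1'  # fast persistent BN kernel
os.environ['FLAGS_max_inplace_grad_add'] = '8'                # in-place gradient accumulation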