diff --git a/PaddleCV/image_classification/build_model.py b/PaddleCV/image_classification/build_model.py index 1168f70c1aaae041292994f2adca7a6d8720efb2..003111ab86206729ae20d56fca304943f16038da 100644 --- a/PaddleCV/image_classification/build_model.py +++ b/PaddleCV/image_classification/build_model.py @@ -15,6 +15,9 @@ import paddle import paddle.fluid as fluid import utils.utility as utility +AMP_MODEL_LIST = ["ResNet50", "SE_ResNet50_vd"] + + def _calc_label_smoothing_loss(softmax_out, label, class_dim, epsilon): """Calculate label smoothing loss @@ -34,11 +37,12 @@ def _calc_label_smoothing_loss(softmax_out, label, class_dim, epsilon): def _basic_model(data, model, args, is_train): image = data[0] label = data[1] - if args.model == "ResNet50": + if args.model in AMP_MODEL_LIST: image_data = (fluid.layers.cast(image, 'float16') - if args.use_pure_fp16 and not args.use_dali else image) + if args.use_pure_fp16 and not args.use_dali else image) image_transpose = fluid.layers.transpose( - image_data, [0, 2, 3, 1]) if args.data_format == 'NHWC' else image_data + image_data, + [0, 2, 3, 1]) if args.data_format == 'NHWC' else image_data image_transpose.stop_gradient = image.stop_gradient net_out = model.net(input=image_transpose, class_dim=args.class_dim, @@ -55,8 +59,8 @@ def _basic_model(data, model, args, is_train): else: cost = fluid.layers.cross_entropy(input=softmax_out, label=label) - target_cost = (fluid.layers.reduce_sum(cost) if args.use_pure_fp16 - else fluid.layers.mean(cost)) + target_cost = (fluid.layers.reduce_sum(cost) + if args.use_pure_fp16 else fluid.layers.mean(cost)) acc_top1 = fluid.layers.accuracy(input=softmax_out, label=label, k=1) acc_top5 = fluid.layers.accuracy( input=softmax_out, label=label, k=min(5, args.class_dim)) @@ -98,11 +102,12 @@ def _mixup_model(data, model, args, is_train): y_b = data[2] lam = data[3] - if args.model == "ResNet50": + if args.model in AMP_MODEL_LIST: image_data = (fluid.layers.cast(image, 'float16') - if args.use_pure_fp16 and not args.use_dali else image) + if args.use_pure_fp16 and not args.use_dali else image) image_transpose = fluid.layers.transpose( - image_data, [0, 2, 3, 1]) if args.data_format == 'NHWC' else image_data + image_data, + [0, 2, 3, 1]) if args.data_format == 'NHWC' else image_data image_transpose.stop_gradient = image.stop_gradient net_out = model.net(input=image_transpose, class_dim=args.class_dim, diff --git a/PaddleCV/image_classification/models/se_resnet_vd.py b/PaddleCV/image_classification/models/se_resnet_vd.py index aa8b910decb09830509eb1198aa57664d50c6ee5..75ce0fad16d87f4cad8fbd7e2a248bad5a6a886c 100644 --- a/PaddleCV/image_classification/models/se_resnet_vd.py +++ b/PaddleCV/image_classification/models/se_resnet_vd.py @@ -21,8 +21,10 @@ import paddle.fluid as fluid from paddle.fluid.param_attr import ParamAttr import math -__all__ = ["SE_ResNet_vd", "SE_ResNet18_vd","SE_ResNet34_vd", "SE_ResNet50_vd", "SE_ResNet101_vd", "SE_ResNet152_vd", - "SE_ResNet200_vd"] +__all__ = [ + "SE_ResNet_vd", "SE_ResNet18_vd", "SE_ResNet34_vd", "SE_ResNet50_vd", + "SE_ResNet101_vd", "SE_ResNet152_vd", "SE_ResNet200_vd" +] class SE_ResNet_vd(): @@ -30,7 +32,7 @@ class SE_ResNet_vd(): self.layers = layers self.is_3x3 = is_3x3 - def net(self, input, class_dim=1000): + def net(self, input, class_dim=1000, data_format="NCHW"): is_3x3 = self.is_3x3 layers = self.layers supported_layers = [18, 34, 50, 101, 152, 200] @@ -38,7 +40,7 @@ class SE_ResNet_vd(): "supported layers are {} but input layer is {}".format(supported_layers, layers) if layers == 18: - depth = [2, 2, 2, 2] + depth = [2, 2, 2, 2] elif layers == 34 or layers == 50: depth = [3, 4, 6, 3] elif layers == 101: @@ -51,66 +53,94 @@ class SE_ResNet_vd(): reduction_ratio = 16 if is_3x3 == False: conv = self.conv_bn_layer( - input=input, num_filters=64, filter_size=7, stride=2, act='relu') + input=input, + num_filters=64, + filter_size=7, + stride=2, + act='relu', + data_format=data_format) else: conv = self.conv_bn_layer( - input=input, num_filters=32, filter_size=3, stride=2, act='relu', name='conv1_1') + input=input, + num_filters=32, + filter_size=3, + stride=2, + act='relu', + name='conv1_1', + data_format=data_format) conv = self.conv_bn_layer( - input=conv, num_filters=32, filter_size=3, stride=1, act='relu', name='conv1_2') + input=conv, + num_filters=32, + filter_size=3, + stride=1, + act='relu', + name='conv1_2', + data_format=data_format) conv = self.conv_bn_layer( - input=conv, num_filters=64, filter_size=3, stride=1, act='relu', name='conv1_3') + input=conv, + num_filters=64, + filter_size=3, + stride=1, + act='relu', + name='conv1_3', + data_format=data_format) conv = fluid.layers.pool2d( input=conv, pool_size=3, pool_stride=2, pool_padding=1, - pool_type='max') + pool_type='max', + data_format=data_format) if layers >= 50: for block in range(len(depth)): for i in range(depth[block]): if layers in [101, 152, 200] and block == 2: if i == 0: - conv_name="res"+str(block+2)+"a" + conv_name = "res" + str(block + 2) + "a" else: - conv_name="res"+str(block+2)+"b"+str(i) + conv_name = "res" + str(block + 2) + "b" + str(i) else: - conv_name="res"+str(block+2)+chr(97+i) + conv_name = "res" + str(block + 2) + chr(97 + i) conv = self.bottleneck_block( input=conv, num_filters=num_filters[block], stride=2 if i == 0 and block != 0 else 1, - if_first=block==i==0, + if_first=block == i == 0, reduction_ratio=reduction_ratio, - name=conv_name) + name=conv_name, + data_format=data_format) else: for block in range(len(depth)): for i in range(depth[block]): - conv_name="res"+str(block+2)+chr(97+i) + conv_name = "res" + str(block + 2) + chr(97 + i) conv = self.basic_block( input=conv, num_filters=num_filters[block], stride=2 if i == 0 and block != 0 else 1, - if_first=block==i==0, + if_first=block == i == 0, reduction_ratio=reduction_ratio, - name=conv_name) + name=conv_name, + data_format=data_format) pool = fluid.layers.pool2d( - input=conv, pool_size=7, pool_type='avg', global_pooling=True) - + input=conv, + pool_size=7, + pool_type='avg', + global_pooling=True, + data_format=data_format) stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) - out = fluid.layers.fc(input=pool, - size=class_dim, - param_attr=fluid.param_attr.ParamAttr( - initializer=fluid.initializer.Uniform(-stdv, stdv), name='fc6_weights'), - bias_attr=ParamAttr(name='fc6_offset')) - + out = fluid.layers.fc( + input=pool, + size=class_dim, + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name='fc6_weights'), + bias_attr=ParamAttr(name='fc6_offset')) + return out - - - def conv_bn_layer(self, input, @@ -119,7 +149,8 @@ class SE_ResNet_vd(): stride=1, groups=1, act=None, - name=None): + name=None, + data_format='NCHW'): conv = fluid.layers.conv2d( input=input, num_filters=num_filters, @@ -129,34 +160,39 @@ class SE_ResNet_vd(): groups=groups, act=None, param_attr=ParamAttr(name=name + "_weights"), - bias_attr=False) + bias_attr=False, + data_format=data_format) if name == "conv1": bn_name = "bn_" + name else: - bn_name = "bn" + name[3:] - return fluid.layers.batch_norm(input=conv, - act=act, - param_attr=ParamAttr(name=bn_name + '_scale'), - bias_attr=ParamAttr(bn_name + '_offset'), - moving_mean_name=bn_name + '_mean', - moving_variance_name=bn_name + '_variance') - - + bn_name = "bn" + name[3:] + return fluid.layers.batch_norm( + input=conv, + act=act, + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance', + data_layout=data_format) + def conv_bn_layer_new(self, - input, - num_filters, - filter_size, - stride=1, - groups=1, - act=None, - name=None): - pool = fluid.layers.pool2d(input=input, + input, + num_filters, + filter_size, + stride=1, + groups=1, + act=None, + name=None, + data_format='NCHW'): + pool = fluid.layers.pool2d( + input=input, pool_size=2, pool_stride=2, pool_padding=0, pool_type='avg', - ceil_mode=True) - + ceil_mode=True, + data_format=data_format) + conv = fluid.layers.conv2d( input=pool, num_filters=num_filters, @@ -166,130 +202,198 @@ class SE_ResNet_vd(): groups=groups, act=None, param_attr=ParamAttr(name=name + "_weights"), - bias_attr=False) + bias_attr=False, + data_format=data_format) if name == "conv1": bn_name = "bn_" + name else: bn_name = "bn" + name[3:] - return fluid.layers.batch_norm(input=conv, - act=act, - param_attr=ParamAttr(name=bn_name + '_scale'), - bias_attr=ParamAttr(bn_name + '_offset'), - moving_mean_name=bn_name + '_mean', - moving_variance_name=bn_name + '_variance') - - - - def shortcut(self, input, ch_out, stride, name, if_first=False): - ch_in = input.shape[1] + return fluid.layers.batch_norm( + input=conv, + act=act, + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance', + data_layout=data_format) + + def shortcut(self, + input, + ch_out, + stride, + name, + if_first=False, + data_format='NCHW'): + if data_format == 'NCHW': + ch_in = input.shape[1] + else: + ch_in = input.shape[-1] if ch_in != ch_out or stride != 1: if if_first: - return self.conv_bn_layer(input, ch_out, 1, stride, name=name) + return self.conv_bn_layer( + input, + ch_out, + 1, + stride, + name=name, + data_format=data_format) else: - return self.conv_bn_layer_new(input, ch_out, 1, stride, name=name) + return self.conv_bn_layer_new( + input, + ch_out, + 1, + stride, + name=name, + data_format=data_format) elif if_first: - return self.conv_bn_layer(input, ch_out, 1, stride, name=name) + return self.conv_bn_layer( + input, ch_out, 1, stride, name=name, data_format=data_format) else: return input - def bottleneck_block(self, input, num_filters, stride, name, if_first, reduction_ratio): + def bottleneck_block(self, input, num_filters, stride, name, if_first, + reduction_ratio, data_format): conv0 = self.conv_bn_layer( - input=input, - num_filters=num_filters, - filter_size=1, - act='relu', - name=name+"_branch2a") + input=input, + num_filters=num_filters, + filter_size=1, + act='relu', + name=name + "_branch2a", + data_format=data_format) conv1 = self.conv_bn_layer( input=conv0, num_filters=num_filters, filter_size=3, stride=stride, act='relu', - name=name+"_branch2b") - conv2 =self.conv_bn_layer( - input=conv1, - num_filters=num_filters * 4, - filter_size=1, - act=None, - name=name+"_branch2c") + name=name + "_branch2b", + data_format=data_format) + conv2 = self.conv_bn_layer( + input=conv1, + num_filters=num_filters * 4, + filter_size=1, + act=None, + name=name + "_branch2c", + data_format=data_format) scale = self.squeeze_excitation( input=conv2, num_channels=num_filters * 4, reduction_ratio=reduction_ratio, - name='fc_'+name) + name='fc_' + name, + data_format=data_format) - short = self.shortcut(input, num_filters * 4, stride, if_first=if_first, name=name + "_branch1") + short = self.shortcut( + input, + num_filters * 4, + stride, + if_first=if_first, + name=name + "_branch1", + data_format=data_format) return fluid.layers.elementwise_add(x=short, y=scale, act='relu') - - def basic_block(self, input, num_filters, stride, name, if_first, reduction_ratio): - conv0 = self.conv_bn_layer(input=input, - num_filters=num_filters, - filter_size=3, - act='relu', - stride=stride, - name=name+"_branch2a") - conv1 = self.conv_bn_layer(input=conv0, - num_filters=num_filters, - filter_size=3, - act=None, - name=name+"_branch2b") + + def basic_block(self, input, num_filters, stride, name, if_first, + reduction_ratio, data_format): + conv0 = self.conv_bn_layer( + input=input, + num_filters=num_filters, + filter_size=3, + act='relu', + stride=stride, + name=name + "_branch2a", + data_format=data_format) + conv1 = self.conv_bn_layer( + input=conv0, + num_filters=num_filters, + filter_size=3, + act=None, + name=name + "_branch2b", + data_format=data_format) scale = self.squeeze_excitation( input=conv1, num_channels=num_filters, reduction_ratio=reduction_ratio, - name='fc_'+name) - short = self.shortcut(input, - num_filters, - stride, - if_first=if_first, - name=name + "_branch1") + name='fc_' + name, + data_format=data_format) + short = self.shortcut( + input, + num_filters, + stride, + if_first=if_first, + name=name + "_branch1", + data_format=data_format) return fluid.layers.elementwise_add(x=short, y=scale, act='relu') - - - def squeeze_excitation(self, input, num_channels, reduction_ratio, name=None): + + def squeeze_excitation(self, + input, + num_channels, + reduction_ratio, + name=None, + data_format='NCHW'): pool = fluid.layers.pool2d( - input=input, pool_size=0, pool_type='avg', global_pooling=True) + input=input, + pool_size=0, + pool_type='avg', + global_pooling=True, + data_format=data_format) stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) - squeeze = fluid.layers.fc(input=pool, - size=num_channels // reduction_ratio, - act='relu', - param_attr=fluid.param_attr.ParamAttr( - initializer=fluid.initializer.Uniform( - -stdv, stdv),name=name+'_sqz_weights'), - bias_attr=ParamAttr(name=name+'_sqz_offset')) + squeeze = fluid.layers.fc( + input=pool, + size=num_channels // reduction_ratio, + act='relu', + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=name + '_sqz_weights'), + bias_attr=ParamAttr(name=name + '_sqz_offset')) stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0) - excitation = fluid.layers.fc(input=squeeze, - size=num_channels, - act='sigmoid', - param_attr=fluid.param_attr.ParamAttr( - initializer=fluid.initializer.Uniform(-stdv, stdv), - name=name+'_exc_weights'), - bias_attr=ParamAttr(name=name+'_exc_offset')) - scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0) - return scale + excitation = fluid.layers.fc( + input=squeeze, + size=num_channels, + act='sigmoid', + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=name + '_exc_weights'), + bias_attr=ParamAttr(name=name + '_exc_offset')) + + # scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0) + # return scale + + input_in = fluid.layers.transpose( + input, [0, 3, 1, 2]) if data_format == 'NHWC' else input + input_in.stop_gradient = input.stop_gradient + scale = fluid.layers.elementwise_mul(x=input_in, y=excitation, axis=0) + scale_out = fluid.layers.transpose( + scale, [0, 2, 3, 1]) if data_format == 'NHWC' else scale + scale_out.stop_gradient = scale.stop_gradient + + return scale_out + def SE_ResNet18_vd(): - model = SE_ResNet_vd(layers=18, is_3x3 = True) + model = SE_ResNet_vd(layers=18, is_3x3=True) return model + def SE_ResNet34_vd(): - model = SE_ResNet_vd(layers=34, is_3x3 = True) + model = SE_ResNet_vd(layers=34, is_3x3=True) return model - + + def SE_ResNet50_vd(): - model = SE_ResNet_vd(layers=50, is_3x3 = True) + model = SE_ResNet_vd(layers=50, is_3x3=True) return model + def SE_ResNet101_vd(): - model = SE_ResNet_vd(layers=101, is_3x3 = True) + model = SE_ResNet_vd(layers=101, is_3x3=True) return model + def SE_ResNet152_vd(): - model = SE_ResNet_vd(layers=152, is_3x3 = True) + model = SE_ResNet_vd(layers=152, is_3x3=True) return model + def SE_ResNet200_vd(): - model = SE_ResNet_vd(layers=200, is_3x3 = True) + model = SE_ResNet_vd(layers=200, is_3x3=True) return model - diff --git a/PaddleCV/image_classification/scripts/train/SE_ResNet50_vd_fp16.sh b/PaddleCV/image_classification/scripts/train/SE_ResNet50_vd_fp16.sh new file mode 100644 index 0000000000000000000000000000000000000000..a6eb162e56250e1d28caa8fcdecdea2a86f43ca6 --- /dev/null +++ b/PaddleCV/image_classification/scripts/train/SE_ResNet50_vd_fp16.sh @@ -0,0 +1,43 @@ +#SE_ResNet50_vd + +export CUDA_VISIBLE_DEVICES=4 + +export FLAGS_conv_workspace_size_limit=4000 #MB +export FLAGS_cudnn_exhaustive_search=1 +export FLAGS_cudnn_batchnorm_spatial_persistent=1 + +DATA_DIR="Your image dataset path, e.g. /work/datasets/ILSVRC2012/" + +DATA_FORMAT="NHWC" +USE_FP16=true #whether to use float16 +USE_DALI=true +USE_ADDTO=true + +if ${USE_ADDTO} ;then + export FLAGS_max_inplace_grad_add=8 +fi + +if ${USE_DALI}; then + export FLAGS_fraction_of_gpu_memory_to_use=0.8 +fi + +python train.py \ + --model=SE_ResNet50_vd \ + --data_dir=${DATA_DIR} \ + --batch_size=128 \ + --lr_strategy=cosine_decay \ + --use_fp16=${USE_FP16} \ + --data_format=${DATA_FORMAT} \ + --lr=0.1 \ + --num_epochs=200 \ + --model_save_dir=output/ \ + --l2_decay=1e-4 \ + --use_mixup=False \ + --use_label_smoothing=True \ + --label_smoothing_epsilon=0.1 \ + --enable_addto=${USE_ADDTO} \ + --use_dali=${USE_DALI} \ + --image_shape 4 224 224 \ + --fuse_bn_act_ops=true \ + --fuse_bn_add_act_ops=true \ + --fuse_elewise_add_act_ops=true \ diff --git a/PaddleCV/image_classification/train.py b/PaddleCV/image_classification/train.py index d944b33859ac9be2435a0abb669e15d162656b3f..3749c689563cc5e07f707a552aea39ce676b2a3c 100755 --- a/PaddleCV/image_classification/train.py +++ b/PaddleCV/image_classification/train.py @@ -268,6 +268,7 @@ def train(args): #NOTE: this is for benchmark if args.max_iter and total_batch_num == args.max_iter: return + reader_cost_averager.record(time.time() - batch_start) train_batch_metrics = exe.run(compiled_train_prog,