diff --git a/PaddleCV/image_classification/README.md b/PaddleCV/image_classification/README.md index f6901ffb91d2a29797344c76fcdef8c6256b138c..507932b40d184ee352a602598244cf1523393234 100644 --- a/PaddleCV/image_classification/README.md +++ b/PaddleCV/image_classification/README.md @@ -365,15 +365,31 @@ Mixup相关介绍参考[mixup: Beyond Empirical Risk Minimization](https://arxiv ### 混合精度训练 -通过指定--fp16=True 启动混合训练,在训练过程中会使用float16数据,并输出float32的模型参数。您可能需要同时传入--scale_loss来解决fp16训练的精度问题,通常传入--scale_loss=0.8即可。 +通过指定--use_fp16=True 启动混合精度训练,在训练过程中会使用float16数据类型,并输出float32的模型参数。您可能需要同时传入--scale_loss来解决fp16训练的精度问题,如传入--scale_loss=128.0。 -```bash -python train.py \ - --model=ResNet50 \ - --fp16=True \ - --scale_loss=0.8 -``` -具体内容也可参考[Fleet](https://github.com/PaddlePaddle/Fleet/tree/develop/benchmark/collective/resnet) +在配置好数据集路径后(修改[scripts/train/ResNet50_fp16.sh](scripts/train/ResNet50_fp16.sh)文件中`DATA_DIR`的值),对ResNet50模型进行混合精度训练可通过运行`bash run.sh train ResNet50_fp16`命令完成。 + +多机多卡ResNet50模型的混合精度训练请参考[PaddlePaddle/Fleet](https://github.com/PaddlePaddle/Fleet/tree/develop/benchmark/collective/resnet)。 + +使用Tesla V100单机8卡、2机器16卡、4机器32卡,对ResNet50模型进行混合精度训练的结果如下(开启DALI): + +* BatchSize = 256 + +节点数*卡数|吞吐|加速比|test\_acc1|test\_acc5 +---|---|---|---|--- +1*1|1035 ins/s|1|0.75333|0.92702 +1*8|7840 ins/s|7.57|0.75603|0.92771 +2*8|14277 ins/s|13.79|0.75872|0.92793 +4*8|28594 ins/s|27.63|0.75253|0.92713 + +* BatchSize = 128 + +节点数*卡数|吞吐|加速比|test\_acc1|test\_acc5 +---|---|---|---|--- +1*1|936 ins/s|1|0.75280|0.92531 +1*8|7108 ins/s|7.59|0.75832|0.92771 +2*8|12343 ins/s|13.18|0.75766|0.92723 +4*8|24407 ins/s|26.07|0.75859|0.92871 ### 性能分析 diff --git a/PaddleCV/image_classification/README_en.md b/PaddleCV/image_classification/README_en.md index 1f1cbbf3ed76f71a017e371bfba21032092c4e5f..437cd1ba01352650437e0cdbda2dcdb406568708 100644 --- a/PaddleCV/image_classification/README_en.md +++ b/PaddleCV/image_classification/README_en.md @@ -253,16 +253,31 @@ Refer to [mixup: Beyond Empirical Risk Minimization](https://arxiv.org/abs/1710. ### Using Mixed-Precision Training -Set --fp16=True to sart Mixed-Precision Training. +Set --use_fp16=True to sart Automatic Mixed Precision (AMP) Training. During the training process, the float16 data type will be used to speed up the training performance. You may need to use the --scale_loss parameter to avoid the accuracy dropping, such as setting --scale_loss=128.0. -```bash -python train.py \ - --model=ResNet50 \ - --fp16=True \ - --scale_loss=0.8 -``` +After configuring the data path (modify the value of `DATA_DIR` in [scripts/train/ResNet50_fp16.sh](scripts/train/ResNet50_fp16.sh)), you can enable ResNet50 to start AMP Training by executing the command of `bash run.sh train ResNet50_fp16`. + +Refer to [PaddlePaddle/Fleet](https://github.com/PaddlePaddle/Fleet/tree/develop/benchmark/collective/resnet) for the multi-machine and multi-card training. + +Performing on Tesla V100 single machine with 8 cards, two machines with 16 cards and four machines with 32 cards, the performance of ResNet50 AMP training is shown as below (enable DALI). + +* BatchSize = 256 + +nodes*crads|throughput|speedup|test\_acc1|test\_acc5 +---|---|---|---|--- +1*1|1035 ins/s|1|0.75333|0.92702 +1*8|7840 ins/s|7.57|0.75603|0.92771 +2*8|14277 ins/s|13.79|0.75872|0.92793 +4*8|28594 ins/s|27.63|0.75253|0.92713 + +* BatchSize = 128 -Refer to [PaddlePaddle/Fleet](https://github.com/PaddlePaddle/Fleet/tree/develop/benchmark/collective/resnet) +nodes*crads|throughput|speedup|test\_acc1|test\_acc5 +---|---|---|---|--- +1*1|936 ins/s|1|0.75280|0.92531 +1*8|7108 ins/s|7.59|0.75832|0.92771 +2*8|12343 ins/s|13.18|0.75766|0.92723 +4*8|24407 ins/s|26.07|0.75859|0.92871 ### Preprocessing with Nvidia DALI diff --git a/PaddleCV/image_classification/build_model.py b/PaddleCV/image_classification/build_model.py index 04c0d0ee71834b475cf00bdda26ea9209a302e97..a0dfd1310ad83c5bb16efceb4895c98f471a5c20 100644 --- a/PaddleCV/image_classification/build_model.py +++ b/PaddleCV/image_classification/build_model.py @@ -35,8 +35,11 @@ def _calc_label_smoothing_loss(softmax_out, label, class_dim, epsilon): def _basic_model(data, model, args, is_train): image = data[0] label = data[1] - - net_out = model.net(input=image, class_dim=args.class_dim) + if args.model == "ResNet50": + image_in = fluid.layers.transpose(image, [0, 2, 3, 1]) if args.data_format == 'NHWC' else image + net_out = model.net(input=image_in, class_dim=args.class_dim, data_format=args.data_format) + else: + net_out = model.net(input=image, class_dim=args.class_dim) softmax_out = fluid.layers.softmax(net_out, use_cudnn=False) if is_train and args.use_label_smoothing: @@ -88,7 +91,11 @@ def _mixup_model(data, model, args, is_train): y_b = data[2] lam = data[3] - net_out = model.net(input=image, class_dim=args.class_dim) + if args.model == "ResNet50": + image_in = fluid.layers.transpose(image, [0, 2, 3, 1]) if args.data_format == 'NHWC' else image + net_out = model.net(input=image_in, class_dim=args.class_dim, data_format=args.data_format) + else: + net_out = model.net(input=image, class_dim=args.class_dim) softmax_out = fluid.layers.softmax(net_out, use_cudnn=False) if not args.use_label_smoothing: loss_a = fluid.layers.cross_entropy(input=softmax_out, label=y_a) diff --git a/PaddleCV/image_classification/models/resnet.py b/PaddleCV/image_classification/models/resnet.py index bb68d018b58d6bd7197306a17042619215558eb9..fcf453588ff13e8c53d185940cfc2b060ec4e1ac 100644 --- a/PaddleCV/image_classification/models/resnet.py +++ b/PaddleCV/image_classification/models/resnet.py @@ -31,7 +31,7 @@ class ResNet(): def __init__(self, layers=50): self.layers = layers - def net(self, input, class_dim=1000): + def net(self, input, class_dim=1000, data_format="NCHW"): layers = self.layers supported_layers = [18, 34, 50, 101, 152] assert layers in supported_layers, \ @@ -53,13 +53,15 @@ class ResNet(): filter_size=7, stride=2, act='relu', - name="conv1") + name="conv1", + data_format=data_format) conv = fluid.layers.pool2d( input=conv, pool_size=3, pool_stride=2, pool_padding=1, - pool_type='max') + pool_type='max', + data_format=data_format) if layers >= 50: for block in range(len(depth)): for i in range(depth[block]): @@ -74,10 +76,11 @@ class ResNet(): input=conv, num_filters=num_filters[block], stride=2 if i == 0 and block != 0 else 1, - name=conv_name) + name=conv_name, + data_format=data_format) pool = fluid.layers.pool2d( - input=conv, pool_type='avg', global_pooling=True) + input=conv, pool_type='avg', global_pooling=True, data_format=data_format) stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) out = fluid.layers.fc( input=pool, @@ -93,10 +96,11 @@ class ResNet(): num_filters=num_filters[block], stride=2 if i == 0 and block != 0 else 1, is_first=block == i == 0, - name=conv_name) + name=conv_name, + data_format=data_format) pool = fluid.layers.pool2d( - input=conv, pool_type='avg', global_pooling=True) + input=conv, pool_type='avg', global_pooling=True, data_format=data_format) stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) out = fluid.layers.fc( input=pool, @@ -112,7 +116,8 @@ class ResNet(): stride=1, groups=1, act=None, - name=None): + name=None, + data_format='NCHW'): conv = fluid.layers.conv2d( input=input, num_filters=num_filters, @@ -123,7 +128,8 @@ class ResNet(): act=None, param_attr=ParamAttr(name=name + "_weights"), bias_attr=False, - name=name + '.conv2d.output.1') + name=name + '.conv2d.output.1', + data_format=data_format) if name == "conv1": bn_name = "bn_" + name @@ -136,62 +142,72 @@ class ResNet(): param_attr=ParamAttr(name=bn_name + '_scale'), bias_attr=ParamAttr(bn_name + '_offset'), moving_mean_name=bn_name + '_mean', - moving_variance_name=bn_name + '_variance', ) + moving_variance_name=bn_name + '_variance', + data_layout=data_format) - def shortcut(self, input, ch_out, stride, is_first, name): - ch_in = input.shape[1] + def shortcut(self, input, ch_out, stride, is_first, name, data_format): + if data_format == 'NCHW': + ch_in = input.shape[1] + else: + ch_in = input.shape[-1] if ch_in != ch_out or stride != 1 or is_first == True: - return self.conv_bn_layer(input, ch_out, 1, stride, name=name) + return self.conv_bn_layer(input, ch_out, 1, stride, name=name, data_format=data_format) else: return input - def bottleneck_block(self, input, num_filters, stride, name): + def bottleneck_block(self, input, num_filters, stride, name, data_format): conv0 = self.conv_bn_layer( input=input, num_filters=num_filters, filter_size=1, act='relu', - name=name + "_branch2a") + name=name + "_branch2a", + data_format=data_format) conv1 = self.conv_bn_layer( input=conv0, num_filters=num_filters, filter_size=3, stride=stride, act='relu', - name=name + "_branch2b") + name=name + "_branch2b", + data_format=data_format) conv2 = self.conv_bn_layer( input=conv1, num_filters=num_filters * 4, filter_size=1, act=None, - name=name + "_branch2c") + name=name + "_branch2c", + data_format=data_format) short = self.shortcut( input, num_filters * 4, stride, is_first=False, - name=name + "_branch1") + name=name + "_branch1", + data_format=data_format) return fluid.layers.elementwise_add( x=short, y=conv2, act='relu', name=name + ".add.output.5") - def basic_block(self, input, num_filters, stride, is_first, name): + def basic_block(self, input, num_filters, stride, is_first, name, data_format): conv0 = self.conv_bn_layer( input=input, num_filters=num_filters, filter_size=3, act='relu', stride=stride, - name=name + "_branch2a") + name=name + "_branch2a", + data_format=data_format) conv1 = self.conv_bn_layer( input=conv0, num_filters=num_filters, filter_size=3, act=None, - name=name + "_branch2b") + name=name + "_branch2b", + data_format=data_format) short = self.shortcut( - input, num_filters, stride, is_first, name=name + "_branch1") + input, num_filters, stride, is_first, name=name + "_branch1", data_format=data_format) return fluid.layers.elementwise_add(x=short, y=conv1, act='relu') diff --git a/PaddleCV/image_classification/scripts/train/ResNet50_fp16.sh b/PaddleCV/image_classification/scripts/train/ResNet50_fp16.sh new file mode 100755 index 0000000000000000000000000000000000000000..6ebd5c01ff3de389ec5cf9f8aaeb5d0f3f690715 --- /dev/null +++ b/PaddleCV/image_classification/scripts/train/ResNet50_fp16.sh @@ -0,0 +1,40 @@ +#!/bin/bash -ex + +export FLAGS_conv_workspace_size_limit=4000 #MB +export FLAGS_cudnn_exhaustive_search=1 +export FLAGS_cudnn_batchnorm_spatial_persistent=1 + +DATA_DIR="Your image dataset path, e.g. /work/datasets/ILSVRC2012/" + +DATA_FORMAT="NHWC" +USE_FP16=true #whether to use float16 +USE_DALI=true + +if ${USE_DALI}; then + export FLAGS_fraction_of_gpu_memory_to_use=0.8 +fi + +python train.py \ + --model=ResNet50 \ + --data_dir=${DATA_DIR} \ + --batch_size=256 \ + --total_images=1281167 \ + --image_shape 3 224 224 \ + --class_dim=1000 \ + --print_step=10 \ + --model_save_dir=output/ \ + --lr_strategy=piecewise_decay \ + --use_fp16=${USE_FP16} \ + --scale_loss=128.0 \ + --use_dynamic_loss_scaling=true \ + --data_format=${DATA_FORMAT} \ + --fuse_elewise_add_act_ops=true \ + --fuse_bn_act_ops=true \ + --validate=true \ + --is_profiler=false \ + --profiler_path=profile/ \ + --reader_thread=10 \ + --reader_buf_size=4000 \ + --use_dali=${USE_DALI} \ + --lr=0.1 + diff --git a/PaddleCV/image_classification/utils/utility.py b/PaddleCV/image_classification/utils/utility.py index 3779460367a7b6877d24b8a200989beb568f277c..5abd0c7d6ec0c4996cc994d1ce76d84f0cab5a15 100644 --- a/PaddleCV/image_classification/utils/utility.py +++ b/PaddleCV/image_classification/utils/utility.py @@ -65,7 +65,7 @@ def print_arguments(args): def add_arguments(argname, type, default, help, argparser, **kwargs): - """Add argparse's argument. + """Add argparse's argument. Usage: @@ -87,7 +87,7 @@ def add_arguments(argname, type, default, help, argparser, **kwargs): def parse_args(): """Add arguments - Returns: + Returns: all training args """ parser = argparse.ArgumentParser(description=__doc__) @@ -142,6 +142,9 @@ def parse_args(): add_arg('use_fp16', bool, False, "Whether to enable half precision training with fp16." ) add_arg('scale_loss', float, 1.0, "The value of scale_loss for fp16." ) add_arg('use_dynamic_loss_scaling', bool, True, "Whether to use dynamic loss scaling.") + add_arg('data_format', str, "NCHW", "Tensor data format when training.") + add_arg('fuse_elewise_add_act_ops', bool, False, "Whether to use elementwise_act fusion.") + add_arg('fuse_bn_act_ops', bool, False, "Whether to use batch_norm and act fusion.") add_arg('use_label_smoothing', bool, False, "Whether to use label_smoothing") add_arg('label_smoothing_epsilon', float, 0.1, "The value of label_smoothing_epsilon parameter") @@ -168,7 +171,7 @@ def parse_args(): def check_gpu(): - """ + """ Log error and exit when set use_gpu=true in paddlepaddle cpu ver sion. """ @@ -364,12 +367,12 @@ def create_data_loader(is_train, args): Usage: Using mixup process in training, it will return 5 results, include data_loader, image, y_a(label), y_b(label) and lamda, or it will return 3 results, include data_loader, image, and label. - Args: + Args: is_train: mode args: arguments Returns: - data_loader and the input data of net, + data_loader and the input data of net, """ image_shape = args.image_shape feed_image = fluid.data( @@ -428,7 +431,7 @@ def print_info(info_mode, time_info: time infomation info_mode: mode """ - #XXX: Use specific name to choose pattern, not the length of metrics. + #XXX: Use specific name to choose pattern, not the length of metrics. if info_mode == "batch": if batch_id % print_step == 0: #if isinstance(metrics,np.ndarray): @@ -518,6 +521,8 @@ def best_strategy_compiled(args, return program else: build_strategy = fluid.compiler.BuildStrategy() + build_strategy.fuse_bn_act_ops = args.fuse_bn_act_ops + build_strategy.fuse_elewise_add_act_ops = args.fuse_elewise_add_act_ops exec_strategy = fluid.ExecutionStrategy()