diff --git a/PaddleCV/image_classification/build_model.py b/PaddleCV/image_classification/build_model.py index 3f7a3a80753fc3bd4a51b78ff995112d5d45eafe..1168f70c1aaae041292994f2adca7a6d8720efb2 100644 --- a/PaddleCV/image_classification/build_model.py +++ b/PaddleCV/image_classification/build_model.py @@ -15,13 +15,12 @@ import paddle import paddle.fluid as fluid import utils.utility as utility - def _calc_label_smoothing_loss(softmax_out, label, class_dim, epsilon): """Calculate label smoothing loss Returns: label smoothing loss - + """ label_one_hot = fluid.layers.one_hot(input=label, depth=class_dim) @@ -36,33 +35,37 @@ def _basic_model(data, model, args, is_train): image = data[0] label = data[1] if args.model == "ResNet50": - image_in = fluid.layers.transpose( - image, [0, 2, 3, 1]) if args.data_format == 'NHWC' else image - image_in.stop_gradient = image.stop_gradient - net_out = model.net(input=image_in, + image_data = (fluid.layers.cast(image, 'float16') + if args.use_pure_fp16 and not args.use_dali else image) + image_transpose = fluid.layers.transpose( + image_data, [0, 2, 3, 1]) if args.data_format == 'NHWC' else image_data + image_transpose.stop_gradient = image.stop_gradient + net_out = model.net(input=image_transpose, class_dim=args.class_dim, data_format=args.data_format) else: net_out = model.net(input=image, class_dim=args.class_dim) + if args.use_pure_fp16: + net_out = fluid.layers.cast(x=net_out, dtype="float32") softmax_out = fluid.layers.softmax(net_out, use_cudnn=False) if is_train and args.use_label_smoothing: cost = _calc_label_smoothing_loss(softmax_out, label, args.class_dim, args.label_smoothing_epsilon) - else: cost = fluid.layers.cross_entropy(input=softmax_out, label=label) - avg_cost = fluid.layers.mean(cost) + target_cost = (fluid.layers.reduce_sum(cost) if args.use_pure_fp16 + else fluid.layers.mean(cost)) acc_top1 = fluid.layers.accuracy(input=softmax_out, label=label, k=1) acc_top5 = fluid.layers.accuracy( input=softmax_out, label=label, k=min(5, args.class_dim)) - return [avg_cost, acc_top1, acc_top5] + return [target_cost, acc_top1, acc_top5] def _googlenet_model(data, model, args, is_train): """GoogLeNet model output, include avg_cost, acc_top1 and acc_top5 - + Returns: GoogLeNet model output @@ -96,15 +99,21 @@ def _mixup_model(data, model, args, is_train): lam = data[3] if args.model == "ResNet50": - image_in = fluid.layers.transpose( - image, [0, 2, 3, 1]) if args.data_format == 'NHWC' else image - image_in.stop_gradient = image.stop_gradient - net_out = model.net(input=image_in, + image_data = (fluid.layers.cast(image, 'float16') + if args.use_pure_fp16 and not args.use_dali else image) + image_transpose = fluid.layers.transpose( + image_data, [0, 2, 3, 1]) if args.data_format == 'NHWC' else image_data + image_transpose.stop_gradient = image.stop_gradient + net_out = model.net(input=image_transpose, class_dim=args.class_dim, data_format=args.data_format) else: net_out = model.net(input=image, class_dim=args.class_dim) - softmax_out = fluid.layers.softmax(net_out, use_cudnn=False) + if args.use_pure_fp16: + net_out_fp32 = fluid.layers.cast(x=net_out, dtype="float32") + softmax_out = fluid.layers.softmax(net_out_fp32, use_cudnn=False) + else: + softmax_out = fluid.layers.softmax(net_out, use_cudnn=False) if not args.use_label_smoothing: loss_a = fluid.layers.cross_entropy(input=softmax_out, label=y_a) loss_b = fluid.layers.cross_entropy(input=softmax_out, label=y_b) @@ -114,11 +123,17 @@ def _mixup_model(data, model, args, is_train): loss_b = _calc_label_smoothing_loss(softmax_out, y_b, args.class_dim, args.label_smoothing_epsilon) - loss_a_mean = fluid.layers.mean(x=loss_a) - loss_b_mean = fluid.layers.mean(x=loss_b) - cost = lam * loss_a_mean + (1 - lam) * loss_b_mean - avg_cost = fluid.layers.mean(x=cost) - return [avg_cost] + if args.use_pure_fp16: + target_loss_a = fluid.layers.reduce_sum(x=loss_a) + target_loss_b = fluid.layers.reduce_sum(x=loss_b) + cost = lam * target_loss_a + (1 - lam) * target_loss_b + target_cost = fluid.layers.reduce_sum(x=cost) + else: + target_loss_a = fluid.layers.mean(x=loss_a) + target_loss_b = fluid.layers.mean(x=loss_b) + cost = lam * target_loss_a + (1 - lam) * target_loss_b + target_cost = fluid.layers.mean(x=cost) + return [target_cost] def create_model(model, args, is_train): diff --git a/PaddleCV/image_classification/dali.py b/PaddleCV/image_classification/dali.py index 061aaefd32d8ef4a381130e7155485762a2297ff..653acd6631e2f2062becddcdce5b95ff23ee6f2f 100644 --- a/PaddleCV/image_classification/dali.py +++ b/PaddleCV/image_classification/dali.py @@ -44,7 +44,8 @@ class HybridTrainPipe(Pipeline): random_shuffle=True, num_threads=4, seed=42, - pad_output=False): + pad_output=False, + output_dtype=types.FLOAT): super(HybridTrainPipe, self).__init__( batch_size, num_threads, device_id, seed=seed) self.input = ops.FileReader( @@ -69,7 +70,7 @@ class HybridTrainPipe(Pipeline): device='gpu', resize_x=crop, resize_y=crop, interp_type=interp) self.cmnp = ops.CropMirrorNormalize( device="gpu", - output_dtype=types.FLOAT, + output_dtype=output_dtype, output_layout=types.NCHW, crop=(crop, crop), image_type=types.RGB, @@ -107,7 +108,8 @@ class HybridValPipe(Pipeline): random_shuffle=False, num_threads=4, seed=42, - pad_output=False): + pad_output=False, + output_dtype=types.FLOAT): super(HybridValPipe, self).__init__( batch_size, num_threads, device_id, seed=seed) self.input = ops.FileReader( @@ -121,7 +123,7 @@ class HybridValPipe(Pipeline): device="gpu", resize_shorter=resize_shorter, interp_type=interp) self.cmnp = ops.CropMirrorNormalize( device="gpu", - output_dtype=types.FLOAT, + output_dtype=output_dtype, output_layout=types.NCHW, crop=(crop, crop), image_type=types.RGB, @@ -163,6 +165,7 @@ def build(settings, mode='train'): min_area = settings.lower_scale lower = settings.lower_ratio upper = settings.upper_ratio + output_dtype = types.FLOAT16 if settings.use_pure_fp16 else types.FLOAT interp = settings.interpolation or 1 # default to linear interp_map = { @@ -196,7 +199,8 @@ def build(settings, mode='train'): mean, std, device_id=device_id, - pad_output=pad_output) + pad_output=pad_output, + output_dtype=output_dtype) pipe.build() return DALIGenericIterator( pipe, ['feed_image', 'feed_label'], @@ -230,7 +234,8 @@ def build(settings, mode='train'): shard_id, num_shards, seed=42 + shard_id, - pad_output=pad_output) + pad_output=pad_output, + output_dtype=output_dtype) pipe.build() pipelines = [pipe] sample_per_shard = len(pipe) // num_shards @@ -258,7 +263,8 @@ def build(settings, mode='train'): idx, num_shards, seed=42 + idx, - pad_output=pad_output) + pad_output=pad_output, + output_dtype=output_dtype) pipe.build() pipelines.append(pipe) sample_per_shard = len(pipelines[0]) diff --git a/PaddleCV/image_classification/scripts/train/ResNet50_fp16.sh b/PaddleCV/image_classification/scripts/train/ResNet50_fp16.sh index 3a4090c1c43d42fdf72dea90d1df3d53e9a6126d..9c1ba24b0fc6bbfe4d88e646f3b129ccdcd23014 100755 --- a/PaddleCV/image_classification/scripts/train/ResNet50_fp16.sh +++ b/PaddleCV/image_classification/scripts/train/ResNet50_fp16.sh @@ -8,7 +8,9 @@ export FLAGS_cudnn_batchnorm_spatial_persistent=1 DATA_DIR="Your image dataset path, e.g. /work/datasets/ILSVRC2012/" DATA_FORMAT="NHWC" -USE_FP16=true #whether to use float16 +USE_AMP=true #whether to use amp +USE_PURE_FP16=false +MULTI_PRECISION=${USE_PURE_FP16} USE_DALI=true USE_ADDTO=true @@ -30,7 +32,9 @@ python train.py \ --print_step=10 \ --model_save_dir=output/ \ --lr_strategy=piecewise_decay \ - --use_fp16=${USE_FP16} \ + --use_amp=${USE_AMP} \ + --use_pure_fp16=${USE_PURE_FP16} \ + --multi_precision=${MULTI_PRECISION} \ --scale_loss=128.0 \ --use_dynamic_loss_scaling=true \ --data_format=${DATA_FORMAT} \ @@ -44,5 +48,6 @@ python train.py \ --reader_thread=10 \ --reader_buf_size=4000 \ --use_dali=${USE_DALI} \ - --lr=0.1 + --lr=0.1 \ + --random_seed=2020 diff --git a/PaddleCV/image_classification/train.py b/PaddleCV/image_classification/train.py index 59ae6983cec8521bc23c716d46da441ef887f7a1..d944b33859ac9be2435a0abb669e15d162656b3f 100755 --- a/PaddleCV/image_classification/train.py +++ b/PaddleCV/image_classification/train.py @@ -61,7 +61,7 @@ def build_program(is_train, main_prog, startup_prog, args): startup_prog: strartup program args: arguments - Returns : + Returns : train mode: [Loss, global_lr, data_loader] test mode: [Loss, data_loader] """ @@ -85,12 +85,12 @@ def build_program(is_train, main_prog, startup_prog, args): if is_train: optimizer = create_optimizer(args) avg_cost = loss_out[0] - #XXX: fetch learning rate now, better implement is required here. + #XXX: fetch learning rate now, better implement is required here. global_lr = optimizer._global_learning_rate() global_lr.persistable = True loss_out.append(global_lr) - if args.use_fp16: + if args.use_amp: optimizer = fluid.contrib.mixed_precision.decorate( optimizer, init_loss_scaling=args.scale_loss, @@ -172,9 +172,9 @@ def validate(args, def train(args): """Train model - + Args: - args: all arguments. + args: all arguments. """ startup_prog = fluid.Program() train_prog = fluid.Program() diff --git a/PaddleCV/image_classification/utils/optimizer.py b/PaddleCV/image_classification/utils/optimizer.py index 176ba0af0be50aaaba9decffa79c0917f137dfe2..8a9c11041b04316a915720b6bc04d21ac02a372c 100644 --- a/PaddleCV/image_classification/utils/optimizer.py +++ b/PaddleCV/image_classification/utils/optimizer.py @@ -17,7 +17,8 @@ from __future__ import division from __future__ import print_function import math - +import os +import subprocess import paddle.fluid as fluid import paddle.fluid.layers.ops as ops from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter @@ -144,7 +145,7 @@ class Optimizer(object): total_images: total images. step: total steps in the an epoch. - + """ def __init__(self, args): @@ -159,6 +160,9 @@ class Optimizer(object): self.decay_epochs = args.decay_epochs self.decay_rate = args.decay_rate self.total_images = args.total_images + self.multi_precision = args.multi_precision + self.rescale_grad = (1.0 / (args.batch_size / len(fluid.cuda_places())) + if args.use_pure_fp16 else 1.0) self.step = int(math.ceil(float(self.total_images) / self.batch_size)) @@ -171,10 +175,12 @@ class Optimizer(object): bd = [self.step * e for e in self.step_epochs] lr = [self.lr * (0.1**i) for i in range(len(bd) + 1)] learning_rate = fluid.layers.piecewise_decay(boundaries=bd, values=lr) - optimizer = fluid.optimizer.Momentum( + optimizer = fluid.contrib.optimizer.Momentum( learning_rate=learning_rate, momentum=self.momentum_rate, - regularization=fluid.regularizer.L2Decay(self.l2_decay)) + regularization=fluid.regularizer.L2Decay(self.l2_decay), + multi_precision=self.multi_precision, + rescale_grad=self.rescale_grad) return optimizer def cosine_decay(self): @@ -188,10 +194,12 @@ class Optimizer(object): learning_rate=self.lr, step_each_epoch=self.step, epochs=self.num_epochs) - optimizer = fluid.optimizer.Momentum( + optimizer = fluid.contrib.optimizer.Momentum( learning_rate=learning_rate, momentum=self.momentum_rate, - regularization=fluid.regularizer.L2Decay(self.l2_decay)) + regularization=fluid.regularizer.L2Decay(self.l2_decay), + multi_precision=self.multi_precision, + rescale_grad=self.rescale_grad) return optimizer def cosine_decay_warmup(self): @@ -206,10 +214,12 @@ class Optimizer(object): step_each_epoch=self.step, epochs=self.num_epochs, warm_up_epoch=self.warm_up_epochs) - optimizer = fluid.optimizer.Momentum( + optimizer = fluid.contrib.optimizer.Momentum( learning_rate=learning_rate, momentum=self.momentum_rate, - regularization=fluid.regularizer.L2Decay(self.l2_decay)) + regularization=fluid.regularizer.L2Decay(self.l2_decay), + multi_precision=self.multi_precision, + rescale_grad=self.rescale_grad) return optimizer def exponential_decay_warmup(self): @@ -243,17 +253,19 @@ class Optimizer(object): end_lr = 0 learning_rate = fluid.layers.polynomial_decay( self.lr, self.step, end_lr, power=1) - optimizer = fluid.optimizer.Momentum( + optimizer = fluid.contrib.optimizer.Momentum( learning_rate=learning_rate, momentum=self.momentum_rate, - regularization=fluid.regularizer.L2Decay(self.l2_decay)) + regularization=fluid.regularizer.L2Decay(self.l2_decay), + multi_precision=self.multi_precision, + rescale_grad=self.rescale_grad) return optimizer def adam_decay(self): """Adam optimizer - Returns: + Returns: an adam_decay optimizer """ @@ -262,7 +274,7 @@ class Optimizer(object): def cosine_decay_RMSProp(self): """cosine decay with RMSProp optimizer - Returns: + Returns: an cosine_decay_RMSProp optimizer """ @@ -285,10 +297,12 @@ class Optimizer(object): default decay optimizer """ - optimizer = fluid.optimizer.Momentum( + optimizer = fluid.contrib.optimizer.Momentum( learning_rate=self.lr, momentum=self.momentum_rate, - regularization=fluid.regularizer.L2Decay(self.l2_decay)) + regularization=fluid.regularizer.L2Decay(self.l2_decay), + multi_precision=self.multi_precision, + rescale_grad=self.rescale_grad) return optimizer diff --git a/PaddleCV/image_classification/utils/utility.py b/PaddleCV/image_classification/utils/utility.py index 537004c0f23a0a8806732c84f94bd518b40b4157..e9162d0510c33e3369764ae6361f07925c9533f8 100644 --- a/PaddleCV/image_classification/utils/utility.py +++ b/PaddleCV/image_classification/utils/utility.py @@ -139,7 +139,9 @@ def parse_args(): # SWITCH add_arg('validate', bool, True, "whether to validate when training.") - add_arg('use_fp16', bool, False, "Whether to enable half precision training with fp16." ) + add_arg('use_amp', bool, False, "Whether to enable mixed precision training with fp16." ) + add_arg('use_pure_fp16', bool, False, "Whether to enable all half precision training with fp16." ) + add_arg('multi_precision', bool, False, "Whether to enable multi-precision training with fp16." ) add_arg('scale_loss', float, 1.0, "The value of scale_loss for fp16." ) add_arg('use_dynamic_loss_scaling', bool, True, "Whether to use dynamic loss scaling.") add_arg('data_format', str, "NCHW", "Tensor data format when training.") @@ -377,10 +379,13 @@ def create_data_loader(is_train, args): data_loader and the input data of net, """ image_shape = args.image_shape + image_dtype = "float32" + if args.model == "ResNet50" and args.use_pure_fp16 and args.use_dali: + image_dtype = "float16" feed_image = fluid.data( name="feed_image", shape=[None] + image_shape, - dtype="float32", + dtype=image_dtype, lod_level=0) feed_label = fluid.data( @@ -394,7 +399,7 @@ def create_data_loader(is_train, args): feed_y_b = fluid.data( name="feed_y_b", shape=[None, 1], dtype="int64", lod_level=0) feed_lam = fluid.data( - name="feed_lam", shape=[None, 1], dtype="float32", lod_level=0) + name="feed_lam", shape=[None, 1], dtype=image_dtype, lod_level=0) data_loader = fluid.io.DataLoader.from_generator( feed_list=[feed_image, feed_y_a, feed_y_b, feed_lam], @@ -556,7 +561,7 @@ def best_strategy_compiled(args, if args.use_gpu: exec_strategy.num_threads = fluid.core.get_cuda_device_count() - exec_strategy.num_iteration_per_drop_scope = 10 + exec_strategy.num_iteration_per_drop_scope = 10000 if args.use_pure_fp16 else 10 num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) if num_trainers > 1 and args.use_gpu: