From c6a598a2767f930f675980f6527be3c55895d9d9 Mon Sep 17 00:00:00 2001
From: Jie Fang
Date: Fri, 6 Sep 2019 09:30:43 +0800
Subject: [PATCH] init new amp, optimize inserting cast op for batchnorm
 (#18596)

init new amp, optimize inserting cast op for batchnorm
---
 .../contrib/mixed_precision/decorator.py     | 34 +++----
 .../contrib/mixed_precision/fp16_lists.py    |  1 +
 .../contrib/mixed_precision/fp16_utils.py    | 99 ++-----------------
 .../tests/test_image_classification_fp16.py  |  5 +-
 4 files changed, 28 insertions(+), 111 deletions(-)

diff --git a/python/paddle/fluid/contrib/mixed_precision/decorator.py b/python/paddle/fluid/contrib/mixed_precision/decorator.py
index abca8c52a45..d7e028a8675 100644
--- a/python/paddle/fluid/contrib/mixed_precision/decorator.py
+++ b/python/paddle/fluid/contrib/mixed_precision/decorator.py
@@ -17,7 +17,6 @@
 from ... import default_startup_program
 from ... import layers
 from ... import unique_name
 from . import fp16_utils
-from .fp16_utils import create_master_params_grads, master_param_to_train_param
 from .fp16_utils import update_loss_scaling, rewrite_program
 from .fp16_lists import AutoMixedPrecisionLists
@@ -128,19 +127,20 @@ class OptimizerWithMixedPrecison(object):
         self._param_grads = self._optimizer.backward(
             scaled_loss, startup_program, parameter_list, no_grad_set,
             callbacks)
-        master_params_grads = create_master_params_grads(
-            self._param_grads, self._train_program, self._startup_prog,
-            self._loss_scaling)
+        scaled_params_grad = []
+        for p, g in self._param_grads:
+            scaled_g = g / self._loss_scaling
+            scaled_params_grad.append([p, scaled_g])
 
-        return master_params_grads, scaled_loss
+        return scaled_params_grad, scaled_loss
 
-    def apply_gradients(self, master_params_grads):
+    def apply_gradients(self, scaled_params_grads):
         """
-        Update master parameters by their gradients, and cast to parameters
-        in float16.
+        Check scaled gradients to determine whether to update loss scaling and update
+        parameters by their scaled gradients.
 
         Args:
-            master_params_grads (list): A list of master params and grads.
+            scaled_params_grads (list): A list of params and scaled grads.
 
         Returns:
             A list of optimize operators.
         """
@@ -148,7 +148,7 @@ class OptimizerWithMixedPrecison(object):
         if self._use_dynamic_loss_scaling:
 
-            grads = [layers.reduce_sum(g) for [_, g] in master_params_grads]
+            grads = [layers.reduce_sum(g) for [_, g] in scaled_params_grads]
             all_grads = layers.concat(grads)
             all_grads_sum = layers.reduce_sum(all_grads)
             is_overall_finite = layers.isfinite(all_grads_sum)
@@ -165,12 +165,10 @@ class OptimizerWithMixedPrecison(object):
             with switch.case(is_overall_finite):
                 pass
             with switch.default():
-                for _, g in master_params_grads:
+                for _, g in scaled_params_grads:
                     layers.assign(layers.zeros_like(g), g)
 
-        optimize_ops = self._optimizer.apply_gradients(master_params_grads)
-        master_param_to_train_param(master_params_grads, self._param_grads,
-                                    self._train_program)
+        optimize_ops = self._optimizer.apply_gradients(scaled_params_grads)
 
         return optimize_ops
 
@@ -183,12 +181,12 @@ class OptimizerWithMixedPrecison(object):
 
         Returns:
             The scaled loss by scaling factor, the list of optimize ops, and a
-            list of master parameters and gradients.
+            list of scaled parameters and gradients.
         """
-        master_params_grads, scaled_loss = self.backward(loss)
-        optimize_ops = self.apply_gradients(master_params_grads)
+        scaled_params_grads, scaled_loss = self.backward(loss)
+        optimize_ops = self.apply_gradients(scaled_params_grads)
 
-        return scaled_loss, optimize_ops, master_params_grads
+        return scaled_loss, optimize_ops, scaled_params_grads
 
 
 def decorate(optimizer,
diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
index a4705e8b833..44a2497045d 100644
--- a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
+++ b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
@@ -94,6 +94,7 @@ gray_list = {
     'elementwise_pow',
     'elementwise_mod',
     'elementwise_floordiv',
+    'batch_norm',
     'tanh',
     'sigmoid',
     'lookup_table',
diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
index 8d9abf0762f..52fd2ba9ca1 100644
--- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
+++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
@@ -36,92 +36,6 @@ def append_cast_op(i, o, prog):
                "out_dtype": o.dtype})
 
 
-def copy_to_master_param(p, block):
-    """
-    New a master parameter for the input parameter, and they two share the same
-    attributes except the data type.
-
-    Args:
-        p(Parameter): The input parameter in float16.
-        block(Program): The block in which the parameter is.
-    """
-    v = block.vars.get(p.name, None)
-    if v is None:
-        raise ValueError("no param name %s found!" % p.name)
-    new_p = framework.Parameter(
-        block=block,
-        shape=v.shape,
-        dtype=core.VarDesc.VarType.FP32,
-        type=v.type,
-        lod_level=v.lod_level,
-        stop_gradient=p.stop_gradient,
-        trainable=p.trainable,
-        optimize_attr=p.optimize_attr,
-        regularizer=p.regularizer,
-        gradient_clip_attr=p.gradient_clip_attr,
-        error_clip=p.error_clip,
-        name=v.name + ".master")
-    return new_p
-
-
-def create_master_params_grads(params_grads, main_prog, startup_prog,
-                               loss_scaling):
-    """
-    Create master parameters and gradients in float32 from params and grads
-    in float16.
-
-    Args:
-        params_grads (list): A list of tuple (parameter, gradient) in float32.
-        main_prog (Program): The main program for training.
-        startup_prog (Program): The startup program to initialize all parameters.
-        loss_scaling (float): The factor to scale loss and gradients.
-
-    Returns:
-        A list of master parameters and gradients.
-    """
-    master_params_grads = []
-    for p, g in params_grads:
-        # create master parameters
-        with main_prog._optimized_guard([p, g]):
-            # create master parameters
-            master_param = copy_to_master_param(p, main_prog.global_block())
-            startup_master_param = startup_prog.global_block()._clone_variable(
-                master_param)
-            startup_p = startup_prog.global_block().var(p.name)
-            # fp16 -> fp32
-            append_cast_op(startup_p, startup_master_param, startup_prog)
-            # cast fp16 gradients to fp32 before apply gradients
-            if g.name.find("batch_norm") > -1:
-                scaled_g = g / loss_scaling
-                master_params_grads.append([p, scaled_g])
-                continue
-            master_grad = layers.cast(x=g, dtype="float32")
-            master_grad = master_grad / loss_scaling
-            master_params_grads.append([master_param, master_grad])
-
-    return master_params_grads
-
-
-def master_param_to_train_param(master_params_grads, params_grads, main_prog):
-    """
-    Convert master master parameters and gradients in float32 to parameters and
-    gradients in float16 for forward computation.
-
-    Args:
-        master_params_grads (list): A list of master parameters and gradients in
-            float32.
-        params_grads (list): A list of parameters and gradients in float16.
-        main_prog (list): The main program for execution.
-    """
-    for idx, m_p_g in enumerate(master_params_grads):
-        train_p, _ = params_grads[idx]
-        if train_p.name.find("batch_norm") > -1:
-            continue
-        with main_prog._optimized_guard([m_p_g[0], m_p_g[1]]):
-            # fp32 -> fp16
-            append_cast_op(m_p_g[0], train_p, main_prog)
-
-
 def _rename_arg(op, old_name, new_name):
     """
     If an op has old_name input and output, rename these input
@@ -172,6 +86,9 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
         core.VarDesc.VarType.LOD_TENSOR_ARRAY
     ]
     for in_name in op.input_names:
+        if src_dtype == core.VarDesc.VarType.FP32 and op.type == 'batch_norm':
+            if in_name != 'X':
+                continue
         for in_var_name in op.input(in_name):
             in_var = block.var(in_var_name)
             if in_var.type not in valid_types:
@@ -197,16 +114,18 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
             else:
                 if op.has_attr('in_dtype'):
                     op._set_attr('in_dtype', dest_dtype)
-    if src_dtype == core.VarDesc.VarType.FP16:
+    if src_dtype == core.VarDesc.VarType.FP32:
         for out_name in op.output_names:
+            if op.type == 'batch_norm' and out_name != 'Y':
+                continue
             for out_var_name in op.output(out_name):
                 out_var = block.var(out_var_name)
                 if out_var.type not in valid_types:
                     continue
-                if out_var.dtype == core.VarDesc.VarType.FP16:
-                    out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
+                if out_var.dtype == core.VarDesc.VarType.FP32:
+                    out_var.desc.set_dtype(core.VarDesc.VarType.FP16)
                 if op.has_attr('out_dtype'):
-                    op._set_attr('out_dtype', core.VarDesc.VarType.FP32)
+                    op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
 
     return num_cast_ops
 
diff --git a/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py b/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py
index bde77b3d316..47beb9e21cb 100644
--- a/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py
+++ b/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py
@@ -113,13 +113,12 @@ def train(net_type, use_cuda, save_dirname, is_local):
         name='pixel', shape=data_shape, dtype='float32')
     label = fluid.layers.data(name='label', shape=[1], dtype='int64')
 
-    imgs = fluid.layers.cast(images, "float16")
     if net_type == "vgg":
         print("train vgg net")
-        net = vgg16_bn_drop(imgs)
+        net = vgg16_bn_drop(images)
    elif net_type == "resnet":
         print("train resnet")
-        net = resnet_cifar10(imgs, 32)
+        net = resnet_cifar10(images, 32)
     else:
         raise ValueError("%s network is not supported" % net_type)
 
--
GitLab
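Note on the change above: the patch drops the separate float32 "master parameter" copies and instead keeps the scaled gradients themselves, dividing each gradient by the loss-scaling factor in backward() and, under dynamic loss scaling, zeroing the gradients in apply_gradients() whenever a non-finite value is detected so that the update becomes a no-op. The standalone sketch below illustrates the same dynamic loss-scaling idea in plain Python/NumPy under a common policy (grow the scale after a run of finite steps, shrink it on overflow). It skips the update rather than zeroing gradients, and every name in it (ToyLossScaler, unscale, step) is invented for illustration; none of it is the Paddle API touched by this patch.

# Framework-agnostic sketch of dynamic loss scaling (illustrative names only).
import numpy as np


class ToyLossScaler:
    """Grow the scale after N finite steps, shrink it when gradients overflow."""

    def __init__(self, init_scale=2.0**15, incr_every_n_steps=1000,
                 incr_ratio=2.0, decr_ratio=0.5):
        self.scale = init_scale
        self.incr_every_n_steps = incr_every_n_steps
        self.incr_ratio = incr_ratio
        self.decr_ratio = decr_ratio
        self.good_steps = 0

    def unscale(self, scaled_grads):
        # Analogue of backward(): divide every scaled gradient by the loss scale.
        return [g / self.scale for g in scaled_grads]

    def step(self, params, grads, lr=0.1):
        # Analogue of apply_gradients(): skip the update and shrink the scale
        # when any gradient is non-finite; otherwise apply SGD and maybe grow.
        all_finite = all(np.isfinite(g).all() for g in grads)
        if not all_finite:
            self.scale *= self.decr_ratio
            self.good_steps = 0
            return params  # parameters unchanged this step
        self.good_steps += 1
        if self.good_steps >= self.incr_every_n_steps:
            self.scale *= self.incr_ratio
            self.good_steps = 0
        return [p - lr * g for p, g in zip(params, grads)]


if __name__ == "__main__":
    scaler = ToyLossScaler()
    params = [np.ones(4, dtype=np.float32)]
    # Pretend these gradients came from backprop on loss * scale.
    scaled_grads = [0.01 * scaler.scale * np.ones(4, dtype=np.float32)]
    grads = scaler.unscale(scaled_grads)
    params = scaler.step(params, grads)
    print(params[0], scaler.scale)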