From efc994ecf63be5775406d6d23868c6fa879172da Mon Sep 17 00:00:00 2001 From: Yang Zhang Date: Sun, 29 Sep 2019 12:23:07 +0800 Subject: [PATCH] Revert to mixed precision training with manual control (#3434) * Place mixed precision inside PaddleDetection roll back to the monkey patch version as a temporary measure, before it is merged into paddle * Add command flag for `loss_scale` * Fix a stupid indentation error optimizer should be in the mixed precision context * Initial FP16 training * Add mixed precision training to rest of the detection models * Revert "Add support for mixed precision training (#3406)" This reverts commit 3a2c106271885071db7c0d85587540a8f83c24db. * Bug fixes and some tweaks --- .../ppdet/experimental/__init__.py | 6 + .../ppdet/experimental/mixed_precision.py | 324 ++++++++++++++++++ .../architectures/cascade_mask_rcnn.py | 13 + .../modeling/architectures/cascade_rcnn.py | 13 + .../modeling/architectures/faster_rcnn.py | 15 + .../ppdet/modeling/architectures/mask_rcnn.py | 17 +- .../ppdet/modeling/architectures/retinanet.py | 14 + .../ppdet/modeling/architectures/ssd.py | 20 +- .../ppdet/modeling/architectures/yolov3.py | 14 + PaddleCV/PaddleDetection/tools/train.py | 39 ++- 10 files changed, 454 insertions(+), 21 deletions(-) create mode 100644 PaddleCV/PaddleDetection/ppdet/experimental/__init__.py create mode 100644 PaddleCV/PaddleDetection/ppdet/experimental/mixed_precision.py diff --git a/PaddleCV/PaddleDetection/ppdet/experimental/__init__.py b/PaddleCV/PaddleDetection/ppdet/experimental/__init__.py new file mode 100644 index 00000000..5b9d75f9 --- /dev/null +++ b/PaddleCV/PaddleDetection/ppdet/experimental/__init__.py @@ -0,0 +1,6 @@ +from __future__ import absolute_import + +from .mixed_precision import * +from . import mixed_precision + +__all__ = mixed_precision.__all__ diff --git a/PaddleCV/PaddleDetection/ppdet/experimental/mixed_precision.py b/PaddleCV/PaddleDetection/ppdet/experimental/mixed_precision.py new file mode 100644 index 00000000..0de2f26c --- /dev/null +++ b/PaddleCV/PaddleDetection/ppdet/experimental/mixed_precision.py @@ -0,0 +1,324 @@ +from __future__ import absolute_import +from __future__ import print_function + +import six +from paddle.fluid.framework import Parameter +from paddle.fluid import layers +from paddle.fluid import core +from paddle.fluid import unique_name +import paddle.fluid.layer_helper_base as lhb +import paddle.fluid.optimizer as optim + +__all__ = ['mixed_precision_global_state', 'mixed_precision_context', + 'StaticLossScale', 'DynamicLossScale'] + +_mixed_precision_global_state = None + + +def mixed_precision_global_state(): + return _mixed_precision_global_state + + +class LossScale(object): + def __init__(self): + super(LossScale, self).__init__() + + def get_loss_scale_var(self): + return self.scale + + def increment(self): + raise NotImplementedError() + + def decrement(self): + raise NotImplementedError() + + +class StaticLossScale(LossScale): + """ + Static (fixed) loss scale manager. + + Args: + init_loss_scale (float): initial loss scale value. + + Examples: + + .. code-block:: python + + from paddle import fluid + from ppdet.experimental import (mixed_precision_context, + StaticLossScale) + + with mixed_precision_context(StaticLossScale(8.), True) as ctx: + # ... + # scale loss + loss_scale = ctx.get_loss_scale_var() + + """ + + def __init__(self, init_loss_scale=1.): + super(StaticLossScale, self).__init__() + self.scale = layers.create_global_var( + name=unique_name.generate("loss_scale"), + shape=[1], + value=init_loss_scale, + dtype='float32', + persistable=True) + + +class DynamicLossScale(LossScale): + """ + Dynamic loss scale manager. it works as follows: + if gradients is valid for `increment_every` steps, loss scale values is + increased by `factor`, otherwise loss scale values is decreased by `factor` + + Args: + init_loss_scale (float): initial loss scale value. + increment_every (int): minimum 'good' steps before loss scale increase. + factor (float): increase/decrease loss scale by this much. + + Examples: + + .. code-block:: python + + from paddle import fluid + from ppdet.experimental import (mixed_precision_context, + DynamicLossScale) + + loss_scale = DynamicLossScale(8., 1000, 4.) + with mixed_precision_context(loss_scale, True) as ctx: + # ... + # scale loss + loss_scale = ctx.get_loss_scale_var() + + """ + + def __init__(self, init_loss_scale=2**15, increment_every=2000, factor=2.): + super(DynamicLossScale, self).__init__() + self.scale = layers.create_global_var( + name=unique_name.generate("loss_scale"), + shape=[1], + value=init_loss_scale, + dtype='float32', + persistable=True) + self.good_steps = layers.create_global_var( + name=unique_name.generate("good_steps"), + shape=[1], + value=0, + dtype='int32', + persistable=True) + self.increment_every = layers.fill_constant( + shape=[1], dtype='int32', value=increment_every) + self.factor = factor + + def increment(self): + enough_steps = layers.less_than(self.increment_every, + self.good_steps + 1) + with layers.Switch() as switch: + with switch.case(enough_steps): + new_scale = self.scale * self.factor + scale_valid = layers.isfinite(new_scale) + with layers.Switch() as switch2: + with switch2.case(scale_valid): + layers.assign(new_scale, self.scale) + layers.assign(layers.zeros_like(self.good_steps), + self.good_steps) + with switch2.default(): + layers.increment(self.good_steps) + with switch.default(): + layers.increment(self.good_steps) + + def decrement(self): + new_scale = self.scale / self.factor + one = layers.fill_constant(shape=[1], dtype='float32', value=1.0) + less_than_one = layers.less_than(new_scale, one) + with layers.Switch() as switch: + with switch.case(less_than_one): + layers.assign(one, self.scale) + with switch.default(): + layers.assign(new_scale, self.scale) + + layers.assign(layers.zeros_like(self.good_steps), + self.good_steps) + + +class mixed_precision_context(object): + """ + Context manager for mixed precision training. + + Args: + loss_scale (float, str or obj): loss scale settings, can be: + 1. an number: use fixed loss scale. + 2. 'dynamic': use a default `DynamicLossScale`. + 3. `DynamicLossScale` or `StaticLossScale` instance. + enabled (bool): enable mixed precision training. + + Examples: + + .. code-block:: python + + from paddle import fluid + from ppdet.experimental import mixed_precision_context + + with mixed_precision_context('dynamic', True) as ctx: + # cast inputs to float16 + inputs = fluid.layers.cast(inputs, "float16") + # build model here + logits = model(inputs) + # use float32 for softmax + logits = fluid.layers.cast(logits, "float32") + softmax = fluid.layers.softmax(logits) + loss = fluid.layers.cross_entropy(input=softmax, label=label) + avg_loss = fluid.layers.mean(loss) + # scale loss + loss_scale = ctx.get_loss_scale_var() + avg_loss *= loss_scale + optimizer = fluid.optimizer.Momentum(...) + optimizer.minimize(avg_loss) + + """ + + def __init__(self, loss_scale=1., enabled=True): + super(mixed_precision_context, self).__init__() + self.enabled = enabled + if not enabled: + return + monkey_patch() + if isinstance(loss_scale, six.integer_types + (float,)): + self.loss_scale = StaticLossScale(loss_scale) + elif loss_scale == 'dynamic': + self.loss_scale = DynamicLossScale() + else: + assert isinstance(loss_scale, LossScale), \ + "Invalid loss scale argument" + self.loss_scale = loss_scale + + @property + def dynamic_scaling(self): + return isinstance(self.loss_scale, DynamicLossScale) + + def __getattr__(self, attr): + if attr in ['get_loss_scale_var', 'increment', 'decrement']: + return getattr(self.loss_scale, attr) + + def __enter__(self): + if not self.enabled: + return + global _mixed_precision_global_state + _mixed_precision_global_state = self + return mixed_precision_global_state() + + def __exit__(self, *args): + if not self.enabled: + return + global _mixed_precision_global_state + _mixed_precision_global_state = None + return mixed_precision_global_state() + + +def create_parameter(self, + attr, + shape, + dtype, + is_bias=False, + default_initializer=None): + mp_state = mixed_precision_global_state() + is_half = (isinstance(dtype, str) and dtype == 'float16') \ + or (isinstance(dtype, core.VarDesc.VarType) + and dtype == core.VarDesc.VarType.FP16) + + if is_half and mp_state is not None: + dtype = 'float32' + + param = self._create_parameter(attr, shape, dtype, + is_bias, default_initializer) + if not is_half or mp_state is None: + return param + + param16 = self.main_program.current_block().create_var( + name=param.name + '.fp16', + dtype='float16', + type=param.type, + persistable=False) + self.append_op( + type='cast', + inputs={'X': [param]}, + outputs={'Out': [param16]}, + attrs={'in_dtype': param.dtype, + 'out_dtype': param16.dtype}) + return param16 + + +def scale_gradient(block, context): + state = mixed_precision_global_state() + if state is None: + return + scale = state.get_loss_scale_var() + op_desc = block.desc.op(block.desc.op_size() - 1) + op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName() + bwd_role = core.op_proto_and_checker_maker.OpRole.Backward + for name in [n for n in op_desc.output_arg_names() if n in context]: + fwd_var = block._var_recursive(context[name]) + if not isinstance(fwd_var, Parameter): + continue # TODO verify all use cases + clip_op_desc = block.desc.append_op() + clip_op_desc.set_type("elementwise_div") + clip_op_desc.set_input("X", [name]) + clip_op_desc.set_input("Y", [scale.name]) + clip_op_desc.set_output("Out", [name]) + clip_op_desc._set_attr(op_role_attr_name, bwd_role) + + +def update_loss_scale(grads): + state = mixed_precision_global_state() + if state is None or not state.dynamic_scaling: + return + per_grad_check = layers.stack([layers.reduce_sum(g) for g in grads]) + grad_valid = layers.isfinite(per_grad_check) + + with layers.Switch() as switch: + with switch.case(grad_valid): + state.increment() + with switch.default(): + state.decrement() + return grad_valid + + +def backward(self, loss, **kwargs): + state = mixed_precision_global_state() + callbacks = 'callbacks' in kwargs and kwargs['callbacks'] or None + if callbacks is None: + from paddle.fluid.clip import error_clip_callback + callbacks = [error_clip_callback] # XXX what if gradient is zero? + if state is not None: + kwargs['callbacks'] = [scale_gradient] + callbacks + else: + kwargs['callbacks'] = callbacks + param_grads = self._backward(loss, **kwargs) + if state is not None: + grad_valid = update_loss_scale(v for k, v in param_grads) + if state.dynamic_scaling: + with layers.Switch() as switch: + with switch.case(grad_valid): + pass + with switch.default(): + for _, g in param_grads: + layers.assign(layers.zeros_like(g), g) + + return param_grads + + +mixed_precision_patched = False + + +# XXX this is a temporary measure, until thoroughly evaluated +def monkey_patch(): + global mixed_precision_patched + if mixed_precision_patched: + return + create_parameter_orig = lhb.LayerHelperBase.create_parameter + lhb.LayerHelperBase.create_parameter = create_parameter + lhb.LayerHelperBase._create_parameter = create_parameter_orig + backward_orig = optim.Optimizer.backward + optim.Optimizer.backward = backward + optim.Optimizer._backward = backward_orig + mixed_precision_patched = True diff --git a/PaddleCV/PaddleDetection/ppdet/modeling/architectures/cascade_mask_rcnn.py b/PaddleCV/PaddleDetection/ppdet/modeling/architectures/cascade_mask_rcnn.py index 2e480b78..ccfb16c2 100644 --- a/PaddleCV/PaddleDetection/ppdet/modeling/architectures/cascade_mask_rcnn.py +++ b/PaddleCV/PaddleDetection/ppdet/modeling/architectures/cascade_mask_rcnn.py @@ -16,8 +16,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from collections import OrderedDict + import paddle.fluid as fluid +from ppdet.experimental import mixed_precision_global_state from ppdet.core.workspace import register __all__ = ['CascadeMaskRCNN'] @@ -98,9 +101,19 @@ class CascadeMaskRCNN(object): im_info = feed_vars['im_info'] + mixed_precision_enabled = mixed_precision_global_state() is not None + # cast inputs to FP16 + if mixed_precision_enabled: + im = fluid.layers.cast(im, 'float16') + # backbone body_feats = self.backbone(im) + # cast features back to FP32 + if mixed_precision_enabled: + body_feats = OrderedDict((k, fluid.layers.cast(v, 'float32')) + for k, v in body_feats.items()) + # FPN if self.fpn is not None: body_feats, spatial_scale = self.fpn.get_output(body_feats) diff --git a/PaddleCV/PaddleDetection/ppdet/modeling/architectures/cascade_rcnn.py b/PaddleCV/PaddleDetection/ppdet/modeling/architectures/cascade_rcnn.py index 26e05925..647d8bbc 100644 --- a/PaddleCV/PaddleDetection/ppdet/modeling/architectures/cascade_rcnn.py +++ b/PaddleCV/PaddleDetection/ppdet/modeling/architectures/cascade_rcnn.py @@ -16,8 +16,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from collections import OrderedDict + import paddle.fluid as fluid +from ppdet.experimental import mixed_precision_global_state from ppdet.core.workspace import register __all__ = ['CascadeRCNN'] @@ -87,9 +90,19 @@ class CascadeRCNN(object): gt_box = feed_vars['gt_box'] is_crowd = feed_vars['is_crowd'] + mixed_precision_enabled = mixed_precision_global_state() is not None + # cast inputs to FP16 + if mixed_precision_enabled: + im = fluid.layers.cast(im, 'float16') + # backbone body_feats = self.backbone(im) + # cast features back to FP32 + if mixed_precision_enabled: + body_feats = OrderedDict((k, fluid.layers.cast(v, 'float32')) + for k, v in body_feats.items()) + # FPN if self.fpn is not None: body_feats, spatial_scale = self.fpn.get_output(body_feats) diff --git a/PaddleCV/PaddleDetection/ppdet/modeling/architectures/faster_rcnn.py b/PaddleCV/PaddleDetection/ppdet/modeling/architectures/faster_rcnn.py index a6ef2f6a..69855986 100644 --- a/PaddleCV/PaddleDetection/ppdet/modeling/architectures/faster_rcnn.py +++ b/PaddleCV/PaddleDetection/ppdet/modeling/architectures/faster_rcnn.py @@ -16,8 +16,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from collections import OrderedDict + from paddle import fluid +from ppdet.experimental import mixed_precision_global_state from ppdet.core.workspace import register __all__ = ['FasterRCNN'] @@ -67,9 +70,21 @@ class FasterRCNN(object): is_crowd = feed_vars['is_crowd'] else: im_shape = feed_vars['im_shape'] + + mixed_precision_enabled = mixed_precision_global_state() is not None + + # cast inputs to FP16 + if mixed_precision_enabled: + im = fluid.layers.cast(im, 'float16') + body_feats = self.backbone(im) body_feat_names = list(body_feats.keys()) + # cast features back to FP32 + if mixed_precision_enabled: + body_feats = OrderedDict((k, fluid.layers.cast(v, 'float32')) + for k, v in body_feats.items()) + if self.fpn is not None: body_feats, spatial_scale = self.fpn.get_output(body_feats) diff --git a/PaddleCV/PaddleDetection/ppdet/modeling/architectures/mask_rcnn.py b/PaddleCV/PaddleDetection/ppdet/modeling/architectures/mask_rcnn.py index ef7f4af1..97eacbf0 100644 --- a/PaddleCV/PaddleDetection/ppdet/modeling/architectures/mask_rcnn.py +++ b/PaddleCV/PaddleDetection/ppdet/modeling/architectures/mask_rcnn.py @@ -16,7 +16,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from paddle import fluid +from collections import OrderedDict + +import paddle.fluid as fluid + +from ppdet.experimental import mixed_precision_global_state from ppdet.core.workspace import register __all__ = ['MaskRCNN'] @@ -79,8 +83,19 @@ class MaskRCNN(object): "{} has no {} field".format(feed_vars, var) im_info = feed_vars['im_info'] + mixed_precision_enabled = mixed_precision_global_state() is not None + # cast inputs to FP16 + if mixed_precision_enabled: + im = fluid.layers.cast(im, 'float16') + + # backbone body_feats = self.backbone(im) + # cast features back to FP32 + if mixed_precision_enabled: + body_feats = OrderedDict((k, fluid.layers.cast(v, 'float32')) + for k, v in body_feats.items()) + # FPN if self.fpn is not None: body_feats, spatial_scale = self.fpn.get_output(body_feats) diff --git a/PaddleCV/PaddleDetection/ppdet/modeling/architectures/retinanet.py b/PaddleCV/PaddleDetection/ppdet/modeling/architectures/retinanet.py index e06cf9e0..4ce5ac50 100644 --- a/PaddleCV/PaddleDetection/ppdet/modeling/architectures/retinanet.py +++ b/PaddleCV/PaddleDetection/ppdet/modeling/architectures/retinanet.py @@ -16,8 +16,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from collections import OrderedDict + import paddle.fluid as fluid +from ppdet.experimental import mixed_precision_global_state from ppdet.core.workspace import register __all__ = ['RetinaNet'] @@ -50,9 +53,20 @@ class RetinaNet(object): gt_box = feed_vars['gt_box'] gt_label = feed_vars['gt_label'] is_crowd = feed_vars['is_crowd'] + + mixed_precision_enabled = mixed_precision_global_state() is not None + # cast inputs to FP16 + if mixed_precision_enabled: + im = fluid.layers.cast(im, 'float16') + # backbone body_feats = self.backbone(im) + # cast features back to FP32 + if mixed_precision_enabled: + body_feats = OrderedDict((k, fluid.layers.cast(v, 'float32')) + for k, v in body_feats.items()) + # FPN body_feats, spatial_scale = self.fpn.get_output(body_feats) diff --git a/PaddleCV/PaddleDetection/ppdet/modeling/architectures/ssd.py b/PaddleCV/PaddleDetection/ppdet/modeling/architectures/ssd.py index 18132b20..e899075f 100644 --- a/PaddleCV/PaddleDetection/ppdet/modeling/architectures/ssd.py +++ b/PaddleCV/PaddleDetection/ppdet/modeling/architectures/ssd.py @@ -16,8 +16,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from paddle import fluid +from collections import OrderedDict +import paddle.fluid as fluid + +from ppdet.experimental import mixed_precision_global_state from ppdet.core.workspace import register from ppdet.modeling.ops import SSDOutputDecoder @@ -59,7 +62,22 @@ class SSD(object): gt_box = feed_vars['gt_box'] gt_label = feed_vars['gt_label'] + mixed_precision_enabled = mixed_precision_global_state() is not None + # cast inputs to FP16 + if mixed_precision_enabled: + im = fluid.layers.cast(im, 'float16') + + # backbone body_feats = self.backbone(im) + + if isinstance(body_feats, OrderedDict): + body_feat_names = list(body_feats.keys()) + body_feats = [body_feats[name] for name in body_feat_names] + + # cast features back to FP32 + if mixed_precision_enabled: + body_feats = [fluid.layers.cast(v, 'float32') for v in body_feats] + locs, confs, box, box_var = self.multi_box_head( inputs=body_feats, image=im, num_classes=self.num_classes) diff --git a/PaddleCV/PaddleDetection/ppdet/modeling/architectures/yolov3.py b/PaddleCV/PaddleDetection/ppdet/modeling/architectures/yolov3.py index 29efe1a6..2912ffda 100644 --- a/PaddleCV/PaddleDetection/ppdet/modeling/architectures/yolov3.py +++ b/PaddleCV/PaddleDetection/ppdet/modeling/architectures/yolov3.py @@ -18,6 +18,9 @@ from __future__ import print_function from collections import OrderedDict +from paddle import fluid + +from ppdet.experimental import mixed_precision_global_state from ppdet.core.workspace import register __all__ = ['YOLOv3'] @@ -43,12 +46,23 @@ class YOLOv3(object): def build(self, feed_vars, mode='train'): im = feed_vars['image'] + + mixed_precision_enabled = mixed_precision_global_state() is not None + + # cast inputs to FP16 + if mixed_precision_enabled: + im = fluid.layers.cast(im, 'float16') + body_feats = self.backbone(im) if isinstance(body_feats, OrderedDict): body_feat_names = list(body_feats.keys()) body_feats = [body_feats[name] for name in body_feat_names] + # cast features back to FP32 + if mixed_precision_enabled: + body_feats = [fluid.layers.cast(v, 'float32') for v in body_feats] + if mode == 'train': gt_box = feed_vars['gt_box'] gt_label = feed_vars['gt_label'] diff --git a/PaddleCV/PaddleDetection/tools/train.py b/PaddleCV/PaddleDetection/tools/train.py index f1980307..2745c7f7 100644 --- a/PaddleCV/PaddleDetection/tools/train.py +++ b/PaddleCV/PaddleDetection/tools/train.py @@ -36,7 +36,8 @@ set_paddle_flags( ) from paddle import fluid -from paddle.fluid.contrib import mixed_precision + +from ppdet.experimental import mixed_precision_context from ppdet.core.workspace import load_config, merge_config, create from ppdet.data.data_feed import create_reader @@ -115,16 +116,18 @@ def main(): with fluid.unique_name.guard(): model = create(main_arch) train_pyreader, feed_vars = create_feed(train_feed) - train_fetches = model.train(feed_vars) - loss = train_fetches['loss'] - lr = lr_builder() - optimizer = optim_builder(lr) - if FLAGS.fp16: - optimizer = mixed_precision.decorate( - optimizer=optimizer, - init_loss_scaling=FLAGS.loss_scale, - use_dynamic_loss_scaling=False) - optimizer.minimize(loss) + + with mixed_precision_context(FLAGS.loss_scale, FLAGS.fp16) as ctx: + train_fetches = model.train(feed_vars) + + loss = train_fetches['loss'] + if FLAGS.fp16: + loss *= ctx.get_loss_scale_var() + lr = lr_builder() + optimizer = optim_builder(lr) + optimizer.minimize(loss) + if FLAGS.fp16: + loss /= ctx.get_loss_scale_var() # parse train fetches train_keys, train_values, _ = parse_fetches(train_fetches) @@ -154,8 +157,6 @@ def main(): # compile program for multi-devices build_strategy = fluid.BuildStrategy() build_strategy.fuse_all_optimizer_ops = False - if FLAGS.fp16: - build_strategy.fuse_all_reduce_ops = False # only enable sync_bn in multi GPU devices sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn' build_strategy.sync_batch_norm = sync_bn and devices_num > 1 \ @@ -280,6 +281,12 @@ def main(): if __name__ == '__main__': parser = ArgsParser() + parser.add_argument( + "-r", + "--resume_checkpoint", + default=None, + type=str, + help="Checkpoint path for resuming training.") parser.add_argument( "--fp16", action='store_true', @@ -290,12 +297,6 @@ if __name__ == '__main__': default=8., type=float, help="Mixed precision training loss scale.") - parser.add_argument( - "-r", - "--resume_checkpoint", - default=None, - type=str, - help="Checkpoint path for resuming training.") parser.add_argument( "--eval", action='store_true', -- GitLab