diff --git a/ppdet/experimental/mixed_precision.py b/ppdet/experimental/mixed_precision.py
index 8ff9db21ceddab4c8c608489c15a809052c6cd82..e13a7214252db2e0436f8a9902fe9066177b4f23 100644
--- a/ppdet/experimental/mixed_precision.py
+++ b/ppdet/experimental/mixed_precision.py
@@ -129,30 +129,27 @@ class DynamicLossScale(LossScale):
 
     def increment(self):
         enough_steps = layers.less_than(self.increment_every,
                                         self.good_steps + 1)
-        with layers.Switch() as switch:
-            with switch.case(enough_steps):
-                new_scale = self.scale * self.factor
-                scale_valid = layers.isfinite(new_scale)
-                with layers.Switch() as switch2:
-                    with switch2.case(scale_valid):
-                        layers.assign(new_scale, self.scale)
-                        layers.assign(
-                            layers.zeros_like(self.good_steps), self.good_steps)
-                    with switch2.default():
-                        layers.increment(self.good_steps)
-            with switch.default():
-                layers.increment(self.good_steps)
+
+        def increment_step():
+            layers.increment(self.good_steps)
+
+        def maybe_update():
+            new_scale = self.scale * self.factor
+            scale_valid = layers.isfinite(new_scale)
+
+            def update_scale_and_step():
+                layers.assign(new_scale, self.scale)
+                layers.assign(
+                    layers.zeros_like(self.good_steps), self.good_steps)
+
+            layers.cond(scale_valid, update_scale_and_step)
+
+        layers.cond(enough_steps, maybe_update, increment_step)
 
     def decrement(self):
         new_scale = self.scale / self.factor
         one = layers.fill_constant(shape=[1], dtype='float32', value=1.0)
-        less_than_one = layers.less_than(new_scale, one)
-        with layers.Switch() as switch:
-            with switch.case(less_than_one):
-                layers.assign(one, self.scale)
-            with switch.default():
-                layers.assign(new_scale, self.scale)
-
+        layers.assign(layers.elementwise_max(new_scale, one), self.scale)
         layers.assign(layers.zeros_like(self.good_steps), self.good_steps)
 
@@ -275,12 +272,13 @@ def scale_gradient(block, context):
         fwd_var = block._var_recursive(context[name])
         if not isinstance(fwd_var, Parameter):
             continue  # TODO verify all use cases
-        clip_op_desc = block.desc.append_op()
-        clip_op_desc.set_type("elementwise_div")
-        clip_op_desc.set_input("X", [name])
-        clip_op_desc.set_input("Y", [scale.name])
-        clip_op_desc.set_output("Out", [name])
-        clip_op_desc._set_attr(op_role_attr_name, bwd_role)
+        scale_op_desc = block.desc.append_op()
+        scale_op_desc.set_type("elementwise_div")
+        scale_op_desc.set_input("X", [name])
+        scale_op_desc.set_input("Y", [scale.name])
+        scale_op_desc.set_output("Out", [name])
+        scale_op_desc._set_attr("axis", -1)
+        scale_op_desc._set_attr(op_role_attr_name, bwd_role)
 
 
 def update_loss_scale(grads):
@@ -289,12 +287,8 @@ def update_loss_scale(grads):
         return
     per_grad_check = layers.stack([layers.reduce_sum(g) for g in grads])
     grad_valid = layers.isfinite(per_grad_check)
-
-    with layers.Switch() as switch:
-        with switch.case(grad_valid):
-            state.increment()
-        with switch.default():
-            state.decrement()
+    layers.cond(grad_valid, lambda: state.increment(),
+                lambda: state.decrement())
 
     return grad_valid
 
@@ -309,15 +303,15 @@ def backward(self, loss, **kwargs):
     else:
         kwargs['callbacks'] = callbacks
     param_grads = self._backward(loss, **kwargs)
+
+    def zero_grad():
+        for _, g in param_grads:
+            layers.assign(layers.zeros_like(g), g)
+
     if state is not None:
         grad_valid = update_loss_scale(v for k, v in param_grads)
        if state.dynamic_scaling:
-            with layers.Switch() as switch:
-                with switch.case(grad_valid):
-                    pass
-                with switch.default():
-                    for _, g in param_grads:
-                        layers.assign(layers.zeros_like(g), g)
+            layers.cond(grad_valid, None, zero_grad)
 
     return param_grads
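
For context, the pattern this patch migrates to is fluid.layers.cond(pred, true_fn, false_fn), which builds both branches as callbacks instead of the nested layers.Switch blocks. Below is a standalone sketch, not part of the patch, using a toy counter in place of the loss-scale state; the names counter, limit, if_true and if_false are illustrative only. It assumes a fluid release that provides layers.cond (roughly 1.8-era, as used by PaddleDetection at the time; on Paddle 2.x you would additionally call paddle.enable_static()).

import paddle.fluid as fluid
from paddle.fluid import layers

main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    counter = layers.fill_constant(shape=[1], dtype='int32', value=0)
    limit = layers.fill_constant(shape=[1], dtype='int32', value=3)
    pred = layers.less_than(counter, limit)

    def if_true():
        # branch with side effects only, mirroring increment_step() in the patch
        layers.increment(counter)

    def if_false():
        # reset the counter, mirroring the old switch.default() branch
        layers.assign(layers.zeros_like(counter), counter)

    # a single cond call replaces a layers.Switch block with case/default branches
    layers.cond(pred, if_true, if_false)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_prog)
# should print [array([1], dtype=int32)]: 0 < 3, so the increment branch ran
print(exe.run(main_prog, fetch_list=[counter]))

As in the patched increment(), the branch functions return nothing and only record in-place updates to variables from the enclosing scope, which cond supports because both branches produce the same (empty) output structure.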