Unverified commit 2835d5ad, authored by Yang Zhang, committed by GitHub

Upgrade paddle API used in mixed precision training (#227)

Parent 59b70495
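In short, this change moves the mixed-precision helpers off the older `layers.Switch`/`switch.case` construct onto `layers.cond`, which takes the branch bodies as Python callables, and collapses two hand-written branches into single ops (`elementwise_max` in `decrement`, an explicit `axis` on the appended `elementwise_div`). A minimal sketch of the `Switch` to `cond` pattern, written against the `paddle.fluid` 1.x static-graph API that `layers.cond` belongs to; the variables below are illustrative only, not code from this repository:

```python
import paddle.fluid as fluid
from paddle.fluid import layers

prog = fluid.Program()
with fluid.program_guard(prog, fluid.Program()):
    x = layers.fill_constant(shape=[1], dtype='float32', value=3.0)
    limit = layers.fill_constant(shape=[1], dtype='float32', value=5.0)
    pred = layers.less_than(x, limit)

    def raise_to_limit():
        # Branch body: executed only when pred is True at runtime.
        layers.assign(limit, x)

    # Old style: `with layers.Switch() as switch: with switch.case(pred): ...`
    # New style: the branch is a callable; an omitted branch simply does nothing.
    layers.cond(pred, raise_to_limit)
```

Because `cond` builds each branch as a sub-block from a callable, the branch can capture surrounding variables by closure instead of nesting `with` statements, which is what the rewritten `increment` below relies on.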
@@ -129,30 +129,27 @@ class DynamicLossScale(LossScale):
     def increment(self):
         enough_steps = layers.less_than(self.increment_every,
                                         self.good_steps + 1)
-        with layers.Switch() as switch:
-            with switch.case(enough_steps):
-                new_scale = self.scale * self.factor
-                scale_valid = layers.isfinite(new_scale)
-                with layers.Switch() as switch2:
-                    with switch2.case(scale_valid):
-                        layers.assign(new_scale, self.scale)
-                        layers.assign(
-                            layers.zeros_like(self.good_steps), self.good_steps)
-                    with switch2.default():
-                        layers.increment(self.good_steps)
-            with switch.default():
-                layers.increment(self.good_steps)
+
+        def increment_step():
+            layers.increment(self.good_steps)
+
+        def maybe_update():
+            new_scale = self.scale * self.factor
+            scale_valid = layers.isfinite(new_scale)
+
+            def update_scale_and_step():
+                layers.assign(new_scale, self.scale)
+                layers.assign(
+                    layers.zeros_like(self.good_steps), self.good_steps)
+
+            layers.cond(scale_valid, update_scale_and_step)
+
+        layers.cond(enough_steps, maybe_update, increment_step)
 
     def decrement(self):
         new_scale = self.scale / self.factor
         one = layers.fill_constant(shape=[1], dtype='float32', value=1.0)
-        less_than_one = layers.less_than(new_scale, one)
-        with layers.Switch() as switch:
-            with switch.case(less_than_one):
-                layers.assign(one, self.scale)
-            with switch.default():
-                layers.assign(new_scale, self.scale)
+        layers.assign(layers.elementwise_max(new_scale, one), self.scale)
         layers.assign(layers.zeros_like(self.good_steps), self.good_steps)
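The `decrement` rewrite above also drops an entire `Switch` by observing that clamping the halved scale at 1.0 is just an elementwise max. A small sanity check of that equivalence, using made-up values rather than the class's real state:

```python
import paddle.fluid as fluid
from paddle.fluid import layers

prog = fluid.Program()
with fluid.program_guard(prog, fluid.Program()):
    # Pretend the halved loss scale has dropped below 1.0.
    new_scale = layers.fill_constant(shape=[1], dtype='float32', value=0.5)
    one = layers.fill_constant(shape=[1], dtype='float32', value=1.0)
    # max(new_scale, 1.0) reproduces "assign one if new_scale < one else new_scale".
    clamped = layers.elementwise_max(new_scale, one)

exe = fluid.Executor(fluid.CPUPlace())
result, = exe.run(prog, fetch_list=[clamped])
print(result)  # expected: [1.]
```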
@@ -275,12 +272,13 @@ def scale_gradient(block, context):
         fwd_var = block._var_recursive(context[name])
         if not isinstance(fwd_var, Parameter):
             continue  # TODO verify all use cases
-        clip_op_desc = block.desc.append_op()
-        clip_op_desc.set_type("elementwise_div")
-        clip_op_desc.set_input("X", [name])
-        clip_op_desc.set_input("Y", [scale.name])
-        clip_op_desc.set_output("Out", [name])
-        clip_op_desc._set_attr(op_role_attr_name, bwd_role)
+        scale_op_desc = block.desc.append_op()
+        scale_op_desc.set_type("elementwise_div")
+        scale_op_desc.set_input("X", [name])
+        scale_op_desc.set_input("Y", [scale.name])
+        scale_op_desc.set_output("Out", [name])
+        scale_op_desc._set_attr("axis", -1)
+        scale_op_desc._set_attr(op_role_attr_name, bwd_role)
 
 
 def update_loss_scale(grads):
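The renamed op descriptor appended above encodes a plain `elementwise_div` of each gradient by the loss scale; the `axis` attribute is presumably set explicitly because a hand-built op desc does not get the defaults the Python layer wrapper would fill in. At the layer level the same computation would read like this (illustrative tensors, not the repository's gradients):

```python
import paddle.fluid as fluid
from paddle.fluid import layers

prog = fluid.Program()
with fluid.program_guard(prog, fluid.Program()):
    # Stand-ins for one parameter gradient and the scalar loss scale.
    grad = layers.fill_constant(shape=[2, 3], dtype='float32', value=8.0)
    loss_scale = layers.fill_constant(shape=[1], dtype='float32', value=4.0)
    # Same computation the appended op desc encodes: Out = X / Y with axis=-1,
    # broadcasting the [1]-shaped scale over the whole gradient.
    unscaled = layers.elementwise_div(grad, loss_scale, axis=-1)

exe = fluid.Executor(fluid.CPUPlace())
out, = exe.run(prog, fetch_list=[unscaled])
print(out)  # expected: a 2x3 array of 2.0
```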
@@ -289,12 +287,8 @@ def update_loss_scale(grads):
         return
     per_grad_check = layers.stack([layers.reduce_sum(g) for g in grads])
     grad_valid = layers.isfinite(per_grad_check)
-
-    with layers.Switch() as switch:
-        with switch.case(grad_valid):
-            state.increment()
-        with switch.default():
-            state.decrement()
+    layers.cond(grad_valid, lambda: state.increment(),
+                lambda: state.decrement())
 
     return grad_valid
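With the `Switch` gone, the scale update is a single `cond` over a one-shot overflow check. A sketch of how stacking the per-gradient sums and calling `isfinite` once flags an overflowed gradient, with made-up tensors in place of real parameter gradients:

```python
import paddle.fluid as fluid
from paddle.fluid import layers

prog = fluid.Program()
with fluid.program_guard(prog, fluid.Program()):
    g0 = layers.fill_constant(shape=[4], dtype='float32', value=1.0)
    # Force an overflow: 3e38 * 2 exceeds the float32 range and becomes inf.
    g1 = layers.fill_constant(shape=[4], dtype='float32', value=3e38) * 2.0
    # One scalar per gradient, stacked, then a single finiteness test.
    per_grad_check = layers.stack([layers.reduce_sum(g) for g in (g0, g1)])
    grad_valid = layers.isfinite(per_grad_check)

exe = fluid.Executor(fluid.CPUPlace())
valid, = exe.run(prog, fetch_list=[grad_valid])
print(valid)  # expected: [False], because g1 overflowed
```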
@@ -309,15 +303,15 @@ def backward(self, loss, **kwargs):
     else:
         kwargs['callbacks'] = callbacks
     param_grads = self._backward(loss, **kwargs)
+
+    def zero_grad():
+        for _, g in param_grads:
+            layers.assign(layers.zeros_like(g), g)
+
     if state is not None:
         grad_valid = update_loss_scale(v for k, v in param_grads)
         if state.dynamic_scaling:
-            with layers.Switch() as switch:
-                with switch.case(grad_valid):
-                    pass
-                with switch.default():
-                    for _, g in param_grads:
-                        layers.assign(layers.zeros_like(g), g)
+            layers.cond(grad_valid, None, zero_grad)
 
     return param_grads
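Zeroing the gradients on overflow now uses a false-branch-only `cond`: passing `None` as `true_fn` is acceptable as long as the other branch also returns nothing. A toy version of that shape, with a stand-in tensor rather than the optimizer's real gradients:

```python
import paddle.fluid as fluid
from paddle.fluid import layers

prog = fluid.Program()
with fluid.program_guard(prog, fluid.Program()):
    g = layers.fill_constant(shape=[3], dtype='float32', value=7.0)
    grad_valid = layers.isfinite(g)

    def zero_grad():
        # Runs only when grad_valid is False; discards the overflowed update.
        layers.assign(layers.zeros_like(g), g)

    # true_fn=None: there is nothing to do when the gradients are finite.
    layers.cond(grad_valid, None, zero_grad)
```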