未验证 提交 c0f26683 编写于 作者: L littletomatodonkey 提交者: GitHub

fix l1 decay for inplace (#32718)

上级 df00636b
...@@ -326,19 +326,21 @@ class L1DecayRegularizer(WeightDecayRegularizer): ...@@ -326,19 +326,21 @@ class L1DecayRegularizer(WeightDecayRegularizer):
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
if framework.in_dygraph_mode(): if framework.in_dygraph_mode():
sign = block.create_var(dtype=param.dtype, shape=param.shape)
decay = block.create_var(dtype=param.dtype, shape=param.shape) decay = block.create_var(dtype=param.dtype, shape=param.shape)
else: else:
sign = block.create_var(
dtype=param.dtype, shape=param.shape, lod_level=param.lod_level)
decay = block.create_var( decay = block.create_var(
dtype=param.dtype, shape=param.shape, lod_level=param.lod_level) dtype=param.dtype, shape=param.shape, lod_level=param.lod_level)
# Append sign op # Append sign op
block.append_op( block.append_op(type='sign', inputs={"X": param}, outputs={"Out": sign})
type='sign', inputs={"X": param}, outputs={"Out": decay})
# Append scale op to the output of sign op # Append scale op to the output of sign op
block.append_op( block.append_op(
type='scale', type='scale',
inputs={"X": decay}, inputs={"X": sign},
outputs={"Out": decay}, outputs={"Out": decay},
attrs={"scale": self._regularization_coeff}) attrs={"scale": self._regularization_coeff})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册