提交 74523c41 编写于 作者: C chengduoZH

enhance regularizer.py

上级 0d49b921
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import framework import framework
from . import core
__all__ = [ __all__ = [
'append_regularization_ops', 'append_regularization_ops',
...@@ -46,9 +47,9 @@ def append_regularization_ops(parameters_and_grads, regularization=None): ...@@ -46,9 +47,9 @@ def append_regularization_ops(parameters_and_grads, regularization=None):
regularization_term = None regularization_term = None
if param.regularizer is not None: if param.regularizer is not None:
# Add variable for regularization term in grad block # Add variable for regularization term in grad block
regularization_term = param.regularizer(param, grad.block) regularization_term = param.regularizer(param, grad, grad.block)
elif regularization is not None: elif regularization is not None:
regularization_term = regularization(param, grad.block) regularization_term = regularization(param, grad, grad.block)
# If no gradient or no regularization specified, # If no gradient or no regularization specified,
# then we don't need to do anything # then we don't need to do anything
...@@ -82,7 +83,7 @@ class WeightDecayRegularizer(object): ...@@ -82,7 +83,7 @@ class WeightDecayRegularizer(object):
def __init__(self): def __init__(self):
pass pass
def __call__(self, param, block): def __call__(self, param, grad, block):
"""Add corresponding weight decay operations to the network """Add corresponding weight decay operations to the network
""" """
raise NotImplementedError() raise NotImplementedError()
...@@ -102,7 +103,7 @@ class L2DecayRegularizer(WeightDecayRegularizer): ...@@ -102,7 +103,7 @@ class L2DecayRegularizer(WeightDecayRegularizer):
super(L2DecayRegularizer, self).__init__() super(L2DecayRegularizer, self).__init__()
self._regularization_coeff = regularization_coeff self._regularization_coeff = regularization_coeff
def __call__(self, param, block): def __call__(self, param, grad, block):
"""Add L2 weight decay ops to network """Add L2 weight decay ops to network
Adds L2 weight decay ops. Adds L2 weight decay ops.
...@@ -117,8 +118,23 @@ class L2DecayRegularizer(WeightDecayRegularizer): ...@@ -117,8 +118,23 @@ class L2DecayRegularizer(WeightDecayRegularizer):
""" """
assert isinstance(param, framework.Parameter) assert isinstance(param, framework.Parameter)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
decay = block.create_var( decay = block.create_var(
dtype="float32", shape=param.shape, lod_level=param.lod_level) dtype="float32", shape=param.shape, lod_level=param.lod_level)
if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
decay = block.create_var(
dtype="float32",
shape=param.shape,
type=core.VarDesc.VarType.SELECTED_ROWS)
block.append_op(
type='lookup_table',
inputs={'W': param,
'Ids': grad},
outputs={'Out': decay},
attrs={'is_sparse': True})
param = decay
# Append Op to calculate decay # Append Op to calculate decay
block.append_op( block.append_op(
type='scale', type='scale',
...@@ -141,7 +157,7 @@ class L1DecayRegularizer(WeightDecayRegularizer): ...@@ -141,7 +157,7 @@ class L1DecayRegularizer(WeightDecayRegularizer):
super(L1DecayRegularizer, self).__init__() super(L1DecayRegularizer, self).__init__()
self._regularization_coeff = regularization_coeff self._regularization_coeff = regularization_coeff
def __call__(self, param, block): def __call__(self, param, grad, block):
"""Add L1 weight decay ops to network """Add L1 weight decay ops to network
Adds L1 weight decay ops. Adds L1 weight decay ops.
...@@ -158,6 +174,20 @@ class L1DecayRegularizer(WeightDecayRegularizer): ...@@ -158,6 +174,20 @@ class L1DecayRegularizer(WeightDecayRegularizer):
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
decay = block.create_var( decay = block.create_var(
dtype="float32", shape=param.shape, lod_level=param.lod_level) dtype="float32", shape=param.shape, lod_level=param.lod_level)
if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
# add concat_rows
decay = block.create_var(
dtype="float32",
shape=param.shape,
type=core.VarDesc.VarType.SELECTED_ROWS)
block.append_op(
type='lookup_table',
inputs={'W': param,
'Ids': grad},
outputs={'Out': decay},
attrs={'is_sparse': True})
# Append sign op # Append sign op
block.append_op( block.append_op(
type='sign', inputs={"X": param}, outputs={"Out": decay}) type='sign', inputs={"X": param}, outputs={"Out": decay})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册