diff --git a/python/paddle/fluid/regularizer.py b/python/paddle/fluid/regularizer.py
index a29f9a208ebefc75b531030c9f0de9487f2b136c..dc641cdd1afbecfc9122c9f2e8ce6fac77b53f21 100644
--- a/python/paddle/fluid/regularizer.py
+++ b/python/paddle/fluid/regularizer.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import framework
+from . import core
 
 __all__ = [
     'append_regularization_ops',
@@ -46,9 +47,9 @@ def append_regularization_ops(parameters_and_grads, regularization=None):
         regularization_term = None
         if param.regularizer is not None:
             # Add variable for regularization term in grad block
-            regularization_term = param.regularizer(param, grad.block)
+            regularization_term = param.regularizer(param, grad, grad.block)
         elif regularization is not None:
-            regularization_term = regularization(param, grad.block)
+            regularization_term = regularization(param, grad, grad.block)
 
         # If no gradient or no regularization specified,
         # then we don't need to do anything
@@ -82,7 +83,7 @@ class WeightDecayRegularizer(object):
     def __init__(self):
         pass
 
-    def __call__(self, param, block):
+    def __call__(self, param, grad, block):
         """Add corresponding weight decay operations to the network
         """
         raise NotImplementedError()
@@ -102,7 +103,7 @@ class L2DecayRegularizer(WeightDecayRegularizer):
         super(L2DecayRegularizer, self).__init__()
         self._regularization_coeff = regularization_coeff
 
-    def __call__(self, param, block):
+    def __call__(self, param, grad, block):
         """Add L2 weight decay ops to network
 
         Adds L2 weight decay ops.
@@ -117,8 +118,23 @@
         """
         assert isinstance(param, framework.Parameter)
         assert isinstance(block, framework.Block)
+
         decay = block.create_var(
             dtype="float32", shape=param.shape, lod_level=param.lod_level)
+
+        if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
+            decay = block.create_var(
+                dtype="float32",
+                shape=param.shape,
+                type=core.VarDesc.VarType.SELECTED_ROWS)
+            block.append_op(
+                type='lookup_table',
+                inputs={'W': param,
+                        'Ids': grad},
+                outputs={'Out': decay},
+                attrs={'is_sparse': True})
+            param = decay
+
         # Append Op to calculate decay
         block.append_op(
             type='scale',
@@ -141,7 +157,7 @@ class L1DecayRegularizer(WeightDecayRegularizer):
         super(L1DecayRegularizer, self).__init__()
         self._regularization_coeff = regularization_coeff
 
-    def __call__(self, param, block):
+    def __call__(self, param, grad, block):
         """Add L1 weight decay ops to network
 
         Adds L1 weight decay ops.
@@ -158,6 +174,21 @@
         assert isinstance(block, framework.Block)
         decay = block.create_var(
             dtype="float32", shape=param.shape, lod_level=param.lod_level)
+
+        if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
+            # Gather the rows of param indexed by the sparse gradient
+            decay = block.create_var(
+                dtype="float32",
+                shape=param.shape,
+                type=core.VarDesc.VarType.SELECTED_ROWS)
+            block.append_op(
+                type='lookup_table',
+                inputs={'W': param,
+                        'Ids': grad},
+                outputs={'Out': decay},
+                attrs={'is_sparse': True})
+            param = decay
+
         # Append sign op
         block.append_op(
             type='sign', inputs={"X": param}, outputs={"Out": decay})
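
For orientation: the new `grad` argument lets each regularizer detect a sparse (SELECTED_ROWS) gradient and, via `lookup_table` with `is_sparse=True`, gather only the parameter rows that the gradient touches; the existing `scale` (L2) and `sign` (L1) ops then run on those rows instead of the full table. A minimal NumPy sketch of the intended semantics (names and shapes here are illustrative, not taken from the diff):

```python
import numpy as np

# Hypothetical stand-ins: a SELECTED_ROWS gradient is essentially
# (row indices, gradient values); only the indices matter here.
param = np.random.rand(10000, 64).astype("float32")  # e.g. an embedding table
rows = np.array([3, 17, 42])                         # rows present in the sparse grad
coeff = 1e-4                                         # regularization_coeff

gathered = param[rows]                 # what the lookup_table op gathers
l2_decay = coeff * gathered            # what the L2 scale op then produces
l1_decay = coeff * np.sign(gathered)   # the analogous L1 (sign + scale) term
```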
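
And a hedged end-to-end sketch of how this path would be exercised, assuming the fluid API of this era (`fluid.layers.embedding` with `is_sparse=True` is what yields a SELECTED_ROWS gradient; the layer names and sizes are made up):

```python
import paddle.fluid as fluid

ids = fluid.layers.data(name="ids", shape=[1], dtype="int64")
emb = fluid.layers.embedding(
    input=ids,
    size=[10000, 64],
    is_sparse=True,  # makes the gradient a SELECTED_ROWS variable
    param_attr=fluid.ParamAttr(
        regularizer=fluid.regularizer.L2DecayRegularizer(
            regularization_coeff=1e-4)))
loss = fluid.layers.reduce_sum(emb)

# minimize() calls append_regularization_ops(), which now forwards grad
# to the regularizer so the SELECTED_ROWS branch above can trigger.
sgd = fluid.optimizer.SGD(learning_rate=0.01)
sgd.minimize(loss)
```

A regularizer passed globally through the optimizer's `regularization` argument would instead reach the `elif regularization is not None` branch, with the same new three-argument call.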