Unverified · commit 5a159f34, authored by chengduo, committed by GitHub

Merge pull request #8934 from chengduoZH/feature/Enhance_regularizer_py

Enhance regularizer.py
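Summary: this PR widens `WeightDecayRegularizer.__call__` from `(param, block)` to `(param, grad, block)` so a regularizer can inspect the gradient it is paired with. When that gradient is a `SELECTED_ROWS` (sparse) variable, `L1DecayRegularizer` and `L2DecayRegularizer` now build the decay term as `SELECTED_ROWS` as well, using a `lookup_table` op with `is_sparse=True` to gather only the parameter rows the gradient actually touches instead of decaying the whole dense parameter. The hunks below are from `regularizer.py` (per the PR title), plus one book test that now wires a regularizer into its optimizer. A minimal usage sketch mirroring that test (fluid-era API):

```python
import paddle.fluid as fluid

# Optimizer-level default: applied to every parameter that does not
# carry a per-parameter regularizer of its own.
optimizer = fluid.optimizer.Adagrad(
    learning_rate=1e-4,
    regularization=fluid.regularizer.L2DecayRegularizer(
        regularization_coeff=0.1))
```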
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import framework
+from . import core
 
 __all__ = [
     'append_regularization_ops',
@@ -46,9 +47,9 @@ def append_regularization_ops(parameters_and_grads, regularization=None):
         regularization_term = None
         if param.regularizer is not None:
             # Add variable for regularization term in grad block
-            regularization_term = param.regularizer(param, grad.block)
+            regularization_term = param.regularizer(param, grad, grad.block)
         elif regularization is not None:
-            regularization_term = regularization(param, grad.block)
+            regularization_term = regularization(param, grad, grad.block)
 
         # If no gradient or no regularization specified,
         # then we don't need to do anything
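`append_regularization_ops` now forwards `grad` to whichever regularizer applies; a parameter's own `regularizer` attribute still takes precedence over the optimizer-wide `regularization` argument. A condensed sketch of that dispatch (the helper name is hypothetical, not part of the PR):

```python
def _pick_regularization_term(param, grad, regularization=None):
    # Per-parameter regularizer wins over the optimizer-level default.
    if param.regularizer is not None:
        return param.regularizer(param, grad, grad.block)
    if regularization is not None:
        return regularization(param, grad, grad.block)
    return None  # no regularization for this parameter
```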
@@ -82,7 +83,7 @@ class WeightDecayRegularizer(object):
     def __init__(self):
         pass
 
-    def __call__(self, param, block):
+    def __call__(self, param, grad, block):
         """Add corresponding weight decay operations to the network
         """
         raise NotImplementedError()
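Every subclass, including user-defined regularizers, must adopt the widened signature. A hedged sketch of a conforming no-op subclass (illustrative only; it reuses the `scale` op that already appears in this diff):

```python
class NoDecayRegularizer(WeightDecayRegularizer):
    """Illustrative subclass honoring the new (param, grad, block) contract."""

    def __call__(self, param, grad, block):
        # Emit a zero decay term with the parameter's shape; real
        # regularizers append their own ops (scale, sign, ...) to `block`.
        decay = block.create_var(dtype="float32", shape=param.shape)
        block.append_op(
            type='scale',
            inputs={"X": param},
            outputs={"Out": decay},
            attrs={"scale": 0.0})
        return decay
```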
@@ -102,7 +103,7 @@ class L2DecayRegularizer(WeightDecayRegularizer):
         super(L2DecayRegularizer, self).__init__()
         self._regularization_coeff = regularization_coeff
 
-    def __call__(self, param, block):
+    def __call__(self, param, grad, block):
         """Add L2 weight decay ops to network
 
         Adds L2 weight decay ops.
@@ -117,8 +118,23 @@ class L2DecayRegularizer(WeightDecayRegularizer):
         """
         assert isinstance(param, framework.Parameter)
         assert isinstance(block, framework.Block)
+
         decay = block.create_var(
             dtype="float32", shape=param.shape, lod_level=param.lod_level)
+
+        if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
+            decay = block.create_var(
+                dtype="float32",
+                shape=param.shape,
+                type=core.VarDesc.VarType.SELECTED_ROWS)
+            block.append_op(
+                type='lookup_table',
+                inputs={'W': param,
+                        'Ids': grad},
+                outputs={'Out': decay},
+                attrs={'is_sparse': True})
+            param = decay
+
         # Append Op to calculate decay
         block.append_op(
             type='scale',
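For a dense gradient, the method scales the whole parameter by `regularization_coeff`. For a `SELECTED_ROWS` gradient, decaying the full parameter would densify an otherwise sparse update, so the new branch re-creates `decay` as `SELECTED_ROWS`, gathers just the gradient's rows from `param` via `lookup_table` (`is_sparse=True`), and rebinds `param = decay` so the following `scale` op touches only those rows. A framework-free NumPy sketch of the same computation (names are illustrative):

```python
import numpy as np

def sparse_l2_decay(param, grad_ids, coeff):
    rows = param[grad_ids]         # gather: the lookup_table op above
    return coeff * rows, grad_ids  # scale: the term stays sparse (rows + ids)

param = np.arange(15, dtype="float32").reshape(5, 3)
rows, ids = sparse_l2_decay(param, np.array([0, 3]), coeff=0.1)
assert rows.shape == (2, 3)        # only the touched rows are decayed
```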
@@ -141,7 +157,7 @@ class L1DecayRegularizer(WeightDecayRegularizer):
         super(L1DecayRegularizer, self).__init__()
         self._regularization_coeff = regularization_coeff
 
-    def __call__(self, param, block):
+    def __call__(self, param, grad, block):
         """Add L1 weight decay ops to network
 
         Adds L1 weight decay ops.
@@ -158,6 +174,19 @@ class L1DecayRegularizer(WeightDecayRegularizer):
         assert isinstance(block, framework.Block)
+
         decay = block.create_var(
             dtype="float32", shape=param.shape, lod_level=param.lod_level)
+
+        if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
+            decay = block.create_var(
+                dtype="float32",
+                shape=param.shape,
+                type=core.VarDesc.VarType.SELECTED_ROWS)
+            block.append_op(
+                type='lookup_table',
+                inputs={'W': param,
+                        'Ids': grad},
+                outputs={'Out': decay},
+                attrs={'is_sparse': True})
+
         # Append sign op
         block.append_op(
             type='sign', inputs={"X": param}, outputs={"Out": decay})
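The L1 branch gathers the touched rows the same way and then applies the `sign` op; the multiplication by the coefficient (stored above as `_regularization_coeff`) happens outside the displayed lines. In NumPy terms (illustrative):

```python
import numpy as np

def sparse_l1_decay(param, grad_ids, coeff):
    rows = param[grad_ids]                  # gather via lookup_table
    return coeff * np.sign(rows), grad_ids  # sign op, then scale by coeff
```

The final hunk comes from a separate test file: its `train_main` now constructs the optimizer with an L2 regularizer, exercising the path added above.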
@@ -181,7 +181,10 @@ def train_main(use_cuda, is_sparse, is_local=True):
     cost = pd.cross_entropy(input=rnn_out, label=label)
     avg_cost = pd.mean(cost)
 
-    optimizer = fluid.optimizer.Adagrad(learning_rate=1e-4)
+    optimizer = fluid.optimizer.Adagrad(
+        learning_rate=1e-4,
+        regularization=fluid.regularizer.L2DecayRegularizer(
+            regularization_coeff=0.1))
     optimize_ops, params_grads = optimizer.minimize(avg_cost)
 
     train_data = paddle.batch(
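Since `train_main` takes an `is_sparse` flag, this test can presumably exercise both the dense and the new `SELECTED_ROWS` code paths; with `regularization_coeff=0.1`, a `0.1 * param` decay term is folded into each parameter's gradient before the Adagrad step.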