Unverified commit 00e415de, authored by LoneRanger and committed by GitHub

relocate python/paddle/fluid/regularizer.py (#53106)

* relocate regularizer.py

* fix bug

* fix bug

* fix bug

* relocate the import

* replace _regularization_coeff with coeff

* remove the L1DecayRegularizer and L2DecayRegularizer
Parent 81056073
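In short: the regularizer classes now live in paddle.regularizer, the constructor argument is coeff (formerly regularization_coeff), and optimizers read the private attribute _coeff. A minimal before/after sketch of the user-facing change (the layer, optimizer choice, and values below are illustrative placeholders, not code from this PR):

    import paddle

    # Old spelling, removed by this PR:
    #   reg = paddle.fluid.regularizer.L2DecayRegularizer(regularization_coeff=1e-4)
    # New spelling:
    reg = paddle.regularizer.L2Decay(coeff=1e-4)

    # The regularizer object is passed as weight_decay to paddle.optimizer classes
    # (legacy fluid optimizers take it as regularization), as in the diffs below.
    linear = paddle.nn.Linear(4, 4)
    opt = paddle.optimizer.Momentum(
        learning_rate=0.01,
        momentum=0.9,
        weight_decay=reg,
        parameters=linear.parameters(),
    )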
@@ -24,6 +24,7 @@ from paddle.fluid.dygraph import base as imperative_base
from paddle.fluid.optimizer import Momentum, Optimizer
from paddle.framework import core, in_dygraph_mode
from paddle.nn.clip import ClipGradByNorm, append_gradient_clip_ops
+from paddle.regularizer import L1Decay, L2Decay
from paddle.static import create_global_var
@@ -99,8 +100,7 @@ class DGCMomentumOptimizer(Optimizer):
regular_coeff = 0.0
if regularization is not None:
-regular_coeff = regularization._regularization_coeff
+regular_coeff = regularization._coeff
-from paddle.fluid.regularizer import L1Decay, L2Decay
if isinstance(regularization, L1Decay):
regular_type = 1
......
@@ -58,7 +58,6 @@ from . import nets
from . import optimizer
from . import backward
from .backward import gradients
-from . import regularizer
from . import incubate
from .param_attr import ParamAttr, WeightNormParamAttr
from .data_feeder import DataFeeder
@@ -116,7 +115,6 @@ __all__ = (
'nets',
'optimizer',
'backward',
-'regularizer',
'LoDTensor',
'LoDTensorArray',
'CPUPlace',
......
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.optimizer import Optimizer
-from paddle.fluid.regularizer import L1DecayRegularizer
-from paddle.fluid.regularizer import L2DecayRegularizer
+from paddle.regularizer import L1Decay
+from paddle.regularizer import L2Decay
from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.framework import program_guard
@@ -117,7 +117,7 @@ class Momentum(Optimizer):
):
assert learning_rate is not None
assert momentum is not None
-predicate = lambda regular: isinstance(regular, L2DecayRegularizer)
+predicate = lambda regular: isinstance(regular, L2Decay)
py_regular = None if predicate(regularization) else regularization
super().__init__(
learning_rate=learning_rate,
@@ -131,9 +131,9 @@ class Momentum(Optimizer):
self._use_nesterov = bool(use_nesterov)
self._regularization_method = ""
self._regularization_coeff = 0
-if isinstance(regularization, L2DecayRegularizer):
+if isinstance(regularization, L2Decay):
self._regularization_method = "l2_decay"
-self._regularization_coeff = regularization._regularization_coeff
+self._regularization_coeff = regularization._coeff
self._multi_precision = multi_precision
self._rescale_grad = rescale_grad
self._master_weights = {}
......
@@ -413,7 +413,7 @@ def piecewise_decay(boundaries, values):
optimizer = fluid.optimizer.Momentum(
momentum=0.9,
learning_rate=fluid.layers.piecewise_decay(boundaries=boundaries, values=values),
-regularization=fluid.regularizer.L2Decay(1e-4))
+regularization=paddle.regularizer.L2Decay(1e-4))
"""
......
@@ -13,7 +13,7 @@
# limitations under the License.
import paddle
-from .regularizer import WeightDecayRegularizer
+from paddle.regularizer import WeightDecayRegularizer
from paddle.fluid.data_feeder import check_type
__all__ = [
......
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import logging
-from . import framework
-from .framework import _non_static_mode, in_dygraph_mode
-from . import core
-from paddle import _C_ops, _legacy_C_ops
-__all__ = ['L1Decay', 'L2Decay', 'L1DecayRegularizer', 'L2DecayRegularizer']
-class WeightDecayRegularizer:
-    """Base class for weight decay regularizers
-    Defines the common interface of weight-decay regularizers.
-    Weight-decay regularizers are added only during the backward
-    pass for faster regularization. They add operations to the network
-    that correspond to gradient of the regularization function.
-    Users should not use this class directly, but need to use one
-    of its implementations
-    """
-    def __init__(self):
-        pass
-    def __call__(self, param, grad, block):
-        """Add corresponding weight decay operations to the network"""
-        raise NotImplementedError()
-    def __str__(self):
-        """Debug string"""
-        raise NotImplementedError()
-class L2DecayRegularizer(WeightDecayRegularizer):
-    r"""
-    Implement the L2 Weight Decay Regularization, which helps to prevent the model over-fitting.
-    It can be set in :ref:`api_fluid_ParamAttr` or ``optimizer`` (such as :ref:`api_fluid_optimizer_SGDOptimizer` ).
-    When set in ``ParamAttr`` , it only takes effect for trainable parameters in this layer. When set in
-    ``optimizer`` , it takes effect for all trainable parameters. When set together, ``ParamAttr`` has
-    higher priority than ``optimizer`` .
-    In the implementation, the formula of L2 Weight Decay Regularization is as follows:
-    .. math::
-        L2WeightDecay = reg\_coeff * parameter
-    Args:
-        regularization_coeff(float, optional): regularization coeff. Default:0.0
-    Examples:
-        .. code-block:: python
-            # Example1: set Regularizer in optimizer
-            import paddle.fluid as fluid
-            import paddle
-            paddle.enable_static()
-            main_prog = fluid.Program()
-            startup_prog = fluid.Program()
-            with fluid.program_guard(main_prog, startup_prog):
-                data = paddle.static.data(name='image', shape=[-1, 3, 28, 28], dtype='float32')
-                label = paddle.static.data(name='label', shape=[-1, 1], dtype='int64')
-                hidden = paddle.static.nn.fc(x=data, size=128, activation='relu')
-                prediction = paddle.static.nn.fc(x=hidden, size=10, activation='softmax')
-                loss = paddle.nn.functional.cross_entropy(
-                    input=prediction, label=label,
-                    reduction='none', use_softmax=False
-                )
-                avg_loss = paddle.mean(loss)
-                optimizer = fluid.optimizer.Adagrad(
-                    learning_rate=1e-4,
-                    regularization=fluid.regularizer.L2Decay(
-                        regularization_coeff=0.1))
-                optimizer.minimize(avg_loss)
-            # Example2: set Regularizer both in ParamAttr and optimizer
-            import paddle.fluid as fluid
-            import paddle
-            paddle.enable_static()
-            l1 = fluid.regularizer.L1Decay(regularization_coeff=0.1)
-            l2 = fluid.regularizer.L2Decay(regularization_coeff=0.1)
-            x = paddle.uniform([3,4])
-            # set L1 regularization in fluid.ParamAttr
-            w_param = fluid.ParamAttr(regularizer=l1)
-            hidden1 = paddle.static.nn.fc(x, 8, weight_attr=w_param)  # fc_0.w_0(L1), fc_0.b_0
-            hidden2 = paddle.static.nn.fc(hidden1, 16, weight_attr=w_param)  # fc_1.w_0(L1), fc_1.b_0
-            predict = paddle.static.nn.fc(hidden2, 32)  # fc_3.w_0, fc_3.b_0
-            avg_loss = paddle.mean(predict)
-            # set L2 regularization in optimizer
-            optimizer = fluid.optimizer.SGD(learning_rate=1e-4, regularization=l2)
-            optimizer.minimize(avg_loss)
-            # it will Print Message:
-            # Regularization of [fc_0.w_0, fc_1.w_0] have been set by ParamAttr or WeightNormParamAttr already.
-            # So, the Regularization of Optimizer will not take effect for these parameters!
-    """
-    def __init__(self, regularization_coeff=0.0):
-        assert regularization_coeff is not None
-        super().__init__()
-        self._regularization_coeff = regularization_coeff
-    def __call__(self, param, grad, block):
-        """Add L2 weight decay ops to network
-        Adds L2 weight decay ops.
-        L2WeightDecay = reg_coeff * parameter
-        Args:
-            param: parameter variable for which regularization is applied
-            block: block in which variable is to be created
-        Returns:
-            new variable for weight decay
-        """
-        assert isinstance(param, framework.Variable)
-        assert isinstance(block, framework.Block)
-        if framework._non_static_mode():
-            if framework.in_dygraph_mode():
-                return _C_ops.scale(
-                    param, self._regularization_coeff, 0.0, True
-                )
-            else:
-                return _legacy_C_ops.scale(
-                    param, "scale", self._regularization_coeff
-                )
-        else:
-            decay = block.create_var(
-                dtype=param.dtype, shape=param.shape, lod_level=param.lod_level
-            )
-            # Append Op to calculate decay
-            block.append_op(
-                type='scale',
-                inputs={"X": param},
-                outputs={"Out": decay},
-                attrs={"scale": self._regularization_coeff},
-            )
-            return decay
-    def __str__(self):
-        return "L2Decay, regularization_coeff=%f" % self._regularization_coeff
-class L1DecayRegularizer(WeightDecayRegularizer):
-    r"""
-    Implement the L1 Weight Decay Regularization, which encourages the weights to be sparse.
-    It can be set in :ref:`api_fluid_ParamAttr` or ``optimizer`` (such as :ref:`api_fluid_optimizer_SGDOptimizer` ).
-    When set in ``ParamAttr`` , it only takes effect for trainable parameters in this layer. When set in
-    ``optimizer`` , it takes effect for all trainable parameters. When set together, ``ParamAttr`` has
-    higher priority than ``optimizer`` .
-    In the implementation, the formula of L1 Weight Decay Regularization is as follows:
-    .. math::
-        L1WeightDecay = reg\_coeff * sign(parameter)
-    Args:
-        regularization_coeff(float, optional): regularization coeff. Default:0.0.
-    Examples:
-        .. code-block:: python
-            # Example1: set Regularizer in optimizer
-            import paddle.fluid as fluid
-            import paddle
-            paddle.enable_static()
-            main_prog = fluid.Program()
-            startup_prog = fluid.Program()
-            with fluid.program_guard(main_prog, startup_prog):
-                data = paddle.static.data(name='image', shape=[-1, 3, 28, 28], dtype='float32')
-                label = paddle.static.data(name='label', shape=[-1, 1], dtype='int64')
-                hidden = paddle.static.nn.fc(x=data, size=128, activation='relu')
-                prediction = paddle.static.nn.fc(x=hidden, size=10, activation='softmax')
-                loss = paddle.nn.functional.cross_entropy(
-                    input=prediction, label=label,
-                    reduction='none', use_softmax=False
-                )
-                avg_loss = paddle.mean(loss)
-                optimizer = fluid.optimizer.Adagrad(
-                    learning_rate=1e-4,
-                    regularization=fluid.regularizer.L1DecayRegularizer(
-                        regularization_coeff=0.1))
-                optimizer.minimize(avg_loss)
-            # Example2: set Regularizer both in ParamAttr and optimizer
-            import paddle.fluid as fluid
-            import paddle
-            paddle.enable_static()
-            l1 = fluid.regularizer.L1Decay(regularization_coeff=0.1)
-            l2 = fluid.regularizer.L2Decay(regularization_coeff=0.1)
-            x = paddle.uniform([3,4])
-            # set L1 regularization in fluid.ParamAttr
-            w_param = fluid.ParamAttr(regularizer=l1)
-            hidden1 = paddle.static.nn.fc(x, 8, weight_attr=w_param)  # fc_0.w_0(L1), fc_0.b_0
-            hidden2 = paddle.static.nn.fc(hidden1, 16, weight_attr=w_param)  # fc_1.w_0(L1), fc_1.b_0
-            predict = paddle.static.nn.fc(hidden2, 32)  # fc_3.w_0, fc_3.b_0
-            avg_loss = paddle.mean(predict)
-            # set L2 regularization in optimizer
-            optimizer = fluid.optimizer.SGD(learning_rate=1e-4, regularization=l2)
-            optimizer.minimize(avg_loss)
-            # it will Print Message:
-            # Regularization of [fc_0.w_0, fc_1.w_0] have been set by ParamAttr or WeightNormParamAttr already.
-            # So, the Regularization of Optimizer will not take effect for these parameters!
-    """
-    def __init__(self, regularization_coeff=0.0):
-        assert regularization_coeff is not None
-        super().__init__()
-        self._regularization_coeff = regularization_coeff
-    def __call__(self, param, grad, block):
-        """Add L1 weight decay ops to network
-        Adds L1 weight decay ops.
-        L1WeightDecay = reg_coeff * sign(parameter)
-        Args:
-            param: parameter variable for which regularization is applied
-            block: block in which variable is to be created
-        Returns:
-            new variable for weight decay
-        """
-        assert isinstance(param, framework.Variable)
-        assert isinstance(block, framework.Block)
-        if framework._non_static_mode():
-            sign = block.create_var(dtype=param.dtype, shape=param.shape)
-            decay = block.create_var(dtype=param.dtype, shape=param.shape)
-        else:
-            sign = block.create_var(
-                dtype=param.dtype, shape=param.shape, lod_level=param.lod_level
-            )
-            decay = block.create_var(
-                dtype=param.dtype, shape=param.shape, lod_level=param.lod_level
-            )
-        if in_dygraph_mode():
-            sign = _C_ops.sign(param)
-            return _C_ops.scale(sign, self._regularization_coeff, 0.0, True)
-        # Append sign op
-        block.append_op(type='sign', inputs={"X": param}, outputs={"Out": sign})
-        # Append scale op to the output of sign op
-        block.append_op(
-            type='scale',
-            inputs={"X": sign},
-            outputs={"Out": decay},
-            attrs={"scale": self._regularization_coeff},
-        )
-        return decay
-    def __str__(self):
-        return "L1Decay, regularization_coeff=%f" % self._regularization_coeff
-# We short the class name, since users will use the regulaizer with the package
-# name. The sample code:
-#
-#   import paddle
-#   import paddle.fluid as fluid
-#
-#   hidden = paddle.static.nn.fc(...,
-#                                weight_attr=fluid.regularizer.Xavier())
-#
-# It is no need to add a `Regularizer` as the class suffix
-L1Decay = L1DecayRegularizer
-L2Decay = L2DecayRegularizer
@@ -61,7 +61,7 @@ def optimizer_setting(params, parameter_list=None):
learning_rate=lr, step_each_epoch=step, epochs=num_epochs
),
momentum=momentum_rate,
-regularization=fluid.regularizer.L2Decay(l2_decay),
+regularization=paddle.regularizer.L2Decay(l2_decay),
parameter_list=parameter_list,
)
else:
@@ -70,7 +70,7 @@ def optimizer_setting(params, parameter_list=None):
learning_rate=lr, step_each_epoch=step, epochs=num_epochs
),
momentum=momentum_rate,
-regularization=fluid.regularizer.L2Decay(l2_decay),
+regularization=paddle.regularizer.L2Decay(l2_decay),
)
return optimizer
......
@@ -15,7 +15,8 @@
import unittest
import paddle
-from paddle.fluid import framework, optimizer, regularizer
+from paddle import regularizer
+from paddle.fluid import framework, optimizer
from paddle.nn import clip
paddle.enable_static()
@@ -49,7 +50,7 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
optimize_attr={'learning_rate': 1.1},
regularizer=None
if regularization is not None
-else regularizer.L2DecayRegularizer(2e-4),
+else regularizer.L2Decay(2e-4),
)
mul_y = block.create_var(
dtype="float32", shape=[dims[1], dims[2]], lod_level=0, name="mul.y"
......
@@ -66,7 +66,7 @@ class TestFleetAMPInit(unittest.TestCase):
optimizer = paddle.optimizer.Momentum(
learning_rate=0.001,
momentum=0.9,
-weight_decay=fluid.regularizer.L2Decay(1e-4),
+weight_decay=paddle.regularizer.L2Decay(1e-4),
multi_precision=True,
)
@@ -110,7 +110,7 @@ class TestFleetAMPInit(unittest.TestCase):
optimizer = paddle.optimizer.Momentum(
learning_rate=0.001,
momentum=0.9,
-weight_decay=fluid.regularizer.L2Decay(1e-4),
+weight_decay=paddle.regularizer.L2Decay(1e-4),
multi_precision=True,
)
......
@@ -539,7 +539,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
)
avg_cost, strategy = self.net(train_prog, startup_prog)
self.set_strategy(strategy, 'sharding')
-regularization = paddle.fluid.regularizer.L2Decay(0.0001)
+regularization = paddle.regularizer.L2Decay(0.0001)
self.optimizer(
avg_cost,
strategy,
......
@@ -114,9 +114,7 @@ class TestDistCTR2x2(TestDistRunnerBase):
regularization = None
use_l2_decay = bool(os.getenv('USE_L2_DECAY', 0))
if use_l2_decay:
-regularization = fluid.regularizer.L2DecayRegularizer(
-regularization_coeff=1e-1
-)
+regularization = paddle.regularizer.L2Decay(coeff=1e-1)
use_lr_decay = bool(os.getenv('LR_DECAY', 0))
lr = 0.0001
if use_lr_decay:
......
@@ -243,7 +243,7 @@ class DistSeResneXt2x2(TestDistRunnerBase):
boundaries=bd, values=lr
),
momentum=0.9,
-regularization=fluid.regularizer.L2Decay(1e-4),
+regularization=paddle.regularizer.L2Decay(1e-4),
)
else:
optimizer = (
@@ -253,7 +253,7 @@ class DistSeResneXt2x2(TestDistRunnerBase):
),
momentum=0.9,
rampup_begin_step=0,
-regularization=fluid.regularizer.L2Decay(1e-4),
+regularization=paddle.regularizer.L2Decay(1e-4),
)
)
optimizer.minimize(avg_cost)
......
@@ -185,7 +185,7 @@ def optimizer(learning_rate=0.01):
learning_rate=learning_rate, step_each_epoch=2, epochs=1
),
momentum=0.9,
-regularization=fluid.regularizer.L2Decay(1e-4),
+regularization=paddle.regularizer.L2Decay(1e-4),
)
return optimizer
......
@@ -711,7 +711,7 @@ class TestAdamOpV2(unittest.TestCase):
)
adam = paddle.optimizer.Adam(
learning_rate=learning_rate,
-weight_decay=fluid.regularizer.L2Decay(0.001),
+weight_decay=paddle.regularizer.L2Decay(0.001),
parameters=emb.parameters(),
)
lr = adam.get_lr()
@@ -976,9 +976,7 @@ class TestAdamOptimizer(unittest.TestCase):
weight_attr = paddle.ParamAttr(
name="weight1",
initializer=paddle.nn.initializer.Constant(value=1.0),
-regularizer=fluid.regularizer.L1DecayRegularizer(
-regularization_coeff=0.1
-),
+regularizer=paddle.regularizer.L1Decay(coeff=0.1),
trainable=True,
)
with fluid.program_guard(main):
......
@@ -574,7 +574,7 @@ class TestL2Decay(TranspilerTest):
x,
size=1000,
weight_attr=fluid.ParamAttr(
-name='fc_w', regularizer=fluid.regularizer.L2Decay()
+name='fc_w', regularizer=paddle.regularizer.L2Decay()
),
bias_attr=fluid.ParamAttr(name='fc_b'),
)
@@ -625,7 +625,7 @@ class TestL2DecayWithPiecewise(TranspilerTest):
boundaries=bd, values=lr
),
momentum=0.9,
-regularization=fluid.regularizer.L2Decay(1e-4),
+regularization=paddle.regularizer.L2Decay(1e-4),
)
sgd_optimizer.minimize(avg_cost)
......
@@ -88,7 +88,7 @@ class TestFuseAllReduceOpsBase(TestParallelExecutorBase):
def optimizer(self, learning_rate=1e-3):
optimizer = fluid.optimizer.SGD(
learning_rate=learning_rate,
-regularization=fluid.regularizer.L2Decay(1e-3),
+regularization=paddle.regularizer.L2Decay(1e-3),
)
return optimizer
......
@@ -38,7 +38,7 @@ class TestMNIST(TestParallelExecutorBase):
def _optimizer(learning_rate=1e-6):
optimizer = fluid.optimizer.SGD(
learning_rate=learning_rate,
-regularization=fluid.regularizer.L2Decay(1e-6),
+regularization=paddle.regularizer.L2Decay(1e-6),
)
return optimizer
......
@@ -87,7 +87,7 @@ class TestMNIST(TestParallelExecutorBase):
def _optimizer(learning_rate=1e-6):
optimizer = fluid.optimizer.SGD(
learning_rate=learning_rate,
-regularization=fluid.regularizer.L2Decay(1e-6),
+regularization=paddle.regularizer.L2Decay(1e-6),
)
return optimizer
......
@@ -70,7 +70,7 @@ def optimizer_setting(params, parameter_list=None):
# learning_rate=fluid.layers.piecewise_decay(
#     boundaries=bd, values=lr),
# momentum=0.9,
-# regularization=fluid.regularizer.L2Decay(1e-4))
+# regularization=paddle.regularizer.L2Decay(1e-4))
return optimizer
......
@@ -66,7 +66,7 @@ def optimizer_setting(params, parameter_list=None):
# learning_rate=fluid.layers.piecewise_decay(
#     boundaries=bd, values=lr),
# momentum=0.9,
-# regularization=fluid.regularizer.L2Decay(1e-4))
+# regularization=paddle.regularizer.L2Decay(1e-4))
return optimizer
......
@@ -687,9 +687,7 @@ class TestMomentumOpWithDecayAPI(unittest.TestCase):
def test_momentum_dygraph_1(self):
self._test_momentum_dygraph_common(
-regularization=paddle.fluid.regularizer.L2Decay(
-regularization_coeff=0.1
-)
+regularization=paddle.regularizer.L2Decay(coeff=0.1)
)
def test_momentum_static(self):
@@ -825,9 +823,7 @@ class TestMomentumOpVsMomentumOpWithDecayAPI(unittest.TestCase):
learning_rate=0.01,
momentum=0.9,
parameter_list=linear_old.parameters(),
-regularization=paddle.fluid.regularizer.L2Decay(
-regularization_coeff=0.1
-),
+regularization=paddle.regularizer.L2Decay(coeff=0.1),
)
self.__update_params(momentum=momentum_old, linear=linear_old)
@@ -841,9 +837,7 @@ class TestMomentumOpVsMomentumOpWithDecayAPI(unittest.TestCase):
learning_rate=0.01,
momentum=0.9,
parameter_list=linear_new.parameters(),
-regularization=paddle.fluid.regularizer.L2Decay(
-regularization_coeff=0.1
-),
+regularization=paddle.regularizer.L2Decay(coeff=0.1),
)
self.__update_params(momentum=momentum_new, linear=linear_new)
......
@@ -167,7 +167,7 @@ class TestProgramPruneBackward(unittest.TestCase):
def optimizer():
optimizer = fluid.optimizer.SGD(
learning_rate=0.001,
-regularization=fluid.regularizer.L2Decay(1e-4),
+regularization=paddle.regularizer.L2Decay(1e-4),
)
return optimizer
@@ -183,7 +183,7 @@ class TestProgramPruneBackward(unittest.TestCase):
def optimizer():
optimizer = fluid.optimizer.SGD(
learning_rate=0.001,
-regularization=fluid.regularizer.L2Decay(1e-4),
+regularization=paddle.regularizer.L2Decay(1e-4),
)
return optimizer
@@ -199,7 +199,7 @@ class TestProgramPruneBackward(unittest.TestCase):
def optimizer():
optimizer = fluid.optimizer.SGD(
learning_rate=0.001,
-regularization=fluid.regularizer.L2Decay(1e-4),
+regularization=paddle.regularizer.L2Decay(1e-4),
)
return optimizer
@@ -223,7 +223,7 @@ class TestProgramPruneBackward(unittest.TestCase):
def optimizer():
optimizer = fluid.optimizer.Adam(
learning_rate=0.001,
-regularization=fluid.regularizer.L2Decay(1e-4),
+regularization=paddle.regularizer.L2Decay(1e-4),
)
return optimizer
......
@@ -20,12 +20,12 @@ from functools import partial
import numpy as np
import paddle
-from paddle import fluid
-from paddle.fluid import core, framework, regularizer
+from paddle import fluid, regularizer
+from paddle.fluid import core, framework
from paddle.fluid.backward import append_backward
-class TestL2DecayRegularizer(unittest.TestCase):
+class TestL2Decay(unittest.TestCase):
def test_l2decay_regularizer(self):
paddle.enable_static()
program = framework.Program()
@@ -35,12 +35,10 @@ class TestL2DecayRegularizer(unittest.TestCase):
shape=[5, 10],
lod_level=0,
name="mul.x",
-regularizer=regularizer.L2DecayRegularizer(0.5),
+regularizer=regularizer.L2Decay(0.5),
)
self.assertIsNotNone(mul_x.regularizer)
-self.assertTrue(
-isinstance(mul_x.regularizer, regularizer.L2DecayRegularizer)
-)
+self.assertTrue(isinstance(mul_x.regularizer, regularizer.L2Decay))
mul_y = block.create_var(
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y"
)
@@ -70,7 +68,7 @@ class TestL2DecayRegularizer(unittest.TestCase):
self.assertEqual(block.ops[-2].type, 'scale')
-class TestL1DecayRegularizer(unittest.TestCase):
+class TestL1Decay(unittest.TestCase):
def test_l2decay_regularizer(self):
paddle.enable_static()
program = framework.Program()
@@ -80,12 +78,10 @@ class TestL1DecayRegularizer(unittest.TestCase):
shape=[5, 10],
lod_level=0,
name="mul.x",
-regularizer=regularizer.L1DecayRegularizer(0.5),
+regularizer=regularizer.L1Decay(0.5),
)
self.assertIsNotNone(mul_x.regularizer)
-self.assertTrue(
-isinstance(mul_x.regularizer, regularizer.L1DecayRegularizer)
-)
+self.assertTrue(isinstance(mul_x.regularizer, regularizer.L1Decay))
mul_y = block.create_var(
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y"
)
@@ -208,7 +204,8 @@ class TestRegularizer(unittest.TestCase):
avg_cost = model(data, label, self.word_len)
optimizer = fluid.optimizer.Adagrad(
-learning_rate=0.1, regularization=fluid.regularizer.L2Decay(1.0)
+learning_rate=0.1,
+regularization=paddle.regularizer.L2Decay(1.0),
)
optimizer.minimize(avg_cost)
param_sum = self.run_program(place, [data, label])
@@ -265,8 +262,8 @@ class TestRegularizer(unittest.TestCase):
)
def test_repeated_regularization(self):
-l1 = fluid.regularizer.L1Decay(regularization_coeff=0.1)
-l2 = fluid.regularizer.L2Decay(regularization_coeff=0.01)
+l1 = paddle.regularizer.L1Decay(coeff=0.1)
+l2 = paddle.regularizer.L2Decay(coeff=0.01)
fc_param_attr = paddle.ParamAttr(
regularizer=paddle.regularizer.L1Decay()
)
......
@@ -1127,7 +1127,7 @@ class TestVarBase(unittest.TestCase):
)
self.assertTrue(
isinstance(
-static_var.regularizer, fluid.regularizer.L1Decay
+static_var.regularizer, paddle.regularizer.L1Decay
)
)
else:
......
@@ -17,7 +17,7 @@ import warnings
import paddle
from paddle import _C_ops
from paddle.fluid.framework import in_dygraph_mode
-from paddle.fluid.regularizer import L2DecayRegularizer
+from paddle.regularizer import L2Decay
from ..fluid import core, framework
from .optimizer import Optimizer
@@ -136,9 +136,7 @@ class Momentum(Optimizer):
if momentum is None:
raise ValueError("momentum is not set")
-predicate = lambda regular: isinstance(
-regular, (L2DecayRegularizer, float)
-)
+predicate = lambda regular: isinstance(regular, (L2Decay, float))
if isinstance(parameters, list):
if isinstance(parameters[0], dict):
for param_group in parameters:
@@ -192,9 +190,9 @@ class Momentum(Optimizer):
reg_method = ""
reg_coeff = 0.0
-if isinstance(weight_decay, L2DecayRegularizer):
+if isinstance(weight_decay, L2Decay):
reg_method = "l2_decay"
-reg_coeff = weight_decay._regularization_coeff
+reg_coeff = weight_decay._coeff
if isinstance(weight_decay, float):
reg_method = "l2_decay"
reg_coeff = weight_decay
@@ -237,7 +235,7 @@ class Momentum(Optimizer):
# If ParamAttr is set to L2Decay, we skip doing regularization here. And then we fused
# L2Decay with momentum which can refer to _append_optimize_op below.
if hasattr(param, 'regularizer') and isinstance(
-param.regularizer, L2DecayRegularizer
+param.regularizer, L2Decay
):
return grad
return super()._create_regularization_of_grad(
@@ -260,9 +258,9 @@ class Momentum(Optimizer):
regularization_coeff = self._regularization_coeff
if hasattr(param, 'regularizer'):
# we skip param's l2decay before, so fuse it with momentum here.
-if isinstance(param.regularizer, L2DecayRegularizer):
+if isinstance(param.regularizer, L2Decay):
regularization_method = "l2_decay"
-regularization_coeff = param.regularizer._regularization_coeff
+regularization_coeff = param.regularizer._coeff
# the param's regularization has been done before, we avoid do l2decay in momentum.
elif param.regularizer is not None:
regularization_method = ""
@@ -348,11 +346,9 @@ class Momentum(Optimizer):
regularization_coeff = self._regularization_coeff
if hasattr(param, 'regularizer'):
# we skip param's l2decay before, so fuse it with momentum here.
-if isinstance(param.regularizer, L2DecayRegularizer):
+if isinstance(param.regularizer, L2Decay):
regularization_method = "l2_decay"
-regularization_coeff = (
-param.regularizer._regularization_coeff
-)
+regularization_coeff = param.regularizer._coeff
elif param.regularizer is not None:
regularization_method = ""
regularization_coeff = 0.0
......
@@ -30,6 +30,7 @@ from paddle.fluid.framework import (
in_dygraph_mode,
name_scope,
)
+from paddle.regularizer import L2Decay
from ..fluid import framework, unique_name
from ..fluid.backward import _get_no_grad_set_name, append_backward
@@ -224,8 +225,6 @@ class Optimizer:
"'grad_clip' should be an instance of GradientClipBase's derived class"
)
if isinstance(weight_decay, float):
-from ..fluid.regularizer import L2Decay
self.regularization = L2Decay(weight_decay)
else:
self.regularization = weight_decay
@@ -1571,8 +1570,6 @@ class Optimizer:
for param in param_group['params']:
weight_decay = param_group['weight_decay']
if isinstance(weight_decay, float):
-from ..fluid.regularizer import L2Decay
regularization = L2Decay(weight_decay)
else:
regularization = weight_decay
......
@@ -12,12 +12,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from paddle import _C_ops, _legacy_C_ops
+from paddle.fluid import framework
+from paddle.fluid.framework import in_dygraph_mode
__all__ = ['L1Decay', 'L2Decay']
-from paddle import fluid
+class WeightDecayRegularizer:
+    """Base class for weight decay regularizers
+    Defines the common interface of weight-decay regularizers.
+    Weight-decay regularizers are added only during the backward
+    pass for faster regularization. They add operations to the network
+    that correspond to gradient of the regularization function.
+    Users should not use this class directly, but need to use one
+    of its implementations
+    """
+    def __init__(self):
+        pass
+    def __call__(self, param, grad, block):
+        """Add corresponding weight decay operations to the network"""
+        raise NotImplementedError()
+    def __str__(self):
+        """Debug string"""
+        raise NotImplementedError()
-class L1Decay(fluid.regularizer.L1Decay):
+class L1Decay(WeightDecayRegularizer):
    r"""
    Implement the L1 Weight Decay Regularization, which encourages the weights to be sparse.
@@ -76,10 +102,58 @@ class L1Decay(fluid.regularizer.L1Decay):
    """
    def __init__(self, coeff=0.0):
-        super().__init__(coeff)
+        assert coeff is not None
+        super().__init__()
+        self._coeff = coeff
+    def __call__(self, param, grad, block):
+        """Add L1 weight decay ops to network
+        Adds L1 weight decay ops.
+        L1WeightDecay = reg_coeff * sign(parameter)
+        Args:
+            param: parameter variable for which regularization is applied
+            block: block in which variable is to be created
+        Returns:
+            new variable for weight decay
+        """
+        assert isinstance(param, framework.Variable)
+        assert isinstance(block, framework.Block)
+        if framework._non_static_mode():
+            sign = block.create_var(dtype=param.dtype, shape=param.shape)
+            decay = block.create_var(dtype=param.dtype, shape=param.shape)
+        else:
+            sign = block.create_var(
+                dtype=param.dtype, shape=param.shape, lod_level=param.lod_level
+            )
+            decay = block.create_var(
+                dtype=param.dtype, shape=param.shape, lod_level=param.lod_level
+            )
+        if in_dygraph_mode():
+            sign = _C_ops.sign(param)
+            return _C_ops.scale(sign, self._coeff, 0.0, True)
+        # Append sign op
+        block.append_op(type='sign', inputs={"X": param}, outputs={"Out": sign})
+        # Append scale op to the output of sign op
+        block.append_op(
+            type='scale',
+            inputs={"X": sign},
+            outputs={"Out": decay},
+            attrs={"scale": self._coeff},
+        )
+        return decay
+    def __str__(self):
+        return "L1Decay, coeff=%f" % self._coeff
-class L2Decay(fluid.regularizer.L2Decay):
+class L2Decay(WeightDecayRegularizer):
    r"""
    Implement the L2 Weight Decay Regularization, which helps to prevent the model over-fitting.
@@ -97,7 +171,7 @@ class L2Decay(fluid.regularizer.L2Decay):
        loss = 0.5 * coeff * reduce\_sum(square(x))
    Args:
-        regularization_coeff(float, optional): regularization coeff. Default:0.0
+        coeff(float, optional): regularization coeff. Default:0.0
    Examples:
        .. code-block:: python
@@ -137,4 +211,45 @@ class L2Decay(fluid.regularizer.L2Decay):
    """
    def __init__(self, coeff=0.0):
-        super().__init__(coeff)
+        assert coeff is not None
+        super().__init__()
+        self._coeff = coeff
+    def __call__(self, param, grad, block):
+        """Add L2 weight decay ops to network
+        Adds L2 weight decay ops.
+        L2WeightDecay = reg_coeff * parameter
+        Args:
+            param: parameter variable for which regularization is applied
+            block: block in which variable is to be created
+        Returns:
+            new variable for weight decay
+        """
+        assert isinstance(param, framework.Variable)
+        assert isinstance(block, framework.Block)
+        if framework._non_static_mode():
+            if framework.in_dygraph_mode():
+                return _C_ops.scale(param, self._coeff, 0.0, True)
+            else:
+                return _legacy_C_ops.scale(param, "scale", self._coeff)
+        else:
+            decay = block.create_var(
+                dtype=param.dtype, shape=param.shape, lod_level=param.lod_level
+            )
+            # Append Op to calculate decay
+            block.append_op(
+                type='scale',
+                inputs={"X": param},
+                outputs={"Out": decay},
+                attrs={"scale": self._coeff},
+            )
+            return decay
+    def __str__(self):
+        return "L2Decay, coeff=%f" % self._coeff
@@ -134,7 +134,7 @@ def train(use_pure_fp16=True, use_nesterov=False, optimizer=""):
learning_rate=0.001,
momentum=0.9,
use_nesterov=use_nesterov,
-weight_decay=fluid.regularizer.L2Decay(1e-4),
+weight_decay=paddle.regularizer.L2Decay(1e-4),
multi_precision=use_pure_fp16,
)
......
@@ -14,8 +14,8 @@
import paddle
from paddle.fluid.param_attr import ParamAttr
-from paddle.fluid.regularizer import L2Decay
from paddle.nn import BatchNorm
+from paddle.regularizer import L2Decay
class ConvBNLayer(paddle.nn.Layer):
......
@@ -451,9 +451,7 @@ def optimizer(cfg, parameter_list):
optimizer = fluid.optimizer.Adam(
fluid.layers.piecewise_decay(boundaries=bd, values=lr),
parameter_list=parameter_list,
-regularization=fluid.regularizer.L2DecayRegularizer(
-regularization_coeff=l2_weight_decay
-),
+regularization=paddle.regularizer.L2Decay(coeff=l2_weight_decay),
)
return optimizer
......
@@ -100,9 +100,7 @@ class BiGRU(paddle.nn.Layer):
initializer=paddle.nn.initializer.Uniform(
low=-init_bound, high=init_bound
),
-regularizer=fluid.regularizer.L2DecayRegularizer(
-regularization_coeff=1e-4
-),
+regularizer=paddle.regularizer.L2Decay(coeff=1e-4),
),
)
@@ -113,9 +111,7 @@ class BiGRU(paddle.nn.Layer):
initializer=paddle.nn.initializer.Uniform(
low=-init_bound, high=init_bound
),
-regularizer=fluid.regularizer.L2DecayRegularizer(
-regularization_coeff=1e-4
-),
+regularizer=paddle.regularizer.L2Decay(coeff=1e-4),
),
)
@@ -126,9 +122,7 @@ class BiGRU(paddle.nn.Layer):
initializer=paddle.nn.initializer.Uniform(
low=-init_bound, high=init_bound
),
-regularizer=fluid.regularizer.L2DecayRegularizer(
-regularization_coeff=1e-4
-),
+regularizer=paddle.regularizer.L2Decay(coeff=1e-4),
),
)
@@ -140,9 +134,7 @@ class BiGRU(paddle.nn.Layer):
initializer=paddle.nn.initializer.Uniform(
low=-init_bound, high=init_bound
),
-regularizer=fluid.regularizer.L2DecayRegularizer(
-regularization_coeff=1e-4
-),
+regularizer=paddle.regularizer.L2Decay(coeff=1e-4),
),
)
@@ -417,9 +409,7 @@ class LexNet(paddle.nn.Layer):
initializer=paddle.nn.initializer.Uniform(
low=-self.init_bound, high=self.init_bound
),
-regularizer=fluid.regularizer.L2DecayRegularizer(
-regularization_coeff=1e-4
-),
+regularizer=paddle.regularizer.L2Decay(coeff=1e-4),
),
)
......
@@ -448,7 +448,7 @@ def create_optimizer(args, parameter_list):
optimizer = fluid.optimizer.Momentum(
learning_rate=args.lr,
momentum=args.momentum_rate,
-regularization=fluid.regularizer.L2Decay(args.l2_decay),
+regularization=paddle.regularizer.L2Decay(args.l2_decay),
parameter_list=parameter_list,
)
......
@@ -48,7 +48,7 @@ def optimizer_setting(parameter_list=None):
optimizer = fluid.optimizer.Momentum(
learning_rate=base_lr,
momentum=momentum_rate,
-regularization=fluid.regularizer.L2Decay(l2_decay),
+regularization=paddle.regularizer.L2Decay(l2_decay),
parameter_list=parameter_list,
)
......
@@ -82,7 +82,7 @@ def optimizer_setting(params, parameter_list):
learning_rate=lr, step_each_epoch=step, epochs=num_epochs
),
momentum=momentum_rate,
-regularization=fluid.regularizer.L2Decay(l2_decay),
+regularization=paddle.regularizer.L2Decay(l2_decay),
parameter_list=parameter_list,
)
......
@@ -281,7 +281,7 @@ def create_optimizer(cfg, params):
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.piecewise_decay(boundaries=bd, values=lr),
momentum=momentum,
-regularization=fluid.regularizer.L2Decay(l2_weight_decay),
+regularization=paddle.regularizer.L2Decay(l2_weight_decay),
parameter_list=params,
)
......
@@ -107,7 +107,7 @@ def train(to_static):
optimizer = fluid.optimizer.Momentum(
learning_rate=lr,
-regularization=fluid.regularizer.L2Decay(cfg.weight_decay),
+regularization=paddle.regularizer.L2Decay(cfg.weight_decay),
momentum=cfg.momentum,
parameter_list=model.parameters(),
)
......
@@ -20,8 +20,8 @@ from darknet import ConvBNLayer, DarkNet53_conv_body
import paddle
from paddle import _legacy_C_ops, fluid
from paddle.fluid.param_attr import ParamAttr
-from paddle.fluid.regularizer import L2Decay
from paddle.jit.api import to_static
+from paddle.regularizer import L2Decay
class AttrDict(dict):
......
@@ -75,7 +75,7 @@ def optimizer_setting(parameter_list=None):
optimizer = fluid.optimizer.Momentum(
learning_rate=base_lr,
momentum=momentum_rate,
-regularization=fluid.regularizer.L2Decay(l2_decay),
+regularization=paddle.regularizer.L2Decay(l2_decay),
parameter_list=parameter_list,
)
......