Unverified · Commit 18fc9275 authored by littletomatodonkey, committed by GitHub

add regularizer api (#27292)

Parent: 8fe1c2d1
......@@ -49,6 +49,7 @@ import paddle.optimizer
import paddle.metric
import paddle.device
import paddle.incubate.complex as complex
import paddle.regularizer
# TODO: define alias in tensor and framework directory
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
from functools import partial
import contextlib
import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
import paddle.fluid.framework as framework
import paddle.fluid.optimizer as optimizer
import paddle.regularizer as regularizer
from paddle.fluid.backward import append_backward

def bow_net(data,
label,
dict_dim,
is_sparse=False,
emb_dim=8,
hid_dim=8,
hid_dim2=6,
class_dim=2):
"""
BOW net
This model is from https://github.com/PaddlePaddle/models:
fluid/PaddleNLP/text_classification/nets.py
"""
emb = fluid.layers.embedding(
input=data, is_sparse=is_sparse, size=[dict_dim, emb_dim])
bow = fluid.layers.sequence_pool(input=emb, pool_type='sum')
bow_tanh = fluid.layers.tanh(bow)
fc_1 = fluid.layers.fc(input=bow_tanh, size=hid_dim, act="tanh")
fc_2 = fluid.layers.fc(input=fc_1, size=hid_dim2, act="tanh")
prediction = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax")
cost = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=cost)
return avg_cost

class TestRegularizer(unittest.TestCase):
def setUp(self):
self.word_dict = paddle.dataset.imdb.word_dict()
reader = paddle.batch(
paddle.dataset.imdb.train(self.word_dict), batch_size=1)()
self.train_data = [next(reader) for _ in range(1)]
def get_places(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
return places
@contextlib.contextmanager
def scope_prog_guard(self, main_prog, startup_prog):
scope = fluid.core.Scope()
with fluid.unique_name.guard():
with fluid.scope_guard(scope):
with fluid.program_guard(main_prog, startup_prog):
yield
def run_program(self, place, feed_list):
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=feed_list, place=place)
exe.run(fluid.default_startup_program())
main_prog = fluid.default_main_program()
param_list = [var.name for var in main_prog.block(0).all_parameters()]
param_sum = []
for data in self.train_data:
out = exe.run(main_prog,
feed=feeder.feed(data),
fetch_list=param_list)
p_sum = 0
for v in out:
p_sum += np.sum(np.abs(v))
param_sum.append(p_sum)
return param_sum
def check_l2decay_regularizer(self, place, model):
paddle.manual_seed(1)
paddle.framework.random._manual_program_seed(1)
main_prog = fluid.framework.Program()
startup_prog = fluid.framework.Program()
with self.scope_prog_guard(
main_prog=main_prog, startup_prog=startup_prog):
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
avg_cost = model(data, label, len(self.word_dict))
optimizer = fluid.optimizer.Adagrad(
learning_rate=0.1,
regularization=paddle.regularizer.L2Decay(1.0))
optimizer.minimize(avg_cost)
param_sum = self.run_program(place, [data, label])
return param_sum
def check_l2decay(self, place, model):
paddle.manual_seed(1)
paddle.framework.random._manual_program_seed(1)
main_prog = fluid.framework.Program()
startup_prog = fluid.framework.Program()
with self.scope_prog_guard(
main_prog=main_prog, startup_prog=startup_prog):
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
avg_cost_l2 = model(data, label, len(self.word_dict))
param_list = fluid.default_main_program().block(0).all_parameters()
para_sum = []
for para in param_list:
para_mul = fluid.layers.square(x=para)
para_sum.append(fluid.layers.reduce_sum(input=para_mul))
avg_cost_l2 += fluid.layers.sums(para_sum) * .5
optimizer = fluid.optimizer.Adagrad(learning_rate=0.1)
optimizer.minimize(avg_cost_l2)
param_sum = self.run_program(place, [data, label])
return param_sum
def test_l2(self):
for place in self.get_places():
dense_sparse_p_sum = []
for sparse in [True, False]:
model = partial(bow_net, is_sparse=sparse)
framework_l2 = self.check_l2decay_regularizer(place, model)
l2 = self.check_l2decay(place, model)
assert len(l2) == len(framework_l2)
for i in range(len(l2)):
assert np.isclose(a=framework_l2[i], b=l2[i], rtol=5e-5)
dense_sparse_p_sum.append(framework_l2)
assert len(dense_sparse_p_sum[0]) == len(dense_sparse_p_sum[1])
for i in range(len(dense_sparse_p_sum[0])):
assert np.isclose(
a=dense_sparse_p_sum[0][i],
b=dense_sparse_p_sum[1][i],
rtol=5e-5)
def test_repeated_regularization(self):
l1 = paddle.regularizer.L1Decay(0.1)
l2 = paddle.regularizer.L2Decay(0.01)
fc_param_attr = fluid.ParamAttr(regularizer=l1)
with fluid.program_guard(fluid.Program(), fluid.Program()):
x = fluid.layers.uniform_random([2, 2, 3])
out = fluid.layers.fc(x, 5, param_attr=fc_param_attr)
loss = fluid.layers.reduce_sum(out)
sgd = fluid.optimizer.SGD(learning_rate=0.1, regularization=l2)
sgd.minimize(loss)
with fluid.dygraph.guard():
input = fluid.dygraph.to_variable(
np.random.randn(3, 2).astype('float32'))
paddle.manual_seed(1)
paddle.framework.random._manual_program_seed(1)
linear1 = fluid.dygraph.Linear(
2, 2, param_attr=fc_param_attr, bias_attr=fc_param_attr)
linear2 = fluid.dygraph.Linear(
2, 2, param_attr=fc_param_attr, bias_attr=fc_param_attr)
loss1 = linear1(input)
loss1.backward()
# set l2 regularizer in optimizer, but l1 in fluid.ParamAttr
fluid.optimizer.SGD(parameter_list=linear1.parameters(),
learning_rate=1e-2,
regularization=l2).minimize(loss1)
# only set l1 in fluid.ParamAttr
loss2 = linear2(input)
loss2.backward()
fluid.optimizer.SGD(parameter_list=linear2.parameters(),
learning_rate=1e-2).minimize(loss2)
            # both should be regularized by l1, so the parameters stay identical
self.assertTrue(
np.allclose(linear1.weight.numpy(), linear2.weight.numpy()),
"weight should use the regularization in fluid.ParamAttr!")
self.assertTrue(
np.allclose(linear1.bias.numpy(), linear2.bias.numpy()),
"bias should use the regularization in fluid.ParamAttr!")

if __name__ == '__main__':
unittest.main()
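
A note on why check_l2decay and check_l2decay_regularizer agree in test_l2: adding the explicit
penalty 0.5 * sum(w ** 2) to the loss contributes a gradient of w for each parameter, which is
exactly the coeff * parameter term that L2Decay(coeff=1.0) adds. A minimal numpy sketch of that
identity (illustrative only, not part of this commit):

    import numpy as np

    coeff = 1.0  # matches L2Decay(1.0) in check_l2decay_regularizer
    w = np.random.randn(4, 3).astype("float64")
    eps = 1e-6

    # numerical gradient of the explicit penalty 0.5 * coeff * sum(w ** 2)
    penalty = lambda p: 0.5 * coeff * np.sum(p ** 2)
    num_grad = np.empty_like(w)
    for idx in np.ndindex(w.shape):
        w_plus, w_minus = w.copy(), w.copy()
        w_plus[idx] += eps
        w_minus[idx] -= eps
        num_grad[idx] = (penalty(w_plus) - penalty(w_minus)) / (2 * eps)

    # the term L2Decay adds to each parameter's gradient: coeff * parameter
    assert np.allclose(num_grad, coeff * w, atol=1e-5)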
......@@ -12,8 +12,134 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: define the regularizer functions
# __all__ = ['L1Decay',
# 'L1DecayRegularizer',
# 'L2Decay',
# 'L2DecayRegularizer']
__all__ = ['L1Decay', 'L2Decay']
import paddle.fluid as fluid

class L1Decay(fluid.regularizer.L1Decay):
"""
    Implement the L1 Weight Decay Regularization, which encourages the weights to be sparse.

    It can be set in :ref:`api_fluid_ParamAttr` or ``optimizer`` (such as :ref:`api_paddle_optimizer_Momentum` ).
    When set in ``ParamAttr``, it only takes effect on the trainable parameters of that layer. When set in
    the ``optimizer``, it takes effect on all trainable parameters. When both are set, ``ParamAttr`` has
    higher priority: if a trainable parameter defines a regularizer in its ``ParamAttr``, the regularizer
    in the optimizer is ignored for that parameter; otherwise, the regularizer in the optimizer is used.
In the implementation, the formula of L1 Weight Decay Regularization is as follows:
.. math::
        L1WeightDecay = coeff * sign(parameter)
Args:
        coeff (float, optional): regularization coefficient. Default: 0.0.
Examples:
.. code-block:: python
# Example1: set Regularizer in optimizer
import paddle
from paddle.regularizer import L1Decay
import numpy as np
paddle.disable_static()
inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
inp = paddle.to_tensor(inp)
out = linear(inp)
loss = paddle.mean(out)
            momentum = paddle.optimizer.Momentum(
                learning_rate=0.1,
                parameters=linear.parameters(),
                weight_decay=L1Decay(0.0001))
            loss.backward()
            momentum.step()
            momentum.clear_grad()
# Example2: set Regularizer in parameters
# Set L1 regularization in parameters.
# Global regularizer does not take effect on my_conv2d for this case.
from paddle.nn import Conv2d
from paddle import ParamAttr
            from paddle.regularizer import L1Decay
            my_conv2d = Conv2d(
                in_channels=10,
                out_channels=10,
                kernel_size=1,
                stride=1,
                padding=0,
                weight_attr=ParamAttr(regularizer=L1Decay(coeff=0.01)),
                bias_attr=False)
"""
def __init__(self, coeff=0.0):
super(L1Decay, self).__init__(coeff)

class L2Decay(fluid.regularizer.L2Decay):
"""
    Implement the L2 Weight Decay Regularization, which helps prevent the model from overfitting.

    It can be set in :ref:`api_fluid_ParamAttr` or ``optimizer`` (such as :ref:`api_paddle_optimizer_Momentum` ).
    When set in ``ParamAttr``, it only takes effect on the trainable parameters of that layer. When set in
    the ``optimizer``, it takes effect on all trainable parameters. When both are set, ``ParamAttr`` has
    higher priority: if a trainable parameter defines a regularizer in its ``ParamAttr``, the regularizer
    in the optimizer is ignored for that parameter; otherwise, the regularizer in the optimizer is used.
In the implementation, the formula of L2 Weight Decay Regularization is as follows:
.. math::
        L2WeightDecay = coeff * parameter
Args:
        coeff (float, optional): regularization coefficient. Default: 0.0.
Examples:
.. code-block:: python
# Example1: set Regularizer in optimizer
import paddle
from paddle.regularizer import L2Decay
import numpy as np
paddle.disable_static()
inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
inp = paddle.to_tensor(inp)
out = linear(inp)
loss = paddle.mean(out)
            momentum = paddle.optimizer.Momentum(
                learning_rate=0.1,
                parameters=linear.parameters(),
                weight_decay=L2Decay(0.0001))
            loss.backward()
            momentum.step()
            momentum.clear_grad()
# Example2: set Regularizer in parameters
# Set L2 regularization in parameters.
# Global regularizer does not take effect on my_conv2d for this case.
from paddle.nn import Conv2d
from paddle import ParamAttr
from paddle.regularizer import L2Decay
my_conv2d = Conv2d(
in_channels=10,
out_channels=10,
kernel_size=1,
stride=1,
padding=0,
weight_attr=ParamAttr(regularizer=L2Decay(coeff=0.01)),
bias_attr=False)
"""
def __init__(self, coeff=0.0):
super(L2Decay, self).__init__(coeff)
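
The priority rule stated in both docstrings is exactly what test_repeated_regularization above
verifies. A minimal dygraph sketch of the behavior, using the same 2.0-beta-era fluid API as the
tests in this commit (illustrative only, not part of the diff):

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    l1 = paddle.regularizer.L1Decay(0.1)
    l2 = paddle.regularizer.L2Decay(0.01)
    with fluid.dygraph.guard():
        x = fluid.dygraph.to_variable(np.ones((3, 2), dtype="float32"))
        linear = fluid.dygraph.Linear(
            2, 2, param_attr=fluid.ParamAttr(regularizer=l1))
        loss = fluid.layers.reduce_sum(linear(x))
        loss.backward()
        # linear.weight defines l1 in its ParamAttr, so the optimizer-level
        # l2 regularizer is ignored for it; parameters without their own
        # regularizer would fall back to l2.
        fluid.optimizer.SGD(parameter_list=linear.parameters(),
                            learning_rate=1e-2,
                            regularization=l2).minimize(loss)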
......@@ -16,12 +16,13 @@ from .profiler import ProfilerOptions
from .profiler import Profiler
from .profiler import get_profiler
from .deprecated import deprecated
from ..fluid.framework import unique_name
from ..fluid.framework import load_op_library
from ..fluid.framework import require_version
from . import download
__all__ = ['dump_config', 'deprecated', 'download']
# TODO: define new APIs under this directory
# __all__ = ['unique_name',
# 'load_op_library',
# 'require_version']
__all__ += ['unique_name', 'load_op_library', 'require_version']