From 18fc92756274e537eb1edf548f16ae1af72893be Mon Sep 17 00:00:00 2001
From: littletomatodonkey <2120160898@bit.edu.cn>
Date: Wed, 16 Sep 2020 13:02:31 +0800
Subject: [PATCH] add regularizer api (#27292)

---
 python/paddle/__init__.py                     |   1 +
 .../tests/unittests/test_regularizer_api.py   | 204 ++++++++++++++++++
 python/paddle/regularizer.py                  | 136 +++++++++++-
 python/paddle/utils/__init__.py               |   7 +-
 4 files changed, 340 insertions(+), 8 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/test_regularizer_api.py

diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index ed0b415d0bf..016726633ea 100755
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -49,6 +49,7 @@ import paddle.optimizer
 import paddle.metric
 import paddle.device
 import paddle.incubate.complex as complex
+import paddle.regularizer
 
 # TODO: define alias in tensor and framework directory
 
diff --git a/python/paddle/fluid/tests/unittests/test_regularizer_api.py b/python/paddle/fluid/tests/unittests/test_regularizer_api.py
new file mode 100644
index 00000000000..76186d2e39f
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_regularizer_api.py
@@ -0,0 +1,204 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+from functools import partial
+import contextlib
+import numpy as np
+import paddle
+import paddle.fluid.core as core
+import paddle.fluid as fluid
+import paddle.fluid.framework as framework
+import paddle.fluid.optimizer as optimizer
+import paddle.regularizer as regularizer
+from paddle.fluid.backward import append_backward
+
+
+def bow_net(data,
+            label,
+            dict_dim,
+            is_sparse=False,
+            emb_dim=8,
+            hid_dim=8,
+            hid_dim2=6,
+            class_dim=2):
+    """
+    BOW net
+    This model is from https://github.com/PaddlePaddle/models:
+    fluid/PaddleNLP/text_classification/nets.py
+    """
+    emb = fluid.layers.embedding(
+        input=data, is_sparse=is_sparse, size=[dict_dim, emb_dim])
+    bow = fluid.layers.sequence_pool(input=emb, pool_type='sum')
+    bow_tanh = fluid.layers.tanh(bow)
+    fc_1 = fluid.layers.fc(input=bow_tanh, size=hid_dim, act="tanh")
+    fc_2 = fluid.layers.fc(input=fc_1, size=hid_dim2, act="tanh")
+    prediction = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax")
+    cost = fluid.layers.cross_entropy(input=prediction, label=label)
+    avg_cost = fluid.layers.mean(x=cost)
+
+    return avg_cost
+
+
+class TestRegularizer(unittest.TestCase):
+    def setUp(self):
+        self.word_dict = paddle.dataset.imdb.word_dict()
+        reader = paddle.batch(
+            paddle.dataset.imdb.train(self.word_dict), batch_size=1)()
+        self.train_data = [next(reader) for _ in range(1)]
+
+    def get_places(self):
+        places = [core.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(core.CUDAPlace(0))
+        return places
+
+    @contextlib.contextmanager
+    def scope_prog_guard(self, main_prog, startup_prog):
+        scope = fluid.core.Scope()
+        with fluid.unique_name.guard():
+            with fluid.scope_guard(scope):
+                with fluid.program_guard(main_prog, startup_prog):
+                    yield
+
+    def run_program(self, place, feed_list):
+        exe = fluid.Executor(place)
+        feeder = fluid.DataFeeder(feed_list=feed_list, place=place)
+        exe.run(fluid.default_startup_program())
+
+        main_prog = fluid.default_main_program()
+        param_list = [var.name for var in main_prog.block(0).all_parameters()]
+
+        param_sum = []
+        for data in self.train_data:
+            out = exe.run(main_prog,
+                          feed=feeder.feed(data),
+                          fetch_list=param_list)
+            p_sum = 0
+            for v in out:
+                p_sum += np.sum(np.abs(v))
+            param_sum.append(p_sum)
+        return param_sum
+
+    def check_l2decay_regularizer(self, place, model):
+        paddle.manual_seed(1)
+        paddle.framework.random._manual_program_seed(1)
+        main_prog = fluid.framework.Program()
+        startup_prog = fluid.framework.Program()
+        with self.scope_prog_guard(
+                main_prog=main_prog, startup_prog=startup_prog):
+            data = fluid.layers.data(
+                name="words", shape=[1], dtype="int64", lod_level=1)
+            label = fluid.layers.data(name="label", shape=[1], dtype="int64")
+
+            avg_cost = model(data, label, len(self.word_dict))
+
+            optimizer = fluid.optimizer.Adagrad(
+                learning_rate=0.1,
+                regularization=paddle.regularizer.L2Decay(1.0))
+            optimizer.minimize(avg_cost)
+            param_sum = self.run_program(place, [data, label])
+        return param_sum
+
+    def check_l2decay(self, place, model):
+        paddle.manual_seed(1)
+        paddle.framework.random._manual_program_seed(1)
+        main_prog = fluid.framework.Program()
+        startup_prog = fluid.framework.Program()
+
+        with self.scope_prog_guard(
+                main_prog=main_prog, startup_prog=startup_prog):
+            data = fluid.layers.data(
+                name="words", shape=[1], dtype="int64", lod_level=1)
+            label = fluid.layers.data(name="label", shape=[1], dtype="int64")
+
+            avg_cost_l2 = model(data, label, len(self.word_dict))
+
+            param_list = fluid.default_main_program().block(0).all_parameters()
+            para_sum = []
+            for para in param_list:
+                para_mul = fluid.layers.square(x=para)
+                para_sum.append(fluid.layers.reduce_sum(input=para_mul))
+            avg_cost_l2 += fluid.layers.sums(para_sum) * .5
+
+            optimizer = fluid.optimizer.Adagrad(learning_rate=0.1)
+            optimizer.minimize(avg_cost_l2)
+            param_sum = self.run_program(place, [data, label])
+        return param_sum
+
+    def test_l2(self):
+        for place in self.get_places():
+            dense_sparse_p_sum = []
+            for sparse in [True, False]:
+                model = partial(bow_net, is_sparse=sparse)
+                framework_l2 = self.check_l2decay_regularizer(place, model)
+                l2 = self.check_l2decay(place, model)
+                assert len(l2) == len(framework_l2)
+                for i in range(len(l2)):
+                    assert np.isclose(a=framework_l2[i], b=l2[i], rtol=5e-5)
+                dense_sparse_p_sum.append(framework_l2)
+
+            assert len(dense_sparse_p_sum[0]) == len(dense_sparse_p_sum[1])
+            for i in range(len(dense_sparse_p_sum[0])):
+                assert np.isclose(
+                    a=dense_sparse_p_sum[0][i],
+                    b=dense_sparse_p_sum[1][i],
+                    rtol=5e-5)
+
+    def test_repeated_regularization(self):
+        l1 = paddle.regularizer.L1Decay(0.1)
+        l2 = paddle.regularizer.L2Decay(0.01)
+        fc_param_attr = fluid.ParamAttr(regularizer=l1)
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            x = fluid.layers.uniform_random([2, 2, 3])
+            out = fluid.layers.fc(x, 5, param_attr=fc_param_attr)
+            loss = fluid.layers.reduce_sum(out)
+            sgd = fluid.optimizer.SGD(learning_rate=0.1, regularization=l2)
+            sgd.minimize(loss)
+        with fluid.dygraph.guard():
+            input = fluid.dygraph.to_variable(
+                np.random.randn(3, 2).astype('float32'))
+            paddle.manual_seed(1)
+            paddle.framework.random._manual_program_seed(1)
+
+            linear1 = fluid.dygraph.Linear(
+                2, 2, param_attr=fc_param_attr, bias_attr=fc_param_attr)
+            linear2 = fluid.dygraph.Linear(
+                2, 2, param_attr=fc_param_attr, bias_attr=fc_param_attr)
+
+            loss1 = linear1(input)
+            loss1.backward()
+            # set l2 regularizer in the optimizer, but l1 in fluid.ParamAttr
+            fluid.optimizer.SGD(parameter_list=linear1.parameters(),
+                                learning_rate=1e-2,
+                                regularization=l2).minimize(loss1)
+            # set only l1 in fluid.ParamAttr, nothing in the optimizer
+            loss2 = linear2(input)
+            loss2.backward()
+            fluid.optimizer.SGD(parameter_list=linear2.parameters(),
+                                learning_rate=1e-2).minimize(loss2)
+            # both should be regularized by l1 only, so the results must match
+            self.assertTrue(
+                np.allclose(linear1.weight.numpy(), linear2.weight.numpy()),
+                "weight should use the regularization in fluid.ParamAttr!")
+            self.assertTrue(
+                np.allclose(linear1.bias.numpy(), linear2.bias.numpy()),
+                "bias should use the regularization in fluid.ParamAttr!")
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/regularizer.py b/python/paddle/regularizer.py
index 2b20bb41970..b3f483fd891 100644
--- a/python/paddle/regularizer.py
+++ b/python/paddle/regularizer.py
@@ -12,8 +12,134 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# TODO: define the regularizer functions
-# __all__ = ['L1Decay',
-#            'L1DecayRegularizer',
-#            'L2Decay',
-#            'L2DecayRegularizer']
+__all__ = ['L1Decay', 'L2Decay']
+
+import paddle.fluid as fluid
+
+
+class L1Decay(fluid.regularizer.L1Decay):
+    """
+    Implement the L1 Weight Decay Regularization, which encourages the weights to be sparse.
+
+    It can be set in :ref:`api_fluid_ParamAttr` or ``optimizer`` (such as :ref:`api_paddle_optimizer_Momentum` ).
+    When set in ``ParamAttr`` , it only takes effect for trainable parameters in this layer. When set in
+    ``optimizer`` , it takes effect for all trainable parameters. When set together, ``ParamAttr`` has
+    higher priority than ``optimizer`` , which means that for a trainable parameter, if a regularizer is
+    defined in its ParamAttr, the regularizer in the optimizer is ignored; otherwise, the regularizer
+    in the optimizer is used.
+
+    In the implementation, the formula of L1 Weight Decay Regularization is as follows:
+
+    .. math::
+
+        L1WeightDecay = reg\_coeff * sign(parameter)
+
+    Args:
+        coeff(float, optional): regularization coefficient. Default: 0.0.
+
+    Examples:
+        .. code-block:: python
+
+            # Example1: set Regularizer in optimizer
+            import paddle
+            from paddle.regularizer import L1Decay
+            import numpy as np
+            paddle.disable_static()
+            inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
+            linear = paddle.nn.Linear(10, 10)
+            inp = paddle.to_tensor(inp)
+            out = linear(inp)
+            loss = paddle.mean(out)
+            momentum = paddle.optimizer.Momentum(
+                learning_rate=0.1,
+                parameters=linear.parameters(),
+                weight_decay=L1Decay(0.0001))
+            loss.backward()
+            momentum.step()
+            momentum.clear_grad()
+
+            # Example2: set Regularizer in parameters
+            # Set L1 regularization in parameters.
+            # Global regularizer does not take effect on my_conv2d for this case.
+            from paddle.nn import Conv2d
+            from paddle import ParamAttr
+            from paddle.regularizer import L1Decay
+
+            my_conv2d = Conv2d(
+                in_channels=10,
+                out_channels=10,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                weight_attr=ParamAttr(regularizer=L1Decay(coeff=0.01)),
+                bias_attr=False)
+    """
+
+    def __init__(self, coeff=0.0):
+        super(L1Decay, self).__init__(coeff)
+
+
+class L2Decay(fluid.regularizer.L2Decay):
+    """
+    Implement the L2 Weight Decay Regularization, which helps prevent the model from over-fitting.
+
+    It can be set in :ref:`api_fluid_ParamAttr` or ``optimizer`` (such as :ref:`api_paddle_optimizer_Momentum` ).
+    When set in ``ParamAttr`` , it only takes effect for trainable parameters in this layer. When set in
+    ``optimizer`` , it takes effect for all trainable parameters. When set together, ``ParamAttr`` has
+    higher priority than ``optimizer`` , which means that for a trainable parameter, if a regularizer is
+    defined in its ParamAttr, the regularizer in the optimizer is ignored; otherwise, the regularizer
+    in the optimizer is used.
+
+    In the implementation, the formula of L2 Weight Decay Regularization is as follows:
+
+    .. math::
+
+        L2WeightDecay = reg\_coeff * parameter
+
+    Args:
+        coeff(float, optional): regularization coefficient. Default: 0.0.
+
+    Examples:
+        .. code-block:: python
+
+            # Example1: set Regularizer in optimizer
+            import paddle
+            from paddle.regularizer import L2Decay
+            import numpy as np
+            paddle.disable_static()
+            inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
+            linear = paddle.nn.Linear(10, 10)
+            inp = paddle.to_tensor(inp)
+            out = linear(inp)
+            loss = paddle.mean(out)
+            momentum = paddle.optimizer.Momentum(
+                learning_rate=0.1,
+                parameters=linear.parameters(),
+                weight_decay=L2Decay(0.0001))
+            loss.backward()
+            momentum.step()
+            momentum.clear_grad()
+
+            # Example2: set Regularizer in parameters
+            # Set L2 regularization in parameters.
+            # Global regularizer does not take effect on my_conv2d for this case.
+            from paddle.nn import Conv2d
+            from paddle import ParamAttr
+            from paddle.regularizer import L2Decay
+
+            my_conv2d = Conv2d(
+                in_channels=10,
+                out_channels=10,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                weight_attr=ParamAttr(regularizer=L2Decay(coeff=0.01)),
+                bias_attr=False)
+    """
+
+    def __init__(self, coeff=0.0):
+        super(L2Decay, self).__init__(coeff)
diff --git a/python/paddle/utils/__init__.py b/python/paddle/utils/__init__.py
index 2a649c776b4..4a786679727 100644
--- a/python/paddle/utils/__init__.py
+++ b/python/paddle/utils/__init__.py
@@ -16,12 +16,13 @@ from .profiler import ProfilerOptions
 from .profiler import Profiler
 from .profiler import get_profiler
 from .deprecated import deprecated
+from ..fluid.framework import unique_name
+from ..fluid.framework import load_op_library
+from ..fluid.framework import require_version
 from . import download
 
 __all__ = ['dump_config', 'deprecated', 'download']
 
 #TODO: define new api under this directory
-# __all__ = ['unique_name',
-#            'load_op_library',
-#            'require_version']
+__all__ += ['unique_name', 'load_op_library', 'require_version']
-- 
GitLab
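
Usage note: the two docstring examples above show each regularizer in isolation. The minimal sketch below combines them to illustrate the precedence rule both docstrings document (and that test_repeated_regularization asserts): a regularizer set in ParamAttr wins over the optimizer-level weight_decay for that parameter, while parameters without a per-parameter regularizer fall back to the optimizer's setting. It is not part of the patch; it assumes the same Paddle 2.0-style dygraph environment and uses only APIs that appear in the patch itself.

    import numpy as np
    import paddle
    from paddle import ParamAttr
    from paddle.regularizer import L1Decay, L2Decay

    paddle.disable_static()

    inp = paddle.to_tensor(
        np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32"))

    # L1 decay is attached to the weight through ParamAttr; the bias has no
    # per-parameter regularizer, so it falls back to the optimizer's L2 decay.
    linear = paddle.nn.Linear(
        10, 10, weight_attr=ParamAttr(regularizer=L1Decay(coeff=0.01)))

    loss = paddle.mean(linear(inp))
    loss.backward()

    # weight_decay applies only to parameters whose ParamAttr does not already
    # define a regularizer (here: the bias), per the documented precedence rule.
    momentum = paddle.optimizer.Momentum(
        learning_rate=0.1,
        parameters=linear.parameters(),
        weight_decay=L2Decay(0.0001))
    momentum.step()
    momentum.clear_grad()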