From e77f54734b04484aac99fa866cf9d40db53da876 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Fri, 28 Dec 2018 12:28:52 +0800 Subject: [PATCH] add unit test for dist sparse l2 decay --- .../paddle/fluid/tests/unittests/dist_ctr.py | 13 ++++++++- .../tests/unittests/dist_ctr_with_l2_decay.py | 27 +++++++++++++++++++ .../fluid/tests/unittests/test_dist_ctr.py | 10 +++++++ 3 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 python/paddle/fluid/tests/unittests/dist_ctr_with_l2_decay.py diff --git a/python/paddle/fluid/tests/unittests/dist_ctr.py b/python/paddle/fluid/tests/unittests/dist_ctr.py index 65969824338..dd97853a4c4 100644 --- a/python/paddle/fluid/tests/unittests/dist_ctr.py +++ b/python/paddle/fluid/tests/unittests/dist_ctr.py @@ -30,7 +30,12 @@ fluid.default_main_program().random_seed = 1 class TestDistCTR2x2(TestDistRunnerBase): + def config(self): + self.use_l2_decay = False + def get_model(self, batch_size=2): + self.config() + dnn_input_dim, lr_input_dim = dist_ctr_reader.load_data_meta() """ network definition """ dnn_data = fluid.layers.data( @@ -97,7 +102,13 @@ class TestDistCTR2x2(TestDistRunnerBase): inference_program = paddle.fluid.default_main_program().clone() - sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.0001) + regularization = None + if self.use_l2_decay: + regularization = fluid.regularizer.L2DecayRegularizer( + regularization_coeff=1e-3) + + sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.0001, + regularization=regularization) sgd_optimizer.minimize(avg_cost) dataset = dist_ctr_reader.Dataset() diff --git a/python/paddle/fluid/tests/unittests/dist_ctr_with_l2_decay.py b/python/paddle/fluid/tests/unittests/dist_ctr_with_l2_decay.py new file mode 100644 index 00000000000..a7fbfd644d2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dist_ctr_with_l2_decay.py @@ -0,0 +1,27 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import dist_ctr +from test_dist_base import runtime_main + + +class TestDistCTRWithL2Decay(dist_ctr.TestDistCTR2x2): + def config(self): + self.use_l2_decay = True + + +if __name__ == "__main__": + runtime_main(TestDistCTRWithL2Decay) diff --git a/python/paddle/fluid/tests/unittests/test_dist_ctr.py b/python/paddle/fluid/tests/unittests/test_dist_ctr.py index b2d979729bc..f6b0971c5c9 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_ctr.py +++ b/python/paddle/fluid/tests/unittests/test_dist_ctr.py @@ -28,5 +28,15 @@ class TestDistCTR2x2(TestDistBase): self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False) +class TestDistCTR2x2WithL2Decay(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._enforce_place = "CPU" + + def test_dist_ctr(self): + self.check_with_place( + "dist_ctr_with_l2_decay.py", delta=1e-7, check_error_log=False) + + if __name__ == "__main__": unittest.main() -- GitLab