From 9ae6064125947d5c8701d86d98f2e19951fbc24a Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 21 Aug 2018 14:05:07 +0800 Subject: [PATCH] add dist ut for simnet bow --- .../fluid/tests/unittests/dist_simnet_bow.py | 11 +++++--- .../tests/unittests/test_dist_simnet_bow.py | 26 +++++++++++++++++++ 2 files changed, 34 insertions(+), 3 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py diff --git a/python/paddle/fluid/tests/unittests/dist_simnet_bow.py b/python/paddle/fluid/tests/unittests/dist_simnet_bow.py index cbf5ab7a773..0dcf227107a 100644 --- a/python/paddle/fluid/tests/unittests/dist_simnet_bow.py +++ b/python/paddle/fluid/tests/unittests/dist_simnet_bow.py @@ -32,6 +32,8 @@ from functools import reduce from test_dist_base import TestDistRunnerBase, runtime_main DTYPE = "float32" +DATA_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz' +DATA_MD5 = '4cc060b0a0939a343fc9242aa1ee2e4e' # For Net base_lr = 0.005 @@ -191,7 +193,9 @@ def get_batch_reader(file_list): def get_train_reader(): # The training data set. - train_reader = get_batch_reader("sample") + train_file = os.path.join(paddle.dataset.common.DATA_HOME, "simnet", + "train") + train_reader = get_batch_reader(train_file) train_feed = ["query_ids", "pos_title_ids", "neg_title_ids", "label"] return train_reader, train_feed @@ -202,15 +206,16 @@ class TestDistSimnetBow2x2(TestDistRunnerBase): avg_cost, acc, predict = train_network() inference_program = fluid.default_main_program().clone() + # Optimization - opt = get_optimizer(learning_rate=0.001) + opt = get_optimizer() opt.minimize(avg_cost) # Reader train_reader, _ = get_train_reader() - return inference_program, avg_cost, train_reader, _, acc, predict if __name__ == "__main__": + paddle.dataset.common.download(DATA_URL, 'simnet', DATA_MD5, "train") runtime_main(TestDistSimnetBow2x2) diff --git a/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py b/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py new file mode 100644 index 00000000000..2a01f01a87e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py @@ -0,0 +1,26 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +from test_dist_base import TestDistBase + + +class TestDistSimnetBow2x2(TestDistBase): + def test_simnet_bow(self): + self.check_with_place("dist_simnet_bow.py", delta=1e-7) + + +if __name__ == "__main__": + unittest.main() -- GitLab