diff --git a/python/paddle/fluid/tests/unittests/dist_simnet_bow.py b/python/paddle/fluid/tests/unittests/dist_simnet_bow.py index cbf5ab7a7739b6fbf8bd2dc56f3599e73e3126ed..0dcf227107a849bfdafbbe7863d45e7b09ec567f 100644 --- a/python/paddle/fluid/tests/unittests/dist_simnet_bow.py +++ b/python/paddle/fluid/tests/unittests/dist_simnet_bow.py @@ -32,6 +32,8 @@ from functools import reduce from test_dist_base import TestDistRunnerBase, runtime_main DTYPE = "float32" +DATA_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz' +DATA_MD5 = '4cc060b0a0939a343fc9242aa1ee2e4e' # For Net base_lr = 0.005 @@ -191,7 +193,9 @@ def get_batch_reader(file_list): def get_train_reader(): # The training data set. - train_reader = get_batch_reader("sample") + train_file = os.path.join(paddle.dataset.common.DATA_HOME, "simnet", + "train") + train_reader = get_batch_reader(train_file) train_feed = ["query_ids", "pos_title_ids", "neg_title_ids", "label"] return train_reader, train_feed @@ -202,15 +206,16 @@ class TestDistSimnetBow2x2(TestDistRunnerBase): avg_cost, acc, predict = train_network() inference_program = fluid.default_main_program().clone() + # Optimization - opt = get_optimizer(learning_rate=0.001) + opt = get_optimizer() opt.minimize(avg_cost) # Reader train_reader, _ = get_train_reader() - return inference_program, avg_cost, train_reader, _, acc, predict if __name__ == "__main__": + paddle.dataset.common.download(DATA_URL, 'simnet', DATA_MD5, "train") runtime_main(TestDistSimnetBow2x2) diff --git a/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py b/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py new file mode 100644 index 0000000000000000000000000000000000000000..2a01f01a87e4a60be577557c717d5148e054dbc1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py @@ -0,0 +1,26 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +from test_dist_base import TestDistBase + + +class TestDistSimnetBow2x2(TestDistBase): + def test_simnet_bow(self): + self.check_with_place("dist_simnet_bow.py", delta=1e-7) + + +if __name__ == "__main__": + unittest.main()