提交 9ae60641 编写于 作者: T tangwei12

add dist ut for simnet bow

上级 f40110a8
......@@ -32,6 +32,8 @@ from functools import reduce
from test_dist_base import TestDistRunnerBase, runtime_main
DTYPE = "float32"
DATA_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz'
DATA_MD5 = '4cc060b0a0939a343fc9242aa1ee2e4e'
# For Net
base_lr = 0.005
......@@ -191,7 +193,9 @@ def get_batch_reader(file_list):
def get_train_reader():
# The training data set.
train_reader = get_batch_reader("sample")
train_file = os.path.join(paddle.dataset.common.DATA_HOME, "simnet",
"train")
train_reader = get_batch_reader(train_file)
train_feed = ["query_ids", "pos_title_ids", "neg_title_ids", "label"]
return train_reader, train_feed
......@@ -202,15 +206,16 @@ class TestDistSimnetBow2x2(TestDistRunnerBase):
avg_cost, acc, predict = train_network()
inference_program = fluid.default_main_program().clone()
# Optimization
opt = get_optimizer(learning_rate=0.001)
opt = get_optimizer()
opt.minimize(avg_cost)
# Reader
train_reader, _ = get_train_reader()
return inference_program, avg_cost, train_reader, _, acc, predict
if __name__ == "__main__":
paddle.dataset.common.download(DATA_URL, 'simnet', DATA_MD5, "train")
runtime_main(TestDistSimnetBow2x2)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
from test_dist_base import TestDistBase
class TestDistSimnetBow2x2(TestDistBase):
def test_simnet_bow(self):
self.check_with_place("dist_simnet_bow.py", delta=1e-7)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册