test_module.py 2.2 KB
Newer Older
Z
Zeyu Chen 已提交
1
# coding=utf-8
Z
Zeyu Chen 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
W
wuzewu 已提交
17
import paddlehub as hub
18
import paddle.fluid as fluid
Z
Zeyu Chen 已提交
19 20 21


class TestModule(unittest.TestCase):
    """Usage tests that load PaddleHub modules and run them with fluid."""

    # TODO(ZeyuChen): add setUp for test environment preparation

    def test_word2vec_module_usage(self):
        """Load the word2vec test module and check its prediction.

        Downloads the module from a fixed URL, appends an argmax over its
        ``pred_prob`` output to the module's program, feeds one sample of
        four word ids, and expects exactly one predicted id equal to 5.
        """
        url = "https://paddlehub.cdn.bcebos.com/word2vec/word2vec_test_module.tar.gz"
        w2v_module = hub.Module(module_url=url)
        feed_dict, fetch_dict, program = w2v_module(
            sign_name="default", trainable=False)
        with fluid.program_guard(main_program=program):
            pred_prob = fetch_dict["pred_prob"]
            # assumes pred_prob is (batch, vocab_size) so axis=1 picks the
            # most likely word id per sample -- TODO confirm against module.
            pred_word = fluid.layers.argmax(x=pred_prob, axis=1)
            # set place, executor, datafeeder
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            feed_vars = [
                feed_dict["firstw"], feed_dict["secondw"], feed_dict["thirdw"],
                feed_dict["fourthw"]
            ]
            feeder = fluid.DataFeeder(place=place, feed_list=feed_vars)

            # One sample: one word id for each of the four feed variables.
            word_ids = [[1, 2, 3, 4]]
            # Run `program` explicitly -- it is the program pred_word was
            # appended to -- instead of relying on program_guard having made
            # it the default main program.
            result = exe.run(
                program,
                feed=feeder.feed(word_ids),
                fetch_list=[pred_word],
                return_numpy=True)

        # result[0] is the fetched ndarray for pred_word. Compare it as a
        # plain list so the check does not depend on ndarray truthiness and
        # also pins that exactly one prediction was produced.
        self.assertEqual(result[0].flatten().tolist(), [5])

    @unittest.skip("senta module usage test is not implemented yet")
    def test_senta_module_usage(self):
        """Placeholder for a senta (bow_net) module usage test.

        Marked as skipped instead of an empty ``pass`` body so the runner
        reports it as not implemented rather than as a passing test.
        """
        # Sketch of the intended test, kept for reference:
        # m = Module(module_dir="./models/bow_net")
        # inputs = [["外人", "爸妈", "翻车"], ["金钱", "电量"]]
        # tensor = m._preprocess_input(inputs)
        # print(tensor)
        # result = m({"words": tensor})
        # print(result)
Z
Zeyu Chen 已提交
57 58 59 60


if __name__ == "__main__":
    # Executed as a script: discover and run the TestCase classes in this file.
    unittest.main(module="__main__")