提交 63192b42 编写于 作者: W wuzewu
...@@ -18,7 +18,7 @@ import paddle_hub as hub ...@@ -18,7 +18,7 @@ import paddle_hub as hub
class TestDownloader(unittest.TestCase): class TestDownloader(unittest.TestCase):
def test_download(self): def test_download(self):
link = "http://paddlehub.bj.bcebos.com/word2vec/word2vec-dim16-simple-example-2.tar.gz" link = "https://paddlehub.cdn.bcebos.com/word2vec/word2vec_test_module.tar.gz"
module_path = hub.download_and_uncompress(link) module_path = hub.download_and_uncompress(link)
......
...@@ -154,36 +154,40 @@ def test_create_w2v_module(use_gpu=False): ...@@ -154,36 +154,40 @@ def test_create_w2v_module(use_gpu=False):
main_program.global_block().var("fourthw"), main_program.global_block().var("fourthw"),
] ]
signature = hub.create_signature( signature = hub.create_signature(
"default", inputs=module_inputs, outputs=[pred_prob]) "default",
inputs=module_inputs,
outputs=[pred_prob],
feed_names=["firstw", "secondw", "thirdw", "fourthw"],
fetch_names=["pred_prob"])
hub.create_module( hub.create_module(
sign_arr=signature, sign_arr=signature, module_dir=saved_module_dir, word_dict=dictionary)
program=fluid.default_main_program(),
module_dir=saved_module_dir,
word_dict=dictionary)
def test_load_w2v_module(use_gpu=False): def test_load_w2v_module(use_gpu=False):
saved_module_dir = "./tmp/word2vec_test_module" saved_module_dir = "./tmp/word2vec_test_module"
w2v_module = hub.Module(module_dir=saved_module_dir) w2v_module = hub.Module(module_dir=saved_module_dir)
feed_list, fetch_list, program, generator = w2v_module( feed_dict, fetch_dict, program = w2v_module(
sign_name="default", trainable=False) sign_name="default", trainable=False)
with fluid.program_guard(main_program=program): with fluid.program_guard(main_program=program):
with fluid.unique_name.guard(generator): pred_prob = fetch_dict["pred_prob"]
pred_prob = fetch_list[0] pred_word = fluid.layers.argmax(x=pred_prob, axis=1)
pred_word = fluid.layers.argmax(x=pred_prob, axis=1) # set place, executor, datafeeder
# set place, executor, datafeeder place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() exe = fluid.Executor(place)
exe = fluid.Executor(place) feed_vars = [
feeder = fluid.DataFeeder(place=place, feed_list=feed_list) feed_dict["firstw"], feed_dict["secondw"], feed_dict["thirdw"],
feed_dict["fourthw"]
word_ids = [[1, 2, 3, 4]] ]
result = exe.run( feeder = fluid.DataFeeder(place=place, feed_list=feed_vars)
fluid.default_main_program(),
feed=feeder.feed(word_ids), word_ids = [[1, 2, 3, 4]]
fetch_list=[pred_word], result = exe.run(
return_numpy=True) fluid.default_main_program(),
feed=feeder.feed(word_ids),
print(result) fetch_list=[pred_word],
return_numpy=True)
print(result)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -15,19 +15,36 @@ ...@@ -15,19 +15,36 @@
import unittest import unittest
import paddle_hub as hub import paddle_hub as hub
import paddle.fluid as fluid
class TestModule(unittest.TestCase): class TestModule(unittest.TestCase):
#TODO(ZeyuChen): add setup for test envrinoment prepration #TODO(ZeyuChen): add setup for test envrinoment prepration
def test_word2vec_module_usage(self): def test_word2vec_module_usage(self):
pass url = "https://paddlehub.cdn.bcebos.com/word2vec/word2vec_test_module.tar.gz"
# url = "http://paddlehub.cdn.bcebos.com/word2vec/word2vec-dim16-simple-example-2.tar.gz" w2v_module = hub.Module(module_url=url)
# module = Module(module_url=url) feed_dict, fetch_dict, program = w2v_module(
# inputs = [["it", "is", "new"], ["hello", "world"]] sign_name="default", trainable=False)
# tensor = module._process_input(inputs) with fluid.program_guard(main_program=program):
# print(tensor) pred_prob = fetch_dict["pred_prob"]
# result = module(inputs) pred_word = fluid.layers.argmax(x=pred_prob, axis=1)
# print(result) # set place, executor, datafeeder
place = fluid.CPUPlace()
exe = fluid.Executor(place)
feed_vars = [
feed_dict["firstw"], feed_dict["secondw"], feed_dict["thirdw"],
feed_dict["fourthw"]
]
feeder = fluid.DataFeeder(place=place, feed_list=feed_vars)
word_ids = [[1, 2, 3, 4]]
result = exe.run(
fluid.default_main_program(),
feed=feeder.feed(word_ids),
fetch_list=[pred_word],
return_numpy=True)
self.assertEqual(result[0], 5)
def test_senta_module_usage(self): def test_senta_module_usage(self):
pass pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册