Unverified commit 1961470f authored by Siddharth Goyal, committed by GitHub

Add inference example and unit-test for word2vec chapter (#8206)

* Add unit-test and example

* Fix type error

* Fix unit test cases

* Fix init error for CUDAPlace

* Change unit-test options
Parent 4b62fcd0
@@ -25,9 +25,10 @@ function(inference_test TARGET_NAME)
endfunction(inference_test)
inference_test(fit_a_line)
-inference_test(recognize_digits ARGS mlp)
inference_test(image_classification ARGS vgg resnet)
inference_test(label_semantic_roles)
-inference_test(rnn_encoder_decoder)
+inference_test(recognize_digits ARGS mlp)
inference_test(recommender_system)
+inference_test(rnn_encoder_decoder)
inference_test(understand_sentiment)
+inference_test(word2vec)
@@ -91,7 +91,7 @@ template <typename Place, bool IsCombined = false>
void TestInference(const std::string& dirname,
                   const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
                   std::vector<paddle::framework::LoDTensor*>& cpu_fetchs) {
-  // 1. Define place, executor, scope and inference_program
+  // 1. Define place, executor, scope
  auto place = Place();
  auto executor = paddle::framework::Executor(place);
  auto* scope = new paddle::framework::Scope();
@@ -101,7 +101,8 @@ void TestInference(const std::string& dirname,
  if (IsCombined) {
    // All parameters are saved in a single file.
    // Hard-coding the file names of program and parameters in unittest.
-    // Users are free to specify different filename.
+    // Users are free to specify a different filename,
+    // provided the filenames are changed in the Python API (io.py) as well.
    std::string prog_filename = "__model_combined__";
    std::string param_filename = "__params_combined__";
    inference_program = paddle::inference::Load(executor,
......
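The two combined-file names hard-coded above have to match what the Python side writes. As a rough illustration of the saving side, here is a sketch; it assumes a fluid version whose fluid.io.save_inference_model accepts model_filename/params_filename keyword arguments (the io.py referenced in the comment may expose the combined-file option differently), and save_combined itself is a hypothetical helper:

import paddle.v2.fluid as fluid

def save_combined(save_dirname, feed_names, fetch_vars, exe):
    # Write the program description and all parameters as two single files,
    # using the same names the C++ unittest above hard-codes.
    fluid.io.save_inference_model(
        save_dirname,
        feed_names,   # e.g. ['firstw', 'secondw', 'thirdw', 'forthw']
        fetch_vars,   # e.g. [predict_word]
        exe,
        model_filename="__model_combined__",
        params_filename="__params_combined__")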
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "gflags/gflags.h"
#include "test_helper.h"
DEFINE_string(dirname, "", "Directory of the inference model.");
TEST(inference, word2vec) {
  if (FLAGS_dirname.empty()) {
    LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model";
  }

  LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl;
  std::string dirname = FLAGS_dirname;
  // 0. Call `paddle::framework::InitDevices()` to initialize all the devices.
  // In unittests, this is done in paddle/testing/paddle_gtest_main.cc

  paddle::framework::LoDTensor first_word, second_word, third_word, fourth_word;
  paddle::framework::LoD lod{{0, 1}};  // one sequence holding a single word id
  int64_t dict_size = 2072;            // Hard-code the size of the dictionary

  SetupLoDTensor(first_word, lod, static_cast<int64_t>(0), dict_size);
  SetupLoDTensor(second_word, lod, static_cast<int64_t>(0), dict_size);
  SetupLoDTensor(third_word, lod, static_cast<int64_t>(0), dict_size);
  SetupLoDTensor(fourth_word, lod, static_cast<int64_t>(0), dict_size);
  std::vector<paddle::framework::LoDTensor*> cpu_feeds;
  cpu_feeds.push_back(&first_word);
  cpu_feeds.push_back(&second_word);
  cpu_feeds.push_back(&third_word);
  cpu_feeds.push_back(&fourth_word);

  paddle::framework::LoDTensor output1;
  std::vector<paddle::framework::LoDTensor*> cpu_fetchs1;
  cpu_fetchs1.push_back(&output1);

  // Run inference on CPU
  TestInference<paddle::platform::CPUPlace>(dirname, cpu_feeds, cpu_fetchs1);
  LOG(INFO) << output1.lod();
  LOG(INFO) << output1.dims();
#ifdef PADDLE_WITH_CUDA
  paddle::framework::LoDTensor output2;
  std::vector<paddle::framework::LoDTensor*> cpu_fetchs2;
  cpu_fetchs2.push_back(&output2);

  // Run inference on CUDA GPU
  TestInference<paddle::platform::CUDAPlace>(dirname, cpu_feeds, cpu_fetchs2);
  LOG(INFO) << output2.lod();
  LOG(INFO) << output2.dims();

  CheckError<float>(output1, output2);
#endif
}
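CheckError<float> above asserts that the CPU and GPU runs agree within a floating-point tolerance. A hypothetical Python counterpart of that check, for illustration only (check_outputs_close is not part of this commit):

import numpy as np

def check_outputs_close(cpu_tensor, gpu_tensor, rtol=1e-5, atol=1e-8):
    # Convert both LoDTensors to numpy arrays and compare them element-wise,
    # mirroring what CheckError<float> does on the C++ side.
    cpu_arr = np.array(cpu_tensor)
    gpu_arr = np.array(gpu_tensor)
    assert cpu_arr.shape == gpu_arr.shape, "output shapes differ"
    np.testing.assert_allclose(cpu_arr, gpu_arr, rtol=rtol, atol=atol)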
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
-# # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
@@ -16,14 +15,67 @@ import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import unittest
import os
import numpy as np
import math
import sys
-def main(use_cuda, is_sparse, parallel):
-    if use_cuda and not fluid.core.is_compiled_with_cuda():

def create_random_lodtensor(lod, place, low, high):
    data = np.random.random_integers(low, high, [lod[-1], 1]).astype("int64")
    res = fluid.LoDTensor()
    res.set(data, place)
    res.set_lod([lod])
    return res
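A quick usage sketch (illustrative, not part of the commit): with lod = [0, 1] the tensor holds exactly one word id, and numpy's random_integers draws it inclusively between low and high.

place = fluid.CPUPlace()
one_word = create_random_lodtensor([0, 1], place, low=0, high=2071)
print(np.array(one_word).shape)  # (1, 1): one sequence containing one word id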
def infer(use_cuda, save_dirname=None):
    if save_dirname is None:
        return

    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)

    # Use fluid.io.load_inference_model to obtain the inference program desc,
    # the feed_target_names (the names of variables that will be fed
    # data using feed operators), and the fetch_targets (variables that
    # we want to obtain data from using fetch operators).
    [inference_program, feed_target_names,
     fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)

    word_dict = paddle.dataset.imikolov.build_dict()
    # Valid ids span [0, len(word_dict) - 1]; random_integers includes both
    # bounds, so this is passed as the inclusive upper bound below.
    dict_size = len(word_dict) - 1
    # Set up the input by creating 4 words, and set up the LoD required by
    # lookup_table_op.
    lod = [0, 1]
    first_word = create_random_lodtensor(lod, place, low=0, high=dict_size)
    second_word = create_random_lodtensor(lod, place, low=0, high=dict_size)
    third_word = create_random_lodtensor(lod, place, low=0, high=dict_size)
    fourth_word = create_random_lodtensor(lod, place, low=0, high=dict_size)

    # The feed names must match the names the model was saved with;
    # 'forthw' is the spelling the training program uses.
    assert feed_target_names[0] == 'firstw'
    assert feed_target_names[1] == 'secondw'
    assert feed_target_names[2] == 'thirdw'
    assert feed_target_names[3] == 'forthw'
    # Construct feed as a dictionary of {feed_target_name: feed_target_data}
    # and results will contain a list of data corresponding to fetch_targets.
    results = exe.run(inference_program,
                      feed={
                          feed_target_names[0]: first_word,
                          feed_target_names[1]: second_word,
                          feed_target_names[2]: third_word,
                          feed_target_names[3]: fourth_word
                      },
                      fetch_list=fetch_targets,
                      return_numpy=False)
    print(results[0].lod())
    np_data = np.array(results[0])
    print("Inference Shape: ", np_data.shape)
    print("Inference results: ", np_data)

def train(use_cuda, is_sparse, parallel, save_dirname):
    PASS_NUM = 100
    EMBED_SIZE = 32
    HIDDEN_SIZE = 256
@@ -67,7 +119,7 @@ def train(use_cuda, is_sparse, parallel, save_dirname):
                                       act='softmax')
        cost = fluid.layers.cross_entropy(input=predict_word, label=words[4])
        avg_cost = fluid.layers.mean(x=cost)
-        return avg_cost
+        return avg_cost, predict_word
    word_dict = paddle.dataset.imikolov.build_dict()
    dict_size = len(word_dict)
@@ -79,13 +131,13 @@ def train(use_cuda, is_sparse, parallel, save_dirname):
    next_word = fluid.layers.data(name='nextw', shape=[1], dtype='int64')

    if not parallel:
-        avg_cost = __network__(
+        avg_cost, predict_word = __network__(
            [first_word, second_word, third_word, forth_word, next_word])
    else:
        places = fluid.layers.get_places()
        pd = fluid.layers.ParallelDo(places)
        with pd.do():
-            avg_cost = __network__(
+            avg_cost, predict_word = __network__(
                map(pd.read_input, [
                    first_word, second_word, third_word, forth_word, next_word
                ]))
@@ -113,6 +165,10 @@ def train(use_cuda, is_sparse, parallel, save_dirname):
                                  feed=feeder.feed(data),
                                  fetch_list=[avg_cost])
            if avg_cost_np[0] < 5.0:
                if save_dirname is not None:
                    fluid.io.save_inference_model(save_dirname, [
                        'firstw', 'secondw', 'thirdw', 'forthw'
                    ], [predict_word], exe)
                return
            if math.isnan(float(avg_cost_np[0])):
                sys.exit("got NaN loss, training failed.")
@@ -120,6 +176,14 @@ def train(use_cuda, is_sparse, parallel, save_dirname):
    raise AssertionError("Cost is too large {0:2.2}".format(avg_cost_np[0]))


def main(use_cuda, is_sparse, parallel):
    if use_cuda and not fluid.core.is_compiled_with_cuda():
        return

    save_dirname = "word2vec.inference.model"
    train(use_cuda, is_sparse, parallel, save_dirname)
    infer(use_cuda, save_dirname)


FULL_TEST = os.getenv('FULL_TEST',
                      '0').lower() in ['true', '1', 't', 'y', 'yes', 'on']
SKIP_REASON = "Only run minimum number of tests in CI server, to make CI faster"
@@ -142,7 +206,8 @@ def inject_test_method(use_cuda, is_sparse, parallel):
        with fluid.program_guard(prog, startup_prog):
            main(use_cuda=use_cuda, is_sparse=is_sparse, parallel=parallel)

-    if use_cuda and is_sparse and parallel:
+    # run only 2 cases: use_cuda is either True or False
+    if is_sparse == False and parallel == False:
        fn = __impl__
    else:
        # skip the other test when on CI server
......
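For context, the test-injection pattern that the last hunk modifies works roughly as follows. This is a sketch, not the file's exact code: the TestCase class name W2VTest and the test-name format are assumptions, while the FULL_TEST/SKIP_REASON gate comes from the lines shown above.

def inject_test_method(use_cuda, is_sparse, parallel):
    fn_name = "test_{0}_{1}_{2}".format("cuda" if use_cuda else "cpu",
                                        "sparse" if is_sparse else "dense",
                                        "parallel" if parallel else "normal")

    def __impl__(*args):
        # Run each injected case inside fresh programs so cases stay isolated.
        prog = fluid.Program()
        startup_prog = fluid.Program()
        with fluid.program_guard(prog, startup_prog):
            main(use_cuda=use_cuda, is_sparse=is_sparse, parallel=parallel)

    # By default only the two plain cases (CPU and GPU) run; the other
    # combinations are skipped on the CI server unless FULL_TEST is set.
    if is_sparse == False and parallel == False:
        fn = __impl__
    else:
        fn = unittest.skipUnless(FULL_TEST, SKIP_REASON)(__impl__)

    fn.__name__ = fn_name
    setattr(W2VTest, fn_name, fn)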