Commit 1185a1b5 authored by Yiqun Liu, committed by kexinzhao

Add C++ inference unittest of recommender system (#8227)

* Save the inference model in Python example of recommender_system.

* Add infer() in Python unittest recommender_system.

* Add C++ inference unittest of recommender_system.
Parent b257ca9a
@@ -28,3 +28,4 @@ inference_test(recognize_digits ARGS mlp)
inference_test(image_classification ARGS vgg resnet)
inference_test(label_semantic_roles)
inference_test(rnn_encoder_decoder)
inference_test(recommender_system)
@@ -30,6 +30,15 @@ void SetupTensor(paddle::framework::LoDTensor& input,
}
}
template <typename T>
void SetupTensor(paddle::framework::LoDTensor& input,
paddle::framework::DDim dims,
std::vector<T>& data) {
CHECK_EQ(static_cast<size_t>(paddle::framework::product(dims)), data.size());
T* input_ptr = input.mutable_data<T>(dims, paddle::platform::CPUPlace());
memcpy(input_ptr, data.data(), input.numel() * sizeof(T));
}
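This new overload fills a tensor from explicit data instead of drawing random values from a [lower, upper] range. A minimal usage sketch, with an illustrative tensor name and value mirroring the recommender test later in this commit:

// Sketch: copy a known int64 id into a {1, 1} CPU tensor.
paddle::framework::LoDTensor user_id;
std::vector<int64_t> user_id_data = {1};
SetupTensor<int64_t>(user_id, {1, 1}, user_id_data);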
template <typename T>
void SetupLoDTensor(paddle::framework::LoDTensor& input,
paddle::framework::LoD& lod,
@@ -37,7 +46,18 @@ void SetupLoDTensor(paddle::framework::LoDTensor& input,
T upper) {
input.set_lod(lod);
int dim = lod[0][lod[0].size() - 1];
SetupTensor(input, {dim, 1}, lower, upper);
SetupTensor<T>(input, {dim, 1}, lower, upper);
}
template <typename T>
void SetupLoDTensor(paddle::framework::LoDTensor& input,
paddle::framework::DDim dims,
paddle::framework::LoD lod,
std::vector<T>& data) {
const size_t level = lod.size() - 1;
CHECK_EQ(dims[0], static_cast<int64_t>(lod[level].back()));
input.set_lod(lod);
SetupTensor<T>(input, dims, data);
}
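The LoD argument is a table of sequence offsets: a level-0 LoD of {{0, 2, 5}} marks two sequences covering rows [0, 2) and [2, 5), which is why the CHECK above requires dims[0] to equal the last offset. A minimal sketch with made-up word ids:

// Sketch: pack two sequences (lengths 2 and 3) into one {5, 1} tensor.
paddle::framework::LoDTensor words;
std::vector<int64_t> word_data = {3, 7, 1, 4, 9};  // illustrative ids
SetupLoDTensor<int64_t>(words, {5, 1}, {{0, 2, 5}}, word_data);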
template <typename T>
@@ -67,7 +87,7 @@ void CheckError(paddle::framework::LoDTensor& output1,
EXPECT_EQ(count, 0) << "There are " << count << " different elements.";
}
template <typename Place, typename T, bool IsCombined = false>
template <typename Place, bool IsCombined = false>
void TestInference(const std::string& dirname,
const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
std::vector<paddle::framework::LoDTensor*>& cpu_fetchs) {
@@ -75,11 +95,13 @@ void TestInference(const std::string& dirname,
auto place = Place();
auto executor = paddle::framework::Executor(place);
auto* scope = new paddle::framework::Scope();
std::unique_ptr<paddle::framework::ProgramDesc> inference_program;
// 2. Initialize the inference_program and load all parameters from file
// 2. Initialize the inference_program and load parameters
std::unique_ptr<paddle::framework::ProgramDesc> inference_program;
if (IsCombined) {
// Hard-coding the names for combined params case
// All parameters are saved in a single file.
// Hard-coding the file names of the program and parameters in the unittest.
// Users are free to specify different filenames.
std::string prog_filename = "__model_combined__";
std::string param_filename = "__params_combined__";
inference_program = paddle::inference::Load(executor,
@@ -87,6 +109,7 @@ void TestInference(const std::string& dirname,
dirname + "/" + prog_filename,
dirname + "/" + param_filename);
} else {
// Parameters are saved in separate files located in the specified `dirname`.
inference_program = paddle::inference::Load(executor, *scope, dirname);
}
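For reference, a minimal sketch of how a caller selects between the two on-disk layouts using the same Load overloads seen above. The wrapper function name is hypothetical; the Load signatures come from this diff, and the headers are assumed to be available via test_helper.h:

// Hypothetical helper (not in the diff) wrapping the two Load overloads.
std::unique_ptr<paddle::framework::ProgramDesc> LoadModel(
    paddle::framework::Executor& executor,
    paddle::framework::Scope& scope,
    const std::string& dirname,
    bool is_combined) {
  if (is_combined) {
    // Program and all parameters are stored in two fixed-name files.
    return paddle::inference::Load(executor, scope,
                                   dirname + "/__model_combined__",
                                   dirname + "/__params_combined__");
  }
  // Each parameter is stored in its own file under `dirname`.
  return paddle::inference::Load(executor, scope, dirname);
}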
@@ -29,11 +29,15 @@ TEST(inference, image_classification) {
// 0. Call `paddle::framework::InitDevices()` to initialize all the devices
// In unittests, this is done in paddle/testing/paddle_gtest_main.cc
int64_t batch_size = 1;
paddle::framework::LoDTensor input;
// Use normalized image pixels as input data,
// which should be in the range [0.0, 1.0].
SetupTensor<float>(
input, {1, 3, 32, 32}, static_cast<float>(0), static_cast<float>(1));
SetupTensor<float>(input,
{batch_size, 3, 32, 32},
static_cast<float>(0),
static_cast<float>(1));
std::vector<paddle::framework::LoDTensor*> cpu_feeds;
cpu_feeds.push_back(&input);
@@ -42,8 +46,7 @@ TEST(inference, image_classification) {
cpu_fetchs1.push_back(&output1);
// Run inference on CPU
TestInference<paddle::platform::CPUPlace, float>(
dirname, cpu_feeds, cpu_fetchs1);
TestInference<paddle::platform::CPUPlace>(dirname, cpu_feeds, cpu_fetchs1);
LOG(INFO) << output1.dims();
#ifdef PADDLE_WITH_CUDA
@@ -52,8 +55,7 @@ TEST(inference, image_classification) {
cpu_fetchs2.push_back(&output2);
// Run inference on CUDA GPU
TestInference<paddle::platform::CUDAPlace, float>(
dirname, cpu_feeds, cpu_fetchs2);
TestInference<paddle::platform::CUDAPlace>(dirname, cpu_feeds, cpu_fetchs2);
LOG(INFO) << output2.dims();
CheckError<float>(output1, output2);
@@ -58,8 +58,7 @@ TEST(inference, label_semantic_roles) {
cpu_fetchs1.push_back(&output1);
// Run inference on CPU
TestInference<paddle::platform::CPUPlace, float>(
dirname, cpu_feeds, cpu_fetchs1);
TestInference<paddle::platform::CPUPlace>(dirname, cpu_feeds, cpu_fetchs1);
LOG(INFO) << output1.lod();
LOG(INFO) << output1.dims();
@@ -69,8 +68,7 @@ TEST(inference, label_semantic_roles) {
cpu_fetchs2.push_back(&output2);
// Run inference on CUDA GPU
TestInference<paddle::platform::CUDAPlace, float>(
dirname, cpu_feeds, cpu_fetchs2);
TestInference<paddle::platform::CUDAPlace>(dirname, cpu_feeds, cpu_fetchs2);
LOG(INFO) << output2.lod();
LOG(INFO) << output2.dims();
@@ -29,11 +29,15 @@ TEST(inference, recognize_digits) {
// 0. Call `paddle::framework::InitDevices()` to initialize all the devices
// In unittests, this is done in paddle/testing/paddle_gtest_main.cc
int64_t batch_size = 1;
paddle::framework::LoDTensor input;
// Use normalized image pixels as input data,
// which should be in the range [-1.0, 1.0].
SetupTensor<float>(
input, {1, 28, 28}, static_cast<float>(-1), static_cast<float>(1));
SetupTensor<float>(input,
{batch_size, 1, 28, 28},
static_cast<float>(-1),
static_cast<float>(1));
std::vector<paddle::framework::LoDTensor*> cpu_feeds;
cpu_feeds.push_back(&input);
@@ -42,8 +46,7 @@ TEST(inference, recognize_digits) {
cpu_fetchs1.push_back(&output1);
// Run inference on CPU
TestInference<paddle::platform::CPUPlace, float>(
dirname, cpu_feeds, cpu_fetchs1);
TestInference<paddle::platform::CPUPlace>(dirname, cpu_feeds, cpu_fetchs1);
LOG(INFO) << output1.dims();
#ifdef PADDLE_WITH_CUDA
@@ -52,8 +55,7 @@ TEST(inference, recognize_digits) {
cpu_fetchs2.push_back(&output2);
// Run inference on CUDA GPU
TestInference<paddle::platform::CUDAPlace, float>(
dirname, cpu_feeds, cpu_fetchs2);
TestInference<paddle::platform::CUDAPlace>(dirname, cpu_feeds, cpu_fetchs2);
LOG(INFO) << output2.dims();
CheckError<float>(output1, output2);
@@ -84,7 +86,7 @@ TEST(inference, recognize_digits_combine) {
cpu_fetchs1.push_back(&output1);
// Run inference on CPU
TestInference<paddle::platform::CPUPlace, float, true>(
TestInference<paddle::platform::CPUPlace, true>(
dirname, cpu_feeds, cpu_fetchs1);
LOG(INFO) << output1.dims();
@@ -94,7 +96,7 @@ TEST(inference, recognize_digits_combine) {
cpu_fetchs2.push_back(&output2);
// Run inference on CUDA GPU
TestInference<paddle::platform::CUDAPlace, float, true>(
TestInference<paddle::platform::CUDAPlace, true>(
dirname, cpu_feeds, cpu_fetchs2);
LOG(INFO) << output2.dims();
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "gflags/gflags.h"
#include "test_helper.h"
DEFINE_string(dirname, "", "Directory of the inference model.");
TEST(inference, recommender_system) {
if (FLAGS_dirname.empty()) {
LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model";
}
LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl;
std::string dirname = FLAGS_dirname;
// 0. Call `paddle::framework::InitDevices()` to initialize all the devices
// In unittests, this is done in paddle/testing/paddle_gtest_main.cc
int64_t batch_size = 1;
paddle::framework::LoDTensor user_id, gender_id, age_id, job_id, movie_id,
category_id, movie_title;
// Use the first record from paddle.dataset.movielens.test() as input
std::vector<int64_t> user_id_data = {1};
SetupTensor<int64_t>(user_id, {batch_size, 1}, user_id_data);
std::vector<int64_t> gender_id_data = {1};
SetupTensor<int64_t>(gender_id, {batch_size, 1}, gender_id_data);
std::vector<int64_t> age_id_data = {0};
SetupTensor<int64_t>(age_id, {batch_size, 1}, age_id_data);
std::vector<int64_t> job_id_data = {10};
SetupTensor<int64_t>(job_id, {batch_size, 1}, job_id_data);
std::vector<int64_t> movie_id_data = {783};
SetupTensor<int64_t>(movie_id, {batch_size, 1}, movie_id_data);
std::vector<int64_t> category_id_data = {10, 8, 9};
SetupLoDTensor<int64_t>(category_id, {3, 1}, {{0, 3}}, category_id_data);
std::vector<int64_t> movie_title_data = {1069, 4140, 2923, 710, 988};
SetupLoDTensor<int64_t>(movie_title, {5, 1}, {{0, 5}}, movie_title_data);
std::vector<paddle::framework::LoDTensor*> cpu_feeds;
cpu_feeds.push_back(&user_id);
cpu_feeds.push_back(&gender_id);
cpu_feeds.push_back(&age_id);
cpu_feeds.push_back(&job_id);
cpu_feeds.push_back(&movie_id);
cpu_feeds.push_back(&category_id);
cpu_feeds.push_back(&movie_title);
paddle::framework::LoDTensor output1;
std::vector<paddle::framework::LoDTensor*> cpu_fetchs1;
cpu_fetchs1.push_back(&output1);
// Run inference on CPU
TestInference<paddle::platform::CPUPlace>(dirname, cpu_feeds, cpu_fetchs1);
LOG(INFO) << output1.dims();
#ifdef PADDLE_WITH_CUDA
paddle::framework::LoDTensor output2;
std::vector<paddle::framework::LoDTensor*> cpu_fetchs2;
cpu_fetchs2.push_back(&output2);
// Run inference on CUDA GPU
TestInference<paddle::platform::CUDAPlace>(dirname, cpu_feeds, cpu_fetchs2);
LOG(INFO) << output2.dims();
CheckError<float>(output1, output2);
#endif
}
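The test finishes by comparing the CPU and GPU results with CheckError<float>, whose tail appears in an earlier hunk. A minimal sketch of such an element-wise comparison follows; the name, signature, and tolerance are assumptions, and only the final EXPECT_EQ line is taken verbatim from the diff:

// Sketch of an element-wise tensor comparison in the style of CheckError.
// Requires <cmath> and <gtest/gtest.h>.
template <typename T>
void CheckErrorSketch(const paddle::framework::LoDTensor& output1,
                      const paddle::framework::LoDTensor& output2) {
  EXPECT_EQ(output1.dims(), output2.dims());
  const T* data1 = output1.data<T>();
  const T* data2 = output2.data<T>();
  const T tolerance = static_cast<T>(1e-3);  // assumed tolerance
  int count = 0;
  for (int64_t i = 0; i < output1.numel(); ++i) {
    if (std::abs(data1[i] - data2[i]) > tolerance) {
      count++;
    }
  }
  EXPECT_EQ(count, 0) << "There are " << count << " different elements.";
}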
@@ -46,8 +46,7 @@ TEST(inference, rnn_encoder_decoder) {
cpu_fetchs1.push_back(&output1);
// Run inference on CPU
TestInference<paddle::platform::CPUPlace, float>(
dirname, cpu_feeds, cpu_fetchs1);
TestInference<paddle::platform::CPUPlace>(dirname, cpu_feeds, cpu_fetchs1);
LOG(INFO) << output1.lod();
LOG(INFO) << output1.dims();
@@ -57,8 +56,7 @@ TEST(inference, rnn_encoder_decoder) {
cpu_fetchs2.push_back(&output2);
// Run inference on CUDA GPU
TestInference<paddle::platform::CUDAPlace, float>(
dirname, cpu_feeds, cpu_fetchs2);
TestInference<paddle::platform::CUDAPlace>(dirname, cpu_feeds, cpu_fetchs2);
LOG(INFO) << output2.lod();
LOG(INFO) << output2.dims();
recognize_digits_*.inference.model
*.inference.model
@@ -174,8 +174,9 @@ def infer(use_cuda, save_dirname=None, param_filename=None):
# The input to conv should be a 4-D or 5-D tensor.
# Use normalized image pixels as input data, which should be in the range [-1.0, 1.0].
batch_size = 1
tensor_img = numpy.random.uniform(-1.0, 1.0,
[1, 1, 28, 28]).astype("float32")
[batch_size, 1, 28, 28]).astype("float32")
# Construct feed as a dictionary of {feed_target_name: feed_target_data}
# and results will contain a list of data corresponding to fetch_targets.
@@ -16,7 +16,7 @@ import math
import sys
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid.core as core
import paddle.v2.fluid as fluid
import paddle.v2.fluid.framework as framework
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.nets as nets
@@ -104,7 +104,8 @@ def get_mov_combined_features():
CATEGORY_DICT_SIZE = len(paddle.dataset.movielens.movie_categories())
category_id = layers.data(name='category_id', shape=[1], dtype='int64')
category_id = layers.data(
name='category_id', shape=[1], dtype='int64', lod_level=1)
mov_categories_emb = layers.embedding(
input=category_id, size=[CATEGORY_DICT_SIZE, 32], is_sparse=IS_SPARSE)
@@ -114,7 +115,8 @@ def get_mov_combined_features():
MOV_TITLE_DICT_SIZE = len(paddle.dataset.movielens.get_movie_title_dict())
mov_title_id = layers.data(name='movie_title', shape=[1], dtype='int64')
mov_title_id = layers.data(
name='movie_title', shape=[1], dtype='int64', lod_level=1)
mov_title_emb = layers.embedding(
input=mov_title_id, size=[MOV_TITLE_DICT_SIZE, 32], is_sparse=IS_SPARSE)
@@ -144,23 +146,22 @@ def model():
scale_infer = layers.scale(x=inference, scale=5.0)
label = layers.data(name='score', shape=[1], dtype='float32')
square_cost = layers.square_error_cost(input=scale_infer, label=label)
avg_cost = layers.mean(x=square_cost)
return avg_cost
return scale_infer, avg_cost
def train(use_cuda, save_dirname):
scale_infer, avg_cost = model()
# test program
test_program = fluid.default_main_program().clone()
def main():
cost = model()
sgd_optimizer = SGDOptimizer(learning_rate=0.2)
opts = sgd_optimizer.minimize(cost)
opts = sgd_optimizer.minimize(avg_cost)
if USE_GPU:
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = Executor(place)
exe.run(framework.default_startup_program())
@@ -169,6 +170,8 @@ def main():
paddle.reader.shuffle(
paddle.dataset.movielens.train(), buf_size=8192),
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
paddle.dataset.movielens.test(), batch_size=BATCH_SIZE)
feeding = {
'user_id': 0,
@@ -184,7 +187,7 @@ def main():
def func_feed(feeding, data):
feed_tensors = {}
for (key, idx) in feeding.iteritems():
tensor = core.LoDTensor()
tensor = fluid.LoDTensor()
if key != "category_id" and key != "movie_title":
if key == "score":
numpy_data = np.array(map(lambda x: x[idx], data)).astype(
@@ -211,16 +214,117 @@ def main():
PASS_NUM = 100
for pass_id in range(PASS_NUM):
for data in train_reader():
outs = exe.run(framework.default_main_program(),
for batch_id, data in enumerate(train_reader()):
# train a mini-batch
outs = exe.run(program=fluid.default_main_program(),
feed=func_feed(feeding, data),
fetch_list=[cost])
fetch_list=[avg_cost])
out = np.array(outs[0])
if out[0] < 6.0:
# if avg cost is less than 6.0, we think our code is good.
exit(0)
if (batch_id + 1) % 10 == 0:
avg_cost_set = []
for test_data in test_reader():
avg_cost_np = exe.run(program=test_program,
feed=func_feed(feeding, test_data),
fetch_list=[avg_cost])
avg_cost_set.append(avg_cost_np[0])
break # test only 1 segment for speeding up CI
# get test avg_cost
test_avg_cost = np.array(avg_cost_set).mean()
if test_avg_cost < 6.0:
# if avg_cost is less than 6.0, we think our code is good.
if save_dirname is not None:
fluid.io.save_inference_model(save_dirname, [
"user_id", "gender_id", "age_id", "job_id",
"movie_id", "category_id", "movie_title"
], [scale_infer], exe)
return
if math.isnan(float(out[0])):
sys.exit("got NaN loss, training failed.")
main()
def infer(use_cuda, save_dirname=None):
if save_dirname is None:
return
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
# Use fluid.io.load_inference_model to obtain the inference program desc,
# the feed_target_names (the names of variables that will be fed
# data using feed operators), and the fetch_targets (variables that
# we want to obtain data from using fetch operators).
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
def create_lod_tensor(data, lod=None):
tensor = fluid.LoDTensor()
if lod is None:
# Tensor, the shape is [batch_size, 1]
index = 0
lod_0 = [index]
for l in range(len(data)):
index += 1
lod_0.append(index)
lod = [lod_0]
tensor.set_lod(lod)
flattened_data = np.concatenate(data, axis=0).astype("int64")
flattened_data = flattened_data.reshape([len(flattened_data), 1])
tensor.set(flattened_data, place)
return tensor
# Use the first record from paddle.dataset.movielens.test() as input
assert feed_target_names[0] == "user_id"
user_id = create_lod_tensor([[1]])
assert feed_target_names[1] == "gender_id"
gender_id = create_lod_tensor([[1]])
assert feed_target_names[2] == "age_id"
age_id = create_lod_tensor([[0]])
assert feed_target_names[3] == "job_id"
job_id = create_lod_tensor([[10]])
assert feed_target_names[4] == "movie_id"
movie_id = create_lod_tensor([[783]])
assert feed_target_names[5] == "category_id"
category_id = create_lod_tensor([[10], [8], [9]], [[0, 3]])
assert feed_target_names[6] == "movie_title"
movie_title = create_lod_tensor([[1069], [4140], [2923], [710], [988]],
[[0, 5]])
# Construct feed as a dictionary of {feed_target_name: feed_target_data}
# and results will contain a list of data corresponding to fetch_targets.
results = exe.run(inference_program,
feed={
feed_target_names[0]: user_id,
feed_target_names[1]: gender_id,
feed_target_names[2]: age_id,
feed_target_names[3]: job_id,
feed_target_names[4]: movie_id,
feed_target_names[5]: category_id,
feed_target_names[6]: movie_title
},
fetch_list=fetch_targets,
return_numpy=False)
print("inferred score: ", np.array(results[0]))
def main(use_cuda):
if use_cuda and not fluid.core.is_compiled_with_cuda():
return
# Directory for saving the inference model
save_dirname = "recommender_system.inference.model"
train(use_cuda, save_dirname)
infer(use_cuda, save_dirname)
if __name__ == '__main__':
main(USE_GPU)