From 1185a1b5ab897133585ffdcef8f68d0ee6caef46 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Fri, 9 Feb 2018 13:27:42 +0800 Subject: [PATCH] Add C++ inference unittest of recommender system (#8227) * Save the inference model in Python example of recommender_system. * Add infer() in Python unittest recommender_system. * Add C++ inference unittest of recommender_system. --- paddle/inference/tests/book/CMakeLists.txt | 1 + paddle/inference/tests/book/test_helper.h | 33 +++- .../test_inference_image_classification.cc | 14 +- .../test_inference_label_semantic_roles.cc | 6 +- .../book/test_inference_recognize_digits.cc | 18 ++- .../book/test_inference_recommender_system.cc | 87 +++++++++++ .../test_inference_rnn_encoder_decoder.cc | 6 +- python/paddle/v2/fluid/tests/book/.gitignore | 2 +- .../fluid/tests/book/test_recognize_digits.py | 3 +- .../tests/book/test_recommender_system.py | 146 +++++++++++++++--- 10 files changed, 266 insertions(+), 50 deletions(-) create mode 100644 paddle/inference/tests/book/test_inference_recommender_system.cc diff --git a/paddle/inference/tests/book/CMakeLists.txt b/paddle/inference/tests/book/CMakeLists.txt index 479f51f1df..5c866eb1e2 100644 --- a/paddle/inference/tests/book/CMakeLists.txt +++ b/paddle/inference/tests/book/CMakeLists.txt @@ -28,3 +28,4 @@ inference_test(recognize_digits ARGS mlp) inference_test(image_classification ARGS vgg resnet) inference_test(label_semantic_roles) inference_test(rnn_encoder_decoder) +inference_test(recommender_system) diff --git a/paddle/inference/tests/book/test_helper.h b/paddle/inference/tests/book/test_helper.h index 3e66ced94f..22ce903c72 100644 --- a/paddle/inference/tests/book/test_helper.h +++ b/paddle/inference/tests/book/test_helper.h @@ -30,6 +30,15 @@ void SetupTensor(paddle::framework::LoDTensor& input, } } +template +void SetupTensor(paddle::framework::LoDTensor& input, + paddle::framework::DDim dims, + std::vector& data) { + CHECK_EQ(paddle::framework::product(dims), data.size()); + T* input_ptr = input.mutable_data(dims, paddle::platform::CPUPlace()); + memcpy(input_ptr, data.data(), input.numel() * sizeof(T)); +} + template void SetupLoDTensor(paddle::framework::LoDTensor& input, paddle::framework::LoD& lod, @@ -37,7 +46,18 @@ void SetupLoDTensor(paddle::framework::LoDTensor& input, T upper) { input.set_lod(lod); int dim = lod[0][lod[0].size() - 1]; - SetupTensor(input, {dim, 1}, lower, upper); + SetupTensor(input, {dim, 1}, lower, upper); +} + +template +void SetupLoDTensor(paddle::framework::LoDTensor& input, + paddle::framework::DDim dims, + paddle::framework::LoD lod, + std::vector& data) { + const size_t level = lod.size() - 1; + CHECK_EQ(dims[0], (lod[level]).back()); + input.set_lod(lod); + SetupTensor(input, dims, data); } template @@ -67,7 +87,7 @@ void CheckError(paddle::framework::LoDTensor& output1, EXPECT_EQ(count, 0) << "There are " << count << " different elements."; } -template +template void TestInference(const std::string& dirname, const std::vector& cpu_feeds, std::vector& cpu_fetchs) { @@ -75,11 +95,13 @@ void TestInference(const std::string& dirname, auto place = Place(); auto executor = paddle::framework::Executor(place); auto* scope = new paddle::framework::Scope(); - std::unique_ptr inference_program; - // 2. Initialize the inference_program and load all parameters from file + // 2. Initialize the inference_program and load parameters + std::unique_ptr inference_program; if (IsCombined) { - // Hard-coding the names for combined params case + // All parameters are saved in a single file. + // Hard-coding the file names of program and parameters in unittest. + // Users are free to specify different filename. std::string prog_filename = "__model_combined__"; std::string param_filename = "__params_combined__"; inference_program = paddle::inference::Load(executor, @@ -87,6 +109,7 @@ void TestInference(const std::string& dirname, dirname + "/" + prog_filename, dirname + "/" + param_filename); } else { + // Parameters are saved in separate files sited in the specified `dirname`. inference_program = paddle::inference::Load(executor, *scope, dirname); } diff --git a/paddle/inference/tests/book/test_inference_image_classification.cc b/paddle/inference/tests/book/test_inference_image_classification.cc index 35ff9431e9..36ea7c77a7 100644 --- a/paddle/inference/tests/book/test_inference_image_classification.cc +++ b/paddle/inference/tests/book/test_inference_image_classification.cc @@ -29,11 +29,15 @@ TEST(inference, image_classification) { // 0. Call `paddle::framework::InitDevices()` initialize all the devices // In unittests, this is done in paddle/testing/paddle_gtest_main.cc + int64_t batch_size = 1; + paddle::framework::LoDTensor input; // Use normilized image pixels as input data, // which should be in the range [0.0, 1.0]. - SetupTensor( - input, {1, 3, 32, 32}, static_cast(0), static_cast(1)); + SetupTensor(input, + {batch_size, 3, 32, 32}, + static_cast(0), + static_cast(1)); std::vector cpu_feeds; cpu_feeds.push_back(&input); @@ -42,8 +46,7 @@ TEST(inference, image_classification) { cpu_fetchs1.push_back(&output1); // Run inference on CPU - TestInference( - dirname, cpu_feeds, cpu_fetchs1); + TestInference(dirname, cpu_feeds, cpu_fetchs1); LOG(INFO) << output1.dims(); #ifdef PADDLE_WITH_CUDA @@ -52,8 +55,7 @@ TEST(inference, image_classification) { cpu_fetchs2.push_back(&output2); // Run inference on CUDA GPU - TestInference( - dirname, cpu_feeds, cpu_fetchs2); + TestInference(dirname, cpu_feeds, cpu_fetchs2); LOG(INFO) << output2.dims(); CheckError(output1, output2); diff --git a/paddle/inference/tests/book/test_inference_label_semantic_roles.cc b/paddle/inference/tests/book/test_inference_label_semantic_roles.cc index 1eaf4022a1..922dbfd333 100644 --- a/paddle/inference/tests/book/test_inference_label_semantic_roles.cc +++ b/paddle/inference/tests/book/test_inference_label_semantic_roles.cc @@ -58,8 +58,7 @@ TEST(inference, label_semantic_roles) { cpu_fetchs1.push_back(&output1); // Run inference on CPU - TestInference( - dirname, cpu_feeds, cpu_fetchs1); + TestInference(dirname, cpu_feeds, cpu_fetchs1); LOG(INFO) << output1.lod(); LOG(INFO) << output1.dims(); @@ -69,8 +68,7 @@ TEST(inference, label_semantic_roles) { cpu_fetchs2.push_back(&output2); // Run inference on CUDA GPU - TestInference( - dirname, cpu_feeds, cpu_fetchs2); + TestInference(dirname, cpu_feeds, cpu_fetchs2); LOG(INFO) << output2.lod(); LOG(INFO) << output2.dims(); diff --git a/paddle/inference/tests/book/test_inference_recognize_digits.cc b/paddle/inference/tests/book/test_inference_recognize_digits.cc index 3a48db7fe0..af8c2b14c3 100644 --- a/paddle/inference/tests/book/test_inference_recognize_digits.cc +++ b/paddle/inference/tests/book/test_inference_recognize_digits.cc @@ -29,11 +29,15 @@ TEST(inference, recognize_digits) { // 0. Call `paddle::framework::InitDevices()` initialize all the devices // In unittests, this is done in paddle/testing/paddle_gtest_main.cc + int64_t batch_size = 1; + paddle::framework::LoDTensor input; // Use normilized image pixels as input data, // which should be in the range [-1.0, 1.0]. - SetupTensor( - input, {1, 28, 28}, static_cast(-1), static_cast(1)); + SetupTensor(input, + {batch_size, 1, 28, 28}, + static_cast(-1), + static_cast(1)); std::vector cpu_feeds; cpu_feeds.push_back(&input); @@ -42,8 +46,7 @@ TEST(inference, recognize_digits) { cpu_fetchs1.push_back(&output1); // Run inference on CPU - TestInference( - dirname, cpu_feeds, cpu_fetchs1); + TestInference(dirname, cpu_feeds, cpu_fetchs1); LOG(INFO) << output1.dims(); #ifdef PADDLE_WITH_CUDA @@ -52,8 +55,7 @@ TEST(inference, recognize_digits) { cpu_fetchs2.push_back(&output2); // Run inference on CUDA GPU - TestInference( - dirname, cpu_feeds, cpu_fetchs2); + TestInference(dirname, cpu_feeds, cpu_fetchs2); LOG(INFO) << output2.dims(); CheckError(output1, output2); @@ -84,7 +86,7 @@ TEST(inference, recognize_digits_combine) { cpu_fetchs1.push_back(&output1); // Run inference on CPU - TestInference( + TestInference( dirname, cpu_feeds, cpu_fetchs1); LOG(INFO) << output1.dims(); @@ -94,7 +96,7 @@ TEST(inference, recognize_digits_combine) { cpu_fetchs2.push_back(&output2); // Run inference on CUDA GPU - TestInference( + TestInference( dirname, cpu_feeds, cpu_fetchs2); LOG(INFO) << output2.dims(); diff --git a/paddle/inference/tests/book/test_inference_recommender_system.cc b/paddle/inference/tests/book/test_inference_recommender_system.cc new file mode 100644 index 0000000000..ec24c7e6ab --- /dev/null +++ b/paddle/inference/tests/book/test_inference_recommender_system.cc @@ -0,0 +1,87 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "gflags/gflags.h" +#include "test_helper.h" + +DEFINE_string(dirname, "", "Directory of the inference model."); + +TEST(inference, recommender_system) { + if (FLAGS_dirname.empty()) { + LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model"; + } + + LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl; + std::string dirname = FLAGS_dirname; + + // 0. Call `paddle::framework::InitDevices()` initialize all the devices + // In unittests, this is done in paddle/testing/paddle_gtest_main.cc + + int64_t batch_size = 1; + + paddle::framework::LoDTensor user_id, gender_id, age_id, job_id, movie_id, + category_id, movie_title; + + // Use the first data from paddle.dataset.movielens.test() as input + std::vector user_id_data = {1}; + SetupTensor(user_id, {batch_size, 1}, user_id_data); + + std::vector gender_id_data = {1}; + SetupTensor(gender_id, {batch_size, 1}, gender_id_data); + + std::vector age_id_data = {0}; + SetupTensor(age_id, {batch_size, 1}, age_id_data); + + std::vector job_id_data = {10}; + SetupTensor(job_id, {batch_size, 1}, job_id_data); + + std::vector movie_id_data = {783}; + SetupTensor(movie_id, {batch_size, 1}, movie_id_data); + + std::vector category_id_data = {10, 8, 9}; + SetupLoDTensor(category_id, {3, 1}, {{0, 3}}, category_id_data); + + std::vector movie_title_data = {1069, 4140, 2923, 710, 988}; + SetupLoDTensor(movie_title, {5, 1}, {{0, 5}}, movie_title_data); + + std::vector cpu_feeds; + cpu_feeds.push_back(&user_id); + cpu_feeds.push_back(&gender_id); + cpu_feeds.push_back(&age_id); + cpu_feeds.push_back(&job_id); + cpu_feeds.push_back(&movie_id); + cpu_feeds.push_back(&category_id); + cpu_feeds.push_back(&movie_title); + + paddle::framework::LoDTensor output1; + std::vector cpu_fetchs1; + cpu_fetchs1.push_back(&output1); + + // Run inference on CPU + TestInference(dirname, cpu_feeds, cpu_fetchs1); + LOG(INFO) << output1.dims(); + +#ifdef PADDLE_WITH_CUDA + paddle::framework::LoDTensor output2; + std::vector cpu_fetchs2; + cpu_fetchs2.push_back(&output2); + + // Run inference on CUDA GPU + TestInference(dirname, cpu_feeds, cpu_fetchs2); + LOG(INFO) << output2.dims(); + + CheckError(output1, output2); +#endif +} diff --git a/paddle/inference/tests/book/test_inference_rnn_encoder_decoder.cc b/paddle/inference/tests/book/test_inference_rnn_encoder_decoder.cc index 9bfc0407b7..248b9dce21 100644 --- a/paddle/inference/tests/book/test_inference_rnn_encoder_decoder.cc +++ b/paddle/inference/tests/book/test_inference_rnn_encoder_decoder.cc @@ -46,8 +46,7 @@ TEST(inference, rnn_encoder_decoder) { cpu_fetchs1.push_back(&output1); // Run inference on CPU - TestInference( - dirname, cpu_feeds, cpu_fetchs1); + TestInference(dirname, cpu_feeds, cpu_fetchs1); LOG(INFO) << output1.lod(); LOG(INFO) << output1.dims(); @@ -57,8 +56,7 @@ TEST(inference, rnn_encoder_decoder) { cpu_fetchs2.push_back(&output2); // Run inference on CUDA GPU - TestInference( - dirname, cpu_feeds, cpu_fetchs2); + TestInference(dirname, cpu_feeds, cpu_fetchs2); LOG(INFO) << output2.lod(); LOG(INFO) << output2.dims(); diff --git a/python/paddle/v2/fluid/tests/book/.gitignore b/python/paddle/v2/fluid/tests/book/.gitignore index f0b574b939..dd28d354f4 100644 --- a/python/paddle/v2/fluid/tests/book/.gitignore +++ b/python/paddle/v2/fluid/tests/book/.gitignore @@ -1 +1 @@ -recognize_digits_*.inference.model +*.inference.model diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits.py index 6f9d85faff..244c1749cd 100644 --- a/python/paddle/v2/fluid/tests/book/test_recognize_digits.py +++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits.py @@ -174,8 +174,9 @@ def infer(use_cuda, save_dirname=None, param_filename=None): # The input's dimension of conv should be 4-D or 5-D. # Use normilized image pixels as input data, which should be in the range [-1.0, 1.0]. + batch_size = 1 tensor_img = numpy.random.uniform(-1.0, 1.0, - [1, 1, 28, 28]).astype("float32") + [batch_size, 1, 28, 28]).astype("float32") # Construct feed as a dictionary of {feed_target_name: feed_target_data} # and results will contain a list of data corresponding to fetch_targets. diff --git a/python/paddle/v2/fluid/tests/book/test_recommender_system.py b/python/paddle/v2/fluid/tests/book/test_recommender_system.py index 9c7ab7d631..612d51e08e 100644 --- a/python/paddle/v2/fluid/tests/book/test_recommender_system.py +++ b/python/paddle/v2/fluid/tests/book/test_recommender_system.py @@ -16,7 +16,7 @@ import math import sys import numpy as np import paddle.v2 as paddle -import paddle.v2.fluid.core as core +import paddle.v2.fluid as fluid import paddle.v2.fluid.framework as framework import paddle.v2.fluid.layers as layers import paddle.v2.fluid.nets as nets @@ -104,7 +104,8 @@ def get_mov_combined_features(): CATEGORY_DICT_SIZE = len(paddle.dataset.movielens.movie_categories()) - category_id = layers.data(name='category_id', shape=[1], dtype='int64') + category_id = layers.data( + name='category_id', shape=[1], dtype='int64', lod_level=1) mov_categories_emb = layers.embedding( input=category_id, size=[CATEGORY_DICT_SIZE, 32], is_sparse=IS_SPARSE) @@ -114,7 +115,8 @@ def get_mov_combined_features(): MOV_TITLE_DICT_SIZE = len(paddle.dataset.movielens.get_movie_title_dict()) - mov_title_id = layers.data(name='movie_title', shape=[1], dtype='int64') + mov_title_id = layers.data( + name='movie_title', shape=[1], dtype='int64', lod_level=1) mov_title_emb = layers.embedding( input=mov_title_id, size=[MOV_TITLE_DICT_SIZE, 32], is_sparse=IS_SPARSE) @@ -144,23 +146,22 @@ def model(): scale_infer = layers.scale(x=inference, scale=5.0) label = layers.data(name='score', shape=[1], dtype='float32') - square_cost = layers.square_error_cost(input=scale_infer, label=label) - avg_cost = layers.mean(x=square_cost) - return avg_cost + return scale_infer, avg_cost + +def train(use_cuda, save_dirname): + scale_infer, avg_cost = model() + + # test program + test_program = fluid.default_main_program().clone() -def main(): - cost = model() sgd_optimizer = SGDOptimizer(learning_rate=0.2) - opts = sgd_optimizer.minimize(cost) + opts = sgd_optimizer.minimize(avg_cost) - if USE_GPU: - place = core.CUDAPlace(0) - else: - place = core.CPUPlace() + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() exe = Executor(place) exe.run(framework.default_startup_program()) @@ -169,6 +170,8 @@ def main(): paddle.reader.shuffle( paddle.dataset.movielens.train(), buf_size=8192), batch_size=BATCH_SIZE) + test_reader = paddle.batch( + paddle.dataset.movielens.test(), batch_size=BATCH_SIZE) feeding = { 'user_id': 0, @@ -184,7 +187,7 @@ def main(): def func_feed(feeding, data): feed_tensors = {} for (key, idx) in feeding.iteritems(): - tensor = core.LoDTensor() + tensor = fluid.LoDTensor() if key != "category_id" and key != "movie_title": if key == "score": numpy_data = np.array(map(lambda x: x[idx], data)).astype( @@ -211,16 +214,117 @@ def main(): PASS_NUM = 100 for pass_id in range(PASS_NUM): - for data in train_reader(): - outs = exe.run(framework.default_main_program(), + for batch_id, data in enumerate(train_reader()): + # train a mini-batch + outs = exe.run(program=fluid.default_main_program(), feed=func_feed(feeding, data), - fetch_list=[cost]) + fetch_list=[avg_cost]) out = np.array(outs[0]) - if out[0] < 6.0: - # if avg cost less than 6.0, we think our code is good. - exit(0) + if (batch_id + 1) % 10 == 0: + avg_cost_set = [] + for test_data in test_reader(): + avg_cost_np = exe.run(program=test_program, + feed=func_feed(feeding, test_data), + fetch_list=[avg_cost]) + avg_cost_set.append(avg_cost_np[0]) + break # test only 1 segment for speeding up CI + + # get test avg_cost + test_avg_cost = np.array(avg_cost_set).mean() + if test_avg_cost < 6.0: + # if avg_cost less than 6.0, we think our code is good. + if save_dirname is not None: + fluid.io.save_inference_model(save_dirname, [ + "user_id", "gender_id", "age_id", "job_id", + "movie_id", "category_id", "movie_title" + ], [scale_infer], exe) + return + if math.isnan(float(out[0])): sys.exit("got NaN loss, training failed.") -main() +def infer(use_cuda, save_dirname=None): + if save_dirname is None: + return + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + + # Use fluid.io.load_inference_model to obtain the inference program desc, + # the feed_target_names (the names of variables that will be feeded + # data using feed operators), and the fetch_targets (variables that + # we want to obtain data from using fetch operators). + [inference_program, feed_target_names, + fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) + + def create_lod_tensor(data, lod=None): + tensor = fluid.LoDTensor() + if lod is None: + # Tensor, the shape is [batch_size, 1] + index = 0 + lod_0 = [index] + for l in range(len(data)): + index += 1 + lod_0.append(index) + lod = [lod_0] + tensor.set_lod(lod) + + flattened_data = np.concatenate(data, axis=0).astype("int64") + flattened_data = flattened_data.reshape([len(flattened_data), 1]) + tensor.set(flattened_data, place) + return tensor + + # Use the first data from paddle.dataset.movielens.test() as input + assert feed_target_names[0] == "user_id" + user_id = create_lod_tensor([[1]]) + + assert feed_target_names[1] == "gender_id" + gender_id = create_lod_tensor([[1]]) + + assert feed_target_names[2] == "age_id" + age_id = create_lod_tensor([[0]]) + + assert feed_target_names[3] == "job_id" + job_id = create_lod_tensor([[10]]) + + assert feed_target_names[4] == "movie_id" + movie_id = create_lod_tensor([[783]]) + + assert feed_target_names[5] == "category_id" + category_id = create_lod_tensor([[10], [8], [9]], [[0, 3]]) + + assert feed_target_names[6] == "movie_title" + movie_title = create_lod_tensor([[1069], [4140], [2923], [710], [988]], + [[0, 5]]) + + # Construct feed as a dictionary of {feed_target_name: feed_target_data} + # and results will contain a list of data corresponding to fetch_targets. + results = exe.run(inference_program, + feed={ + feed_target_names[0]: user_id, + feed_target_names[1]: gender_id, + feed_target_names[2]: age_id, + feed_target_names[3]: job_id, + feed_target_names[4]: movie_id, + feed_target_names[5]: category_id, + feed_target_names[6]: movie_title + }, + fetch_list=fetch_targets, + return_numpy=False) + print("inferred score: ", np.array(results[0])) + + +def main(use_cuda): + if use_cuda and not fluid.core.is_compiled_with_cuda(): + return + + # Directory for saving the inference model + save_dirname = "recommender_system.inference.model" + + train(use_cuda, save_dirname) + infer(use_cuda, save_dirname) + + +if __name__ == '__main__': + main(USE_GPU) -- GitLab