diff --git a/paddle/inference/tests/book/CMakeLists.txt b/paddle/inference/tests/book/CMakeLists.txt index 5d065e53b2d3afcdc4ea78515b4c10d66c768b95..ca3c056b0978c1fb2d6057d28fed761833d34fe4 100644 --- a/paddle/inference/tests/book/CMakeLists.txt +++ b/paddle/inference/tests/book/CMakeLists.txt @@ -30,3 +30,4 @@ inference_test(image_classification ARGS vgg resnet) inference_test(label_semantic_roles) inference_test(rnn_encoder_decoder) inference_test(recommender_system) +inference_test(understand_sentiment) diff --git a/paddle/inference/tests/book/test_inference_understand_sentiment.cc b/paddle/inference/tests/book/test_inference_understand_sentiment.cc new file mode 100644 index 0000000000000000000000000000000000000000..1afb644446569fc4fea8d7e9bf806daec016a5b2 --- /dev/null +++ b/paddle/inference/tests/book/test_inference_understand_sentiment.cc @@ -0,0 +1,60 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "gflags/gflags.h" +#include "test_helper.h" + +DEFINE_string(dirname, "", "Directory of the inference model."); + +TEST(inference, understand_sentiment) { + if (FLAGS_dirname.empty()) { + LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model"; + } + + LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl; + std::string dirname = FLAGS_dirname; + + // 0. Call `paddle::framework::InitDevices()` initialize all the devices + // In unittests, this is done in paddle/testing/paddle_gtest_main.cc + + paddle::framework::LoDTensor words; + paddle::framework::LoD lod{{0, 4, 10}}; + SetupLoDTensor(words, lod, static_cast(0), static_cast(10)); + + std::vector cpu_feeds; + cpu_feeds.push_back(&words); + + paddle::framework::LoDTensor output1; + std::vector cpu_fetchs1; + cpu_fetchs1.push_back(&output1); + + // Run inference on CPU + TestInference(dirname, cpu_feeds, cpu_fetchs1); + LOG(INFO) << output1.lod(); + LOG(INFO) << output1.dims(); + +#ifdef PADDLE_WITH_CUDA + paddle::framework::LoDTensor output2; + std::vector cpu_fetchs2; + cpu_fetchs2.push_back(&output2); + + // Run inference on CUDA GPU + TestInference(dirname, cpu_feeds, cpu_fetchs2); + LOG(INFO) << output2.lod(); + LOG(INFO) << output2.dims(); + + CheckError(output1, output2); +#endif +} diff --git a/python/paddle/v2/fluid/tests/book/test_understand_sentiment.py b/python/paddle/v2/fluid/tests/book/test_understand_sentiment.py index 9c5cb667aed7456b54d32dcd650852cfdbd6cce1..6e0206d41db6265e926991fe35cd53513bd3417d 100644 --- a/python/paddle/v2/fluid/tests/book/test_understand_sentiment.py +++ b/python/paddle/v2/fluid/tests/book/test_understand_sentiment.py @@ -17,6 +17,7 @@ import paddle.v2.fluid as fluid import paddle.v2 as paddle import contextlib import math +import numpy as np import sys @@ -43,7 +44,7 @@ def convolution_net(data, label, input_dim, class_dim=2, emb_dim=32, adam_optimizer = fluid.optimizer.Adam(learning_rate=0.002) adam_optimizer.minimize(avg_cost) accuracy = fluid.layers.accuracy(input=prediction, label=label) - return avg_cost, accuracy + return avg_cost, accuracy, prediction def stacked_lstm_net(data, @@ -81,13 +82,18 @@ def stacked_lstm_net(data, adam_optimizer = fluid.optimizer.Adam(learning_rate=0.002) adam_optimizer.minimize(avg_cost) accuracy = fluid.layers.accuracy(input=prediction, label=label) - return avg_cost, accuracy + return avg_cost, accuracy, prediction -def main(word_dict, net_method, use_cuda): - if use_cuda and not fluid.core.is_compiled_with_cuda(): - return +def create_random_lodtensor(lod, place, low, high): + data = np.random.random_integers(low, high, [lod[-1], 1]).astype("int64") + res = fluid.LoDTensor() + res.set(data, place) + res.set_lod([lod]) + return res + +def train(word_dict, net_method, use_cuda, save_dirname=None): BATCH_SIZE = 128 PASS_NUM = 5 dict_dim = len(word_dict) @@ -96,7 +102,7 @@ def main(word_dict, net_method, use_cuda): data = fluid.layers.data( name="words", shape=[1], dtype="int64", lod_level=1) label = fluid.layers.data(name="label", shape=[1], dtype="int64") - cost, acc_out = net_method( + cost, acc_out, prediction = net_method( data, label, input_dim=dict_dim, class_dim=class_dim) train_data = paddle.batch( @@ -116,6 +122,9 @@ def main(word_dict, net_method, use_cuda): fetch_list=[cost, acc_out]) print("cost=" + str(cost_val) + " acc=" + str(acc_val)) if cost_val < 0.4 and acc_val > 0.8: + if save_dirname is not None: + fluid.io.save_inference_model(save_dirname, ["words"], + prediction, exe) return if math.isnan(float(cost_val)): sys.exit("got NaN loss, training failed.") @@ -123,6 +132,49 @@ def main(word_dict, net_method, use_cuda): net_method.__name__)) +def infer(use_cuda, save_dirname=None): + if save_dirname is None: + return + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + + # Use fluid.io.load_inference_model to obtain the inference program desc, + # the feed_target_names (the names of variables that will be feeded + # data using feed operators), and the fetch_targets (variables that + # we want to obtain data from using fetch operators). + [inference_program, feed_target_names, + fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) + + lod = [0, 4, 10] + word_dict = paddle.dataset.imdb.word_dict() + tensor_words = create_random_lodtensor( + lod, place, low=0, high=len(word_dict) - 1) + + # Construct feed as a dictionary of {feed_target_name: feed_target_data} + # and results will contain a list of data corresponding to fetch_targets. + assert feed_target_names[0] == "words" + results = exe.run(inference_program, + feed={feed_target_names[0]: tensor_words}, + fetch_list=fetch_targets, + return_numpy=False) + print(results[0].lod()) + np_data = np.array(results[0]) + print("Inference Shape: ", np_data.shape) + print("Inference results: ", np_data) + + +def main(word_dict, net_method, use_cuda): + if use_cuda and not fluid.core.is_compiled_with_cuda(): + return + + # Directory for saving the trained model + save_dirname = "understand_sentiment.inference.model" + + train(word_dict, net_method, use_cuda, save_dirname) + infer(use_cuda, save_dirname) + + class TestUnderstandSentiment(unittest.TestCase): @classmethod def setUpClass(cls):