diff --git a/paddle/inference/tests/book/CMakeLists.txt b/paddle/inference/tests/book/CMakeLists.txt index ca3c056b0978c1fb2d6057d28fed761833d34fe4..9fe76afb582a13b741ab086f0c62d77e86d4e8bb 100644 --- a/paddle/inference/tests/book/CMakeLists.txt +++ b/paddle/inference/tests/book/CMakeLists.txt @@ -25,9 +25,10 @@ function(inference_test TARGET_NAME) endfunction(inference_test) inference_test(fit_a_line) -inference_test(recognize_digits ARGS mlp) inference_test(image_classification ARGS vgg resnet) inference_test(label_semantic_roles) -inference_test(rnn_encoder_decoder) +inference_test(recognize_digits ARGS mlp) inference_test(recommender_system) +inference_test(rnn_encoder_decoder) inference_test(understand_sentiment) +inference_test(word2vec) diff --git a/paddle/inference/tests/book/test_helper.h b/paddle/inference/tests/book/test_helper.h index 22ce903c7250b84fd0b08e82cfda03df411a3068..02104306e71b9ec06c0fef0f7417bbde8f6ff88a 100644 --- a/paddle/inference/tests/book/test_helper.h +++ b/paddle/inference/tests/book/test_helper.h @@ -91,7 +91,7 @@ template void TestInference(const std::string& dirname, const std::vector& cpu_feeds, std::vector& cpu_fetchs) { - // 1. Define place, executor, scope and inference_program + // 1. Define place, executor, scope auto place = Place(); auto executor = paddle::framework::Executor(place); auto* scope = new paddle::framework::Scope(); @@ -101,7 +101,8 @@ void TestInference(const std::string& dirname, if (IsCombined) { // All parameters are saved in a single file. // Hard-coding the file names of program and parameters in unittest. - // Users are free to specify different filename. + // Users are free to specify different filename + // (provided: the filenames are changed in the python api as well: io.py) std::string prog_filename = "__model_combined__"; std::string param_filename = "__params_combined__"; inference_program = paddle::inference::Load(executor, diff --git a/paddle/inference/tests/book/test_inference_word2vec.cc b/paddle/inference/tests/book/test_inference_word2vec.cc new file mode 100644 index 0000000000000000000000000000000000000000..ca0c040ff629c55b3392d67628ba93226c04a0ce --- /dev/null +++ b/paddle/inference/tests/book/test_inference_word2vec.cc @@ -0,0 +1,68 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "gflags/gflags.h" +#include "test_helper.h" + +DEFINE_string(dirname, "", "Directory of the inference model."); + +TEST(inference, word2vec) { + if (FLAGS_dirname.empty()) { + LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model"; + } + + LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl; + std::string dirname = FLAGS_dirname; + + // 0. Call `paddle::framework::InitDevices()` initialize all the devices + // In unittests, this is done in paddle/testing/paddle_gtest_main.cc + + paddle::framework::LoDTensor first_word, second_word, third_word, fourth_word; + paddle::framework::LoD lod{{0, 1}}; + int64_t dict_size = 2072; // Hard-coding the size of dictionary + + SetupLoDTensor(first_word, lod, static_cast(0), dict_size); + SetupLoDTensor(second_word, lod, static_cast(0), dict_size); + SetupLoDTensor(third_word, lod, static_cast(0), dict_size); + SetupLoDTensor(fourth_word, lod, static_cast(0), dict_size); + + std::vector cpu_feeds; + cpu_feeds.push_back(&first_word); + cpu_feeds.push_back(&second_word); + cpu_feeds.push_back(&third_word); + cpu_feeds.push_back(&fourth_word); + + paddle::framework::LoDTensor output1; + std::vector cpu_fetchs1; + cpu_fetchs1.push_back(&output1); + + // Run inference on CPU + TestInference(dirname, cpu_feeds, cpu_fetchs1); + LOG(INFO) << output1.lod(); + LOG(INFO) << output1.dims(); + +#ifdef PADDLE_WITH_CUDA + paddle::framework::LoDTensor output2; + std::vector cpu_fetchs2; + cpu_fetchs2.push_back(&output2); + + // Run inference on CUDA GPU + TestInference(dirname, cpu_feeds, cpu_fetchs2); + LOG(INFO) << output2.lod(); + LOG(INFO) << output2.dims(); + + CheckError(output1, output2); +#endif +} diff --git a/python/paddle/v2/fluid/tests/book/test_word2vec.py b/python/paddle/v2/fluid/tests/book/test_word2vec.py index f013d7f1551bdbfb2f725809e2fb4d7d686560fe..69bfbcee69a08f57e4754f1a94f85534be4baac6 100644 --- a/python/paddle/v2/fluid/tests/book/test_word2vec.py +++ b/python/paddle/v2/fluid/tests/book/test_word2vec.py @@ -1,6 +1,5 @@ # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -16,14 +15,67 @@ import paddle.v2 as paddle import paddle.v2.fluid as fluid import unittest import os +import numpy as np import math import sys -def main(use_cuda, is_sparse, parallel): - if use_cuda and not fluid.core.is_compiled_with_cuda(): +def create_random_lodtensor(lod, place, low, high): + data = np.random.random_integers(low, high, [lod[-1], 1]).astype("int64") + res = fluid.LoDTensor() + res.set(data, place) + res.set_lod([lod]) + return res + + +def infer(use_cuda, save_dirname=None): + if save_dirname is None: return + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + + # Use fluid.io.load_inference_model to obtain the inference program desc, + # the feed_target_names (the names of variables that will be feeded + # data using feed operators), and the fetch_targets (variables that + # we want to obtain data from using fetch operators). + [inference_program, feed_target_names, + fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) + + word_dict = paddle.dataset.imikolov.build_dict() + dict_size = len(word_dict) - 1 + + # Setup input, by creating 4 words, and setting up lod required for + # lookup_table_op + lod = [0, 1] + first_word = create_random_lodtensor(lod, place, low=0, high=dict_size) + second_word = create_random_lodtensor(lod, place, low=0, high=dict_size) + third_word = create_random_lodtensor(lod, place, low=0, high=dict_size) + fourth_word = create_random_lodtensor(lod, place, low=0, high=dict_size) + + assert feed_target_names[0] == 'firstw' + assert feed_target_names[1] == 'secondw' + assert feed_target_names[2] == 'thirdw' + assert feed_target_names[3] == 'forthw' + + # Construct feed as a dictionary of {feed_target_name: feed_target_data} + # and results will contain a list of data corresponding to fetch_targets. + results = exe.run(inference_program, + feed={ + feed_target_names[0]: first_word, + feed_target_names[1]: second_word, + feed_target_names[2]: third_word, + feed_target_names[3]: fourth_word + }, + fetch_list=fetch_targets, + return_numpy=False) + print(results[0].lod()) + np_data = np.array(results[0]) + print("Inference Shape: ", np_data.shape) + print("Inference results: ", np_data) + + +def train(use_cuda, is_sparse, parallel, save_dirname): PASS_NUM = 100 EMBED_SIZE = 32 HIDDEN_SIZE = 256 @@ -67,7 +119,7 @@ def main(use_cuda, is_sparse, parallel): act='softmax') cost = fluid.layers.cross_entropy(input=predict_word, label=words[4]) avg_cost = fluid.layers.mean(x=cost) - return avg_cost + return avg_cost, predict_word word_dict = paddle.dataset.imikolov.build_dict() dict_size = len(word_dict) @@ -79,13 +131,13 @@ def main(use_cuda, is_sparse, parallel): next_word = fluid.layers.data(name='nextw', shape=[1], dtype='int64') if not parallel: - avg_cost = __network__( + avg_cost, predict_word = __network__( [first_word, second_word, third_word, forth_word, next_word]) else: places = fluid.layers.get_places() pd = fluid.layers.ParallelDo(places) with pd.do(): - avg_cost = __network__( + avg_cost, predict_word = __network__( map(pd.read_input, [ first_word, second_word, third_word, forth_word, next_word ])) @@ -113,6 +165,10 @@ def main(use_cuda, is_sparse, parallel): feed=feeder.feed(data), fetch_list=[avg_cost]) if avg_cost_np[0] < 5.0: + if save_dirname is not None: + fluid.io.save_inference_model(save_dirname, [ + 'firstw', 'secondw', 'thirdw', 'forthw' + ], [predict_word], exe) return if math.isnan(float(avg_cost_np[0])): sys.exit("got NaN loss, training failed.") @@ -120,6 +176,14 @@ def main(use_cuda, is_sparse, parallel): raise AssertionError("Cost is too large {0:2.2}".format(avg_cost_np[0])) +def main(use_cuda, is_sparse, parallel): + if use_cuda and not fluid.core.is_compiled_with_cuda(): + return + save_dirname = "word2vec.inference.model" + train(use_cuda, is_sparse, parallel, save_dirname) + infer(use_cuda, save_dirname) + + FULL_TEST = os.getenv('FULL_TEST', '0').lower() in ['true', '1', 't', 'y', 'yes', 'on'] SKIP_REASON = "Only run minimum number of tests in CI server, to make CI faster" @@ -142,7 +206,8 @@ def inject_test_method(use_cuda, is_sparse, parallel): with fluid.program_guard(prog, startup_prog): main(use_cuda=use_cuda, is_sparse=is_sparse, parallel=parallel) - if use_cuda and is_sparse and parallel: + # run only 2 cases: use_cuda is either True or False + if is_sparse == False and parallel == False: fn = __impl__ else: # skip the other test when on CI server