提交 e5832019 编写于 作者: K kexinzhao 提交者: Yiqun Liu

Inference example and unit test for label_semantic_roles (#8058)

* set up python code

* fix bug

* add cc file

* fix cmake

* add inference test for label semantic role

* fix

* address comments

* address comments

* address comments

* address comments

* add use_cuda
上级 f2154002
......@@ -11,9 +11,15 @@ cc_test(test_inference_image_classification_resnet
SRCS test_inference_image_classification.cc
DEPS ARCHIVE_START paddle_fluid ARCHIVE_END
ARGS --dirname=${PYTHON_TESTS_DIR}/book/image_classification_resnet.inference.model)
cc_test(test_inference_label_semantic_roles
SRCS test_inference_label_semantic_roles.cc
DEPS ARCHIVE_START paddle_fluid ARCHIVE_END
ARGS --dirname=${PYTHON_TESTS_DIR}/book/label_semantic_roles.inference.model)
set_tests_properties(test_inference_recognize_digits_mlp
PROPERTIES DEPENDS test_recognize_digits)
set_tests_properties(test_inference_image_classification_vgg
PROPERTIES DEPENDS test_image_classification_train)
set_tests_properties(test_inference_image_classification_resnet
PROPERTIES DEPENDS test_image_classification_train)
set_tests_properties(test_inference_label_semantic_roles
PROPERTIES DEPENDS test_label_semantic_roles)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/lod_tensor.h"
#include "paddle/inference/io.h"
template <typename T>
void SetupTensor(paddle::framework::LoDTensor& input,
paddle::framework::DDim dims,
T lower,
T upper) {
srand(time(0));
T* input_ptr = input.mutable_data<T>(dims, paddle::platform::CPUPlace());
for (int i = 0; i < input.numel(); ++i) {
input_ptr[i] =
(static_cast<T>(rand()) / static_cast<T>(RAND_MAX)) * (upper - lower) +
lower;
}
}
template <typename T>
void SetupLoDTensor(paddle::framework::LoDTensor& input,
paddle::framework::LoD& lod,
T lower,
T upper) {
input.set_lod(lod);
int dim = lod[0][lod[0].size() - 1];
SetupTensor(input, {dim, 1}, lower, upper);
}
template <typename T>
void CheckError(paddle::framework::LoDTensor& output1,
paddle::framework::LoDTensor& output2) {
// Check lod information
EXPECT_EQ(output1.lod(), output2.lod());
EXPECT_EQ(output1.dims(), output2.dims());
EXPECT_EQ(output1.numel(), output2.numel());
T err = static_cast<T>(0);
if (typeid(T) == typeid(float)) {
err = 1E-3;
} else if (typeid(T) == typeid(double)) {
err = 1E-6;
} else {
err = 0;
}
size_t count = 0;
for (int64_t i = 0; i < output1.numel(); ++i) {
if (fabs(output1.data<T>()[i] - output2.data<T>()[i]) > err) {
count++;
}
}
EXPECT_EQ(count, 0) << "There are " << count << " different elements.";
}
template <typename Place, typename T>
void TestInference(const std::string& dirname,
const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
std::vector<paddle::framework::LoDTensor*>& cpu_fetchs) {
// 1. Define place, executor and scope
auto place = Place();
auto executor = paddle::framework::Executor(place);
auto* scope = new paddle::framework::Scope();
// 2. Initialize the inference_program and load all parameters from file
auto inference_program = paddle::inference::Load(executor, *scope, dirname);
// 3. Get the feed_target_names and fetch_target_names
const std::vector<std::string>& feed_target_names =
inference_program->GetFeedTargetNames();
const std::vector<std::string>& fetch_target_names =
inference_program->GetFetchTargetNames();
// 4. Prepare inputs: set up maps for feed targets
std::map<std::string, const paddle::framework::LoDTensor*> feed_targets;
for (size_t i = 0; i < feed_target_names.size(); ++i) {
// Please make sure that cpu_feeds[i] is right for feed_target_names[i]
feed_targets[feed_target_names[i]] = cpu_feeds[i];
}
// 5. Define Tensor to get the outputs: set up maps for fetch targets
std::map<std::string, paddle::framework::LoDTensor*> fetch_targets;
for (size_t i = 0; i < fetch_target_names.size(); ++i) {
fetch_targets[fetch_target_names[i]] = cpu_fetchs[i];
}
// 6. Run the inference program
executor.Run(*inference_program, scope, feed_targets, fetch_targets);
delete scope;
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <time.h>
#include <sstream>
#include "gflags/gflags.h"
#include "test_helper.h"
DEFINE_string(dirname, "", "Directory of the inference model.");
TEST(inference, label_semantic_roles) {
if (FLAGS_dirname.empty()) {
LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model";
}
LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl;
std::string dirname = FLAGS_dirname;
// 0. Call `paddle::framework::InitDevices()` initialize all the devices
// In unittests, this is done in paddle/testing/paddle_gtest_main.cc
paddle::framework::LoDTensor word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1,
ctx_p2, mark;
paddle::framework::LoD lod{{0, 4, 10}};
SetupLoDTensor(word, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
SetupLoDTensor(
predicate, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
SetupLoDTensor(ctx_n2, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
SetupLoDTensor(ctx_n1, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
SetupLoDTensor(ctx_0, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
SetupLoDTensor(ctx_p1, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
SetupLoDTensor(ctx_p2, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
SetupLoDTensor(mark, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
std::vector<paddle::framework::LoDTensor*> cpu_feeds;
cpu_feeds.push_back(&word);
cpu_feeds.push_back(&predicate);
cpu_feeds.push_back(&ctx_n2);
cpu_feeds.push_back(&ctx_n1);
cpu_feeds.push_back(&ctx_0);
cpu_feeds.push_back(&ctx_p1);
cpu_feeds.push_back(&ctx_p2);
cpu_feeds.push_back(&mark);
paddle::framework::LoDTensor output1;
std::vector<paddle::framework::LoDTensor*> cpu_fetchs1;
cpu_fetchs1.push_back(&output1);
// Run inference on CPU
TestInference<paddle::platform::CPUPlace, float>(
dirname, cpu_feeds, cpu_fetchs1);
LOG(INFO) << output1.lod();
LOG(INFO) << output1.dims();
#ifdef PADDLE_WITH_CUDA
paddle::framework::LoDTensor output2;
std::vector<paddle::framework::LoDTensor*> cpu_fetchs2;
cpu_fetchs2.push_back(&output2);
// Run inference on CUDA GPU
TestInference<paddle::platform::CUDAPlace, float>(
dirname, cpu_feeds, cpu_fetchs2);
LOG(INFO) << output2.lod();
LOG(INFO) << output2.dims();
CheckError<float>(output1, output2);
#endif
}
......@@ -16,89 +16,10 @@ limitations under the License. */
#include <time.h>
#include <sstream>
#include "gflags/gflags.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/inference/io.h"
#include "test_helper.h"
DEFINE_string(dirname, "", "Directory of the inference model.");
template <typename Place, typename T>
void TestInference(const std::string& dirname,
const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
std::vector<paddle::framework::LoDTensor*>& cpu_fetchs) {
// 1. Define place, executor and scope
auto place = Place();
auto executor = paddle::framework::Executor(place);
auto* scope = new paddle::framework::Scope();
// 2. Initialize the inference_program and load all parameters from file
auto inference_program = paddle::inference::Load(executor, *scope, dirname);
// 3. Get the feed_target_names and fetch_target_names
const std::vector<std::string>& feed_target_names =
inference_program->GetFeedTargetNames();
const std::vector<std::string>& fetch_target_names =
inference_program->GetFetchTargetNames();
// 4. Prepare inputs: set up maps for feed targets
std::map<std::string, const paddle::framework::LoDTensor*> feed_targets;
for (size_t i = 0; i < feed_target_names.size(); ++i) {
// Please make sure that cpu_feeds[i] is right for feed_target_names[i]
feed_targets[feed_target_names[i]] = cpu_feeds[i];
}
// 5. Define Tensor to get the outputs: set up maps for fetch targets
std::map<std::string, paddle::framework::LoDTensor*> fetch_targets;
for (size_t i = 0; i < fetch_target_names.size(); ++i) {
fetch_targets[fetch_target_names[i]] = cpu_fetchs[i];
}
// 6. Run the inference program
executor.Run(*inference_program, scope, feed_targets, fetch_targets);
delete scope;
}
template <typename T>
void SetupTensor(paddle::framework::LoDTensor& input,
paddle::framework::DDim dims,
T lower,
T upper) {
srand(time(0));
float* input_ptr = input.mutable_data<T>(dims, paddle::platform::CPUPlace());
for (int i = 0; i < input.numel(); ++i) {
input_ptr[i] =
(static_cast<T>(rand()) / static_cast<T>(RAND_MAX)) * (upper - lower) +
lower;
}
}
template <typename T>
void CheckError(paddle::framework::LoDTensor& output1,
paddle::framework::LoDTensor& output2) {
// Check lod information
EXPECT_EQ(output1.lod(), output2.lod());
EXPECT_EQ(output1.dims(), output2.dims());
EXPECT_EQ(output1.numel(), output2.numel());
T err = static_cast<T>(0);
if (typeid(T) == typeid(float)) {
err = 1E-3;
} else if (typeid(T) == typeid(double)) {
err = 1E-6;
} else {
err = 0;
}
size_t count = 0;
for (int64_t i = 0; i < output1.numel(); ++i) {
if (fabs(output1.data<T>()[i] - output2.data<T>()[i]) > err) {
count++;
}
}
EXPECT_EQ(count, 0) << "There are " << count << " different elements.";
}
TEST(inference, recognize_digits) {
if (FLAGS_dirname.empty()) {
LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model";
......
......@@ -18,7 +18,9 @@ import numpy as np
import paddle.v2 as paddle
import paddle.v2.dataset.conll05 as conll05
import paddle.v2.fluid as fluid
import contextlib
import time
import unittest
word_dict, verb_dict, label_dict = conll05.get_dict()
word_dict_len = len(word_dict)
......@@ -127,7 +129,15 @@ def to_lodtensor(data, place):
return res
def main():
def create_random_lodtensor(lod, place, low, high):
data = np.random.random_integers(low, high, [lod[-1], 1]).astype("int64")
res = fluid.LoDTensor()
res.set(data, place)
res.set_lod([lod])
return res
def train(use_cuda, save_dirname=None):
# define network topology
word = fluid.layers.data(
name='word_data', shape=[1], dtype='int64', lod_level=1)
......@@ -175,8 +185,8 @@ def main():
paddle.reader.shuffle(
paddle.dataset.conll05.test(), buf_size=8192),
batch_size=BATCH_SIZE)
# place = fluid.CPUPlace()
place = fluid.CUDAPlace(0)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
feeder = fluid.DataFeeder(
feed_list=[
word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, predicate, mark, target
......@@ -211,12 +221,102 @@ def main():
if batch_id != 0:
print("second per batch: " + str((time.time() - start_time)
/ batch_id))
# exit early for CI
exit(0)
# Set the threshold low to speed up the CI test
if float(pass_precision) > 0.05:
if save_dirname is not None:
fluid.io.save_inference_model(save_dirname, [
'word_data', 'verb_data', 'ctx_n2_data',
'ctx_n1_data', 'ctx_0_data', 'ctx_p1_data',
'ctx_p2_data', 'mark_data'
], [feature_out], exe)
return
batch_id = batch_id + 1
def infer(use_cuda, save_dirname=None):
if save_dirname is None:
return
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
# Use fluid.io.load_inference_model to obtain the inference program desc,
# the feed_target_names (the names of variables that will be feeded
# data using feed operators), and the fetch_targets (variables that
# we want to obtain data from using fetch operators).
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
lod = [0, 4, 10]
ts_word = create_random_lodtensor(lod, place, low=0, high=1)
ts_pred = create_random_lodtensor(lod, place, low=0, high=1)
ts_ctx_n2 = create_random_lodtensor(lod, place, low=0, high=1)
ts_ctx_n1 = create_random_lodtensor(lod, place, low=0, high=1)
ts_ctx_0 = create_random_lodtensor(lod, place, low=0, high=1)
ts_ctx_p1 = create_random_lodtensor(lod, place, low=0, high=1)
ts_ctx_p2 = create_random_lodtensor(lod, place, low=0, high=1)
ts_mark = create_random_lodtensor(lod, place, low=0, high=1)
# Construct feed as a dictionary of {feed_target_name: feed_target_data}
# and results will contain a list of data corresponding to fetch_targets.
assert feed_target_names[0] == 'word_data'
assert feed_target_names[1] == 'verb_data'
assert feed_target_names[2] == 'ctx_n2_data'
assert feed_target_names[3] == 'ctx_n1_data'
assert feed_target_names[4] == 'ctx_0_data'
assert feed_target_names[5] == 'ctx_p1_data'
assert feed_target_names[6] == 'ctx_p2_data'
assert feed_target_names[7] == 'mark_data'
results = exe.run(inference_program,
feed={
feed_target_names[0]: ts_word,
feed_target_names[1]: ts_pred,
feed_target_names[2]: ts_ctx_n2,
feed_target_names[3]: ts_ctx_n1,
feed_target_names[4]: ts_ctx_0,
feed_target_names[5]: ts_ctx_p1,
feed_target_names[6]: ts_ctx_p2,
feed_target_names[7]: ts_mark
},
fetch_list=fetch_targets,
return_numpy=False)
print(results[0].lod())
np_data = np.array(results[0])
print("Inference Shape: ", np_data.shape)
print("Inference results: ", np_data)
def main(use_cuda):
if use_cuda and not fluid.core.is_compiled_with_cuda():
return
# Directory for saving the trained model
save_dirname = "label_semantic_roles.inference.model"
train(use_cuda, save_dirname)
infer(use_cuda, save_dirname)
class TestLabelSemanticRoles(unittest.TestCase):
def test_cuda(self):
with self.scope_prog_guard():
main(use_cuda=True)
def test_cpu(self):
with self.scope_prog_guard():
main(use_cuda=False)
@contextlib.contextmanager
def scope_prog_guard(self):
prog = fluid.Program()
startup_prog = fluid.Program()
scope = fluid.core.Scope()
with fluid.scope_guard(scope):
with fluid.program_guard(prog, startup_prog):
yield
if __name__ == '__main__':
main()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册