diff --git a/paddle/inference/tests/book/CMakeLists.txt b/paddle/inference/tests/book/CMakeLists.txt index 0e987eb0240301c58cfb74c9e995d3b525130125..4c71517dc987456c6ccfa85d4c84ab40193d15c3 100644 --- a/paddle/inference/tests/book/CMakeLists.txt +++ b/paddle/inference/tests/book/CMakeLists.txt @@ -3,5 +3,17 @@ cc_test(test_inference_recognize_digits_mlp SRCS test_inference_recognize_digits.cc DEPS ARCHIVE_START paddle_fluid ARCHIVE_END ARGS --dirname=${PYTHON_TESTS_DIR}/book/recognize_digits_mlp.inference.model) +cc_test(test_inference_image_classification_vgg + SRCS test_inference_image_classification.cc + DEPS ARCHIVE_START paddle_fluid ARCHIVE_END + ARGS --dirname=${PYTHON_TESTS_DIR}/book/image_classification_vgg.inference.model) +cc_test(test_inference_image_classification_resnet + SRCS test_inference_image_classification.cc + DEPS ARCHIVE_START paddle_fluid ARCHIVE_END + ARGS --dirname=${PYTHON_TESTS_DIR}/book/image_classification_resnet.inference.model) set_tests_properties(test_inference_recognize_digits_mlp PROPERTIES DEPENDS test_recognize_digits) +set_tests_properties(test_inference_image_classification_vgg + PROPERTIES DEPENDS test_image_classification_train) +set_tests_properties(test_inference_image_classification_resnet + PROPERTIES DEPENDS test_image_classification_train) diff --git a/paddle/inference/tests/book/test_inference_image_classification.cc b/paddle/inference/tests/book/test_inference_image_classification.cc new file mode 100644 index 0000000000000000000000000000000000000000..e01f5b312a097ce3d7b20ce74e3803c79d942e51 --- /dev/null +++ b/paddle/inference/tests/book/test_inference_image_classification.cc @@ -0,0 +1,113 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include "gflags/gflags.h" +#include "paddle/framework/lod_tensor.h" +#include "paddle/inference/io.h" + +DEFINE_string(dirname, "", "Directory of the inference model."); + +template +void TestInference(const std::string& dirname, + const std::vector& cpu_feeds, + std::vector& cpu_fetchs) { + // 1. Define place, executor and scope + auto place = Place(); + auto executor = paddle::framework::Executor(place); + auto* scope = new paddle::framework::Scope(); + + // 2. Initialize the inference_program and load all parameters from file + auto inference_program = paddle::inference::Load(executor, *scope, dirname); + + // 3. Get the feed_target_names and fetch_target_names + const std::vector& feed_target_names = + inference_program->GetFeedTargetNames(); + const std::vector& fetch_target_names = + inference_program->GetFetchTargetNames(); + + // 4. Prepare inputs: set up maps for feed targets + std::map feed_targets; + for (size_t i = 0; i < feed_target_names.size(); ++i) { + // Please make sure that cpu_feeds[i] is right for feed_target_names[i] + feed_targets[feed_target_names[i]] = cpu_feeds[i]; + } + + // 5. Define Tensor to get the outputs: set up maps for fetch targets + std::map fetch_targets; + for (size_t i = 0; i < fetch_target_names.size(); ++i) { + fetch_targets[fetch_target_names[i]] = cpu_fetchs[i]; + } + + // 6. Run the inference program + executor.Run(*inference_program, scope, feed_targets, fetch_targets); + + delete scope; +} + +TEST(inference, image_classification) { + if (FLAGS_dirname.empty()) { + LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model"; + } + + LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl; + std::string dirname = FLAGS_dirname; + + // 0. Call `paddle::framework::InitDevices()` initialize all the devices + // In unittests, this is done in paddle/testing/paddle_gtest_main.cc + + paddle::framework::LoDTensor input; + srand(time(0)); + float* input_ptr = + input.mutable_data({1, 3, 32, 32}, paddle::platform::CPUPlace()); + for (int i = 0; i < 3072; ++i) { + input_ptr[i] = rand() / (static_cast(RAND_MAX)); + } + std::vector cpu_feeds; + cpu_feeds.push_back(&input); + + paddle::framework::LoDTensor output1; + std::vector cpu_fetchs1; + cpu_fetchs1.push_back(&output1); + + // Run inference on CPU + TestInference( + dirname, cpu_feeds, cpu_fetchs1); + LOG(INFO) << output1.dims(); + +#ifdef PADDLE_WITH_CUDA + paddle::framework::LoDTensor output2; + std::vector cpu_fetchs2; + cpu_fetchs2.push_back(&output2); + + // Run inference on CUDA GPU + TestInference( + dirname, cpu_feeds, cpu_fetchs2); + LOG(INFO) << output2.dims(); + + EXPECT_EQ(output1.dims(), output2.dims()); + EXPECT_EQ(output1.numel(), output2.numel()); + + float err = 1E-3; + int count = 0; + for (int64_t i = 0; i < output1.numel(); ++i) { + if (fabs(output1.data()[i] - output2.data()[i]) > err) { + count++; + } + } + EXPECT_EQ(count, 0) << "There are " << count << " different elements."; +#endif +} diff --git a/python/paddle/v2/fluid/tests/book/test_image_classification_train.py b/python/paddle/v2/fluid/tests/book/test_image_classification_train.py index a4168d16db06f904faed811fdda3f0fe52f0b27b..03b009ebb0714a91329a1c56ff3939beecb03435 100644 --- a/python/paddle/v2/fluid/tests/book/test_image_classification_train.py +++ b/python/paddle/v2/fluid/tests/book/test_image_classification_train.py @@ -16,8 +16,9 @@ from __future__ import print_function import paddle.v2 as paddle import paddle.v2.fluid as fluid -import unittest import contextlib +import numpy +import unittest def resnet_cifar10(input, depth=32): @@ -89,10 +90,7 @@ def vgg16_bn_drop(input): return fc2 -def main(net_type, use_cuda): - if use_cuda and not fluid.core.is_compiled_with_cuda(): - return - +def train(net_type, use_cuda, save_dirname): classdim = 10 data_shape = [3, 32, 32] @@ -111,12 +109,14 @@ def main(net_type, use_cuda): predict = fluid.layers.fc(input=net, size=classdim, act='softmax') cost = fluid.layers.cross_entropy(input=predict, label=label) avg_cost = fluid.layers.mean(x=cost) + acc = fluid.layers.accuracy(input=predict, label=label) + + # Test program + test_program = fluid.default_main_program().clone() optimizer = fluid.optimizer.Adam(learning_rate=0.001) optimizer.minimize(avg_cost) - accuracy = fluid.evaluator.Accuracy(input=predict, label=label) - BATCH_SIZE = 128 PASS_NUM = 1 @@ -125,6 +125,9 @@ def main(net_type, use_cuda): paddle.dataset.cifar.train10(), buf_size=128 * 10), batch_size=BATCH_SIZE) + test_reader = paddle.batch( + paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE) + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() exe = fluid.Executor(place) feeder = fluid.DataFeeder(place=place, feed_list=[images, label]) @@ -132,18 +135,68 @@ def main(net_type, use_cuda): loss = 0.0 for pass_id in range(PASS_NUM): - accuracy.reset(exe) - for data in train_reader(): - loss, acc = exe.run(fluid.default_main_program(), - feed=feeder.feed(data), - fetch_list=[avg_cost] + accuracy.metrics) - pass_acc = accuracy.eval(exe) - print("loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str( - pass_acc)) - return - - raise AssertionError( - "Image classification loss is too large, {0:2.2}".format(loss)) + for batch_id, data in enumerate(train_reader()): + exe.run(feed=feeder.feed(data)) + + if (batch_id % 10) == 0: + acc_list = [] + avg_loss_list = [] + for tid, test_data in enumerate(test_reader()): + loss_t, acc_t = exe.run(program=test_program, + feed=feeder.feed(test_data), + fetch_list=[avg_cost, acc]) + acc_list.append(float(acc_t)) + avg_loss_list.append(float(loss_t)) + break # Use 1 segment for speeding up CI + + acc_value = numpy.array(acc_list).mean() + avg_loss_value = numpy.array(avg_loss_list).mean() + + print( + 'PassID {0:1}, BatchID {1:04}, Test Loss {2:2.2}, Acc {3:2.2}'. + format(pass_id, batch_id + 1, + float(avg_loss_value), float(acc_value))) + + if acc_value > 0.01: # Low threshold for speeding up CI + fluid.io.save_inference_model(save_dirname, ["pixel"], + [predict], exe) + return + + +def infer(use_cuda, save_dirname=None): + if save_dirname is None: + return + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + + # Use fluid.io.load_inference_model to obtain the inference program desc, + # the feed_target_names (the names of variables that will be feeded + # data using feed operators), and the fetch_targets (variables that + # we want to obtain data from using fetch operators). + [inference_program, feed_target_names, + fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) + + # The input's dimension of conv should be 4-D or 5-D. + tensor_img = numpy.random.rand(1, 3, 32, 32).astype("float32") + + # Construct feed as a dictionary of {feed_target_name: feed_target_data} + # and results will contain a list of data corresponding to fetch_targets. + results = exe.run(inference_program, + feed={feed_target_names[0]: tensor_img}, + fetch_list=fetch_targets) + print("infer results: ", results[0]) + + +def main(net_type, use_cuda): + if use_cuda and not fluid.core.is_compiled_with_cuda(): + return + + # Directory for saving the trained model + save_dirname = "image_classification_" + net_type + ".inference.model" + + train(net_type, use_cuda, save_dirname) + infer(use_cuda, save_dirname) class TestImageClassification(unittest.TestCase):