未验证 提交 b1869f16 编写于 作者: Y Yiqun Liu 提交者: GitHub

Simplify the inference unittests' cmake and codes. (#8216)

上级 47c13508
set(PYTHON_TESTS_DIR ${PADDLE_SOURCE_DIR}/python/paddle/v2/fluid/tests)
cc_test(test_inference_recognize_digits_mlp
SRCS test_inference_recognize_digits.cc
function(inference_test TARGET_NAME)
set(options "")
set(oneValueArgs "")
set(multiValueArgs ARGS)
cmake_parse_arguments(inference_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(PYTHON_TESTS_DIR ${PADDLE_SOURCE_DIR}/python/paddle/v2/fluid/tests)
if(inference_test_ARGS)
foreach(arg ${inference_test_ARGS})
cc_test(test_inference_${TARGET_NAME}_${arg}
SRCS test_inference_${TARGET_NAME}.cc
DEPS ARCHIVE_START paddle_fluid ARCHIVE_END
ARGS --dirname=${PYTHON_TESTS_DIR}/book/recognize_digits_mlp.inference.model)
cc_test(test_inference_image_classification_vgg
SRCS test_inference_image_classification.cc
ARGS --dirname=${PYTHON_TESTS_DIR}/book/${TARGET_NAME}_${arg}.inference.model)
set_tests_properties(test_inference_${TARGET_NAME}_${arg}
PROPERTIES DEPENDS test_${TARGET_NAME})
endforeach()
else()
cc_test(test_inference_${TARGET_NAME}
SRCS test_inference_${TARGET_NAME}.cc
DEPS ARCHIVE_START paddle_fluid ARCHIVE_END
ARGS --dirname=${PYTHON_TESTS_DIR}/book/image_classification_vgg.inference.model)
cc_test(test_inference_image_classification_resnet
SRCS test_inference_image_classification.cc
DEPS ARCHIVE_START paddle_fluid ARCHIVE_END
ARGS --dirname=${PYTHON_TESTS_DIR}/book/image_classification_resnet.inference.model)
cc_test(test_inference_label_semantic_roles
SRCS test_inference_label_semantic_roles.cc
DEPS ARCHIVE_START paddle_fluid ARCHIVE_END
ARGS --dirname=${PYTHON_TESTS_DIR}/book/label_semantic_roles.inference.model)
set_tests_properties(test_inference_recognize_digits_mlp
PROPERTIES DEPENDS test_recognize_digits)
set_tests_properties(test_inference_image_classification_vgg
PROPERTIES DEPENDS test_image_classification_train)
set_tests_properties(test_inference_image_classification_resnet
PROPERTIES DEPENDS test_image_classification_train)
set_tests_properties(test_inference_label_semantic_roles
PROPERTIES DEPENDS test_label_semantic_roles)
ARGS --dirname=${PYTHON_TESTS_DIR}/book/${TARGET_NAME}.inference.model)
set_tests_properties(test_inference_${TARGET_NAME}
PROPERTIES DEPENDS test_${TARGET_NAME})
endif()
endfunction(inference_test)
inference_test(recognize_digits ARGS mlp)
inference_test(image_classification ARGS vgg resnet)
inference_test(label_semantic_roles)
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <time.h>
#include "paddle/framework/lod_tensor.h"
#include "paddle/inference/io.h"
......
......@@ -13,51 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <time.h>
#include <sstream>
#include "gflags/gflags.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/inference/io.h"
#include "test_helper.h"
DEFINE_string(dirname, "", "Directory of the inference model.");
template <typename Place, typename T>
void TestInference(const std::string& dirname,
const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
std::vector<paddle::framework::LoDTensor*>& cpu_fetchs) {
// 1. Define place, executor and scope
auto place = Place();
auto executor = paddle::framework::Executor(place);
auto* scope = new paddle::framework::Scope();
// 2. Initialize the inference_program and load all parameters from file
auto inference_program = paddle::inference::Load(executor, *scope, dirname);
// 3. Get the feed_target_names and fetch_target_names
const std::vector<std::string>& feed_target_names =
inference_program->GetFeedTargetNames();
const std::vector<std::string>& fetch_target_names =
inference_program->GetFetchTargetNames();
// 4. Prepare inputs: set up maps for feed targets
std::map<std::string, const paddle::framework::LoDTensor*> feed_targets;
for (size_t i = 0; i < feed_target_names.size(); ++i) {
// Please make sure that cpu_feeds[i] is right for feed_target_names[i]
feed_targets[feed_target_names[i]] = cpu_feeds[i];
}
// 5. Define Tensor to get the outputs: set up maps for fetch targets
std::map<std::string, paddle::framework::LoDTensor*> fetch_targets;
for (size_t i = 0; i < fetch_target_names.size(); ++i) {
fetch_targets[fetch_target_names[i]] = cpu_fetchs[i];
}
// 6. Run the inference program
executor.Run(*inference_program, scope, feed_targets, fetch_targets);
delete scope;
}
TEST(inference, image_classification) {
if (FLAGS_dirname.empty()) {
LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model";
......@@ -70,12 +30,10 @@ TEST(inference, image_classification) {
// In unittests, this is done in paddle/testing/paddle_gtest_main.cc
paddle::framework::LoDTensor input;
srand(time(0));
float* input_ptr =
input.mutable_data<float>({1, 3, 32, 32}, paddle::platform::CPUPlace());
for (int i = 0; i < 3072; ++i) {
input_ptr[i] = rand() / (static_cast<float>(RAND_MAX));
}
// Use normilized image pixels as input data,
// which should be in the range [0.0, 1.0].
SetupTensor<float>(
input, {1, 3, 32, 32}, static_cast<float>(0), static_cast<float>(1));
std::vector<paddle::framework::LoDTensor*> cpu_feeds;
cpu_feeds.push_back(&input);
......@@ -98,16 +56,6 @@ TEST(inference, image_classification) {
dirname, cpu_feeds, cpu_fetchs2);
LOG(INFO) << output2.dims();
EXPECT_EQ(output1.dims(), output2.dims());
EXPECT_EQ(output1.numel(), output2.numel());
float err = 1E-3;
int count = 0;
for (int64_t i = 0; i < output1.numel(); ++i) {
if (fabs(output1.data<float>()[i] - output2.data<float>()[i]) > err) {
count++;
}
}
EXPECT_EQ(count, 0) << "There are " << count << " different elements.";
CheckError<float>(output1, output2);
#endif
}
......@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <time.h>
#include <sstream>
#include "gflags/gflags.h"
#include "test_helper.h"
......
......@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <time.h>
#include <sstream>
#include "gflags/gflags.h"
#include "test_helper.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册