提交 e00f32c5 编写于 作者: L lijianshe02

add googlenet runtest and fix some kernel bugs

上级 5e2a2555
...@@ -204,4 +204,7 @@ if (WITH_TESTING) ...@@ -204,4 +204,7 @@ if (WITH_TESTING)
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "resnet50.tar.gz") lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "resnet50.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "inception_v4.tar.gz") lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "inception_v4.tar.gz")
endif() endif()
if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "GoogleNet_inference.tar.gz")
endif()
endif() endif()
...@@ -45,6 +45,11 @@ if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND WITH_TESTING) ...@@ -45,6 +45,11 @@ if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND WITH_TESTING)
ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model
--optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL) --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
add_dependencies(test_cxx_api_lite extern_lite_download_lite_naive_model_tar_gz) add_dependencies(test_cxx_api_lite extern_lite_download_lite_naive_model_tar_gz)
lite_cc_test(test_googlenet_lite SRCS test_googlenet_lite.cc
DEPS cxx_api_lite mir_passes lite_api_test_helper
${ops_lite} ${host_kernels} ${x86_kernels}
ARGS --model_dir=${LITE_MODEL_DIR}/googlenet)
add_dependencies(test_googlenet_lite extern_lite_download_GoogleNet_inference_tar_gz)
endif() endif()
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <ctime>
#include <iostream>
#include <vector>
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/api/lite_api_test_helper.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/mir/use_passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/kernels/use_kernels.h"
#include "paddle/fluid/lite/operators/use_ops.h"
// for googlenet
DEFINE_string(model_dir, "", "");
namespace paddle {
namespace lite {
#ifdef LITE_WITH_X86
TEST(CXXApi, test_lite_googlenet) {
lite::Predictor predictor;
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kX86), PRECISION(kFloat)}});
// LOG(INFO)<<"FLAGS_eval_googlenet_dir:"<<FLAGS_test_lite_googlenet_dir;
std::string model_dir = FLAGS_model_dir;
predictor.Build(model_dir, Place{TARGET(kX86), PRECISION(kFloat)},
valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
auto* data = input_tensor->mutable_data<float>();
for (int i = 0; i < input_tensor->dims().production(); i++) {
data[i] = 1;
}
predictor.Run();
auto* out = predictor.GetOutput(0);
std::vector<float> results(
{0.00034298553, 0.0008200012, 0.0005046297, 0.000839279,
0.00052616704, 0.0003447803, 0.0010877076, 0.00081762316,
0.0003941339, 0.0011430943, 0.0008892841, 0.00080191303,
0.0004442384, 0.000658702, 0.0026721435, 0.0013686896,
0.0005618166, 0.0006556497, 0.0006984528, 0.0014619455});
for (size_t i = 0; i < results.size(); ++i) {
EXPECT_NEAR(out->data<float>()[i * 51], results[i], 1e-5);
}
ASSERT_EQ(out->dims().size(), 2);
ASSERT_EQ(out->dims()[0], 1);
ASSERT_EQ(out->dims()[1], 1000);
}
#endif
} // namespace lite
} // namespace paddle
...@@ -16,6 +16,6 @@ ...@@ -16,6 +16,6 @@
REGISTER_LITE_KERNEL(concat, kX86, kFloat, kNCHW, REGISTER_LITE_KERNEL(concat, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::ConcatCompute<float>, def) paddle::lite::kernels::x86::ConcatCompute<float>, def)
.BindInput("X", {LiteType::GetTensorListTy(TARGET(kX86))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize(); .Finalize();
...@@ -16,6 +16,6 @@ ...@@ -16,6 +16,6 @@
REGISTER_LITE_KERNEL(pool2d, kX86, kFloat, kNCHW, REGISTER_LITE_KERNEL(pool2d, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::PoolCompute<float>, def) paddle::lite::kernels::x86::PoolCompute<float>, def)
.BindInput("x", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize(); .Finalize();
...@@ -58,6 +58,7 @@ class SoftmaxCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -58,6 +58,7 @@ class SoftmaxCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
// auto& context = context_->As<X86Context>(); // auto& context = context_->As<X86Context>();
CHECK(param.output); CHECK(param.output);
CHECK(param.x); CHECK(param.x);
param.output->mutable_data<T>();
const int rank = param.x->dims().size(); const int rank = param.x->dims().size();
const int axis = CanonicalAxis(param.axis, rank); const int axis = CanonicalAxis(param.axis, rank);
int axis_dim = param.x->dims()[axis]; int axis_dim = param.x->dims()[axis];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册