提交 d0f1670e 编写于 作者: S superjomn

make test_apis_lite work

上级 3fa9fff8
...@@ -46,16 +46,34 @@ bool CompareTensors(const std::string& name, const Predictor& cxx_api, ...@@ -46,16 +46,34 @@ bool CompareTensors(const std::string& name, const Predictor& cxx_api,
return TensorCompareWith(*a, *b); return TensorCompareWith(*a, *b);
} }
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK TEST(CXXApi_LightApi, optim_model) {
lite::Predictor cxx_api;
std::vector<Place> valid_places({
Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kX86), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)}, // Both works on X86 and ARM
});
// On ARM devices, the preferred X86 target not works, but it can still
// select ARM kernels.
cxx_api.Build(FLAGS_model_dir, Place{TARGET(kX86), PRECISION(kFloat)},
valid_places);
cxx_api.SaveModel(FLAGS_optimized_model);
}
TEST(CXXApi_LightApi, save_and_load_model) { TEST(CXXApi_LightApi, save_and_load_model) {
lite::Predictor cxx_api; lite::Predictor cxx_api;
lite::LightPredictor light_api(FLAGS_optimized_model); lite::LightPredictor light_api(FLAGS_optimized_model);
// CXXAPi // CXXAPi
{ {
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, std::vector<Place> valid_places({
Place{TARGET(kX86), PRECISION(kFloat)}}); Place{TARGET(kHost), PRECISION(kFloat)},
cxx_api.Build(FLAGS_model_dir, Place{TARGET(kCUDA), PRECISION(kFloat)}, Place{TARGET(kX86), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)}, // Both works on X86 and ARM
});
// On ARM devices, the preferred X86 target not works, but it can still
// select ARM kernels.
cxx_api.Build(FLAGS_model_dir, Place{TARGET(kX86), PRECISION(kFloat)},
valid_places); valid_places);
auto* x = cxx_api.GetInput(0); auto* x = cxx_api.GetInput(0);
...@@ -87,7 +105,6 @@ TEST(CXXApi_LightApi, save_and_load_model) { ...@@ -87,7 +105,6 @@ TEST(CXXApi_LightApi, save_and_load_model) {
ASSERT_TRUE(CompareTensors(tensor_name, cxx_api, light_api)); ASSERT_TRUE(CompareTensors(tensor_name, cxx_api, light_api));
} }
} }
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -24,13 +24,11 @@ namespace lite { ...@@ -24,13 +24,11 @@ namespace lite {
void Predictor::SaveModel(const std::string &dir) { void Predictor::SaveModel(const std::string &dir) {
#ifndef LITE_WITH_ARM #ifndef LITE_WITH_ARM
LOG(INFO) << "Save model to " << dir;
MkDirRecur(dir); MkDirRecur(dir);
program_->PersistModel(dir, program_desc_);
#else #else
LOG(INFO) << "Save model to ./";
program_->PersistModel("./", program_desc_);
#endif #endif
program_->PersistModel(dir, program_desc_);
LOG(INFO) << "Save model to " << dir;
} }
lite::Tensor *Predictor::GetInput(size_t offset) { lite::Tensor *Predictor::GetInput(size_t offset) {
......
...@@ -221,7 +221,7 @@ function test_arm { ...@@ -221,7 +221,7 @@ function test_arm {
echo "android do not need armv7hf" echo "android do not need armv7hf"
return 0 return 0
fi fi
# TODO(yuanshuai): enable armv7 on android # TODO(yuanshuai): enable armv7 on android
if [[ ${abi} == "armv7" ]]; then if [[ ${abi} == "armv7" ]]; then
echo "skip android v7 test yet" echo "skip android v7 test yet"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册