diff --git a/lite/api/model_test.cc b/lite/api/model_test.cc
index 6e0a249a81c8c2476a9a0685ab6492da3d4013a6..114d1acdbe1aa3e73bfa593a7a8950eacf3d415d 100644
--- a/lite/api/model_test.cc
+++ b/lite/api/model_test.cc
@@ -40,6 +40,7 @@ void OutputOptModel(const std::string& load_model_dir,
   config.set_valid_places({
       Place{TARGET(kX86), PRECISION(kFloat)},
       Place{TARGET(kARM), PRECISION(kFloat)},
+      Place{TARGET(kHost), PRECISION(kFloat)},
   });
 
   auto predictor = lite_api::CreatePaddlePredictor(config);
diff --git a/lite/core/mir/fusion/conv_elementwise_fuser.cc b/lite/core/mir/fusion/conv_elementwise_fuser.cc
index c3ab3e4c4ca9bd8d6a6eaaf82e40dcb06cf99ea9..abc78edda88e008945e9d184b02e5feef3e5a4b1 100644
--- a/lite/core/mir/fusion/conv_elementwise_fuser.cc
+++ b/lite/core/mir/fusion/conv_elementwise_fuser.cc
@@ -27,8 +27,10 @@ void ConvElementwiseFuser::BuildPattern() {
       VarNode("input")->assert_is_op_input(conv_type_, "Input")->AsInput();
   auto* filter =
       VarNode("filter")->assert_is_op_input(conv_type_, "Filter")->AsInput();
-  auto* bias =
-      VarNode("bias")->assert_is_op_input("elementwise_add", "Y")->AsInput();
+  auto* bias = VarNode("bias")
+                   ->assert_is_op_input("elementwise_add", "Y")
+                   ->AsInput()
+                   ->assert_is_persistable_var();
 
   // create op nodes
   auto* conv2d =
diff --git a/lite/core/mir/memory_optimize_pass.cc b/lite/core/mir/memory_optimize_pass.cc
index 6956e805c673d8776d7bdd414dce0a5ddfcd965a..02311b0579456768389bb725cae71ae897c78432 100644
--- a/lite/core/mir/memory_optimize_pass.cc
+++ b/lite/core/mir/memory_optimize_pass.cc
@@ -49,6 +49,7 @@ void MemoryOptimizePass::CollectLifeCycleByDevice(
                                             "equal",
                                             "lod_reset",
                                             "concat",
+                                            "yolo_box",
                                             "graph_op",
                                             "feed",
                                             "fetch"};
diff --git a/lite/kernels/arm/norm_compute.cc b/lite/kernels/arm/norm_compute.cc
index 3cc1645fc6823c4c3276cd1f22f4be8a584d2073..fb8b4bbe0773b808a0f6942d1120ebc7d4e844d2 100644
--- a/lite/kernels/arm/norm_compute.cc
+++ b/lite/kernels/arm/norm_compute.cc
@@ -47,4 +47,5 @@ REGISTER_LITE_KERNEL(
     norm, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::NormCompute, def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Norm", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index 5c88a28527aec59e0c5c729f893f5a81d3dfdae2..321474875fe2dc17e28e084b8477f468ccb99a88 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -600,6 +600,7 @@ struct SequenceSoftmaxParam {
 struct NormParam {
   const lite::Tensor* X{};
   lite::Tensor* Out{};
+  lite::Tensor* Norm{};
   int axis{1};
   float epsilon{1e-10};
 };
diff --git a/lite/tests/kernels/norm_compute_test.cc b/lite/tests/kernels/norm_compute_test.cc
index 830bac062784a8c16752f4e43a23ed8157cc6c0f..6aee1758c19cd793de709921c4733b8892e5f3d9 100644
--- a/lite/tests/kernels/norm_compute_test.cc
+++ b/lite/tests/kernels/norm_compute_test.cc
@@ -46,7 +46,7 @@ class NormComputeTester : public arena::TestCase {
     auto* x = scope->FindTensor(input_);
     const auto* x_data = x->data<float>();
 
-    int axis = axis_ < 0 ? axis + dims_.size() : axis_;
+    int axis = axis_ < 0 ? axis_ + dims_.size() : axis_;
     int pre_n = dims_.count(0, axis);
     int n = dims_[axis];
     int post_n = dims_.count(axis + 1, dims_.size());
diff --git a/lite/tools/build.sh b/lite/tools/build.sh
index d56a5f81cb87c84078c475caa634e54df7fd280f..87e50fd11e839ee3dd45552ee17944a39bb5b2be 100755
--- a/lite/tools/build.sh
+++ b/lite/tools/build.sh
@@ -221,7 +221,7 @@ function print_usage {
     echo -e "argument choices:"
     echo -e "--arm_os:\t android|ios|ios64"
     echo -e "--arm_abi:\t armv8|armv7"
-    echo -e "--arm_lang:\t gcc|clang (for android)"
+    echo -e "--arm_lang:\t only gcc is supported now; clang will be supported in the future (for android)"
    echo -e "--android_stl:\t c++_static|c++_shared (for android)"
     echo
     echo -e "tasks:"
@@ -252,6 +252,13 @@ function main {
                 ;;
             --arm_lang=*)
                 ARM_LANG="${i#*=}"
+                if [ ${ARM_LANG} == "clang" ]; then
+                    set +x
+                    echo
+                    echo -e "error: only gcc is supported now; clang will be supported in the future."
+                    echo
+                    exit 1
+                fi
                 shift
                 ;;
             --android_stl=*)