From d1b7aec5a5bc2448c7975fba9e8ab86d6ee75f84 Mon Sep 17 00:00:00 2001 From: juncaipeng <52520497+juncaipeng@users.noreply.github.com> Date: Wed, 18 Dec 2019 15:57:47 +0800 Subject: [PATCH] Support Mask RCNN2 (#2588) * Support Mask RCNN2 (#2588) --- lite/backends/arm/math/elementwise.cc | 113 +++++++++++++ lite/backends/arm/math/interpolate.cc | 6 +- lite/core/arena/framework.h | 37 ++++- .../identity_scale_eliminate_pass.cc | 3 +- lite/core/mir/pattern_matcher.cc | 13 ++ lite/core/mir/pattern_matcher.h | 1 + lite/kernels/arm/CMakeLists.txt | 1 + lite/kernels/arm/cast_compute.cc | 4 +- .../arm/collect_fpn_proposals_compute.cc | 4 +- lite/kernels/arm/compare_compute.cc | 13 ++ lite/kernels/arm/conditional_block_compute.h | 9 +- .../arm/distribute_fpn_proposals_compute.cc | 151 ++++++++++++++++++ .../arm/distribute_fpn_proposals_compute.h | 38 +++++ lite/kernels/arm/elementwise_compute.cc | 34 ++-- lite/kernels/arm/elementwise_compute.h | 4 +- lite/kernels/arm/elementwise_compute_test.cc | 4 +- lite/kernels/arm/gather_compute.cc | 6 +- lite/kernels/arm/interpolate_compute.cc | 12 +- lite/kernels/arm/slice_compute.cc | 11 ++ lite/operators/CMakeLists.txt | 1 + lite/operators/collect_fpn_proposals_op.cc | 3 +- lite/operators/distribute_fpn_proposals_op.cc | 70 ++++++++ lite/operators/distribute_fpn_proposals_op.h | 51 ++++++ lite/operators/op_params.h | 10 ++ lite/tests/kernels/cast_compute_test.cc | 1 + 25 files changed, 563 insertions(+), 37 deletions(-) create mode 100644 lite/kernels/arm/distribute_fpn_proposals_compute.cc create mode 100644 lite/kernels/arm/distribute_fpn_proposals_compute.h create mode 100644 lite/operators/distribute_fpn_proposals_op.cc create mode 100644 lite/operators/distribute_fpn_proposals_op.h diff --git a/lite/backends/arm/math/elementwise.cc b/lite/backends/arm/math/elementwise.cc index a4c61f9a9d..186ad19735 100644 --- a/lite/backends/arm/math/elementwise.cc +++ b/lite/backends/arm/math/elementwise.cc @@ -557,6 +557,52 @@ void elementwise_mul(const float* dinx, } } +template <> +void elementwise_mul(const int* dinx, + const int* diny, + int* dout, + int num) { + int cnt = num >> 4; + int remain = num % 16; +#pragma omp parallel for + for (int i = 0; i < cnt; ++i) { + const int* dinx_ptr = dinx + (i << 4); + const int* diny_ptr = diny + (i << 4); + int* dout_ptr = dout + (i << 4); + + int32x4_t dinx0 = vld1q_s32(dinx_ptr); + int32x4_t dinx1 = vld1q_s32(dinx_ptr + 4); + int32x4_t dinx2 = vld1q_s32(dinx_ptr + 8); + int32x4_t dinx3 = vld1q_s32(dinx_ptr + 12); + + int32x4_t diny0 = vld1q_s32(diny_ptr); + int32x4_t diny1 = vld1q_s32(diny_ptr + 4); + int32x4_t diny2 = vld1q_s32(diny_ptr + 8); + int32x4_t diny3 = vld1q_s32(diny_ptr + 12); + + dinx0 = vmulq_s32(dinx0, diny0); + dinx1 = vmulq_s32(dinx1, diny1); + dinx2 = vmulq_s32(dinx2, diny2); + dinx3 = vmulq_s32(dinx3, diny3); + + vst1q_s32(dout_ptr, dinx0); + vst1q_s32(dout_ptr + 4, dinx1); + vst1q_s32(dout_ptr + 8, dinx2); + vst1q_s32(dout_ptr + 12, dinx3); + } + if (remain > 0) { + const int* dinx_ptr = dinx + (cnt << 4); + const int* diny_ptr = diny + (cnt << 4); + int* dout_ptr = dout + (cnt << 4); + for (int i = 0; i < remain; i++) { + *dout_ptr = *dinx_ptr * *diny_ptr; + dout_ptr++; + dinx_ptr++; + diny_ptr++; + } + } +} + template <> void elementwise_mul_relu(const float* dinx, const float* diny, @@ -678,6 +724,73 @@ void elementwise_mul_broadcast(const float* dinx, } } +template <> +void elementwise_mul_broadcast(const int* dinx, + const int* diny, + int* dout, + int batch, + int channels, + int num) { +#pragma omp parallel for collapse(2) + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + int offset = (i * channels + j) * num; + const int* din_ptr = dinx + offset; + const int diny_data = diny[j]; + int* dout_ptr = dout + offset; + + int cnt = num >> 4; + int remain = num % 16; + int32x4_t rb = vdupq_n_s32(diny_data); + for (int k = 0; k < cnt; ++k) { + int32x4_t din0 = vld1q_s32(din_ptr); + int32x4_t din1 = vld1q_s32(din_ptr + 4); + int32x4_t din2 = vld1q_s32(din_ptr + 8); + int32x4_t din3 = vld1q_s32(din_ptr + 12); + + din0 = vmulq_s32(din0, rb); + din1 = vmulq_s32(din1, rb); + din2 = vmulq_s32(din2, rb); + din3 = vmulq_s32(din3, rb); + + vst1q_s32(dout_ptr, din0); + vst1q_s32(dout_ptr + 4, din1); + vst1q_s32(dout_ptr + 8, din2); + vst1q_s32(dout_ptr + 12, din3); + + din_ptr += 16; + dout_ptr += 16; + } + if (remain >= 8) { + int32x4_t din0 = vld1q_s32(din_ptr); + int32x4_t din1 = vld1q_s32(din_ptr + 4); + din0 = vmulq_s32(din0, rb); + din1 = vmulq_s32(din1, rb); + vst1q_s32(dout_ptr, din0); + vst1q_s32(dout_ptr + 4, din1); + din_ptr += 8; + dout_ptr += 8; + remain -= 8; + } + if (remain >= 4) { + int32x4_t din0 = vld1q_s32(din_ptr); + din0 = vmulq_s32(din0, rb); + vst1q_s32(dout_ptr, din0); + din_ptr += 4; + dout_ptr += 4; + remain -= 4; + } + if (remain > 0) { + for (int p = 0; p < remain; ++p) { + *dout_ptr = *din_ptr * diny_data; + dout_ptr++; + din_ptr++; + } + } + } + } +} + template <> void elementwise_mul_relu_broadcast(const float* dinx, const float* diny, diff --git a/lite/backends/arm/math/interpolate.cc b/lite/backends/arm/math/interpolate.cc index 34d9a20433..1c53142fc5 100644 --- a/lite/backends/arm/math/interpolate.cc +++ b/lite/backends/arm/math/interpolate.cc @@ -526,9 +526,9 @@ void interpolate(lite::Tensor* X, } auto out_size = OutSize; if (out_size != nullptr) { - auto out_size_data = get_new_data_from_tensor(out_size); - out_height = static_cast(out_size_data[0]); - out_width = static_cast(out_size_data[1]); + auto out_size_data = get_new_data_from_tensor(out_size); + out_height = out_size_data[0]; + out_width = out_size_data[1]; } } float height_scale = scale; diff --git a/lite/core/arena/framework.h b/lite/core/arena/framework.h index ac71a6c011..671da20bdc 100644 --- a/lite/core/arena/framework.h +++ b/lite/core/arena/framework.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include "lite/core/op_registry.h" @@ -77,6 +78,20 @@ class TestCase { // kernel registry. void CheckKernelConsistWithDefinition() {} + // Get the real precision of the output for check precision. When the declare + // precision obtained from the kernel is any, we should set the precision of + // the output in test case. + bool GetPrecisonType(const std::string& var_name, + PrecisionType* precision_type) { + auto res = precision_type_map_.find(var_name); + if (res == precision_type_map_.end()) { + return false; + } else { + *precision_type = precision_type_map_.at(var_name); + return true; + } + } + Scope& scope() { return *scope_; } Scope* baseline_scope() { return base_scope_; } @@ -105,6 +120,19 @@ class TestCase { // Prepare for the operator. virtual void PrepareOpDesc(cpp::OpDesc* op_desc) = 0; + // Set the real precision of the output for check precision. When the declare + // precision obtained from the kernel is any, we should set the precision of + // the output in test case. + void SetPrecisionType(const std::string& var_name, + const PrecisionType& precision_type) { + auto res = precision_type_map_.find(var_name); + if (res == precision_type_map_.end()) { + precision_type_map_.insert({var_name, precision_type}); + } else { + precision_type_map_.at(var_name) = precision_type; + } + } + public: const Instruction& instruction() { return *instruction_; } @@ -148,6 +176,7 @@ class TestCase { Scope* base_scope_{}; std::unique_ptr op_desc_; std::unique_ptr instruction_; + std::unordered_map precision_type_map_; }; class Arena { @@ -189,8 +218,11 @@ class Arena { // get tensor type. const Type* type = tester_->instruction().kernel()->GetOutputDeclType(arg_name); - - switch (type->precision()) { + auto precision_type = type->precision(); + if (precision_type == PRECISION(kAny)) { + CHECK(tester_->GetPrecisonType(var_name, &precision_type)); + } + switch (precision_type) { case PRECISION(kFloat): return tester_->CheckPrecision(var_name, abs_error_); case PRECISION(kInt8): @@ -199,7 +231,6 @@ class Arena { return tester_->CheckPrecision(var_name, abs_error_); case PRECISION(kBool): return tester_->CheckPrecision(var_name, abs_error_); - default: LOG(FATAL) << "not support type " << PrecisionToStr(type->precision()); return false; diff --git a/lite/core/mir/elimination/identity_scale_eliminate_pass.cc b/lite/core/mir/elimination/identity_scale_eliminate_pass.cc index acea48c742..345361047b 100644 --- a/lite/core/mir/elimination/identity_scale_eliminate_pass.cc +++ b/lite/core/mir/elimination/identity_scale_eliminate_pass.cc @@ -25,7 +25,8 @@ namespace { class Eliminator : public FuseBase { public: void BuildPattern() override { - auto* pre_op = OpNode("preop"); // the previous op's output need update + // the previous op's output need updat + auto* pre_op = OpNode("preop")->assert_is_not_op_type("conditional_block"); // TODO(Superjomn) check has only one output auto* x = VarNode("x")->assert_is_op_input("scale", "X"); auto* scale_op = OpNode("scale", "scale") diff --git a/lite/core/mir/pattern_matcher.cc b/lite/core/mir/pattern_matcher.cc index 8e0fc55be2..b625919cbf 100644 --- a/lite/core/mir/pattern_matcher.cc +++ b/lite/core/mir/pattern_matcher.cc @@ -377,6 +377,19 @@ PMNode *PMNode::assert_is_op(const std::string &op_type) { return this; } +PMNode *PMNode::assert_is_not_op_type(const std::string &op_type) { + asserts_.emplace_back([op_type](const Node *x) { + if (x && x->IsStmt()) { + auto *op_info = x->stmt()->op_info(); + if (op_info->Type() == op_type) { + return false; + } + } + return true; + }); + return this; +} + PMNode *PMNode::assert_is_var() { asserts_.emplace_back([](const Node *x) { return x && x->IsArg(); }); return this; diff --git a/lite/core/mir/pattern_matcher.h b/lite/core/mir/pattern_matcher.h index 47a0a30b56..90c4359c6d 100644 --- a/lite/core/mir/pattern_matcher.h +++ b/lite/core/mir/pattern_matcher.h @@ -123,6 +123,7 @@ struct PMNode { // Assertions, helper functions to simplify the pattern definition. PMNode* assert_is_op(); PMNode* assert_is_op(const std::string& op_type); + PMNode* assert_is_not_op_type(const std::string& op_type); PMNode* assert_is_var(); PMNode* assert_var_not_persistable(); PMNode* assert_is_persistable_var(); diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index eab03aac6c..f543c000f8 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -79,6 +79,7 @@ add_kernel(box_clip_compute_arm ARM extra SRCS box_clip_compute.cc DEPS ${lite_k add_kernel(assign_value_compute_arm ARM extra SRCS assign_value_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(conditional_block_compute_arm ARM extra SRCS conditional_block_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(collect_fpn_proposals_compute_arm ARM extra SRCS collect_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(distribute_fpn_proposals_compute_arm ARM extra SRCS distribute_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm) # for OCR specific diff --git a/lite/kernels/arm/cast_compute.cc b/lite/kernels/arm/cast_compute.cc index 1fef52bcb7..266ae1fc91 100644 --- a/lite/kernels/arm/cast_compute.cc +++ b/lite/kernels/arm/cast_compute.cc @@ -74,6 +74,6 @@ void CastCompute::Run() { REGISTER_LITE_KERNEL( cast, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::CastCompute, def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .Finalize(); diff --git a/lite/kernels/arm/collect_fpn_proposals_compute.cc b/lite/kernels/arm/collect_fpn_proposals_compute.cc index 5c727e7a5d..d54b96348e 100644 --- a/lite/kernels/arm/collect_fpn_proposals_compute.cc +++ b/lite/kernels/arm/collect_fpn_proposals_compute.cc @@ -141,7 +141,7 @@ REGISTER_LITE_KERNEL(collect_fpn_proposals, kNCHW, paddle::lite::kernels::arm::CollectFpnProposalsCompute, def) - .BindInput("MultiLevelRois", {LiteType::GetTensorListTy(TARGET(kARM))}) - .BindInput("MultiLevelScores", {LiteType::GetTensorListTy(TARGET(kARM))}) + .BindInput("MultiLevelRois", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("MultiLevelScores", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("FpnRois", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); diff --git a/lite/kernels/arm/compare_compute.cc b/lite/kernels/arm/compare_compute.cc index 7dcfa1eac0..6118cbc6e4 100644 --- a/lite/kernels/arm/compare_compute.cc +++ b/lite/kernels/arm/compare_compute.cc @@ -219,6 +219,19 @@ REGISTER_LITE_KERNEL(greater_equal, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))}) .Finalize(); + +REGISTER_LITE_KERNEL(less_than, + kARM, + kInt32, + kNCHW, + paddle::lite::kernels::arm::CompareCompute_int32< + paddle::lite::kernels::arm::_LessThanFunctor>, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))}) + .Finalize(); + REGISTER_LITE_KERNEL(equal, kARM, kInt32, diff --git a/lite/kernels/arm/conditional_block_compute.h b/lite/kernels/arm/conditional_block_compute.h index 0c0369b0fd..91eadff931 100644 --- a/lite/kernels/arm/conditional_block_compute.h +++ b/lite/kernels/arm/conditional_block_compute.h @@ -23,9 +23,8 @@ #include "lite/operators/conditional_block_op.h" #ifdef LITE_WITH_PROFILE #include "lite/core/profile/basic_profiler.h" -#endif // LITE_WITH_PROFILE -#ifdef LITE_WITH_PROFILE #include "lite/core/profile/precision_profiler.h" +#include "lite/core/profile/profiler.h" #endif namespace paddle { @@ -57,6 +56,11 @@ class CondExecutor { } void Run() { +#ifdef LITE_WITH_PROFILE +#ifdef LITE_WITH_PRECISION_PROFILE + lite::profile::Profiler profiler; +#endif // LITE_WITH_PRECISION_PROFILE +#endif // LITE_WITH_PROFILE for (auto &op_handler : ops_of_block_) { op_handler->CheckShape(); op_handler->InferShape(); @@ -64,6 +68,7 @@ class CondExecutor { #ifdef LITE_WITH_PRECISION_PROFILE std::unique_ptr kernel(op_handler->GetKernel()); Instruction inst(op_handler, std::move(kernel)); + inst.set_profiler(&profiler); #endif // LITE_WITH_PRECISION_PROFILE #endif // LITE_WITH_PROFILE op_handler->Run(); diff --git a/lite/kernels/arm/distribute_fpn_proposals_compute.cc b/lite/kernels/arm/distribute_fpn_proposals_compute.cc new file mode 100644 index 0000000000..0871a3e84b --- /dev/null +++ b/lite/kernels/arm/distribute_fpn_proposals_compute.cc @@ -0,0 +1,151 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/distribute_fpn_proposals_compute.h" +#include +#include +#include "lite/backends/arm/math/funcs.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" +#include "lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +const int kBoxDim = 4; + +template +static inline T BBoxArea(const T* box, bool normalized) { + if (box[2] < box[0] || box[3] < box[1]) { + // If coordinate values are is invalid + // (e.g. xmax < xmin or ymax < ymin), return 0. + return static_cast(0.); + } else { + const T w = box[2] - box[0]; + const T h = box[3] - box[1]; + if (normalized) { + return w * h; + } else { + // If coordinate values are not within range [0, 1]. + return (w + 1) * (h + 1); + } + } +} + +void DistributeFpnProposalsCompute::Run() { + auto& param = Param(); + const lite::Tensor* fpn_rois = param.fpn_rois; + std::vector multi_fpn_rois = param.multi_fpn_rois; + lite::Tensor* restore_index = param.restore_index; + int min_level = param.min_level; + int max_level = param.max_level; + int refer_level = param.refer_level; + int refer_scale = param.refer_scale; + int num_level = max_level - min_level + 1; + + CHECK_EQ(fpn_rois->lod().size(), 1); + auto fpn_rois_lod = fpn_rois->lod().back(); + int fpn_rois_num = fpn_rois_lod[fpn_rois_lod.size() - 1]; + std::vector target_level; + // record the number of rois in each level + std::vector num_rois_level(num_level, 0); + std::vector num_rois_level_integral(num_level + 1, 0); + for (size_t i = 0; i < fpn_rois_lod.size() - 1; ++i) { + auto fpn_rois_slice = + fpn_rois->Slice(static_cast(fpn_rois_lod[i]), + static_cast(fpn_rois_lod[i + 1])); + const float* rois_data = fpn_rois_slice.data(); + for (int j = 0; j < fpn_rois_slice.dims()[0]; ++j) { + // get the target level of current rois + float roi_scale = std::sqrt(BBoxArea(rois_data, false)); + int tgt_lvl = std::floor( + std::log2(roi_scale / refer_scale + static_cast(1e-6)) + + refer_level); + tgt_lvl = std::min(max_level, std::max(tgt_lvl, min_level)); + target_level.push_back(tgt_lvl); + num_rois_level[tgt_lvl - min_level]++; + rois_data += kBoxDim; + } + } + // define the output rois + // pointer which point to each level fpn rois + std::vector multi_fpn_rois_data(num_level); + // lod0 which will record the offset information of each level rois + std::vector> multi_fpn_rois_lod0; + for (int i = 0; i < num_level; ++i) { + // allocate memory for each level rois + multi_fpn_rois[i]->Resize({num_rois_level[i], kBoxDim}); + multi_fpn_rois_data[i] = multi_fpn_rois[i]->mutable_data(); + std::vector lod0(1, 0); + multi_fpn_rois_lod0.push_back(lod0); + // statistic start point for each level rois + num_rois_level_integral[i + 1] = + num_rois_level_integral[i] + num_rois_level[i]; + } + restore_index->Resize({fpn_rois_num, 1}); + int* restore_index_data = restore_index->mutable_data(); + std::vector restore_index_inter(fpn_rois_num, -1); + // distribute the rois into different fpn level by target level + for (size_t i = 0; i < fpn_rois_lod.size() - 1; ++i) { + Tensor fpn_rois_slice = + fpn_rois->Slice(static_cast(fpn_rois_lod[i]), + static_cast(fpn_rois_lod[i + 1])); + const float* rois_data = fpn_rois_slice.data(); + size_t cur_offset = fpn_rois_lod[i]; + // std::vector lod_offset[num_level]; + for (int j = 0; j < num_level; j++) { + multi_fpn_rois_lod0[j].push_back(multi_fpn_rois_lod0[j][i]); + } + for (int j = 0; j < fpn_rois_slice.dims()[0]; ++j) { + int lvl = target_level[cur_offset + j]; + memcpy(multi_fpn_rois_data[lvl - min_level], + rois_data, + kBoxDim * sizeof(float)); + multi_fpn_rois_data[lvl - min_level] += kBoxDim; + int index_in_shuffle = num_rois_level_integral[lvl - min_level] + + multi_fpn_rois_lod0[lvl - min_level][i + 1]; + restore_index_inter[index_in_shuffle] = cur_offset + j; + multi_fpn_rois_lod0[lvl - min_level][i + 1]++; + rois_data += kBoxDim; + } + } + for (int i = 0; i < fpn_rois_num; ++i) { + restore_index_data[restore_index_inter[i]] = i; + } + // merge lod information into LoDTensor + for (int i = 0; i < num_level; ++i) { + lite::LoD lod; + lod.emplace_back(multi_fpn_rois_lod0[i]); + multi_fpn_rois[i]->set_lod(lod); + } + return; +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(distribute_fpn_proposals, + kARM, + kFloat, + kNCHW, + paddle::lite::kernels::arm::DistributeFpnProposalsCompute, + def) + .BindInput("FpnRois", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("MultiFpnRois", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("RestoreIndex", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/arm/distribute_fpn_proposals_compute.h b/lite/kernels/arm/distribute_fpn_proposals_compute.h new file mode 100644 index 0000000000..e150b338de --- /dev/null +++ b/lite/kernels/arm/distribute_fpn_proposals_compute.h @@ -0,0 +1,38 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/kernel.h" +#include "lite/operators/distribute_fpn_proposals_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class DistributeFpnProposalsCompute + : public KernelLite { + public: + using param_t = operators::DistributeFpnProposalsParam; + + void Run() override; + + virtual ~DistributeFpnProposalsCompute() = default; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/elementwise_compute.cc b/lite/kernels/arm/elementwise_compute.cc index 2e57b6a3b3..94c5e140ba 100644 --- a/lite/kernels/arm/elementwise_compute.cc +++ b/lite/kernels/arm/elementwise_compute.cc @@ -161,20 +161,21 @@ void ElementwiseSubActivationCompute::Run() { } } -void ElementwiseMulCompute::Run() { - auto& param = Param(); - const float* x_data = param.X->data(); - const float* y_data = param.Y->data(); - float* out_data = param.Out->mutable_data(); +template +void ElementwiseMulCompute::Run() { + auto& param = this->template Param(); + auto* x_data = param.X->template data(); + auto* y_data = param.Y->template data(); + auto* out_data = param.Out->template mutable_data(); int axis = param.axis; auto x_dims = param.X->dims(); auto y_dims = param.Y->dims(); int pre, n, post; if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) { - lite::arm::math::elementwise_mul_broadcast( + lite::arm::math::elementwise_mul_broadcast( x_data, y_data, out_data, pre, n, post); } else { - lite::arm::math::elementwise_mul( + lite::arm::math::elementwise_mul( x_data, y_data, out_data, x_dims.production()); } } @@ -347,17 +348,24 @@ REGISTER_LITE_KERNEL( .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); -REGISTER_LITE_KERNEL(elementwise_mul, - kARM, - kFloat, - kNCHW, - paddle::lite::kernels::arm::ElementwiseMulCompute, - def) +using elementwise_mul_float = + paddle::lite::kernels::arm::ElementwiseMulCompute; +REGISTER_LITE_KERNEL( + elementwise_mul, kARM, kFloat, kNCHW, elementwise_mul_float, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); +using elementwise_mul_int32 = + paddle::lite::kernels::arm::ElementwiseMulCompute; +REGISTER_LITE_KERNEL( + elementwise_mul, kARM, kInt32, kNCHW, elementwise_mul_int32, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) + .Finalize(); + REGISTER_LITE_KERNEL( fusion_elementwise_mul_activation, kARM, diff --git a/lite/kernels/arm/elementwise_compute.h b/lite/kernels/arm/elementwise_compute.h index e76449aebc..731010a0d1 100644 --- a/lite/kernels/arm/elementwise_compute.h +++ b/lite/kernels/arm/elementwise_compute.h @@ -54,8 +54,8 @@ class ElementwiseSubActivationCompute virtual ~ElementwiseSubActivationCompute() = default; }; -class ElementwiseMulCompute - : public KernelLite { +template +class ElementwiseMulCompute : public KernelLite { public: void Run() override; diff --git a/lite/kernels/arm/elementwise_compute_test.cc b/lite/kernels/arm/elementwise_compute_test.cc index 2bc5863a18..b0ac3a7d33 100644 --- a/lite/kernels/arm/elementwise_compute_test.cc +++ b/lite/kernels/arm/elementwise_compute_test.cc @@ -329,13 +329,13 @@ TEST(elementwise_mul_arm, retrive_op) { } TEST(elementwise_mul_arm, init) { - ElementwiseMulCompute elementwise_mul; + ElementwiseMulCompute elementwise_mul; ASSERT_EQ(elementwise_mul.precision(), PRECISION(kFloat)); ASSERT_EQ(elementwise_mul.target(), TARGET(kARM)); } TEST(elementwise_mul, compute) { - ElementwiseMulCompute elementwise_mul; + ElementwiseMulCompute elementwise_mul; operators::ElementwiseParam param; lite::Tensor x, y, output, output_ref; diff --git a/lite/kernels/arm/gather_compute.cc b/lite/kernels/arm/gather_compute.cc index a46a6f9d6a..c91b86e53f 100644 --- a/lite/kernels/arm/gather_compute.cc +++ b/lite/kernels/arm/gather_compute.cc @@ -29,7 +29,7 @@ void GatherCompute::Run() { auto index_size = param.Index->dims()[0]; auto src_dims = param.X->dims(); const float* p_src = param.X->data(); - const float* p_index = param.Index->data(); + const int* p_index = param.Index->data(); int slice_size = 1; for (int i = 1; i < src_dims.size(); ++i) { @@ -50,6 +50,8 @@ void GatherCompute::Run() { REGISTER_LITE_KERNEL( gather, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::GatherCompute, def) - .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Index", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); diff --git a/lite/kernels/arm/interpolate_compute.cc b/lite/kernels/arm/interpolate_compute.cc index 0398dabeae..760b2fcf06 100644 --- a/lite/kernels/arm/interpolate_compute.cc +++ b/lite/kernels/arm/interpolate_compute.cc @@ -84,8 +84,10 @@ REGISTER_LITE_KERNEL(bilinear_interp, paddle::lite::kernels::arm::BilinearInterpCompute, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindInput("OutSize", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindInput("SizeTensor", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("OutSize", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) + .BindInput("SizeTensor", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindInput("Scale", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); @@ -97,8 +99,10 @@ REGISTER_LITE_KERNEL(nearest_interp, paddle::lite::kernels::arm::NearestInterpCompute, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindInput("OutSize", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindInput("SizeTensor", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("OutSize", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) + .BindInput("SizeTensor", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindInput("Scale", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); diff --git a/lite/kernels/arm/slice_compute.cc b/lite/kernels/arm/slice_compute.cc index 4bf790cf14..05f48917aa 100644 --- a/lite/kernels/arm/slice_compute.cc +++ b/lite/kernels/arm/slice_compute.cc @@ -176,3 +176,14 @@ REGISTER_LITE_KERNEL(slice, kARM, kFloat, kNCHW, slice_float, def) .BindInput("EndsTensorList", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); +using slice_int32 = + paddle::lite::kernels::arm::SliceCompute; +REGISTER_LITE_KERNEL(slice, kARM, kInt32, kNCHW, slice_int32, def) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) + .BindInput("StartsTensor", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("EndsTensor", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("StartsTensorList", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("EndsTensorList", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) + .Finalize(); diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 231a2cedec..34b364ae39 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -100,6 +100,7 @@ add_operator(attention_padding_mask_op_lite extra SRCS attention_padding_mask_op add_operator(sequence_arithmetic_op_lite extra SRCS sequence_arithmetic_op.cc DEPS ${op_DEPS}) add_operator(conditional_block_op_lite extra SRCS conditional_block_op.cc DEPS ${op_DEPS}) add_operator(collect_fpn_proposals_op_lite extra SRCS collect_fpn_proposals_op.cc DEPS ${op_DEPS}) +add_operator(distribute_fpn_proposals_op_lite extra SRCS distribute_fpn_proposals_op.cc DEPS ${op_DEPS}) # for OCR specific add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS}) diff --git a/lite/operators/collect_fpn_proposals_op.cc b/lite/operators/collect_fpn_proposals_op.cc index b8659fe466..4731d4bf81 100644 --- a/lite/operators/collect_fpn_proposals_op.cc +++ b/lite/operators/collect_fpn_proposals_op.cc @@ -31,7 +31,7 @@ bool CollectFpnProposalsOpLite::CheckShape() const { } for (auto item : param_.multi_level_scores) { auto dims = item->dims(); - CHECK_OR_FALSE(dims[1] == 2); + CHECK_OR_FALSE(dims[1] == 1); } for (int i = 0; i < param_.multi_level_rois.size(); i++) { auto roi = param_.multi_level_rois[i]; @@ -45,6 +45,7 @@ bool CollectFpnProposalsOpLite::CheckShape() const { bool CollectFpnProposalsOpLite::InferShape() const { param_.fpn_rois->Resize({param_.post_nms_topN, 4}); + return true; } diff --git a/lite/operators/distribute_fpn_proposals_op.cc b/lite/operators/distribute_fpn_proposals_op.cc new file mode 100644 index 0000000000..5d6a0fca92 --- /dev/null +++ b/lite/operators/distribute_fpn_proposals_op.cc @@ -0,0 +1,70 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/distribute_fpn_proposals_op.h" +#include +#include "lite/core/op_lite.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool DistributeFpnProposalsOpLite::CheckShape() const { + CHECK_OR_FALSE(param_.fpn_rois); + CHECK_OR_FALSE(param_.restore_index); + CHECK_OR_FALSE(param_.multi_fpn_rois.size() > 1); + CHECK_OR_FALSE(param_.max_level >= param_.min_level); + size_t num_out_rois = + static_cast(param_.max_level - param_.min_level + 1); + CHECK_OR_FALSE(num_out_rois == param_.multi_fpn_rois.size()); + return true; +} + +bool DistributeFpnProposalsOpLite::InferShape() const { + int num_out_rois = param_.max_level - param_.min_level + 1; + for (int i = 0; i < num_out_rois; i++) { + param_.multi_fpn_rois[i]->Resize({-1, 4}); + } + param_.restore_index->Resize({-1, 1}); + return true; +} + +bool DistributeFpnProposalsOpLite::AttachImpl(const cpp::OpDesc &op_desc, + lite::Scope *scope) { + auto fpn_rois = op_desc.Input("FpnRois").front(); + param_.fpn_rois = scope->FindVar(fpn_rois)->GetMutable(); + + auto multi_fpn_rois = op_desc.Output("MultiFpnRois"); + for (const auto &name : multi_fpn_rois) { + param_.multi_fpn_rois.push_back( + scope->FindVar(name)->GetMutable()); + } + auto restore_index = op_desc.Output("RestoreIndex").front(); + param_.restore_index = + scope->FindVar(restore_index)->GetMutable(); + param_.min_level = op_desc.GetAttr("min_level"); + param_.max_level = op_desc.GetAttr("max_level"); + param_.refer_level = op_desc.GetAttr("refer_level"); + param_.refer_scale = op_desc.GetAttr("refer_scale"); + + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(distribute_fpn_proposals, + paddle::lite::operators::DistributeFpnProposalsOpLite); diff --git a/lite/operators/distribute_fpn_proposals_op.h b/lite/operators/distribute_fpn_proposals_op.h new file mode 100644 index 0000000000..2390e32932 --- /dev/null +++ b/lite/operators/distribute_fpn_proposals_op.h @@ -0,0 +1,51 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/operators/op_params.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class DistributeFpnProposalsOpLite : public OpLite { + public: + DistributeFpnProposalsOpLite() {} + + explicit DistributeFpnProposalsOpLite(const std::string &op_type) + : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { + return "distribute_fpn_proposals"; + } + + private: + mutable DistributeFpnProposalsParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index ef0bce5ead..aade54c0e5 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -1094,6 +1094,16 @@ struct CollectFpnProposalsParam { int post_nms_topN{}; }; +struct DistributeFpnProposalsParam { + const lite::Tensor* fpn_rois{}; + std::vector multi_fpn_rois{}; + lite::Tensor* restore_index{}; + int min_level{}; + int max_level{}; + int refer_level{}; + int refer_scale{}; +}; + /// --------------------- instance_norm operators -------------------- struct InstanceNormParam { lite::Tensor* x{}; diff --git a/lite/tests/kernels/cast_compute_test.cc b/lite/tests/kernels/cast_compute_test.cc index 7c83aed164..fea3452dbc 100644 --- a/lite/tests/kernels/cast_compute_test.cc +++ b/lite/tests/kernels/cast_compute_test.cc @@ -80,6 +80,7 @@ class CastComputeTester : public arena::TestCase { } void PrepareData() override { + SetPrecisionType(output_, PRECISION(kFloat)); if (in_dtype_ == 20) { std::vector x_data(x_dims_.production()); for (int i = 0; i < x_dims_.production(); i++) { -- GitLab