未验证 提交 d1b7aec5 编写于 作者: J juncaipeng 提交者: GitHub

Support Mask RCNN2 (#2588)

* Support Mask RCNN2 (#2588)
上级 b8992673
...@@ -557,6 +557,52 @@ void elementwise_mul<float>(const float* dinx, ...@@ -557,6 +557,52 @@ void elementwise_mul<float>(const float* dinx,
} }
} }
template <>
void elementwise_mul<int>(const int* dinx,
const int* diny,
int* dout,
int num) {
int cnt = num >> 4;
int remain = num % 16;
#pragma omp parallel for
for (int i = 0; i < cnt; ++i) {
const int* dinx_ptr = dinx + (i << 4);
const int* diny_ptr = diny + (i << 4);
int* dout_ptr = dout + (i << 4);
int32x4_t dinx0 = vld1q_s32(dinx_ptr);
int32x4_t dinx1 = vld1q_s32(dinx_ptr + 4);
int32x4_t dinx2 = vld1q_s32(dinx_ptr + 8);
int32x4_t dinx3 = vld1q_s32(dinx_ptr + 12);
int32x4_t diny0 = vld1q_s32(diny_ptr);
int32x4_t diny1 = vld1q_s32(diny_ptr + 4);
int32x4_t diny2 = vld1q_s32(diny_ptr + 8);
int32x4_t diny3 = vld1q_s32(diny_ptr + 12);
dinx0 = vmulq_s32(dinx0, diny0);
dinx1 = vmulq_s32(dinx1, diny1);
dinx2 = vmulq_s32(dinx2, diny2);
dinx3 = vmulq_s32(dinx3, diny3);
vst1q_s32(dout_ptr, dinx0);
vst1q_s32(dout_ptr + 4, dinx1);
vst1q_s32(dout_ptr + 8, dinx2);
vst1q_s32(dout_ptr + 12, dinx3);
}
if (remain > 0) {
const int* dinx_ptr = dinx + (cnt << 4);
const int* diny_ptr = diny + (cnt << 4);
int* dout_ptr = dout + (cnt << 4);
for (int i = 0; i < remain; i++) {
*dout_ptr = *dinx_ptr * *diny_ptr;
dout_ptr++;
dinx_ptr++;
diny_ptr++;
}
}
}
template <> template <>
void elementwise_mul_relu<float>(const float* dinx, void elementwise_mul_relu<float>(const float* dinx,
const float* diny, const float* diny,
...@@ -678,6 +724,73 @@ void elementwise_mul_broadcast<float>(const float* dinx, ...@@ -678,6 +724,73 @@ void elementwise_mul_broadcast<float>(const float* dinx,
} }
} }
template <>
void elementwise_mul_broadcast<int>(const int* dinx,
const int* diny,
int* dout,
int batch,
int channels,
int num) {
#pragma omp parallel for collapse(2)
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const int* din_ptr = dinx + offset;
const int diny_data = diny[j];
int* dout_ptr = dout + offset;
int cnt = num >> 4;
int remain = num % 16;
int32x4_t rb = vdupq_n_s32(diny_data);
for (int k = 0; k < cnt; ++k) {
int32x4_t din0 = vld1q_s32(din_ptr);
int32x4_t din1 = vld1q_s32(din_ptr + 4);
int32x4_t din2 = vld1q_s32(din_ptr + 8);
int32x4_t din3 = vld1q_s32(din_ptr + 12);
din0 = vmulq_s32(din0, rb);
din1 = vmulq_s32(din1, rb);
din2 = vmulq_s32(din2, rb);
din3 = vmulq_s32(din3, rb);
vst1q_s32(dout_ptr, din0);
vst1q_s32(dout_ptr + 4, din1);
vst1q_s32(dout_ptr + 8, din2);
vst1q_s32(dout_ptr + 12, din3);
din_ptr += 16;
dout_ptr += 16;
}
if (remain >= 8) {
int32x4_t din0 = vld1q_s32(din_ptr);
int32x4_t din1 = vld1q_s32(din_ptr + 4);
din0 = vmulq_s32(din0, rb);
din1 = vmulq_s32(din1, rb);
vst1q_s32(dout_ptr, din0);
vst1q_s32(dout_ptr + 4, din1);
din_ptr += 8;
dout_ptr += 8;
remain -= 8;
}
if (remain >= 4) {
int32x4_t din0 = vld1q_s32(din_ptr);
din0 = vmulq_s32(din0, rb);
vst1q_s32(dout_ptr, din0);
din_ptr += 4;
dout_ptr += 4;
remain -= 4;
}
if (remain > 0) {
for (int p = 0; p < remain; ++p) {
*dout_ptr = *din_ptr * diny_data;
dout_ptr++;
din_ptr++;
}
}
}
}
}
template <> template <>
void elementwise_mul_relu_broadcast<float>(const float* dinx, void elementwise_mul_relu_broadcast<float>(const float* dinx,
const float* diny, const float* diny,
......
...@@ -526,9 +526,9 @@ void interpolate(lite::Tensor* X, ...@@ -526,9 +526,9 @@ void interpolate(lite::Tensor* X,
} }
auto out_size = OutSize; auto out_size = OutSize;
if (out_size != nullptr) { if (out_size != nullptr) {
auto out_size_data = get_new_data_from_tensor<float>(out_size); auto out_size_data = get_new_data_from_tensor<int>(out_size);
out_height = static_cast<int>(out_size_data[0]); out_height = out_size_data[0];
out_width = static_cast<int>(out_size_data[1]); out_width = out_size_data[1];
} }
} }
float height_scale = scale; float height_scale = scale;
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <iomanip> #include <iomanip>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
...@@ -77,6 +78,20 @@ class TestCase { ...@@ -77,6 +78,20 @@ class TestCase {
// kernel registry. // kernel registry.
void CheckKernelConsistWithDefinition() {} void CheckKernelConsistWithDefinition() {}
// Get the real precision of the output for check precision. When the declare
// precision obtained from the kernel is any, we should set the precision of
// the output in test case.
bool GetPrecisonType(const std::string& var_name,
PrecisionType* precision_type) {
auto res = precision_type_map_.find(var_name);
if (res == precision_type_map_.end()) {
return false;
} else {
*precision_type = precision_type_map_.at(var_name);
return true;
}
}
Scope& scope() { return *scope_; } Scope& scope() { return *scope_; }
Scope* baseline_scope() { return base_scope_; } Scope* baseline_scope() { return base_scope_; }
...@@ -105,6 +120,19 @@ class TestCase { ...@@ -105,6 +120,19 @@ class TestCase {
// Prepare for the operator. // Prepare for the operator.
virtual void PrepareOpDesc(cpp::OpDesc* op_desc) = 0; virtual void PrepareOpDesc(cpp::OpDesc* op_desc) = 0;
// Set the real precision of the output for check precision. When the declare
// precision obtained from the kernel is any, we should set the precision of
// the output in test case.
void SetPrecisionType(const std::string& var_name,
const PrecisionType& precision_type) {
auto res = precision_type_map_.find(var_name);
if (res == precision_type_map_.end()) {
precision_type_map_.insert({var_name, precision_type});
} else {
precision_type_map_.at(var_name) = precision_type;
}
}
public: public:
const Instruction& instruction() { return *instruction_; } const Instruction& instruction() { return *instruction_; }
...@@ -148,6 +176,7 @@ class TestCase { ...@@ -148,6 +176,7 @@ class TestCase {
Scope* base_scope_{}; Scope* base_scope_{};
std::unique_ptr<cpp::OpDesc> op_desc_; std::unique_ptr<cpp::OpDesc> op_desc_;
std::unique_ptr<Instruction> instruction_; std::unique_ptr<Instruction> instruction_;
std::unordered_map<std::string, PrecisionType> precision_type_map_;
}; };
class Arena { class Arena {
...@@ -189,8 +218,11 @@ class Arena { ...@@ -189,8 +218,11 @@ class Arena {
// get tensor type. // get tensor type.
const Type* type = const Type* type =
tester_->instruction().kernel()->GetOutputDeclType(arg_name); tester_->instruction().kernel()->GetOutputDeclType(arg_name);
auto precision_type = type->precision();
switch (type->precision()) { if (precision_type == PRECISION(kAny)) {
CHECK(tester_->GetPrecisonType(var_name, &precision_type));
}
switch (precision_type) {
case PRECISION(kFloat): case PRECISION(kFloat):
return tester_->CheckPrecision<float>(var_name, abs_error_); return tester_->CheckPrecision<float>(var_name, abs_error_);
case PRECISION(kInt8): case PRECISION(kInt8):
...@@ -199,7 +231,6 @@ class Arena { ...@@ -199,7 +231,6 @@ class Arena {
return tester_->CheckPrecision<int32_t>(var_name, abs_error_); return tester_->CheckPrecision<int32_t>(var_name, abs_error_);
case PRECISION(kBool): case PRECISION(kBool):
return tester_->CheckPrecision<bool>(var_name, abs_error_); return tester_->CheckPrecision<bool>(var_name, abs_error_);
default: default:
LOG(FATAL) << "not support type " << PrecisionToStr(type->precision()); LOG(FATAL) << "not support type " << PrecisionToStr(type->precision());
return false; return false;
......
...@@ -25,7 +25,8 @@ namespace { ...@@ -25,7 +25,8 @@ namespace {
class Eliminator : public FuseBase { class Eliminator : public FuseBase {
public: public:
void BuildPattern() override { void BuildPattern() override {
auto* pre_op = OpNode("preop"); // the previous op's output need update // the previous op's output need updat
auto* pre_op = OpNode("preop")->assert_is_not_op_type("conditional_block");
// TODO(Superjomn) check has only one output // TODO(Superjomn) check has only one output
auto* x = VarNode("x")->assert_is_op_input("scale", "X"); auto* x = VarNode("x")->assert_is_op_input("scale", "X");
auto* scale_op = OpNode("scale", "scale") auto* scale_op = OpNode("scale", "scale")
......
...@@ -377,6 +377,19 @@ PMNode *PMNode::assert_is_op(const std::string &op_type) { ...@@ -377,6 +377,19 @@ PMNode *PMNode::assert_is_op(const std::string &op_type) {
return this; return this;
} }
PMNode *PMNode::assert_is_not_op_type(const std::string &op_type) {
asserts_.emplace_back([op_type](const Node *x) {
if (x && x->IsStmt()) {
auto *op_info = x->stmt()->op_info();
if (op_info->Type() == op_type) {
return false;
}
}
return true;
});
return this;
}
PMNode *PMNode::assert_is_var() { PMNode *PMNode::assert_is_var() {
asserts_.emplace_back([](const Node *x) { return x && x->IsArg(); }); asserts_.emplace_back([](const Node *x) { return x && x->IsArg(); });
return this; return this;
......
...@@ -123,6 +123,7 @@ struct PMNode { ...@@ -123,6 +123,7 @@ struct PMNode {
// Assertions, helper functions to simplify the pattern definition. // Assertions, helper functions to simplify the pattern definition.
PMNode* assert_is_op(); PMNode* assert_is_op();
PMNode* assert_is_op(const std::string& op_type); PMNode* assert_is_op(const std::string& op_type);
PMNode* assert_is_not_op_type(const std::string& op_type);
PMNode* assert_is_var(); PMNode* assert_is_var();
PMNode* assert_var_not_persistable(); PMNode* assert_var_not_persistable();
PMNode* assert_is_persistable_var(); PMNode* assert_is_persistable_var();
......
...@@ -79,6 +79,7 @@ add_kernel(box_clip_compute_arm ARM extra SRCS box_clip_compute.cc DEPS ${lite_k ...@@ -79,6 +79,7 @@ add_kernel(box_clip_compute_arm ARM extra SRCS box_clip_compute.cc DEPS ${lite_k
add_kernel(assign_value_compute_arm ARM extra SRCS assign_value_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(assign_value_compute_arm ARM extra SRCS assign_value_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(conditional_block_compute_arm ARM extra SRCS conditional_block_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(conditional_block_compute_arm ARM extra SRCS conditional_block_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(collect_fpn_proposals_compute_arm ARM extra SRCS collect_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(collect_fpn_proposals_compute_arm ARM extra SRCS collect_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(distribute_fpn_proposals_compute_arm ARM extra SRCS distribute_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm)
# for OCR specific # for OCR specific
......
...@@ -74,6 +74,6 @@ void CastCompute::Run() { ...@@ -74,6 +74,6 @@ void CastCompute::Run() {
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(
cast, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::CastCompute, def) cast, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::CastCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.Finalize(); .Finalize();
...@@ -141,7 +141,7 @@ REGISTER_LITE_KERNEL(collect_fpn_proposals, ...@@ -141,7 +141,7 @@ REGISTER_LITE_KERNEL(collect_fpn_proposals,
kNCHW, kNCHW,
paddle::lite::kernels::arm::CollectFpnProposalsCompute, paddle::lite::kernels::arm::CollectFpnProposalsCompute,
def) def)
.BindInput("MultiLevelRois", {LiteType::GetTensorListTy(TARGET(kARM))}) .BindInput("MultiLevelRois", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("MultiLevelScores", {LiteType::GetTensorListTy(TARGET(kARM))}) .BindInput("MultiLevelScores", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("FpnRois", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("FpnRois", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
...@@ -219,6 +219,19 @@ REGISTER_LITE_KERNEL(greater_equal, ...@@ -219,6 +219,19 @@ REGISTER_LITE_KERNEL(greater_equal,
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL(less_than,
kARM,
kInt32,
kNCHW,
paddle::lite::kernels::arm::CompareCompute_int32<
paddle::lite::kernels::arm::_LessThanFunctor>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
.Finalize();
REGISTER_LITE_KERNEL(equal, REGISTER_LITE_KERNEL(equal,
kARM, kARM,
kInt32, kInt32,
......
...@@ -23,9 +23,8 @@ ...@@ -23,9 +23,8 @@
#include "lite/operators/conditional_block_op.h" #include "lite/operators/conditional_block_op.h"
#ifdef LITE_WITH_PROFILE #ifdef LITE_WITH_PROFILE
#include "lite/core/profile/basic_profiler.h" #include "lite/core/profile/basic_profiler.h"
#endif // LITE_WITH_PROFILE
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/precision_profiler.h" #include "lite/core/profile/precision_profiler.h"
#include "lite/core/profile/profiler.h"
#endif #endif
namespace paddle { namespace paddle {
...@@ -57,6 +56,11 @@ class CondExecutor { ...@@ -57,6 +56,11 @@ class CondExecutor {
} }
void Run() { void Run() {
#ifdef LITE_WITH_PROFILE
#ifdef LITE_WITH_PRECISION_PROFILE
lite::profile::Profiler profiler;
#endif // LITE_WITH_PRECISION_PROFILE
#endif // LITE_WITH_PROFILE
for (auto &op_handler : ops_of_block_) { for (auto &op_handler : ops_of_block_) {
op_handler->CheckShape(); op_handler->CheckShape();
op_handler->InferShape(); op_handler->InferShape();
...@@ -64,6 +68,7 @@ class CondExecutor { ...@@ -64,6 +68,7 @@ class CondExecutor {
#ifdef LITE_WITH_PRECISION_PROFILE #ifdef LITE_WITH_PRECISION_PROFILE
std::unique_ptr<KernelBase> kernel(op_handler->GetKernel()); std::unique_ptr<KernelBase> kernel(op_handler->GetKernel());
Instruction inst(op_handler, std::move(kernel)); Instruction inst(op_handler, std::move(kernel));
inst.set_profiler(&profiler);
#endif // LITE_WITH_PRECISION_PROFILE #endif // LITE_WITH_PRECISION_PROFILE
#endif // LITE_WITH_PROFILE #endif // LITE_WITH_PROFILE
op_handler->Run(); op_handler->Run();
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/distribute_fpn_proposals_compute.h"
#include <string>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
const int kBoxDim = 4;
template <typename T>
static inline T BBoxArea(const T* box, bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) {
// If coordinate values are is invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return static_cast<T>(0.);
} else {
const T w = box[2] - box[0];
const T h = box[3] - box[1];
if (normalized) {
return w * h;
} else {
// If coordinate values are not within range [0, 1].
return (w + 1) * (h + 1);
}
}
}
void DistributeFpnProposalsCompute::Run() {
auto& param = Param<operators::DistributeFpnProposalsParam>();
const lite::Tensor* fpn_rois = param.fpn_rois;
std::vector<lite::Tensor*> multi_fpn_rois = param.multi_fpn_rois;
lite::Tensor* restore_index = param.restore_index;
int min_level = param.min_level;
int max_level = param.max_level;
int refer_level = param.refer_level;
int refer_scale = param.refer_scale;
int num_level = max_level - min_level + 1;
CHECK_EQ(fpn_rois->lod().size(), 1);
auto fpn_rois_lod = fpn_rois->lod().back();
int fpn_rois_num = fpn_rois_lod[fpn_rois_lod.size() - 1];
std::vector<int> target_level;
// record the number of rois in each level
std::vector<int> num_rois_level(num_level, 0);
std::vector<int> num_rois_level_integral(num_level + 1, 0);
for (size_t i = 0; i < fpn_rois_lod.size() - 1; ++i) {
auto fpn_rois_slice =
fpn_rois->Slice<float>(static_cast<int64_t>(fpn_rois_lod[i]),
static_cast<int64_t>(fpn_rois_lod[i + 1]));
const float* rois_data = fpn_rois_slice.data<float>();
for (int j = 0; j < fpn_rois_slice.dims()[0]; ++j) {
// get the target level of current rois
float roi_scale = std::sqrt(BBoxArea(rois_data, false));
int tgt_lvl = std::floor(
std::log2(roi_scale / refer_scale + static_cast<float>(1e-6)) +
refer_level);
tgt_lvl = std::min(max_level, std::max(tgt_lvl, min_level));
target_level.push_back(tgt_lvl);
num_rois_level[tgt_lvl - min_level]++;
rois_data += kBoxDim;
}
}
// define the output rois
// pointer which point to each level fpn rois
std::vector<float*> multi_fpn_rois_data(num_level);
// lod0 which will record the offset information of each level rois
std::vector<std::vector<uint64_t>> multi_fpn_rois_lod0;
for (int i = 0; i < num_level; ++i) {
// allocate memory for each level rois
multi_fpn_rois[i]->Resize({num_rois_level[i], kBoxDim});
multi_fpn_rois_data[i] = multi_fpn_rois[i]->mutable_data<float>();
std::vector<uint64_t> lod0(1, 0);
multi_fpn_rois_lod0.push_back(lod0);
// statistic start point for each level rois
num_rois_level_integral[i + 1] =
num_rois_level_integral[i] + num_rois_level[i];
}
restore_index->Resize({fpn_rois_num, 1});
int* restore_index_data = restore_index->mutable_data<int>();
std::vector<int> restore_index_inter(fpn_rois_num, -1);
// distribute the rois into different fpn level by target level
for (size_t i = 0; i < fpn_rois_lod.size() - 1; ++i) {
Tensor fpn_rois_slice =
fpn_rois->Slice<float>(static_cast<int64_t>(fpn_rois_lod[i]),
static_cast<int64_t>(fpn_rois_lod[i + 1]));
const float* rois_data = fpn_rois_slice.data<float>();
size_t cur_offset = fpn_rois_lod[i];
// std::vector<size_t > lod_offset[num_level];
for (int j = 0; j < num_level; j++) {
multi_fpn_rois_lod0[j].push_back(multi_fpn_rois_lod0[j][i]);
}
for (int j = 0; j < fpn_rois_slice.dims()[0]; ++j) {
int lvl = target_level[cur_offset + j];
memcpy(multi_fpn_rois_data[lvl - min_level],
rois_data,
kBoxDim * sizeof(float));
multi_fpn_rois_data[lvl - min_level] += kBoxDim;
int index_in_shuffle = num_rois_level_integral[lvl - min_level] +
multi_fpn_rois_lod0[lvl - min_level][i + 1];
restore_index_inter[index_in_shuffle] = cur_offset + j;
multi_fpn_rois_lod0[lvl - min_level][i + 1]++;
rois_data += kBoxDim;
}
}
for (int i = 0; i < fpn_rois_num; ++i) {
restore_index_data[restore_index_inter[i]] = i;
}
// merge lod information into LoDTensor
for (int i = 0; i < num_level; ++i) {
lite::LoD lod;
lod.emplace_back(multi_fpn_rois_lod0[i]);
multi_fpn_rois[i]->set_lod(lod);
}
return;
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(distribute_fpn_proposals,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::DistributeFpnProposalsCompute,
def)
.BindInput("FpnRois", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("MultiFpnRois", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("RestoreIndex", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/operators/distribute_fpn_proposals_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class DistributeFpnProposalsCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::DistributeFpnProposalsParam;
void Run() override;
virtual ~DistributeFpnProposalsCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
...@@ -161,20 +161,21 @@ void ElementwiseSubActivationCompute::Run() { ...@@ -161,20 +161,21 @@ void ElementwiseSubActivationCompute::Run() {
} }
} }
void ElementwiseMulCompute::Run() { template <typename T, PrecisionType PType>
auto& param = Param<operators::ElementwiseParam>(); void ElementwiseMulCompute<T, PType>::Run() {
const float* x_data = param.X->data<float>(); auto& param = this->template Param<operators::ElementwiseParam>();
const float* y_data = param.Y->data<float>(); auto* x_data = param.X->template data<T>();
float* out_data = param.Out->mutable_data<float>(); auto* y_data = param.Y->template data<T>();
auto* out_data = param.Out->template mutable_data<T>();
int axis = param.axis; int axis = param.axis;
auto x_dims = param.X->dims(); auto x_dims = param.X->dims();
auto y_dims = param.Y->dims(); auto y_dims = param.Y->dims();
int pre, n, post; int pre, n, post;
if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) { if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_mul_broadcast( lite::arm::math::elementwise_mul_broadcast<T>(
x_data, y_data, out_data, pre, n, post); x_data, y_data, out_data, pre, n, post);
} else { } else {
lite::arm::math::elementwise_mul( lite::arm::math::elementwise_mul<T>(
x_data, y_data, out_data, x_dims.production()); x_data, y_data, out_data, x_dims.production());
} }
} }
...@@ -347,17 +348,24 @@ REGISTER_LITE_KERNEL( ...@@ -347,17 +348,24 @@ REGISTER_LITE_KERNEL(
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL(elementwise_mul, using elementwise_mul_float =
kARM, paddle::lite::kernels::arm::ElementwiseMulCompute<float, PRECISION(kFloat)>;
kFloat, REGISTER_LITE_KERNEL(
kNCHW, elementwise_mul, kARM, kFloat, kNCHW, elementwise_mul_float, def)
paddle::lite::kernels::arm::ElementwiseMulCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
using elementwise_mul_int32 =
paddle::lite::kernels::arm::ElementwiseMulCompute<int, PRECISION(kInt32)>;
REGISTER_LITE_KERNEL(
elementwise_mul, kARM, kInt32, kNCHW, elementwise_mul_int32, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.Finalize();
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(
fusion_elementwise_mul_activation, fusion_elementwise_mul_activation,
kARM, kARM,
......
...@@ -54,8 +54,8 @@ class ElementwiseSubActivationCompute ...@@ -54,8 +54,8 @@ class ElementwiseSubActivationCompute
virtual ~ElementwiseSubActivationCompute() = default; virtual ~ElementwiseSubActivationCompute() = default;
}; };
class ElementwiseMulCompute template <typename T, PrecisionType PType>
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> { class ElementwiseMulCompute : public KernelLite<TARGET(kARM), PType> {
public: public:
void Run() override; void Run() override;
......
...@@ -329,13 +329,13 @@ TEST(elementwise_mul_arm, retrive_op) { ...@@ -329,13 +329,13 @@ TEST(elementwise_mul_arm, retrive_op) {
} }
TEST(elementwise_mul_arm, init) { TEST(elementwise_mul_arm, init) {
ElementwiseMulCompute elementwise_mul; ElementwiseMulCompute<float, PRECISION(kFloat)> elementwise_mul;
ASSERT_EQ(elementwise_mul.precision(), PRECISION(kFloat)); ASSERT_EQ(elementwise_mul.precision(), PRECISION(kFloat));
ASSERT_EQ(elementwise_mul.target(), TARGET(kARM)); ASSERT_EQ(elementwise_mul.target(), TARGET(kARM));
} }
TEST(elementwise_mul, compute) { TEST(elementwise_mul, compute) {
ElementwiseMulCompute elementwise_mul; ElementwiseMulCompute<float, PRECISION(kFloat)> elementwise_mul;
operators::ElementwiseParam param; operators::ElementwiseParam param;
lite::Tensor x, y, output, output_ref; lite::Tensor x, y, output, output_ref;
......
...@@ -29,7 +29,7 @@ void GatherCompute::Run() { ...@@ -29,7 +29,7 @@ void GatherCompute::Run() {
auto index_size = param.Index->dims()[0]; auto index_size = param.Index->dims()[0];
auto src_dims = param.X->dims(); auto src_dims = param.X->dims();
const float* p_src = param.X->data<float>(); const float* p_src = param.X->data<float>();
const float* p_index = param.Index->data<float>(); const int* p_index = param.Index->data<int>();
int slice_size = 1; int slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) { for (int i = 1; i < src_dims.size(); ++i) {
...@@ -50,6 +50,8 @@ void GatherCompute::Run() { ...@@ -50,6 +50,8 @@ void GatherCompute::Run() {
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(
gather, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::GatherCompute, def) gather, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::GatherCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Index",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
...@@ -84,8 +84,10 @@ REGISTER_LITE_KERNEL(bilinear_interp, ...@@ -84,8 +84,10 @@ REGISTER_LITE_KERNEL(bilinear_interp,
paddle::lite::kernels::arm::BilinearInterpCompute, paddle::lite::kernels::arm::BilinearInterpCompute,
def) def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("OutSize", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("OutSize",
.BindInput("SizeTensor", {LiteType::GetTensorTy(TARGET(kARM))}) {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("SizeTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("Scale", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Scale", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
...@@ -97,8 +99,10 @@ REGISTER_LITE_KERNEL(nearest_interp, ...@@ -97,8 +99,10 @@ REGISTER_LITE_KERNEL(nearest_interp,
paddle::lite::kernels::arm::NearestInterpCompute, paddle::lite::kernels::arm::NearestInterpCompute,
def) def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("OutSize", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("OutSize",
.BindInput("SizeTensor", {LiteType::GetTensorTy(TARGET(kARM))}) {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("SizeTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("Scale", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Scale", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
...@@ -176,3 +176,14 @@ REGISTER_LITE_KERNEL(slice, kARM, kFloat, kNCHW, slice_float, def) ...@@ -176,3 +176,14 @@ REGISTER_LITE_KERNEL(slice, kARM, kFloat, kNCHW, slice_float, def)
.BindInput("EndsTensorList", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("EndsTensorList", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
using slice_int32 =
paddle::lite::kernels::arm::SliceCompute<int, PRECISION(kInt32)>;
REGISTER_LITE_KERNEL(slice, kARM, kInt32, kNCHW, slice_int32, def)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("StartsTensor", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("EndsTensor", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("StartsTensorList", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("EndsTensorList", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.Finalize();
...@@ -100,6 +100,7 @@ add_operator(attention_padding_mask_op_lite extra SRCS attention_padding_mask_op ...@@ -100,6 +100,7 @@ add_operator(attention_padding_mask_op_lite extra SRCS attention_padding_mask_op
add_operator(sequence_arithmetic_op_lite extra SRCS sequence_arithmetic_op.cc DEPS ${op_DEPS}) add_operator(sequence_arithmetic_op_lite extra SRCS sequence_arithmetic_op.cc DEPS ${op_DEPS})
add_operator(conditional_block_op_lite extra SRCS conditional_block_op.cc DEPS ${op_DEPS}) add_operator(conditional_block_op_lite extra SRCS conditional_block_op.cc DEPS ${op_DEPS})
add_operator(collect_fpn_proposals_op_lite extra SRCS collect_fpn_proposals_op.cc DEPS ${op_DEPS}) add_operator(collect_fpn_proposals_op_lite extra SRCS collect_fpn_proposals_op.cc DEPS ${op_DEPS})
add_operator(distribute_fpn_proposals_op_lite extra SRCS distribute_fpn_proposals_op.cc DEPS ${op_DEPS})
# for OCR specific # for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS}) add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
......
...@@ -31,7 +31,7 @@ bool CollectFpnProposalsOpLite::CheckShape() const { ...@@ -31,7 +31,7 @@ bool CollectFpnProposalsOpLite::CheckShape() const {
} }
for (auto item : param_.multi_level_scores) { for (auto item : param_.multi_level_scores) {
auto dims = item->dims(); auto dims = item->dims();
CHECK_OR_FALSE(dims[1] == 2); CHECK_OR_FALSE(dims[1] == 1);
} }
for (int i = 0; i < param_.multi_level_rois.size(); i++) { for (int i = 0; i < param_.multi_level_rois.size(); i++) {
auto roi = param_.multi_level_rois[i]; auto roi = param_.multi_level_rois[i];
...@@ -45,6 +45,7 @@ bool CollectFpnProposalsOpLite::CheckShape() const { ...@@ -45,6 +45,7 @@ bool CollectFpnProposalsOpLite::CheckShape() const {
bool CollectFpnProposalsOpLite::InferShape() const { bool CollectFpnProposalsOpLite::InferShape() const {
param_.fpn_rois->Resize({param_.post_nms_topN, 4}); param_.fpn_rois->Resize({param_.post_nms_topN, 4});
return true; return true;
} }
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/distribute_fpn_proposals_op.h"
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool DistributeFpnProposalsOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.fpn_rois);
CHECK_OR_FALSE(param_.restore_index);
CHECK_OR_FALSE(param_.multi_fpn_rois.size() > 1);
CHECK_OR_FALSE(param_.max_level >= param_.min_level);
size_t num_out_rois =
static_cast<size_t>(param_.max_level - param_.min_level + 1);
CHECK_OR_FALSE(num_out_rois == param_.multi_fpn_rois.size());
return true;
}
bool DistributeFpnProposalsOpLite::InferShape() const {
int num_out_rois = param_.max_level - param_.min_level + 1;
for (int i = 0; i < num_out_rois; i++) {
param_.multi_fpn_rois[i]->Resize({-1, 4});
}
param_.restore_index->Resize({-1, 1});
return true;
}
bool DistributeFpnProposalsOpLite::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) {
auto fpn_rois = op_desc.Input("FpnRois").front();
param_.fpn_rois = scope->FindVar(fpn_rois)->GetMutable<lite::Tensor>();
auto multi_fpn_rois = op_desc.Output("MultiFpnRois");
for (const auto &name : multi_fpn_rois) {
param_.multi_fpn_rois.push_back(
scope->FindVar(name)->GetMutable<lite::Tensor>());
}
auto restore_index = op_desc.Output("RestoreIndex").front();
param_.restore_index =
scope->FindVar(restore_index)->GetMutable<lite::Tensor>();
param_.min_level = op_desc.GetAttr<int>("min_level");
param_.max_level = op_desc.GetAttr<int>("max_level");
param_.refer_level = op_desc.GetAttr<int>("refer_level");
param_.refer_scale = op_desc.GetAttr<int>("refer_scale");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(distribute_fpn_proposals,
paddle::lite::operators::DistributeFpnProposalsOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class DistributeFpnProposalsOpLite : public OpLite {
public:
DistributeFpnProposalsOpLite() {}
explicit DistributeFpnProposalsOpLite(const std::string &op_type)
: OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override {
return "distribute_fpn_proposals";
}
private:
mutable DistributeFpnProposalsParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
...@@ -1094,6 +1094,16 @@ struct CollectFpnProposalsParam { ...@@ -1094,6 +1094,16 @@ struct CollectFpnProposalsParam {
int post_nms_topN{}; int post_nms_topN{};
}; };
struct DistributeFpnProposalsParam {
const lite::Tensor* fpn_rois{};
std::vector<lite::Tensor*> multi_fpn_rois{};
lite::Tensor* restore_index{};
int min_level{};
int max_level{};
int refer_level{};
int refer_scale{};
};
/// --------------------- instance_norm operators -------------------- /// --------------------- instance_norm operators --------------------
struct InstanceNormParam { struct InstanceNormParam {
lite::Tensor* x{}; lite::Tensor* x{};
......
...@@ -80,6 +80,7 @@ class CastComputeTester : public arena::TestCase { ...@@ -80,6 +80,7 @@ class CastComputeTester : public arena::TestCase {
} }
void PrepareData() override { void PrepareData() override {
SetPrecisionType(output_, PRECISION(kFloat));
if (in_dtype_ == 20) { if (in_dtype_ == 20) {
std::vector<unsigned char> x_data(x_dims_.production()); std::vector<unsigned char> x_data(x_dims_.production());
for (int i = 0; i < x_dims_.production(); i++) { for (int i = 0; i < x_dims_.production(); i++) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册