未验证 提交 9f767f94 编写于 作者: H HappyAngel 提交者: GitHub

[arm]add deformable Conv op (#3732)

* add deformable Conv op 

* fix ut, test=develop

* fix format. test=develop

* test=develop

* delete unuseful message. test=develop
上级 847311bd
......@@ -95,6 +95,9 @@ add_kernel(fill_constant_batch_size_like_compute_arm ARM basic SRCS fill_constan
add_kernel(lod_reset_compute_arm ARM extra SRCS lod_reset_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(lstm_arm ARM extra SRCS lstm_compute.cc DEPS ${lite_kernel_deps} math_arm)
# for deformable-convNet
add_kernel(deformable_conv_compute_arm ARM extra SRCS deformable_conv_compute.cc DEPS ${lite_kernel_deps} math_arm)
# 4. training kernels
add_kernel(mean_compute_arm ARM extra SRCS mean_compute.cc DEPS ${lite_kernel_deps} math_arm)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/deformable_conv_compute.h"
#include <cmath>
#include <memory>
#include <utility>
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/kernels/arm/conv_depthwise.h"
#include "lite/kernels/arm/conv_direct.h"
#include "lite/kernels/arm/conv_gemmlike.h"
#include "lite/kernels/arm/conv_winograd.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
// One-time preparation for the fp32 kernel: all the work is delegated to
// ReInitWhenNeeded(), which decides whether the filter must be repacked.
template <>
void DeformableConvCompute<PRECISION(kFloat),
                           PRECISION(kFloat)>::PrepareForRun() {
  this->ReInitWhenNeeded();
}
// Bilinear interpolation of a row-major `height` x `width` plane at the
// fractional position (h, w).
//
// The top-left sample is floor(h), floor(w). When that sample already sits on
// (or beyond) the last row/column, both neighbours collapse onto the last
// row/column and the fractional part is dropped, so no read goes past the
// upper bounds. Negative coordinates are NOT handled here; the caller in this
// file only invokes the function for 0 <= h < height and 0 <= w < width.
static inline float deformable_bilinear(const float* bottom_data,
                                        const int height,
                                        const int width,
                                        float h,
                                        float w) {
  const int last_row = height - 1;
  const int last_col = width - 1;

  int row0 = static_cast<int>(std::floor(h));
  int col0 = static_cast<int>(std::floor(w));
  int row1 = row0 + 1;
  int col1 = col0 + 1;

  // Clamp to the bottom edge.
  if (row0 >= last_row) {
    row0 = last_row;
    row1 = last_row;
    h = static_cast<float>(last_row);
  }
  // Clamp to the right edge.
  if (col0 >= last_col) {
    col0 = last_col;
    col1 = last_col;
    w = static_cast<float>(last_col);
  }

  // Fractional distances to the top-left sample and their complements.
  const float frac_h = h - row0;
  const float frac_w = w - col0;
  const float rest_h = 1 - frac_h;
  const float rest_w = 1 - frac_w;

  const float* top_row = bottom_data + row0 * width;
  const float* bottom_row = bottom_data + row1 * width;
  const float top_left = top_row[col0];
  const float top_right = top_row[col1];
  const float bottom_left = bottom_row[col0];
  const float bottom_right = bottom_row[col1];

  const float weight_tl = rest_h * rest_w;
  const float weight_tr = rest_h * frac_w;
  const float weight_bl = frac_h * rest_w;
  const float weight_br = frac_h * frac_w;

  return (weight_tl * top_left + weight_tr * top_right +
          weight_bl * bottom_left + weight_br * bottom_right);
}
// Reference ("basic") fp32 deformable convolution.
//
// Two phases:
//   1. Deformable im2col: for every (batch, group, input channel, kernel tap,
//      input pixel) a bilinearly sampled - and, when param.modulated is set,
//      mask-scaled - value is written into a scratch "column" buffer.
//   2. Per (batch, group) GEMM (or GEMV when the output has a single pixel)
//      of the filter against that buffer, with optional bias / activation.
//
// Tensor layouts as read by the code below:
//   param.x      [n, cin, hin, win]
//   param.offset [n, 2 * group * kh * kw, ...]   (h offset plane, then w)
//   param.mask   [n, group * kh * kw, ...]
//   param.filter [cout, cin / group, kh, kw]
//   param.output [n, cout, hout, wout]
// The code uses `group` where the offset/mask layout would use the deformable
// group count; the operator's CheckShape() requires the two to be equal.
//
// NOTE(review): the column buffer is indexed with the INPUT spatial size
// while offsets are indexed with the OUTPUT spatial size and the GEMM uses
// hout * wout columns with leading dimension hout * wout. This looks like it
// is only consistent when hout == hin and wout == win - confirm before using
// the kernel with other stride/padding combinations.
template <>
void DeformableConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
  auto& param = this->Param<param_t>();
  auto& ctx = this->ctx_->template As<ARMContext>();

  const auto* in_data = param.x->data<float>();
  const auto* filter_data = param.conv_param.filter->data<float>();
  const auto* offset_data = param.offset->data<float>();
  const auto* mask_data = param.mask->data<float>();
  float* out_data = param.output->mutable_data<float>();

  auto in_dims = param.x->dims();
  auto filter_dims = param.conv_param.filter->dims();
  auto out_dims = param.output->dims();
  // Bind by reference: these vectors are only read, so copying them on every
  // invocation is unnecessary.
  const auto& stride = param.conv_param.strides;
  const auto& paddings = *param.conv_param.paddings;
  const auto& dilation = *param.conv_param.dilations;
  auto group = param.conv_param.groups;

  auto num = in_dims[0];
  auto cin = in_dims[1];
  auto hin = in_dims[2];
  auto win = in_dims[3];
  auto cout = filter_dims[0];
  auto kh = filter_dims[2];
  auto kw = filter_dims[3];
  auto hout = out_dims[2];
  auto wout = out_dims[3];

  const bool is_bias = param.conv_param.bias != nullptr;
  const float* bias =
      is_bias ? param.conv_param.bias->data<float>() : nullptr;
  const bool use_mask = param.modulated;

  auto in_c_group = cin / group;
  int in_size = hin * win;
  int out_size = hout * wout;
  int c_in_size = cin * in_size;
  int kernel_size = kw * kh;
  auto offset_in_size = 2 * group * kernel_size * in_size;

  // Scratch column buffer. The element count is computed in size_t so that
  // large shapes do not overflow int, and the buffer is owned by a unique_ptr
  // so it is released on every exit path.
  const size_t col_size =
      static_cast<size_t>(num) * cin * kernel_size * in_size;
  std::unique_ptr<float[]> col_buffer(new float[col_size]);
  float* col_data = col_buffer.get();

  // ---- Phase 1: deformable im2col -----------------------------------------
  for (int batch = 0; batch < num; batch++) {
    for (int g = 0; g < group; ++g) {
      const float* offset_data_ptr =
          offset_data + batch * offset_in_size + g * 2 * kernel_size * in_size;
      const float* in_data_offset =
          in_data + batch * c_in_size + g * in_c_group * in_size;
      float* col_data_g = col_data + batch * c_in_size * kernel_size +
                          g * in_c_group * kernel_size * in_size;
      for (int ic = 0; ic < in_c_group; ++ic) {
        const float* in_data_ch = in_data_offset + ic * in_size;
        float* col_data_ch = col_data_g + ic * kernel_size * in_size;
        for (int fh = 0; fh < kh; fh++) {
          for (int fw = 0; fw < kw; fw++) {
            // Offsets are stored as interleaved (h, w) planes per kernel tap.
            const float* offset_data_ptr_h =
                offset_data_ptr + (2 * (fh * kw + fw)) * out_size;
            const float* offset_data_ptr_w =
                offset_data_ptr + (2 * (fh * kw + fw) + 1) * out_size;
            float* col_data_g_ksize = col_data_ch + (fh * kw + fw) * in_size;
            for (int ih = 0; ih < hin; ih++) {
              const float* offset_data_ptr_h_w = offset_data_ptr_h + ih * wout;
              const float* offset_data_ptr_w_w = offset_data_ptr_w + ih * wout;
              float* col_data_g_ksize_h = col_data_g_ksize + ih * win;
              for (int iw = 0; iw < win; iw++) {
                const float offset_h = *offset_data_ptr_h_w++;
                const float offset_w = *offset_data_ptr_w_w++;
                // NOTE(review): the dilation term uses the kernel EXTENT
                // (kw / kh) rather than the current tap index (fw / fh), so
                // every tap starts from the same base position. This looks
                // unintended, but it is kept unchanged here because the
                // reference implementation used by the unit test is not
                // visible from this file - confirm both sides before changing.
                const float im_w =
                    iw * stride[1] - paddings[2] + kw * dilation[1] + offset_w;
                const float im_h =
                    ih * stride[0] - paddings[0] + kh * dilation[0] + offset_h;
                if (im_h >= 0 && im_h < hin && im_w >= 0 && im_w < win) {
                  float val =
                      deformable_bilinear(in_data_ch, hin, win, im_h, im_w);
                  if (use_mask) {
                    // Modulation: scale the sample by its mask value.
                    const float* mask_ptr =
                        mask_data + batch * group * kernel_size * in_size +
                        g * kernel_size * in_size +
                        (fh * kw + fw) * hout * wout + ih * win + iw;
                    val *= mask_ptr[0];
                  }
                  *col_data_g_ksize_h++ = val;
                } else {
                  // Sampling position falls outside the input: contribute 0.
                  *col_data_g_ksize_h++ = 0.0;
                }
              }
            }
          }
        }
      }
    }
  }

  // ---- Phase 2: convolution as GEMM / GEMV --------------------------------
  int m = cout / group;
  int n = hout * wout;
  int k = cin * kernel_size / group;
  int weights_size_per_group = m * k;
  if (flag_trans_weights_) {
    // ReInitWhenNeeded() repacked the filter for sgemm_prepack.
    filter_data = weights_.data<float>();
  }
  for (int b = 0; b < num; ++b) {
    for (int g = 0; g < group; ++g) {
      float* dout_group = out_data + (b * cout + g * m) * out_size;
      const float* din_group =
          col_data + (b * cin + g * in_c_group) * in_size * kernel_size;
      const float* weights_group = filter_data + g * weights_size_per_group;
      // Only offset the bias pointer when a bias exists; arithmetic on a null
      // pointer is undefined behaviour.
      const float* bias_group = is_bias ? bias + g * m : nullptr;
      if (n == 1) {
        lite::arm::math::sgemv(
            weights_group,
            din_group,
            dout_group,
            false,
            m,
            k,
            is_bias,
            bias_group,
            param.conv_param.activation_param.has_active,
            param.conv_param.activation_param.active_type,
            &ctx,
            param.conv_param.activation_param.Relu_clipped_coef,
            param.conv_param.activation_param.Leaky_relu_alpha);
      } else {
        int ldb = n;
        lite::arm::math::sgemm_prepack(false,
                                       m,
                                       n,
                                       k,
                                       weights_group,
                                       din_group,
                                       ldb,
                                       0.f,
                                       dout_group,
                                       n,
                                       bias_group,
                                       is_bias,
                                       param.conv_param.activation_param,
                                       &ctx);
      }
    }
  }
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Alias for the fp32-in / fp32-out instantiation of the ARM kernel.
typedef paddle::lite::kernels::arm::DeformableConvCompute<PRECISION(kFloat),
PRECISION(kFloat)>
DeformableConvFp32;
// Registers the kernel for target kARM, precision kFloat, layout kNCHW under
// the op type name `deformconv2d`, and binds its tensor arguments.
// NOTE(review): the operator in deformable_conv_op.cc is registered as
// `DeformableConv2d`, which differs from the name used here - confirm that
// kernel lookup is not expected to match on this name.
REGISTER_LITE_KERNEL(deformconv2d, kARM, kFloat, kNCHW, DeformableConvFp32, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Mask", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Offset", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/kernel.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
// ARM kernel for deformable convolution.
//
// PrepareForRun() and Run() are only declared here; their definitions are
// provided per precision in the .cc file. ReInitWhenNeeded() caches the last
// seen input / filter shapes and repacks the filter for the GEMM path when
// either of them changes.
template <PrecisionType Ptype, PrecisionType OutType>
class DeformableConvCompute : public KernelLite<TARGET(kARM), Ptype> {
 public:
  virtual void PrepareForRun();

  // Re-derives the shape dependent state. Does nothing when both the input
  // shape and the filter shape are the same as on the previous call.
  virtual void ReInitWhenNeeded() {
    auto& param = this->template Param<param_t>();
    auto& x_dims = param.x->dims();
    auto w_dims = param.conv_param.filter->dims();
    auto& ctx = this->ctx_->template As<ARMContext>();
    auto o_dims = param.output->dims();
    // Number of output pixels per channel; selects GEMM (> 1) or GEMV (== 1).
    int n = o_dims[2] * o_dims[3];
    if (last_shape_ == x_dims && last_weights_shape_ == w_dims) {
      return;
    }
    if (n > 1) {
      // GEMM path: Run() reads the repacked copy held in weights_.
      lite::arm::math::trans_gemm_weights<Ptype>(
          *(param.conv_param.filter), weights_, param.conv_param.groups, &ctx);
      flag_trans_weights_ = true;
    } else if (n == 1) {
      // GEMV path: Run() reads the original filter tensor.
      flag_trans_weights_ = false;
    }
    last_shape_ = x_dims;
    last_weights_shape_ = w_dims;
  }

  virtual void Run();

#ifdef LITE_WITH_PROFILE
  // This kernel has no delegate implementation object (the previous body
  // dereferenced an undeclared `impl_`, which does not compile when
  // LITE_WITH_PROFILE is defined), so there is no extra runtime kernel
  // information to report.
  virtual void SetProfileRuntimeKernelInfo(
      paddle::lite::profile::OpCharacter* ch) {
    (void)ch;
  }
#endif

  ~DeformableConvCompute() = default;

 private:
  using param_t = operators::DeformableConvParam;
  // Shapes seen by the last ReInitWhenNeeded() call that did real work.
  DDim last_shape_;
  DDim last_weights_shape_;
  // True when weights_ holds the repacked filter. Initialised so that Run()
  // never reads an indeterminate value, even if the output has no pixels.
  bool flag_trans_weights_{false};
  Tensor weights_;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -142,6 +142,8 @@ add_operator(search_seq_fc_op extra SRCS search_seq_fc_op.cc DEPS ${op_DEPS})
add_operator(sequence_topk_avg_pooling_op basic SRCS sequence_topk_avg_pooling_op.cc DEPS ${op_DEPS})
add_operator(search_fc_op basic SRCS search_fc_op.cc DEPS ${op_DEPS})
add_operator(lstm_op extra SRCS lstm_op.cc DEPS ${op_DEPS})
# for deformable-convNet
add_operator(deformable_conv_op basic SRCS deformable_conv_op.cc DEPS ${op_DEPS})
# 4. training op
add_operator(mean_op extra SRCS mean_op.cc DEPS ${op_DEPS})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/deformable_conv_op.h"
#include <algorithm>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
// Validates the attached tensors and attributes before shape inference.
//
// Returns false (after logging, via the CHECK_*_OR_FALSE macros) when a
// required tensor is missing or when the input / filter / group configuration
// is inconsistent. Bias is optional and therefore not checked.
bool DeformableConvOpLite::CheckShape() const {
  CHECK_OR_FALSE(param_.x);
  CHECK_OR_FALSE(param_.output);
  CHECK_OR_FALSE(param_.conv_param.filter);
  CHECK_OR_FALSE(param_.mask);
  CHECK_OR_FALSE(param_.offset);
  // bias is optional.

  const auto in_dims = param_.x->dims();
  const auto filter_dims = param_.conv_param.filter->dims();

  CHECK_OR_FALSE(in_dims.size() == 4);
  CHECK_EQ_OR_FALSE(in_dims.size(), filter_dims.size());
  CHECK_OR_FALSE(in_dims.size() - param_.conv_param.strides.size() == 2U);
  CHECK_EQ_OR_FALSE(filter_dims.size(), 4UL);

  // Guard the modulo / division by the group count below and in the kernel.
  CHECK_OR_FALSE(param_.conv_param.groups > 0);
  CHECK_EQ_OR_FALSE(filter_dims[0] % param_.conv_param.groups, 0);
  // The kernel slices the filter as cin / groups channels per group, so the
  // filter's channel dimension has to match the input exactly.
  CHECK_EQ_OR_FALSE(in_dims[1], filter_dims[1] * param_.conv_param.groups);
  CHECK_EQ_OR_FALSE(param_.conv_param.groups, param_.deformable_groups);
  return true;
}
// Output extent of a convolution along one spatial axis:
//   (input + pad_left + pad_right - effective_kernel) / stride + 1
// where effective_kernel = dilation * (filter_size - 1) + 1. The division is
// plain integer division.
inline int DeformableConvOutputSize(int input_size,
                                    int filter_size,
                                    int dilation,
                                    int pad_left,
                                    int pad_right,
                                    int stride) {
  const int effective_kernel = dilation * (filter_size - 1) + 1;
  const int padded_input = input_size + (pad_left + pad_right);
  return (padded_input - effective_kernel) / stride + 1;
}
bool DeformableConvOpLite::InferShapeImpl() const {
const auto in_dims = param_.x->dims();
const auto filter_dims = param_.conv_param.filter->dims();
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
auto paddings = *param_.conv_param.paddings;
auto dilations = *param_.conv_param.dilations;
for (size_t i = 0; i < param_.conv_param.strides.size(); ++i) {
output_shape.push_back(
DeformableConvOutputSize(in_dims[i + 2],
filter_dims[i + 2],
dilations[i],
paddings[2 * i],
paddings[2 * i + 1],
param_.conv_param.strides[i]));
}
// Set output dims
param_.output->Resize(lite::DDim(output_shape));
// share LoD
param_.output->set_lod(param_.x->lod());
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
// Registers the operator class under the op type name `DeformableConv2d`.
// NOTE(review): the ARM kernel is registered with REGISTER_LITE_KERNEL under
// the name `deformconv2d`; confirm that the two names are meant to differ.
REGISTER_LITE_OP(DeformableConv2d,
paddle::lite::operators::DeformableConvOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/core/tensor.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
#ifdef LITE_WITH_PROFILE
#include "lite/api/paddle_place.h"
#endif
namespace paddle {
namespace lite {
namespace operators {
// Operator wrapper for deformable convolution.
//
// AttachImpl() reads tensors and attributes from the op description into a
// DeformableConvParam (whose conv_param member carries the ordinary
// convolution settings), CheckShape() / InferShapeImpl() are defined in the
// .cc file, and AttachKernel() hands the parameters to the selected kernel.
class DeformableConvOpLite : public OpLite {
 public:
  DeformableConvOpLite() {}

  explicit DeformableConvOpLite(const std::string& type) : OpLite(type) {}

  bool CheckShape() const override;

  bool InferShapeImpl() const override;

#ifdef LITE_WITH_PROFILE
  // Fills the profiler record with the shapes, a short textual summary of the
  // convolution settings and the operation count of this op.
  void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter* ch) {
    auto filter_dims = param_.conv_param.filter->dims();
    auto input_dims = param_.x->dims();
    auto output_dims = param_.output->dims();
    ch->input_shape = ch->DimToStr(input_dims);
    ch->output_shape = ch->DimToStr(output_dims);
    ch->filter_shape = ch->DimToStr(filter_dims);
    ch->remark =
        std::to_string(filter_dims[2]) + "x" + std::to_string(filter_dims[3]) +
        "p" + std::to_string((*param_.conv_param.paddings)[0]) + "s" +
        std::to_string(param_.conv_param.strides[0]) + "g" +
        std::to_string(param_.conv_param.groups) + "d" +
        std::to_string((*param_.conv_param.dilations)[0]) +
        (param_.conv_param.bias ? "Bias" : "") +
        ActivationTypeToStr(param_.conv_param.activation_param.active_type);
    // MACs = 2.f * kw * kh * batchsize * out_c * out_h * out_w * in_c / group
    // GMACs = 1e-9f * MACs
    // GMACPS = 1e-6f * MACs / predict_ms
    ch->macs = 2.f * filter_dims[2] * filter_dims[3] *
               output_dims.production() * input_dims[1] /
               param_.conv_param.groups;
  }
#endif

  // TODO(Superjomn) replace framework::OpDesc with a lite one.
  // Binds inputs, outputs and attributes from `op_desc`, resolving variables
  // in `scope`. Always returns true; unsupported padding sizes or activation
  // types abort through LOG(FATAL) / CHECK.
  bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override {
    AttachParam(&param_);

    // Required tensors.
    auto X = op_desc.Input("Input").front();
    auto Filter = op_desc.Input("Filter").front();
    auto Mask = op_desc.Input("Mask").front();
    auto Offset = op_desc.Input("Offset").front();
    auto Out = op_desc.Output("Output").front();
    param_.x = scope->FindVar(X)->GetMutable<lite::Tensor>();
    param_.mask = scope->FindVar(Mask)->GetMutable<lite::Tensor>();
    param_.offset = scope->FindVar(Offset)->GetMutable<lite::Tensor>();
    param_.output = scope->FindVar(Out)->GetMutable<lite::Tensor>();
    param_.deformable_groups = op_desc.GetAttr<int>("deformable_groups");
    param_.im2col_step = op_desc.GetAttr<int>("im2col_step");
    param_.conv_param.filter =
        scope->FindVar(Filter)->GetMutable<lite::Tensor>();

    // Plain convolution attributes.
    param_.conv_param.strides = op_desc.GetAttr<std::vector<int>>("strides");
    auto paddings = op_desc.GetAttr<std::vector<int>>("paddings");
    auto dilations = op_desc.GetAttr<std::vector<int>>("dilations");
    param_.conv_param.groups = op_desc.GetAttr<int>("groups");
    param_.conv_param.dilations = std::make_shared<std::vector<int>>(dilations);

    // 2-pad to 4-pad: duplicate each per-axis padding so that every spatial
    // axis ends up with an explicit (begin, end) pair.
    if (paddings.size() == 2L) {
      for (size_t i = 0; i < param_.conv_param.strides.size(); ++i) {
        int copy_pad = *(paddings.begin() + 2 * i);
        paddings.insert(paddings.begin() + 2 * i + 1, copy_pad);
      }
    } else {
      if (paddings.size() != 4L) {
        LOG(FATAL)
            << "Paddings size should be the same or twice as the input size.";
      }
    }
    param_.conv_param.paddings = std::make_shared<std::vector<int>>(paddings);

    // optional params
    std::vector<std::string> input_arg_names = op_desc.InputArgumentNames();
    if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") !=
        input_arg_names.end()) {
      auto bias_arguments = op_desc.Input("Bias");
      if (bias_arguments.size() > 0) {
        auto bias_var = scope->FindVar(bias_arguments.front());
        if (bias_var != nullptr) {
          param_.conv_param.bias =
              const_cast<lite::Tensor*>(&(bias_var->Get<lite::Tensor>()));
        }
      }
    }

    // Optional fused activation.
    if (op_desc.HasAttr("with_act") && op_desc.GetAttr<bool>("with_act")) {
      param_.conv_param.activation_param.has_active = true;
      auto act_type = op_desc.GetAttr<std::string>("act_type");
      if (act_type == "relu") {
        param_.conv_param.activation_param.active_type =
            lite_api::ActivationType::kRelu;
        param_.conv_param.fuse_relu = true;
      } else if (act_type == "relu6") {
        param_.conv_param.activation_param.active_type =
            lite_api::ActivationType::kRelu6;
        param_.conv_param.activation_param.Relu_clipped_coef =
            op_desc.GetAttr<float>("fuse_brelu_threshold");  // 6.f
      } else if (act_type == "leaky_relu") {
        param_.conv_param.activation_param.active_type =
            lite_api::ActivationType::kLeakyRelu;
        param_.conv_param.activation_param.Leaky_relu_alpha =
            op_desc.GetAttr<float>("leaky_relu_alpha");
      } else {
        // The message lists every activation handled above; the previous text
        // was missing a separating space and did not mention relu6.
        CHECK(false) << "The fused DeformableConv only supports fuse with "
                        "relu, relu6 and leaky relu";
      }
    }
    return true;
  }

  void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }

  std::string DebugString() const override { return "DeformableConv2d"; }

 private:
  mutable DeformableConvParam param_;
  // Not read anywhere in this class.
  std::string padding_algorithm_{""};
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -1515,6 +1515,39 @@ struct XPUFcParam : ParamBase {
std::string activation_type{""};
};
// Parameters of the deformable convolution op: the deformable specific
// tensors / attributes plus an embedded ConvParam for the plain convolution
// settings (filter, bias, strides, paddings, dilations, groups, activation).
struct DeformableConvParam : ParamBase {
  lite::Tensor* x{nullptr};
  lite::Tensor* offset{nullptr};
  lite::Tensor* mask{nullptr};
  lite::Tensor* output{nullptr};
  int deformable_groups{1};
  int im2col_step{1};
  // true selects the modulated (v2) variant that applies `mask`;
  // false selects v1.
  bool modulated{true};
  std::string data_format{"Anylayout"};
  // convolution parameter
  ConvParam conv_param;
  // support var_length or not
  bool var_length{false};
  // only used in conv_transpose.
  std::vector<int> output_size;

  // Lazily built, cached list of input tensors (currently only `x`).
  const std::vector<const Tensor*>* input_tensor_ptrs() override {
    if (input_tensor_ptrs_cache_) {
      return input_tensor_ptrs_cache_.get();
    }
    input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
    return input_tensor_ptrs_cache_.get();
  }

  // Lazily built, cached list of output tensors (only `output`).
  std::vector<Tensor*>* output_tensor_ptrs() override {
    if (output_tensor_ptrs_cache_) {
      return output_tensor_ptrs_cache_.get();
    }
    output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
    return output_tensor_ptrs_cache_.get();
  }
};
struct PixelShuffleParam : ParamBase {
lite::Tensor* x{nullptr};
lite::Tensor* output{nullptr};
......
......@@ -8,8 +8,10 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_MLU) AND (LITE
lite_cc_test(conv_transpose_compute_test SRCS conv_transpose_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(conv_int8_compute_test SRCS conv_int8_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(pool_compute_test SRCS pool_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(deformable_conv_compute_test SRCS deformable_conv_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
if(LITE_BUILD_EXTRA)
lite_cc_test(deformable_conv_compute_test SRCS deformable_conv_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(layout_compute_test SRCS layout_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
endif()
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "lite/core/context.h"
#include "lite/core/profile/timer.h"
#include "lite/operators/op_params.h"
#include "lite/tests/utils/naive_math_impl.h"
#include "lite/tests/utils/tensor_utils.h"
#ifdef LITE_WITH_ARM
#include "lite/kernels/arm/deformable_conv_compute.h"
#endif // LITE_WITH_ARM
// Command-line flags of the test binary. power_mode, threads, warmup, repeats
// and check_result affect how the kernel is driven; basic_test gates the
// randomized sweep; the remaining flags describe the single configuration run
// by the "custom" test case.
DEFINE_int32(power_mode,
3,
"power mode: "
"0 for POWER_HIGH;"
"1 for POWER_LOW;"
"2 for POWER_FULL;"
"3 for NO_BIND");
DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times");
DEFINE_bool(basic_test, true, "do all tests");
DEFINE_bool(check_result, true, "check the result");
// Shape and convolution settings for the custom test case.
DEFINE_int32(batch, 1, "batch size");
DEFINE_int32(in_channel, 32, "input channel");
DEFINE_int32(in_height, 112, "input height");
DEFINE_int32(in_width, 112, "input width");
DEFINE_int32(out_channel, 32, "output channel");
DEFINE_int32(group, 1, "group");
DEFINE_int32(kernel_h, 3, "kernel height");
DEFINE_int32(kernel_w, 3, "kernel width");
DEFINE_int32(pad_h, 1, "pad height");
DEFINE_int32(pad_w, 1, "pad width");
DEFINE_int32(stride_h, 1, "stride height");
DEFINE_int32(stride_w, 1, "stride width");
DEFINE_int32(dila_h, 1, "dilation height");
DEFINE_int32(dila_w, 1, "dilation width");
// NOTE(review): the custom test passes this value to a `bool flag_relu`
// parameter, so any non-zero value ends up selecting relu there.
DEFINE_int32(flag_act,
0,
"do activation"); // 0-no act, 1-relu, 2-relu6, 4-leakyrelu
DEFINE_double(leakey_relu_alpha, 1.0, "leakey relu alpha");
DEFINE_bool(flag_bias, true, "with bias");
// Short aliases for the Paddle-Lite types used throughout this test.
typedef paddle::lite::DDim DDim;
typedef paddle::lite::Tensor Tensor;
typedef paddle::lite::operators::DeformableConvParam DeformableConvParam;
typedef paddle::lite::operators::ActivationParam ActivationParam;
using paddle::lite::profile::Timer;
// Computes the convolution output dims for `dim_in`: the batch dimension is
// kept, the channel dimension becomes the filter's first dimension, and each
// spatial extent is (in + pad_begin + pad_end - dilated_kernel) / stride + 1.
// Paddings are read as {top, bottom, left, right}; dilations and strides as
// {h, w}.
DDim compute_out_dim(const DDim& dim_in,
                     const paddle::lite::operators::ConvParam& param) {
  const auto& pads = *param.paddings;
  const auto& dilas = *param.dilations;

  const int dilation_h = dilas[0];
  const int dilation_w = dilas[1];
  const int pad_t = pads[0];
  const int pad_b = pads[1];
  const int pad_l = pads[2];
  const int pad_r = pads[3];
  const int step_h = param.strides[0];
  const int step_w = param.strides[1];

  const auto filter_h = param.filter->dims()[2];
  const auto filter_w = param.filter->dims()[3];
  const auto in_h = dim_in[2];
  const auto in_w = dim_in[3];

  // Extent covered by the dilated kernel along each axis.
  const auto span_h = dilation_h * (filter_h - 1) + 1;
  const auto span_w = dilation_w * (filter_w - 1) + 1;

  DDim result = dim_in;
  result[1] = param.filter->dims()[0];
  result[2] = (in_h + pad_t + pad_b - span_h) / step_h + 1;
  result[3] = (in_w + pad_l + pad_r - span_w) / step_w + 1;
  return result;
}
#ifdef LITE_WITH_ARM
// Runs the ARM fp32 deformable-conv kernel for every usable shape in
// `input_dims`, under every (power_mode, thread_num) combination, and - when
// FLAGS_check_result is set - compares the output with deformable_conv_basic.
//
// input_dims         candidate NCHW input shapes
// weight_dim         filter shape; weight_dim[1] * group must equal the input
//                    channel count (enforced with CHECK_EQ below)
// group              used both as conv groups and as deformable_groups
// strides/pads/dilas convolution settings; pads has four entries
// flag_bias          allocate and randomly fill a bias tensor
// flag_relu          fuse a relu activation
// modulated          forwarded to param.modulated (logged as "V2" / "V1")
// thread_num         thread counts to run with
// power_mode         power modes to run with
// leakey_relu_scale  only used by the leaky-relu branch, which is not
//                    reachable from `flag_relu` (see below)
//
// A mismatch beyond the tolerances aborts the test through LOG(FATAL).
void test_deformable_conv_fp32(const std::vector<DDim>& input_dims,
const DDim& weight_dim,
int group,
const std::vector<int>& strides,
const std::vector<int>& pads,
const std::vector<int>& dilas,
bool flag_bias,
bool flag_relu,
bool modulated,
const std::vector<int>& thread_num,
const std::vector<int>& power_mode,
const float leakey_relu_scale) {
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
// Tensors are heap allocated here and released manually at the end of the
// function.
DeformableConvParam param;
param.x = new Tensor;
param.x->set_precision(PRECISION(kFloat));
param.conv_param.filter = new Tensor;
param.conv_param.filter->Resize(weight_dim);
param.conv_param.filter->set_precision(PRECISION(kFloat));
param.offset = new Tensor;
param.offset->set_precision(PRECISION(kFloat));
param.mask = new Tensor;
param.mask->set_precision(PRECISION(kFloat));
if (flag_bias) {
param.conv_param.bias = new Tensor;
param.conv_param.bias->Resize({weight_dim[0]});
param.conv_param.bias->set_precision(PRECISION(kFloat));
}
param.conv_param.strides = strides;
param.conv_param.paddings = std::make_shared<std::vector<int>>(pads);
param.conv_param.dilations = std::make_shared<std::vector<int>>(dilas);
param.conv_param.groups = group;
param.deformable_groups = group;
param.modulated = modulated;
const float six = 6.f;
// flag_act is derived from a bool, so it is either 0 or 1 here; the
// `== 2` and `== 4` branches below cannot be taken from this function.
int flag_act = flag_relu ? 1 : 0;
if (flag_act > 0) {
ActivationParam act_param;
act_param.has_active = true;
act_param.active_type = (paddle::lite_api::ActivationType)
flag_act; // 1-relu, 2-relu6, 4-leakyrelu
if (flag_act == 1) {
param.conv_param.fuse_relu = true;
} else if (flag_act == 2) {
act_param.Relu_clipped_coef = six;
} else if (flag_act == 4) {
act_param.Leaky_relu_alpha = leakey_relu_scale;
}
param.conv_param.activation_param = act_param;
}
param.output = new Tensor;
param.output->set_precision(PRECISION(kFloat));
// Filter and bias are filled once with random values in [-1, 1] and shared by
// every run below.
paddle::lite::fill_tensor_rand(*param.conv_param.filter, -1.f, 1.f);
// paddle::lite::fill_tensor_const(*param.filter, 1.f);
if (flag_bias) {
paddle::lite::fill_tensor_rand(*param.conv_param.bias, -1.f, 1.f);
// paddle::lite::fill_tensor_const(*param.bias, 1.f);
}
auto wptr = param.conv_param.filter->data<float>();
auto bias_ptr = flag_bias ? param.conv_param.bias->data<float>() : nullptr;
for (auto& cls : power_mode) {
for (auto& th : thread_num) {
// A fresh kernel object and context per (power mode, thread count) pair.
paddle::lite::kernels::arm::DeformableConvCompute<PRECISION(kFloat),
PRECISION(kFloat)>
deformableConv;
std::unique_ptr<paddle::lite::KernelContext> ctx1(
new paddle::lite::KernelContext);
auto& ctx = ctx1->As<paddle::lite::ARMContext>();
ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), th);
/// set param and context
// Size x / output from the first input shape that yields a non-empty
// output, so that PrepareForRun() sees valid dims.
for (auto& dim_in : input_dims) {
param.x->Resize(dim_in);
DDim out_tmp_dims = compute_out_dim(dim_in, param.conv_param);
if (out_tmp_dims[2] < 1 || out_tmp_dims[3] < 1) {
continue;
}
param.output->Resize(out_tmp_dims);
break;
}
deformableConv.SetParam(param);
deformableConv.SetContext(std::move(ctx1));
/// prepare for run
deformableConv.PrepareForRun();
for (auto& dim_in : input_dims) {
CHECK_EQ(weight_dim[1] * group, dim_in[1])
<< "input channel must equal to weights channel";
DDim dim_out = compute_out_dim(dim_in, param.conv_param);
int num = dim_in[0];
int in_size = dim_in[2] * dim_in[3];
int kernel_size = weight_dim[2] * weight_dim[3];
// Offset and mask are shaped on the INPUT spatial size and refilled with
// random values for every candidate shape, including shapes that are
// skipped just below.
param.offset->Resize(
{num, 2 * group * kernel_size, dim_in[2], dim_in[3]});
param.mask->Resize({num, group * kernel_size, dim_in[2], dim_in[3]});
paddle::lite::fill_tensor_rand(*param.offset, -1.f, 1.f);
paddle::lite::fill_tensor_rand(*param.mask, -1.f, 1.f);
if (dim_out[2] < 1 || dim_out[3] < 1) {
continue;
}
// Only configurations whose output spatial size equals the input spatial
// size are exercised.
if (dim_out[2] != dim_in[2] || dim_out[3] != dim_in[3]) {
continue;
}
param.x->Resize(dim_in);
param.output->Resize(dim_out);
paddle::lite::fill_tensor_rand(*param.x, -1.f, 1.f);
// paddle::lite::fill_tensor_const(*param.x, 1.f);
auto din = param.x->data<float>();
Tensor tout_basic;
if (FLAGS_check_result) {
// Reference result. NOTE(review): the argument order (w before h for
// kernel / stride / dilation / pad) is presumably what
// deformable_conv_basic expects - verify against its declaration.
auto offset_data = param.offset->data<float>();
auto mask_data = param.mask->data<float>();
tout_basic.set_precision(PRECISION(kFloat));
tout_basic.Resize(dim_out);
fill_tensor_const(tout_basic, 0.f);
auto dout_basic = tout_basic.mutable_data<float>();
LOG(INFO) << "flag_relu: " << flag_relu;
deformable_conv_basic<float, float>(din,
offset_data,
mask_data,
dout_basic,
dim_in[0],
dim_out[1],
dim_out[2],
dim_out[3],
dim_in[1],
dim_in[2],
dim_in[3],
wptr,
bias_ptr,
group,
weight_dim[3],
weight_dim[2],
strides[1],
strides[0],
dilas[1],
dilas[0],
pads[2],
pads[0],
flag_bias,
flag_relu,
modulated);
}
/// warm up
for (int i = 0; i < FLAGS_warmup; ++i) {
deformableConv.Launch();
}
/// compute
Timer t0;
for (int i = 0; i < FLAGS_repeats; ++i) {
t0.Start();
deformableConv.Launch();
t0.Stop();
}
// Operation count used for the throughput figures in the log line.
double gops = 2.0 * dim_out.production() * dim_in[1] * weight_dim[2] *
weight_dim[3] / param.conv_param.groups;
LOG(INFO) << "deformable conv fp32: input shape: " << dim_in
<< ", output shape" << dim_out
<< ",running time, avg: " << t0.LapTimes().Avg()
<< ", min time: " << t0.LapTimes().Min()
<< ", total GOPS: " << 1e-9 * gops
<< " GOPS, avg GOPs: " << 1e-6 * gops / t0.LapTimes().Avg()
<< " GOPs, max GOPs: " << 1e-6 * gops / t0.LapTimes().Min();
if (FLAGS_check_result) {
double max_ratio = 0;
double max_diff = 0;
tensor_cmp_host(tout_basic, *param.output, max_ratio, max_diff);
LOG(INFO) << "compare result, max diff: " << max_diff
<< ", max ratio: " << max_ratio;
// Fail only when BOTH the ratio and the absolute difference exceed their
// thresholds; the tensors are dumped before aborting.
if (std::abs(max_ratio) > 1e-3f) {
if (max_diff > 5e-4f) {
LOG(WARNING) << "weights data";
print_tensor(*param.conv_param.filter);
LOG(WARNING) << "basic result";
print_tensor(tout_basic);
LOG(WARNING) << "lite result";
print_tensor(*param.output);
Tensor tdiff;
tdiff.Resize(tout_basic.dims());
tdiff.set_precision(PRECISION(kFloat));
tensor_diff(tout_basic, *param.output, tdiff);
print_tensor(tdiff);
LOG(FATAL) << "test fp32 deformable conv: input: " << dim_in
<< ", output: " << dim_out
<< ", weight dim: " << weight_dim
<< ", pad: " << pads[0] << ", " << pads[1] << ", "
<< pads[2] << ", " << pads[3]
<< ", stride: " << strides[0] << ", " << strides[1]
<< ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", group: " << group
<< ", bias: " << (flag_bias ? "true" : "false")
<< ", relu: " << (flag_relu ? "true" : "false")
<< ", modulated: " << (modulated ? "V2" : "V1")
<< ", threads: " << th << ", power_mode: " << cls
<< " failed!!\n";
}
}
}
LOG(INFO) << "test fp32 deformable conv: input: " << dim_in
<< ", output: " << dim_out << ", weight dim: " << weight_dim
<< ", pad: " << pads[0] << ", " << pads[1] << ", " << pads[2]
<< ", " << pads[3] << ", stride: " << strides[0] << ", "
<< strides[1] << ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", group: " << group
<< ", bias: " << (flag_bias ? "true" : "false")
<< ", relu: " << (flag_relu ? "true" : "false")
<< ", modulated: " << (modulated ? "V2" : "V1")
<< ", threads: " << th << ", power_mode: " << cls
<< " successed!!\n";
}
}
}
// Release the tensors allocated at the top; bias stays null when flag_bias is
// false, which delete accepts.
delete param.x;
delete param.conv_param.filter;
delete param.offset;
delete param.mask;
delete param.output;
delete param.conv_param.bias;
}
#else
// Fallback used when the ARM backend is not compiled in: same signature as
// the real driver so the TEST bodies still build, but nothing is executed.
void test_deformable_conv_fp32(const std::vector<DDim>& input_dims,
                               const DDim& weight_dim,
                               int group,
                               const std::vector<int>& strides,
                               const std::vector<int>& pads,
                               const std::vector<int>& dilas,
                               bool flag_bias,
                               bool flag_relu,
                               bool modulated,
                               const std::vector<int>& thread_num,
                               const std::vector<int>& power_mode,
                               const float leakey_relu_scale) {
  // Intentionally empty.
}
#endif // LITE_WITH_ARM
#if 1 /// random param conv
// Exhaustive sweep over a small grid of channel counts, groups, kernel sizes,
// strides, paddings, dilations, v1/v2 mode, bias and activation settings.
// Every combination is run on a fixed list of square inputs (two batch sizes,
// six edge lengths) with one thread and the power mode taken from the flags.
// Combinations whose channel counts are not divisible by the group count are
// skipped. The whole sweep is gated by FLAGS_basic_test.
TEST(TestDeformableConvRand, test_deformable_conv_rand) {
  if (!FLAGS_basic_test) {
    return;
  }
  const float leakey_relu_scale = 8.88;
  for (int cin : {1, 3, 8}) {
    // The candidate input shapes depend only on the input channel count.
    std::vector<DDim> dims;
    for (int batch : {1, 2}) {
      for (int h : {1, 3, 16, 19, 32, 64}) {
        dims.push_back(DDim({batch, cin, h, h}));
      }
    }
    for (int cout : {1, 5, 16}) {
      for (int g : {1, 2}) {
        if (cin % g != 0 || cout % g != 0) {
          continue;
        }
        for (int kw : {1, 2, 3}) {
          for (int kh : {1, 2, 3}) {
            const DDim weights_dim({cout, cin / g, kh, kw});
            for (int stride : {1, 2}) {
              for (int pad_h : {0, 1, 2}) {
                for (int pad_w : {0, 1, 2}) {
                  for (int dila : {1, 2}) {
                    for (bool modulated : {false, true}) {
                      for (bool flag_bias : {false, true}) {
                        for (int flag_act : {0, 1}) {
                          test_deformable_conv_fp32(
                              dims,
                              weights_dim,
                              g,
                              {stride, stride},
                              {pad_h, pad_h, pad_w, pad_w},
                              {dila, dila},
                              flag_bias,
                              flag_act,
                              modulated,
                              {1},
                              {FLAGS_power_mode},
                              leakey_relu_scale);
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
  }
}
#endif /// random param conv
#if 1 /// custom
// Runs one deformable-conv case whose shape and hyper-parameters come from
// the command-line FLAGS_* values. The modulated argument is always true.
TEST(TestDeformableConvCustom, test_deformable_conv_fp32_custom_size) {
  CHECK_EQ(FLAGS_in_channel % FLAGS_group, 0)
      << "input channel must be divided by group";
  CHECK_EQ(FLAGS_out_channel % FLAGS_group, 0)
      << "num_output must be divided by group";

  // Single input shape: {batch, channels, height, width}.
  std::vector<DDim> input_shapes;
  input_shapes.push_back(
      DDim({FLAGS_batch, FLAGS_in_channel, FLAGS_in_height, FLAGS_in_width}));

  // Filter shape: {out channels, in channels per group, kernel h, kernel w}.
  const DDim filter_shape({FLAGS_out_channel,
                           FLAGS_in_channel / FLAGS_group,
                           FLAGS_kernel_h,
                           FLAGS_kernel_w});

  const std::vector<int> stride_hw{FLAGS_stride_h, FLAGS_stride_w};
  const std::vector<int> pad_hhww{
      FLAGS_pad_h, FLAGS_pad_h, FLAGS_pad_w, FLAGS_pad_w};
  const std::vector<int> dilation_hw{FLAGS_dila_h, FLAGS_dila_w};
  const std::vector<int> thread_counts{FLAGS_threads};
  const std::vector<int> power_modes{FLAGS_power_mode};

  test_deformable_conv_fp32(input_shapes,
                            filter_shape,
                            FLAGS_group,
                            stride_hw,
                            pad_hhww,
                            dilation_hw,
                            FLAGS_flag_bias,
                            FLAGS_flag_act,
                            true,
                            thread_counts,
                            power_modes,
                            FLAGS_leakey_relu_alpha);
}
#endif // custom
......@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include <cmath>
template <typename type>
static void basic_trans_mat_to_c4(const type* input,
......@@ -502,3 +503,145 @@ void deconv_basic(const Dtype1* din,
}
free(workspace_ptr);
}
/// Bilinearly interpolates a row-major float plane at a fractional position.
///
/// All coordinates are relative to `bottom_data`, i.e. element (r, c) is read
/// from `bottom_data[r * data_width + c]`.
///
/// @param bottom_data  pointer to the element treated as row 0, column 0
/// @param data_width   row stride of the plane, in elements
/// @param height       row limit: a floor(h) >= height - 1 is clamped to
///                     height - 1 and the vertical fraction is dropped
/// @param width        column limit, clamped the same way as `height`
/// @param h            fractional row coordinate
/// @param w            fractional column coordinate
/// @return the weighted sum of the four neighbouring samples
///
/// There is no lower-bound check: floor(h) and floor(w) may be negative, and
/// the caller must ensure the resulting addresses are readable.
///
/// Declared `inline` because this definition lives in a header; without it,
/// including the header from two translation units yields duplicate symbols.
inline float deformable_bilinear(const float* bottom_data,
                                 const int data_width,
                                 const int height,
                                 const int width,
                                 float h,
                                 float w) {
  int h_low = static_cast<int>(std::floor(h));
  int w_low = static_cast<int>(std::floor(w));

  // Clamp to the last row/column; there the high neighbour collapses onto
  // the low one and the coordinate snaps to the integer position.
  if (h_low >= height - 1) {
    h_low = height - 1;
    h = static_cast<float>(h_low);
  }
  if (w_low >= width - 1) {
    w_low = width - 1;
    w = static_cast<float>(w_low);
  }
  const int h_high = (h_low >= height - 1) ? h_low : h_low + 1;
  const int w_high = (w_low >= width - 1) ? w_low : w_low + 1;

  // Fractional distances to the low neighbour and their complements.
  const float lh = h - h_low;
  const float lw = w - w_low;
  const float hh = 1 - lh;
  const float hw = 1 - lw;

  const float v1 = bottom_data[h_low * data_width + w_low];
  const float v2 = bottom_data[h_low * data_width + w_high];
  const float v3 = bottom_data[h_high * data_width + w_low];
  const float v4 = bottom_data[h_high * data_width + w_high];

  const float w1 = hh * hw;
  const float w2 = hh * lw;
  const float w3 = lh * hw;
  const float w4 = lh * lw;

  return (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
}
//! Naive reference deformable convolution over NCHW buffers.
//! For float, Dtype1 and Dtype2 are both float.
//! For int8, Dtype1 is char and Dtype2 is int.
//!
//! - out_data is accumulated into: each output element starts from its
//!   current value plus the optional bias, then receives one weighted,
//!   bilinearly sampled contribution per (input channel, kernel tap).
//! - offset_data is indexed per (n, g) as a slab of 2 * kernel_size planes of
//!   hout * wout floats; plane 2 * tap holds the h offset, plane 2 * tap + 1
//!   the w offset.
//! - mask_data (read only when `modulated` is true) is indexed per (n, g) as
//!   kernel_size planes of hout * wout floats.
//! - When flag_relu is set, every output that is not greater than 0 is
//!   replaced by 0.
//!
//! NOTE(review): the sampling position adds kernel_h * dila_h and
//! kernel_w * dila_w for every tap instead of fh * dila_h and fw * dila_w.
//! That looks unintended, but this reference has to agree with the kernel it
//! is compared against - confirm both sides before changing it.
template <typename Dtype1, typename Dtype2>
void deformable_conv_basic(const Dtype1* in_data,
                           const float* offset_data,
                           const float* mask_data,
                           Dtype2* out_data,
                           int num,
                           int chout,
                           int hout,
                           int wout,
                           int chin,
                           int hin,
                           int win,
                           const Dtype1* weights,
                           const Dtype2* bias,
                           int group,
                           int kernel_w,
                           int kernel_h,
                           int stride_w,
                           int stride_h,
                           int dila_w,
                           int dila_h,
                           int pad_w,
                           int pad_h,
                           bool flag_bias,
                           bool flag_relu,
                           bool modulated) {
  const int out_c_group = chout / group;
  const int in_c_group = chin / group;
  const int in_size = hin * win;
  const int out_size = hout * wout;
  const int c_in_size = chin * in_size;
  const int c_out_size = chout * out_size;
  const int kernel_size = kernel_w * kernel_h;
  // Constant displacement applied to every tap (see NOTE(review) above).
  const int reach_h = kernel_h * dila_h;
  const int reach_w = kernel_w * dila_w;

  for (int n = 0; n < num; n++) {
#pragma omp parallel for collapse(4)
    for (int g = 0; g < group; ++g) {
      for (int oc = 0; oc < out_c_group; ++oc) {
        for (int oh = 0; oh < hout; oh++) {
          for (int ow = 0; ow < wout; ow++) {
            const int out_idx = n * c_out_size + g * out_c_group * out_size +
                                oc * out_size + oh * wout + ow;
            Dtype2& dst = out_data[out_idx];

            Dtype2 bias_d = 0;
            if (flag_bias) {
              bias_d = bias[g * out_c_group + oc];
            }
            dst = bias_d + dst;

            // Top-left corner of the receptive window in input coordinates;
            // it may be negative when it starts inside the padding.
            const int origin_h = oh * stride_h - pad_h;
            const int origin_w = ow * stride_w - pad_w;
            const int offset_base = n * group * 2 * kernel_size * out_size +
                                    g * 2 * kernel_size * out_size;

            for (int ic = 0; ic < in_c_group; ++ic) {
              for (int fh = 0; fh < kernel_h; fh++) {
                for (int fw = 0; fw < kernel_w; fw++) {
                  const int tap = fh * kernel_w + fw;
                  const float* offset_slab = offset_data + offset_base;
                  const float offset_h =
                      offset_slab[((2 * tap) * hout + oh) * wout + ow];
                  const float offset_w =
                      offset_slab[((2 * tap + 1) * hout + oh) * wout + ow];

                  const float iw = origin_w + reach_w + offset_w;
                  const float ih = origin_h + reach_h + offset_h;
                  // Kept in positive form so a NaN position is skipped too.
                  if (!(ih >= 0 && ih < hin && iw >= 0 && iw < win)) {
                    continue;
                  }

                  // Sample relative to the window origin; the remaining
                  // extent bounds the interpolation on the high side.
                  const float map_h = reach_h + offset_h;
                  const float map_w = reach_w + offset_w;
                  const float* window = in_data + n * c_in_size +
                                        (g * in_c_group + ic) * in_size +
                                        origin_h * win + origin_w;
                  float val = deformable_bilinear(window,
                                                  win,
                                                  hin - origin_h,
                                                  win - origin_w,
                                                  map_h,
                                                  map_w);
                  if (modulated) {
                    // Scale the sample by its per-tap modulation mask.
                    const float* mask_ptr =
                        mask_data + n * group * kernel_size * out_size +
                        g * kernel_size * out_size + tap * hout * wout +
                        oh * wout + ow;
                    val *= mask_ptr[0];
                  }

                  const int widx = g * out_c_group * in_c_group * kernel_size +
                                   oc * in_c_group * kernel_size +
                                   ic * kernel_size + tap;
                  dst += val * weights[widx];
                }
              }
            }

            if (flag_relu && !(dst > 0)) {
              dst = 0;
            }
          }
        }
      }
    }
  }
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册