提交 8c572b39 编写于 作者: T tensor-tang

enable more conv impls

上级 aa55112f
...@@ -11,9 +11,20 @@ cc_library(math_arm SRCS ...@@ -11,9 +11,20 @@ cc_library(math_arm SRCS
scale.cc scale.cc
elementwise.cc elementwise.cc
sgemv.cc sgemv.cc
type_trans.cpp
conv_impl.cc conv_impl.cc
conv_direct_3x3s1.cc conv_direct_3x3s1.cc
conv_direct_3x3s2.cc conv_direct_3x3s2.cc
conv_direct.cc conv_direct.cc
conv_depthwise_3x3_int7.cc
conv_depthwise_3x3_int8.cc
conv_depthwise_5x5s1_int8.cc
conv_depthwise_3x3p0.cc
conv_depthwise_3x3p1.cc
conv_depthwise_5x5s1.cc
conv_depthwise_5x5s2.cc
conv_depthwise.cc
conv_gemmlike.cc
conv_winograd_3x3.cc
DEPS ${lite_kernel_deps} eigen3) DEPS ${lite_kernel_deps} eigen3)
此差异已折叠。
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#include "paddle/fluid/lite/kernels/arm/conv_compute.h" #include "paddle/fluid/lite/kernels/arm/conv_compute.h"
#include "paddle/fluid/lite/arm/math/conv_direct.h" #include "paddle/fluid/lite/arm/math/conv_direct.h"
#include "paddle/fluid/lite/arm/math/conv_depthwise.h"
#include "paddle/fluid/lite/arm/math/conv_gemmlike.h"
#include "paddle/fluid/lite/arm/math/funcs.h" #include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h" #include "paddle/fluid/lite/core/type_system.h"
...@@ -62,22 +64,26 @@ void ConvCompute::Run() { ...@@ -62,22 +64,26 @@ void ConvCompute::Run() {
// TODO(xxx): enable more // TODO(xxx): enable more
if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) { if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
// dw conv impl // dw conv impl
// impl_ = new lite::arm::math::prepackA<PRECISION(kFloat)>; impl_ = new lite::arm::math::DepthwiseConv<PRECISION(kFloat)>;
LOG(INFO) << "invoking dw conv";
} else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal && } else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal &&
no_dilation) { no_dilation) {
if (ic >= 32 && oc >= 32 && oh > 16 && ow > 16) { if (ic >= 32 && oc >= 32 && oh > 16 && ow > 16) {
// winograd conv impl // winograd conv impl
// impl_ = new lite::arm::math::WinogradConv<PRECISION(kFloat)>; // impl_ = new lite::arm::math::WinogradConv<PRECISION(kFloat)>;
LOG(FATAL) << "TODO!!! winograd conv";
} else { } else {
// direct conv impl // direct conv impl
impl_ = new lite::arm::math::DirectConv<PRECISION(kFloat)>; impl_ = new lite::arm::math::DirectConv<PRECISION(kFloat)>;
LOG(INFO) << "invoking direct conv";
} }
} else if (param.groups == 1 && kw == 3 && stride == 2 && kps_equal && } else if (param.groups == 1 && kw == 3 && stride == 2 && kps_equal &&
no_dilation) { no_dilation) {
// direct conv impl // direct conv impl
impl_ = new lite::arm::math::DirectConv<PRECISION(kFloat)>; impl_ = new lite::arm::math::DirectConv<PRECISION(kFloat)>;
} else { } else {
// impl_ = new lite::arm::math::GemmLikeConv<PRECISION(kFloat)>; impl_ = new lite::arm::math::GemmLikeConv<PRECISION(kFloat)>;
LOG(INFO) << "invoking gemm like conv";
} }
this->impl_->create(param, &ctx); this->impl_->create(param, &ctx);
...@@ -98,7 +104,15 @@ PrecisionType ConvCompute::precision() const { return PRECISION(kFloat); } ...@@ -98,7 +104,15 @@ PrecisionType ConvCompute::precision() const { return PRECISION(kFloat); }
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL(conv, kARM, kFloat, kNCHW, REGISTER_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::ConvCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::ConvCompute, def) paddle::lite::kernels::arm::ConvCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#include "paddle/fluid/lite/kernels/arm/conv_compute.h" #include "paddle/fluid/lite/kernels/arm/conv_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/lite/arm/math/funcs.h" #include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
...@@ -23,9 +25,95 @@ namespace lite { ...@@ -23,9 +25,95 @@ namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
template <typename dtype>
void conv_compute_ref(const operators::ConvParam& param) {
auto input = param.x;
auto filter = param.filter;
auto output = param.output;
DDim input_dims = param.x->dims();
DDim filter_dims = param.filter->dims();
DDim output_dims = param.output->dims();
std::vector<int> paddings = param.paddings;
std::vector<int> strides = param.strides;
std::vector<int> dilations = param.dilations;
int groups = param.groups;
auto input_data = param.x->data<float>();
auto output_data = param.output->mutable_data<float>();
auto filter_data = param.filter->mutable_data<float>();
const float* bias_data = nullptr;
if (param.bias != nullptr) {
bias_data = param.bias->mutable_data<float>();
}
bool flag_bias = bias_data != nullptr;
bool flag_relu = false; // TODO(hong19860320) param.relu
int num = input_dims[0];
int chout = output_dims[1];
int hout = output_dims[2];
int wout = output_dims[3];
int chin = input_dims[1];
int hin = input_dims[2];
int win = input_dims[3];
int out_c_group = chout / groups;
int in_c_group = chin / groups;
int stride_h = strides[0];
int stride_w = strides[1];
int dilation_h = dilations[0];
int dilation_w = dilations[1];
int padding_h = paddings[0];
int padding_w = paddings[1];
int kernel_h = filter_dims[2];
int kernel_w = filter_dims[3];
for (int n = 0; n < num; ++n) {
for (int g = 0; g < groups; ++g) {
for (int oc = 0; oc < out_c_group; ++oc) {
for (int oh = 0; oh < hout; ++oh) {
for (int ow = 0; ow < wout; ++ow) {
int out_idx = n * groups * out_c_group * hout * wout +
g * out_c_group * hout * wout + oc * hout * wout +
oh * wout + ow;
output_data[out_idx] = 0.0f;
for (int ic = 0; ic < in_c_group; ++ic) {
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int iw = ow * stride_w - padding_w + kw * (dilation_w);
int ih = oh * stride_h - padding_h + kh * (dilation_h);
if (iw < 0 || iw >= win) continue;
if (ih < 0 || ih >= hin) continue;
int iidx = n * chin * hin * win + g * in_c_group * hin * win +
ic * hin * win + ih * win + iw;
int widx =
g * out_c_group * in_c_group * kernel_h * kernel_w +
oc * in_c_group * kernel_h * kernel_w +
ic * kernel_h * kernel_w + kh * kernel_w + kw;
output_data[out_idx] +=
(dtype)input_data[iidx] * (dtype)filter_data[widx];
}
}
}
output_data[out_idx] +=
flag_bias ? static_cast<float>(bias_data[g * out_c_group + oc])
: 0.f;
if (flag_relu) {
output_data[out_idx] =
output_data[out_idx] > 0.f ? output_data[out_idx] : 0.f;
}
}
}
}
}
}
}
TEST(conv_arm, retrive_op) { TEST(conv_arm, retrive_op) {
auto conv = auto conv =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("conv"); KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("conv2d");
ASSERT_FALSE(conv.empty()); ASSERT_FALSE(conv.empty());
ASSERT_TRUE(conv.front()); ASSERT_TRUE(conv.front());
} }
...@@ -36,8 +124,153 @@ TEST(conv_arm, init) { ...@@ -36,8 +124,153 @@ TEST(conv_arm, init) {
ASSERT_EQ(conv.target(), TARGET(kARM)); ASSERT_EQ(conv.target(), TARGET(kARM));
} }
TEST(conv_arm, compare_test) { TEST(conv_arm, compute) {
// TODO(xxx): add more compare ConvCompute conv;
operators::ConvParam param;
lite::Tensor input;
lite::Tensor filter;
lite::Tensor bias;
lite::Tensor output;
lite::Tensor output_ref;
DeviceInfo::Init();
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<ARMContext>();
conv.SetContext(std::move(ctx));
for (auto n : {1, 2}) {
for (auto chin : {3, 8, /*32, 128*/}) {
for (auto chout : {3, 8, /*32, 128*/}) {
for (auto hin : {7, 14, 28, /*56 , 112, 224, 512*/}) {
for (auto win : {7, 14, 28, /*56, 112, 224, 512*/}) {
for (auto flag_bias : {false , true}) {
for (auto flag_relu : {false , true}) {
for (auto depthwise : {false, true}) {
for (auto dilation : {1 /*, 2*/}) {
for (auto stride : {1, 2}) {
for (auto padding : {0, 1}) {
for (auto ks : {/*1, */3/*, 5*/}) {
int group = 1;
if (depthwise) { // depthwise conv ?
group = chin;
chout = chin;
// remove the follow code if
// all kernels are implemented.
if (ks == 5) {
stride = 2;
padding = 2;
}
}
// get input, filter and output shape
std::vector<int64_t> input_shape = {n, chin, hin,
win};
std::vector<int64_t> filter_shape = {
chout, chin / group, ks, ks};
std::vector<int64_t> output_shape({n, chout});
const int dkernel = dilation * (ks - 1) + 1;
output_shape.push_back(
(hin + 2 * padding - dkernel) / stride + 1);
output_shape.push_back(
(win + 2 * padding - dkernel) / stride + 1);
// resize input, filter and output
input.Resize(DDim(input_shape));
filter.Resize(DDim(filter_shape));
output.Resize(DDim(output_shape));
output_ref.Resize(DDim(output_shape));
auto* input_data = input.mutable_data<float>();
auto* filter_data = filter.mutable_data<float>();
auto* output_data = output.mutable_data<float>();
auto* output_ref_data =
output_ref.mutable_data<float>();
for (int i = 0; i < input.dims().production(); i++) {
input_data[i] = static_cast<float>(i % 128);
}
for (int i = 0; i < filter.dims().production(); i++) {
filter_data[i] = i / 1000.0f;
}
param.x = &input;
param.filter = &filter;
param.output = &output;
param.bias = nullptr;
// TODO(hong19860320) param.relu = flag_relu;
param.paddings = std::vector<int>({padding, padding});
param.strides = std::vector<int>({stride, stride});
param.dilations =
std::vector<int>({dilation, dilation});
param.groups = group;
conv.SetParam(param);
conv.Run();
param.output = &output_ref;
conv_compute_ref<float>(param);
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i],
1e-3);
}
}
}
}
}
}
}
}
}
}
}
}
}
#if 0
// for testing gemm like conv
int n = 1;
int chin = 8;
int chout = 8;
int hin = 14;
int win = 14;
int flag_bias = false;
int flag_relu = false;
int dilation = 1;
int stride = 1;
int padding = 1;
int ks = 5;
int group = 1;
// get input, filter and output shape
std::vector<int64_t> input_shape = {n, chin, hin, win};
std::vector<int64_t> filter_shape = {chout, chin / group, ks, ks};
std::vector<int64_t> output_shape({n, chout});
const int dkernel = dilation * (ks - 1) + 1;
output_shape.push_back((hin + 2 * padding - dkernel) / stride + 1);
output_shape.push_back((win + 2 * padding - dkernel) / stride + 1);
// resize input, filter and output
input.Resize(DDim(input_shape));
filter.Resize(DDim(filter_shape));
output.Resize(DDim(output_shape));
output_ref.Resize(DDim(output_shape));
auto* input_data = input.mutable_data<float>();
auto* filter_data = filter.mutable_data<float>();
auto* output_data = output.mutable_data<float>();
auto* output_ref_data = output_ref.mutable_data<float>();
for (int i = 0; i < input.dims().production(); i++) {
input_data[i] = static_cast<float>(i % 128);
}
for (int i = 0; i < filter.dims().production(); i++) {
filter_data[i] = i / 1000.0f;
}
param.x = &input;
param.filter = &filter;
param.output = &output;
param.bias = nullptr;
// TODO(hong19860320) param.relu = flag_relu;
param.paddings = std::vector<int>({padding, padding});
param.strides = std::vector<int>({stride, stride});
param.dilations = std::vector<int>({dilation, dilation});
param.groups = group;
conv.SetParam(param);
conv.Run();
param.output = &output_ref;
conv_compute_ref<float>(param);
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-3);
}
#endif
} }
} // namespace arm } // namespace arm
...@@ -45,4 +278,5 @@ TEST(conv_arm, compare_test) { ...@@ -45,4 +278,5 @@ TEST(conv_arm, compare_test) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
USE_LITE_KERNEL(conv, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, def);
...@@ -73,4 +73,5 @@ bool ConvOpLite::InferShape() const { ...@@ -73,4 +73,5 @@ bool ConvOpLite::InferShape() const {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_OP(conv, paddle::lite::operators::ConvOpLite); REGISTER_LITE_OP(conv2d, paddle::lite::operators::ConvOpLite);
REGISTER_LITE_OP(depthwise_conv2d, paddle::lite::operators::ConvOpLite);
\ No newline at end of file
...@@ -41,27 +41,39 @@ class ConvOpLite : public OpLite { ...@@ -41,27 +41,39 @@ class ConvOpLite : public OpLite {
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto input = op_desc.Input("Input").front(); auto input = op_desc.Input("Input").front();
auto filter = op_desc.Input("Filter").front(); auto filter = op_desc.Input("Filter").front();
auto bias = op_desc.Input("Bias").front();
auto resid = op_desc.Input("ResidualData").front(); // maybe not used
auto out = op_desc.Output("Out").front(); auto out = op_desc.Output("Out").front();
param_.x = scope->FindVar(input)->GetMutable<lite::Tensor>(); param_.x = scope->FindVar(input)->GetMutable<lite::Tensor>();
param_.filter = scope->FindVar(filter)->GetMutable<lite::Tensor>(); param_.filter = scope->FindVar(filter)->GetMutable<lite::Tensor>();
param_.residualData = scope->FindVar(resid)->GetMutable<lite::Tensor>();
param_.bias = scope->FindVar(bias)->GetMutable<lite::Tensor>();
CHECK(scope->FindVar(out)); CHECK(scope->FindVar(out));
param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>(); param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.strides = op_desc.GetAttr<std::vector<int>>("strides"); param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
param_.paddings = op_desc.GetAttr<std::vector<int>>("paddings"); param_.paddings = op_desc.GetAttr<std::vector<int>>("paddings");
param_.groups = op_desc.GetAttr<int>("groups"); param_.groups = op_desc.GetAttr<int>("groups");
param_.dilations = op_desc.GetAttr<std::vector<int>>("dilations"); param_.dilations = op_desc.GetAttr<std::vector<int>>("dilations");
// optional params
std::vector<std::string> input_arg_names = op_desc.InputArgumentNames();
if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") !=
input_arg_names.end()) {
auto bias_var = scope->FindVar(op_desc.Input("Bias").front());
if (bias_var != nullptr) {
param_.bias =
const_cast<lite::Tensor *>(&(bias_var->Get<lite::Tensor>()));
}
}
if (std::find(input_arg_names.begin(), input_arg_names.end(), "ResidualData") !=
input_arg_names.end()) {
auto residual_data_var = scope->FindVar(op_desc.Input("ResidualData").front());
if (residual_data_var != nullptr) {
param_.residualData =
const_cast<lite::Tensor *>(&(residual_data_var->Get<lite::Tensor>()));
}
}
return true; return true;
} }
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "conv"; } std::string DebugString() const override { return "conv2d"; }
private: private:
mutable ConvParam param_; mutable ConvParam param_;
......
...@@ -124,8 +124,8 @@ struct ConcatParam { ...@@ -124,8 +124,8 @@ struct ConcatParam {
struct ConvParam { struct ConvParam {
lite::Tensor* x{}; lite::Tensor* x{};
lite::Tensor* filter{}; lite::Tensor* filter{};
lite::Tensor* bias{}; lite::Tensor* bias{nullptr};
lite::Tensor* residualData{}; lite::Tensor* residualData{nullptr};
lite::Tensor* output{}; lite::Tensor* output{};
std::vector<int> strides{1, 1}; std::vector<int> strides{1, 1};
std::vector<int> paddings{0, 0}; std::vector<int> paddings{0, 0};
......
...@@ -34,7 +34,7 @@ function cmake_arm { ...@@ -34,7 +34,7 @@ function cmake_arm {
function build { function build {
file=$1 file=$1
for _test in $(cat $file); do for _test in $(cat $file); do
make $_test -j$(expr $(nproc) - 2) make $_test -j$(expr $(nproc))
done done
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册