Commit 6f95f589 authored by Hong Ming, committed by Tensor Tang

enable conv_winograd, fix conv_gemmlike bug, and update the unit tests of conv op

test=develop
Parent e0e47bdf
@@ -26,5 +26,6 @@ cc_library(math_arm SRCS
             conv_depthwise.cc
             conv_gemmlike.cc
             conv_winograd_3x3.cc
+            conv_winograd.cc
             DEPS ${lite_kernel_deps} eigen3)
@@ -13,10 +13,6 @@
 // limitations under the License.
 #include "paddle/fluid/lite/kernels/arm/conv_compute.h"
-#include "paddle/fluid/lite/arm/math/conv_direct.h"
-#include "paddle/fluid/lite/arm/math/conv_depthwise.h"
-#include "paddle/fluid/lite/arm/math/conv_gemmlike.h"
-#include "paddle/fluid/lite/arm/math/funcs.h"
 #include "paddle/fluid/lite/core/op_registry.h"
 #include "paddle/fluid/lite/core/type_system.h"
@@ -25,7 +21,7 @@ namespace lite {
 namespace kernels {
 namespace arm {
-void ConvCompute::Run() {
+void ConvCompute::PrepareForRun() {
   auto& param = this->Param<param_t>();
   auto x_dims = param.x->dims();
   auto w_dims = param.filter->dims();
@@ -61,44 +57,42 @@ void ConvCompute::Run() {
   bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
   // select conv impl
-  // TODO(xxx): enable more
   if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
     // dw conv impl
     impl_ = new lite::arm::math::DepthwiseConv<PRECISION(kFloat)>;
-    LOG(INFO) << "invoking dw conv";
+    // LOG(INFO) << "invoking dw conv";
   } else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal &&
              no_dilation) {
     if (ic >= 32 && oc >= 32 && oh > 16 && ow > 16) {
       // winograd conv impl
-      // impl_ = new lite::arm::math::WinogradConv<PRECISION(kFloat)>;
-      LOG(FATAL) << "TODO!!! winograd conv";
+      impl_ = new lite::arm::math::WinogradConv<PRECISION(kFloat)>;
+      // LOG(INFO) << "invoking winograd conv";
     } else {
       // direct conv impl
       impl_ = new lite::arm::math::DirectConv<PRECISION(kFloat)>;
-      LOG(INFO) << "invoking direct conv";
+      // LOG(INFO) << "invoking direct conv";
     }
   } else if (param.groups == 1 && kw == 3 && stride == 2 && kps_equal &&
              no_dilation) {
     // direct conv impl
     impl_ = new lite::arm::math::DirectConv<PRECISION(kFloat)>;
+    // LOG(INFO) << "invoking direct conv";
   } else {
     impl_ = new lite::arm::math::GemmLikeConv<PRECISION(kFloat)>;
-    LOG(INFO) << "invoking gemm like conv";
+    // LOG(INFO) << "invoking gemm like conv";
   }
-  this->impl_->create(param, &ctx);
+  CHECK(this->impl_->create(param, &ctx));
+}
+void ConvCompute::Run() {
+  auto& param = this->Param<param_t>();
   CHECK(impl_);
   impl_->run(param);
   // if (this->act_ != nullptr) {
   //   this->act_->run(outputs, outputs, param.activation_param);
   // }
 }
-TargetType ConvCompute::target() const { return TARGET(kARM); }
-PrecisionType ConvCompute::precision() const { return PRECISION(kFloat); }
 }  // namespace arm
 }  // namespace kernels
 }  // namespace lite
......
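Note (not part of the diff above): this change splits kernel setup from execution. PrepareForRun() inspects the conv shapes once and instantiates one of DepthwiseConv, WinogradConv, DirectConv or GemmLikeConv, and Run() only forwards to the chosen impl_. The sketch below illustrates that select-once/run-many dispatch pattern in isolation; every class and field name in it is invented for the example and is not the Paddle Lite API.

```cpp
#include <iostream>
#include <memory>

// Hypothetical, simplified stand-ins for the real conv parameters and impls.
struct ConvParam {
  int groups{1}, ic{1}, oc{1}, kw{3}, stride{1};
  bool kps_equal{true}, no_dilation{true};
};

struct ConvImpl {
  virtual ~ConvImpl() = default;
  virtual void run(const ConvParam& p) = 0;
};
struct DepthwiseImpl : ConvImpl {
  void run(const ConvParam&) override { std::cout << "depthwise conv\n"; }
};
struct WinogradImpl : ConvImpl {
  void run(const ConvParam&) override { std::cout << "winograd conv\n"; }
};
struct DirectImpl : ConvImpl {
  void run(const ConvParam&) override { std::cout << "direct conv\n"; }
};
struct GemmLikeImpl : ConvImpl {
  void run(const ConvParam&) override { std::cout << "gemm-like conv\n"; }
};

class ConvKernel {
 public:
  // Mirrors PrepareForRun(): pick an implementation once, up front.
  void Prepare(const ConvParam& p) {
    if (p.groups == p.ic && p.ic == p.oc && p.kps_equal && p.no_dilation) {
      impl_ = std::make_unique<DepthwiseImpl>();
    } else if (p.groups == 1 && p.kw == 3 && p.stride == 1) {
      // the real code also checks the feature-map size before picking winograd
      impl_ = std::make_unique<WinogradImpl>();
    } else if (p.groups == 1 && p.kw == 3 && p.stride == 2) {
      impl_ = std::make_unique<DirectImpl>();
    } else {
      impl_ = std::make_unique<GemmLikeImpl>();  // generic fallback
    }
  }
  // Mirrors Run(): no re-selection, just execute the chosen impl.
  void Run(const ConvParam& p) { impl_->run(p); }

 private:
  std::unique_ptr<ConvImpl> impl_;
};

int main() {
  ConvParam p;
  p.groups = 1;
  p.ic = p.oc = 64;
  p.kw = 3;
  p.stride = 1;
  ConvKernel k;
  k.Prepare(p);  // selection (and any expensive setup) happens once here
  k.Run(p);      // prints "winograd conv"
  return 0;
}
```

Doing the selection in PrepareForRun() presumably lets any expensive setup done by impl_->create(), such as weight transformation for the winograd path, happen once per kernel instance rather than on every Run() call.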
@@ -13,8 +13,7 @@
 // limitations under the License.
 #pragma once
-#include "paddle/fluid/lite/arm/math/funcs.h"
+#include "paddle/fluid/lite/arm/math/conv_impl.h"
 #include "paddle/fluid/lite/core/kernel.h"
 #include "paddle/fluid/lite/operators/conv_op.h"
@@ -27,10 +26,9 @@ class ConvCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
  public:
   using param_t = operators::ConvParam;
-  void Run() override;
-  TargetType target() const override;
-  PrecisionType precision() const override;
+  void PrepareForRun() override;
+  void Run() override;
   virtual ~ConvCompute() = default;
......
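Note (not part of the diff above): the header now declares PrepareForRun() alongside Run() and drops the per-kernel target()/precision() overrides, presumably because the KernelLite<TARGET(kARM), PRECISION(kFloat)> base already encodes both in its template arguments. A hypothetical, minimal illustration of that idea follows; none of these types are the real KernelLite.

```cpp
#include <iostream>

enum class TargetType { kARM };
enum class PrecisionType { kFloat };

// A templated base can report the target/precision it was instantiated with,
// so derived kernels do not need to override those accessors themselves.
template <TargetType T, PrecisionType P>
class KernelBase {
 public:
  TargetType target() const { return T; }
  PrecisionType precision() const { return P; }
  virtual void PrepareForRun() {}
  virtual void Run() = 0;
  virtual ~KernelBase() = default;
};

class MyConv : public KernelBase<TargetType::kARM, PrecisionType::kFloat> {
 public:
  void Run() override { std::cout << "run conv\n"; }
};

int main() {
  MyConv k;
  k.PrepareForRun();
  k.Run();
  std::cout << (k.target() == TargetType::kARM) << "\n";  // prints 1
  return 0;
}
```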
@@ -17,7 +17,6 @@
 #include <memory>
 #include <utility>
 #include <vector>
-#include "paddle/fluid/lite/arm/math/funcs.h"
 #include "paddle/fluid/lite/core/op_registry.h"
 namespace paddle {
@@ -76,7 +75,9 @@ void conv_compute_ref(const operators::ConvParam& param) {
           int out_idx = n * groups * out_c_group * hout * wout +
                         g * out_c_group * hout * wout + oc * hout * wout +
                         oh * wout + ow;
-          output_data[out_idx] = 0.0f;
+          output_data[out_idx] =
+              flag_bias ? static_cast<float>(bias_data[g * out_c_group + oc])
+                        : 0.f;
           for (int ic = 0; ic < in_c_group; ++ic) {
             for (int kh = 0; kh < kernel_h; ++kh) {
               for (int kw = 0; kw < kernel_w; ++kw) {
@@ -97,9 +98,6 @@ void conv_compute_ref(const operators::ConvParam& param) {
               }
             }
           }
-          output_data[out_idx] +=
-              flag_bias ? static_cast<float>(bias_data[g * out_c_group + oc])
-                        : 0.f;
           if (flag_relu) {
             output_data[out_idx] =
                 output_data[out_idx] > 0.f ? output_data[out_idx] : 0.f;
@@ -112,8 +110,8 @@ void conv_compute_ref(const operators::ConvParam& param) {
 }
 TEST(conv_arm, retrive_op) {
-  auto conv =
-      KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("conv2d");
+  auto conv = KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
+      "conv2d");
   ASSERT_FALSE(conv.empty());
   ASSERT_TRUE(conv.front());
 }
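Note (not part of the diff above): in conv_compute_ref the bias is now folded into the accumulator's initial value instead of being added after the accumulation loop; both forms compute the same result, the new one simply matches the usual out = bias + sum(input * weight) formulation the test compares against. For orientation, a compact stand-alone naive NCHW convolution in the same spirit (single batch, single group, simplified, and not the test file's exact code) could look like this:

```cpp
#include <cstdio>
#include <vector>

// Naive NCHW reference convolution: the accumulator starts at the bias value
// and an optional ReLU is applied at the end, as in the updated test helper.
void naive_conv(const std::vector<float>& in, int ic, int ih, int iw,
                const std::vector<float>& w, int oc, int kh, int kw,
                const std::vector<float>* bias, int stride, int pad,
                int dilation, bool relu, std::vector<float>* out) {
  const int dkh = dilation * (kh - 1) + 1;
  const int dkw = dilation * (kw - 1) + 1;
  const int oh = (ih + 2 * pad - dkh) / stride + 1;
  const int ow = (iw + 2 * pad - dkw) / stride + 1;
  out->assign(static_cast<size_t>(oc) * oh * ow, 0.f);
  for (int c = 0; c < oc; ++c) {
    for (int y = 0; y < oh; ++y) {
      for (int x = 0; x < ow; ++x) {
        float acc = bias ? (*bias)[c] : 0.f;  // bias as the initial value
        for (int k = 0; k < ic; ++k) {
          for (int r = 0; r < kh; ++r) {
            for (int s = 0; s < kw; ++s) {
              int in_y = y * stride - pad + r * dilation;
              int in_x = x * stride - pad + s * dilation;
              if (in_y < 0 || in_y >= ih || in_x < 0 || in_x >= iw) continue;
              acc += in[(k * ih + in_y) * iw + in_x] *
                     w[((c * ic + k) * kh + r) * kw + s];
            }
          }
        }
        (*out)[(c * oh + y) * ow + x] = (relu && acc < 0.f) ? 0.f : acc;
      }
    }
  }
}

int main() {
  // 1x1x3x3 input of ones, one 3x3 filter of ones, pad 1, stride 1.
  std::vector<float> in(9, 1.f), w(9, 1.f), bias{0.5f}, out;
  naive_conv(in, 1, 3, 3, w, 1, 3, 3, &bias, 1, 1, 1, false, &out);
  std::printf("center = %.1f\n", out[4]);  // 9 * 1 + 0.5 = 9.5
  return 0;
}
```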
@@ -125,73 +123,72 @@ TEST(conv_arm, init) {
 }
 TEST(conv_arm, compute) {
-  ConvCompute conv;
-  operators::ConvParam param;
   lite::Tensor input;
   lite::Tensor filter;
   lite::Tensor bias;
   lite::Tensor output;
   lite::Tensor output_ref;
   DeviceInfo::Init();
-  std::unique_ptr<KernelContext> ctx(new KernelContext);
-  ctx->As<ARMContext>();
-  conv.SetContext(std::move(ctx));
   for (auto n : {1, 2}) {
-    for (auto chin : {3, 8, /*32, 128*/}) {
-      for (auto chout : {3, 8, /*32, 128*/}) {
-        for (auto hin : {7, 14, 28, /*56 , 112, 224, 512*/}) {
-          for (auto win : {7, 14, 28, /*56, 112, 224, 512*/}) {
-            for (auto flag_bias : {false , true}) {
-              for (auto flag_relu : {false , true}) {
+    for (auto ic : {6, 32 /*, 128*/}) {
+      for (auto oc : {6, 32 /*, 128*/}) {
+        for (auto ih : {9, 18 /*, 56 , 112, 224, 512*/}) {
+          for (auto iw : {9, 18 /*, 56, 112, 224, 512*/}) {
+            for (auto flag_bias : {false, true}) {
+              for (auto flag_relu : {false, true}) {
                 for (auto depthwise : {false, true}) {
-                  for (auto dilation : {1 /*, 2*/}) {
+                  for (auto dilation : {1, 2}) {
                     for (auto stride : {1, 2}) {
-                      for (auto padding : {0, 1}) {
-                        for (auto ks : {/*1, */3/*, 5*/}) {
+                      for (auto padding : {0, 1, 2}) {
+                        for (auto ks : {1, 3, 5}) {
                           int group = 1;
-                          if (depthwise) {  // depthwise conv ?
-                            group = chin;
-                            chout = chin;
+                          if (depthwise) {  // depthwise convolution ?
+                            group = oc = ic;
+                            // remove the follow code if
+                            // all kernels are implemented.
+                            if (ks == 5) {
+                              stride = 2;
+                              padding = 2;
+                            }
                           }
                           // get input, filter and output shape
-                          std::vector<int64_t> input_shape = {n, chin, hin,
-                                                              win};
-                          std::vector<int64_t> filter_shape = {
-                              chout, chin / group, ks, ks};
-                          std::vector<int64_t> output_shape({n, chout});
+                          std::vector<int64_t> input_shape = {n, ic, ih, iw};
+                          std::vector<int64_t> filter_shape = {oc, ic / group,
+                                                               ks, ks};
+                          std::vector<int64_t> output_shape({n, oc});
                           const int dkernel = dilation * (ks - 1) + 1;
                           output_shape.push_back(
-                              (hin + 2 * padding - dkernel) / stride + 1);
+                              (ih + 2 * padding - dkernel) / stride + 1);
                           output_shape.push_back(
-                              (win + 2 * padding - dkernel) / stride + 1);
+                              (iw + 2 * padding - dkernel) / stride + 1);
                           // resize input, filter and output
-                          input.Resize(DDim(input_shape));
-                          filter.Resize(DDim(filter_shape));
-                          output.Resize(DDim(output_shape));
-                          output_ref.Resize(DDim(output_shape));
+                          input.Resize(input_shape);
+                          filter.Resize(filter_shape);
+                          output.Resize(output_shape);
+                          output_ref.Resize(output_shape);
                           auto* input_data = input.mutable_data<float>();
                           auto* filter_data = filter.mutable_data<float>();
                           auto* output_data = output.mutable_data<float>();
-                          auto* output_ref_data =
-                              output_ref.mutable_data<float>();
                           for (int i = 0; i < input.dims().production(); i++) {
                             input_data[i] = static_cast<float>(i % 128);
                           }
                           for (int i = 0; i < filter.dims().production(); i++) {
-                            filter_data[i] = i / 1000.0f;
+                            filter_data[i] =
+                                i * 0.001f /
+                                static_cast<float>(filter.dims().production());
                           }
+                          // prepare kernel params and run
+                          ConvCompute conv;
+                          std::unique_ptr<KernelContext> ctx(new KernelContext);
+                          ctx->As<ARMContext>();
+                          conv.SetContext(std::move(ctx));
+                          operators::ConvParam param;
                           param.x = &input;
                           param.filter = &filter;
                           param.output = &output;
                           param.bias = nullptr;
+                          if (flag_bias) {
+                            bias.Resize({oc});
+                            auto* bias_data = bias.mutable_data<float>();
+                            for (int i = 0; i < bias.dims().production(); i++) {
+                              bias_data[i] = static_cast<float>(i);
+                            }
+                            param.bias = &bias;
+                          }
                           // TODO(hong19860320) param.relu = flag_relu;
                           param.paddings = std::vector<int>({padding, padding});
                           param.strides = std::vector<int>({stride, stride});
@@ -199,9 +196,12 @@ TEST(conv_arm, compute) {
                               std::vector<int>({dilation, dilation});
                           param.groups = group;
                           conv.SetParam(param);
-                          conv.Run();
+                          conv.Launch();
+                          // invoking ref implementation and compare results
                           param.output = &output_ref;
                           conv_compute_ref<float>(param);
+                          auto* output_ref_data =
+                              output_ref.mutable_data<float>();
                           for (int i = 0; i < output.dims().production(); i++) {
                             EXPECT_NEAR(output_data[i], output_ref_data[i],
                                         1e-3);
@@ -218,59 +218,6 @@ TEST(conv_arm, compute) {
       }
     }
   }
-#if 0
-  // for testing gemm like conv
-  int n = 1;
-  int chin = 8;
-  int chout = 8;
-  int hin = 14;
-  int win = 14;
-  int flag_bias = false;
-  int flag_relu = false;
-  int dilation = 1;
-  int stride = 1;
-  int padding = 1;
-  int ks = 5;
-  int group = 1;
-  // get input, filter and output shape
-  std::vector<int64_t> input_shape = {n, chin, hin, win};
-  std::vector<int64_t> filter_shape = {chout, chin / group, ks, ks};
-  std::vector<int64_t> output_shape({n, chout});
-  const int dkernel = dilation * (ks - 1) + 1;
-  output_shape.push_back((hin + 2 * padding - dkernel) / stride + 1);
-  output_shape.push_back((win + 2 * padding - dkernel) / stride + 1);
-  // resize input, filter and output
-  input.Resize(DDim(input_shape));
-  filter.Resize(DDim(filter_shape));
-  output.Resize(DDim(output_shape));
-  output_ref.Resize(DDim(output_shape));
-  auto* input_data = input.mutable_data<float>();
-  auto* filter_data = filter.mutable_data<float>();
-  auto* output_data = output.mutable_data<float>();
-  auto* output_ref_data = output_ref.mutable_data<float>();
-  for (int i = 0; i < input.dims().production(); i++) {
-    input_data[i] = static_cast<float>(i % 128);
-  }
-  for (int i = 0; i < filter.dims().production(); i++) {
-    filter_data[i] = i / 1000.0f;
-  }
-  param.x = &input;
-  param.filter = &filter;
-  param.output = &output;
-  param.bias = nullptr;
-  // TODO(hong19860320) param.relu = flag_relu;
-  param.paddings = std::vector<int>({padding, padding});
-  param.strides = std::vector<int>({stride, stride});
-  param.dilations = std::vector<int>({dilation, dilation});
-  param.groups = group;
-  conv.SetParam(param);
-  conv.Run();
-  param.output = &output_ref;
-  conv_compute_ref<float>(param);
-  for (int i = 0; i < output.dims().production(); i++) {
-    EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-3);
-  }
-#endif
 }
 }  // namespace arm
 }  // namespace kernels
......
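Note (not part of the diff above): every configuration the updated test sweeps derives its output spatial size from the dilated kernel extent, dkernel = dilation * (ks - 1) + 1, as out = (in + 2 * padding - dkernel) / stride + 1 with truncating division. A tiny self-contained check of that formula (illustrative only):

```cpp
#include <cassert>

// Output spatial extent of a convolution, as computed in the test above.
int conv_out_size(int in, int ks, int stride, int padding, int dilation) {
  const int dkernel = dilation * (ks - 1) + 1;
  return (in + 2 * padding - dkernel) / stride + 1;
}

int main() {
  assert(conv_out_size(18, 3, 1, 1, 1) == 18);  // "same" padding, stride 1
  assert(conv_out_size(18, 3, 2, 1, 1) == 9);   // stride 2 roughly halves it
  assert(conv_out_size(9, 5, 1, 2, 2) == 5);    // dilation 2 -> dkernel = 9
  return 0;
}
```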