Commit 6f95f589 · authored by Hong Ming, committed by Tensor Tang

Enable conv_winograd, fix a conv_gemmlike bug, and update the unit tests of the conv op.

test=develop
Parent e0e47bdf
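For context, the kernel-selection rules this commit enables can be summarized as follows. This is a minimal, self-contained sketch distilled from the ConvCompute::PrepareForRun diff below; the enum and function are illustrative only, not part of the codebase:

// Sketch of the dispatch order: depthwise > winograd > direct > gemm-like.
enum class ConvImpl { kDepthwise, kWinograd, kDirect, kGemmLike };

ConvImpl Select(int groups, int ic, int oc, int oh, int ow, int kw, int stride,
                bool kps_equal, bool no_dilation, bool flag_dw) {
  if (groups == ic && ic == oc && kps_equal && no_dilation && flag_dw)
    return ConvImpl::kDepthwise;  // one filter per channel
  if (groups == 1 && kw == 3 && stride == 1 && kps_equal && no_dilation)
    return (ic >= 32 && oc >= 32 && oh > 16 && ow > 16)
               ? ConvImpl::kWinograd  // newly enabled by this commit
               : ConvImpl::kDirect;
  if (groups == 1 && kw == 3 && stride == 2 && kps_equal && no_dilation)
    return ConvImpl::kDirect;
  return ConvImpl::kGemmLike;  // general fallback (im2col + GEMM)
}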
paddle/fluid/lite/arm/math/CMakeLists.txt
@@ -26,5 +26,6 @@ cc_library(math_arm SRCS
     conv_depthwise.cc
     conv_gemmlike.cc
     conv_winograd_3x3.cc
+    conv_winograd.cc
     DEPS ${lite_kernel_deps} eigen3)
paddle/fluid/lite/kernels/arm/conv_compute.cc
@@ -13,10 +13,6 @@
 // limitations under the License.

 #include "paddle/fluid/lite/kernels/arm/conv_compute.h"
-#include "paddle/fluid/lite/arm/math/conv_direct.h"
-#include "paddle/fluid/lite/arm/math/conv_depthwise.h"
-#include "paddle/fluid/lite/arm/math/conv_gemmlike.h"
-#include "paddle/fluid/lite/arm/math/funcs.h"
 #include "paddle/fluid/lite/core/op_registry.h"
 #include "paddle/fluid/lite/core/type_system.h"
@@ -25,7 +21,7 @@ namespace lite {
 namespace kernels {
 namespace arm {

-void ConvCompute::Run() {
+void ConvCompute::PrepareForRun() {
   auto& param = this->Param<param_t>();
   auto x_dims = param.x->dims();
   auto w_dims = param.filter->dims();
@@ -61,44 +57,42 @@ void ConvCompute::Run() {
   bool flag_dw = flag_dw_3x3 || flag_dw_5x5;

   // select conv impl
   // TODO(xxx): enable more
   if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
     // dw conv impl
     impl_ = new lite::arm::math::DepthwiseConv<PRECISION(kFloat)>;
-    LOG(INFO) << "invoking dw conv";
+    // LOG(INFO) << "invoking dw conv";
   } else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal &&
              no_dilation) {
     if (ic >= 32 && oc >= 32 && oh > 16 && ow > 16) {
       // winograd conv impl
-      // impl_ = new lite::arm::math::WinogradConv<PRECISION(kFloat)>;
-      LOG(FATAL) << "TODO!!! winograd conv";
+      impl_ = new lite::arm::math::WinogradConv<PRECISION(kFloat)>;
+      // LOG(INFO) << "invoking winograd conv";
     } else {
       // direct conv impl
       impl_ = new lite::arm::math::DirectConv<PRECISION(kFloat)>;
-      LOG(INFO) << "invoking direct conv";
+      // LOG(INFO) << "invoking direct conv";
     }
   } else if (param.groups == 1 && kw == 3 && stride == 2 && kps_equal &&
              no_dilation) {
     // direct conv impl
     impl_ = new lite::arm::math::DirectConv<PRECISION(kFloat)>;
     // LOG(INFO) << "invoking direct conv";
   } else {
     impl_ = new lite::arm::math::GemmLikeConv<PRECISION(kFloat)>;
-    LOG(INFO) << "invoking gemm like conv";
+    // LOG(INFO) << "invoking gemm like conv";
   }
-  this->impl_->create(param, &ctx);
+  CHECK(this->impl_->create(param, &ctx));
 }

+void ConvCompute::Run() {
+  auto& param = this->Param<param_t>();
+  CHECK(impl_);
+  impl_->run(param);
+  // if (this->act_ != nullptr) {
+  //   this->act_->run(outputs, outputs, param.activation_param);
+  // }
+}
+
 TargetType ConvCompute::target() const { return TARGET(kARM); }

 PrecisionType ConvCompute::precision() const { return PRECISION(kFloat); }

 }  // namespace arm
 }  // namespace kernels
 }  // namespace lite
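Splitting the old Run into PrepareForRun and Run moves one-time work (implementation selection, impl_->create) out of the per-inference path. The updated test calls conv.Launch() instead of conv.Run(), which suggests Launch runs PrepareForRun once before the first Run; a minimal sketch of that pattern, under that assumption (this class is hypothetical, not the real KernelLite API):

// Hypothetical minimal version of the prepare-once / run-many lifecycle.
class KernelSketch {
 public:
  void Launch() {
    if (!prepared_) {   // first call only: do one-time setup
      PrepareForRun();  // e.g. pick impl_, transform weights
      prepared_ = true;
    }
    Run();  // per-inference work only
  }

 private:
  void PrepareForRun() { /* select and create impl_ */ }
  void Run() { /* impl_->run(param) */ }
  bool prepared_{false};
};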
paddle/fluid/lite/kernels/arm/conv_compute.h
@@ -13,8 +13,7 @@
 // limitations under the License.

 #pragma once
-#include "paddle/fluid/lite/arm/math/conv_impl.h"
 #include "paddle/fluid/lite/arm/math/funcs.h"
 #include "paddle/fluid/lite/core/kernel.h"
 #include "paddle/fluid/lite/operators/conv_op.h"
@@ -27,10 +26,9 @@ class ConvCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
  public:
   using param_t = operators::ConvParam;

-  void Run() override;
+  void PrepareForRun() override;

   TargetType target() const override;
   PrecisionType precision() const override;
+  void Run() override;

   virtual ~ConvCompute() = default;
paddle/fluid/lite/kernels/arm/conv_compute_test.cc
@@ -17,7 +17,6 @@
 #include <memory>
 #include <utility>
 #include <vector>
-#include "paddle/fluid/lite/arm/math/funcs.h"
 #include "paddle/fluid/lite/core/op_registry.h"

 namespace paddle {
@@ -76,7 +75,9 @@ void conv_compute_ref(const operators::ConvParam& param) {
           int out_idx = n * groups * out_c_group * hout * wout +
                         g * out_c_group * hout * wout + oc * hout * wout +
                         oh * wout + ow;
-          output_data[out_idx] = 0.0f;
+          output_data[out_idx] =
+              flag_bias ? static_cast<float>(bias_data[g * out_c_group + oc])
+                        : 0.f;
           for (int ic = 0; ic < in_c_group; ++ic) {
             for (int kh = 0; kh < kernel_h; ++kh) {
               for (int kw = 0; kw < kernel_w; ++kw) {
@@ -97,9 +98,6 @@ void conv_compute_ref(const operators::ConvParam& param) {
               }
             }
           }
-          output_data[out_idx] +=
-              flag_bias ? static_cast<float>(bias_data[g * out_c_group + oc])
-                        : 0.f;
           if (flag_relu) {
             output_data[out_idx] =
                 output_data[out_idx] > 0.f ? output_data[out_idx] : 0.f;
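Both orderings compute the same sum: seeding the accumulator with the bias instead of adding it after the loop only changes the floating-point evaluation order, and it presumably mirrors how the optimized ARM kernels fold the bias in. A tiny standalone check with made-up values (not taken from the test):

#include <cstdio>

int main() {
  const float w[3] = {0.5f, -1.25f, 2.0f};
  const float x[3] = {1.0f, 2.0f, 3.0f};
  const float bias = 0.75f;

  float a = bias;  // new reference: start from the bias
  for (int i = 0; i < 3; ++i) a += w[i] * x[i];

  float b = 0.0f;  // old reference: add the bias last
  for (int i = 0; i < 3; ++i) b += w[i] * x[i];
  b += bias;

  std::printf("%f %f\n", a, b);  // identical here: 4.750000 4.750000
  return 0;
}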
@@ -112,8 +110,8 @@ void conv_compute_ref(const operators::ConvParam& param) {
 }

 TEST(conv_arm, retrive_op) {
-  auto conv =
-      KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("conv2d");
+  auto conv = KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
+      "conv2d");
   ASSERT_FALSE(conv.empty());
   ASSERT_TRUE(conv.front());
 }
@@ -125,73 +123,72 @@ TEST(conv_arm, init) {
 }

 TEST(conv_arm, compute) {
-  ConvCompute conv;
-  operators::ConvParam param;
-
   lite::Tensor input;
   lite::Tensor filter;
   lite::Tensor bias;
   lite::Tensor output;
   lite::Tensor output_ref;

   DeviceInfo::Init();
-  std::unique_ptr<KernelContext> ctx(new KernelContext);
-  ctx->As<ARMContext>();
-  conv.SetContext(std::move(ctx));

   for (auto n : {1, 2}) {
-    for (auto chin : {3, 8, /*32, 128*/}) {
-      for (auto chout : {3, 8, /*32, 128*/}) {
-        for (auto hin : {7, 14, 28, /*56 , 112, 224, 512*/}) {
-          for (auto win : {7, 14, 28, /*56, 112, 224, 512*/}) {
-            for (auto flag_bias : {false , true}) {
-              for (auto flag_relu : {false , true}) {
+    for (auto ic : {6, 32 /*, 128*/}) {
+      for (auto oc : {6, 32 /*, 128*/}) {
+        for (auto ih : {9, 18 /*, 56 , 112, 224, 512*/}) {
+          for (auto iw : {9, 18 /*, 56, 112, 224, 512*/}) {
+            for (auto flag_bias : {false, true}) {
+              for (auto flag_relu : {false, true}) {
                 for (auto depthwise : {false, true}) {
-                  for (auto dilation : {1 /*, 2*/}) {
+                  for (auto dilation : {1, 2}) {
                     for (auto stride : {1, 2}) {
-                      for (auto padding : {0, 1}) {
-                        for (auto ks : {/*1, */3/*, 5*/}) {
+                      for (auto padding : {0, 1, 2}) {
+                        for (auto ks : {1, 3, 5}) {
                           int group = 1;
-                          if (depthwise) {  // depthwise conv ?
-                            group = chin;
-                            chout = chin;
+                          // remove the follow code if
+                          // all kernels are implemented.
+                          if (ks == 5) {
+                            stride = 2;
+                            padding = 2;
+                          }
+                          if (depthwise) {  // depthwise convolution ?
+                            group = oc = ic;
                           }
                           // get input, filter and output shape
-                          std::vector<int64_t> input_shape = {n, chin, hin,
-                                                              win};
-                          std::vector<int64_t> filter_shape = {
-                              chout, chin / group, ks, ks};
-                          std::vector<int64_t> output_shape({n, chout});
+                          std::vector<int64_t> input_shape = {n, ic, ih, iw};
+                          std::vector<int64_t> filter_shape = {oc, ic / group,
+                                                               ks, ks};
+                          std::vector<int64_t> output_shape({n, oc});
                           const int dkernel = dilation * (ks - 1) + 1;
                           output_shape.push_back(
-                              (hin + 2 * padding - dkernel) / stride + 1);
+                              (ih + 2 * padding - dkernel) / stride + 1);
                           output_shape.push_back(
-                              (win + 2 * padding - dkernel) / stride + 1);
+                              (iw + 2 * padding - dkernel) / stride + 1);
                           // resize input, filter and output
-                          input.Resize(DDim(input_shape));
-                          filter.Resize(DDim(filter_shape));
-                          output.Resize(DDim(output_shape));
-                          output_ref.Resize(DDim(output_shape));
+                          input.Resize(input_shape);
+                          filter.Resize(filter_shape);
+                          output.Resize(output_shape);
+                          output_ref.Resize(output_shape);
                           auto* input_data = input.mutable_data<float>();
                           auto* filter_data = filter.mutable_data<float>();
                           auto* output_data = output.mutable_data<float>();
-                          auto* output_ref_data =
-                              output_ref.mutable_data<float>();
                           for (int i = 0; i < input.dims().production(); i++) {
                             input_data[i] = static_cast<float>(i % 128);
                           }
                           for (int i = 0; i < filter.dims().production(); i++) {
-                            filter_data[i] = i / 1000.0f;
+                            filter_data[i] =
+                                i * 0.001f /
+                                static_cast<float>(filter.dims().production());
                           }
+                          // prepare kernel params and run
+                          ConvCompute conv;
+                          std::unique_ptr<KernelContext> ctx(new KernelContext);
+                          ctx->As<ARMContext>();
+                          conv.SetContext(std::move(ctx));
+                          operators::ConvParam param;
                           param.x = &input;
                           param.filter = &filter;
                           param.output = &output;
                           param.bias = nullptr;
+                          if (flag_bias) {
+                            bias.Resize({oc});
+                            auto* bias_data = bias.mutable_data<float>();
+                            for (int i = 0; i < bias.dims().production(); i++) {
+                              bias_data[i] = static_cast<float>(i);
+                            }
+                            param.bias = &bias;
+                          }
                           // TODO(hong19860320) param.relu = flag_relu;
                           param.paddings = std::vector<int>({padding, padding});
                           param.strides = std::vector<int>({stride, stride});
@@ -199,9 +196,12 @@ TEST(conv_arm, compute) {
                           param.dilations =
                               std::vector<int>({dilation, dilation});
                           param.groups = group;
                           conv.SetParam(param);
-                          conv.Run();
+                          conv.Launch();
+                          // invoking ref implementation and compare results
                           param.output = &output_ref;
                           conv_compute_ref<float>(param);
+                          auto* output_ref_data =
+                              output_ref.mutable_data<float>();
                           for (int i = 0; i < output.dims().production(); i++) {
                             EXPECT_NEAR(output_data[i], output_ref_data[i],
                                         1e-3);
@@ -218,59 +218,6 @@ TEST(conv_arm, compute) {
                         }
                       }
                     }
-#if 0
-  // for testing gemm like conv
-  int n = 1;
-  int chin = 8;
-  int chout = 8;
-  int hin = 14;
-  int win = 14;
-  int flag_bias = false;
-  int flag_relu = false;
-  int dilation = 1;
-  int stride = 1;
-  int padding = 1;
-  int ks = 5;
-  int group = 1;
-  // get input, filter and output shape
-  std::vector<int64_t> input_shape = {n, chin, hin, win};
-  std::vector<int64_t> filter_shape = {chout, chin / group, ks, ks};
-  std::vector<int64_t> output_shape({n, chout});
-  const int dkernel = dilation * (ks - 1) + 1;
-  output_shape.push_back((hin + 2 * padding - dkernel) / stride + 1);
-  output_shape.push_back((win + 2 * padding - dkernel) / stride + 1);
-  // resize input, filter and output
-  input.Resize(DDim(input_shape));
-  filter.Resize(DDim(filter_shape));
-  output.Resize(DDim(output_shape));
-  output_ref.Resize(DDim(output_shape));
-  auto* input_data = input.mutable_data<float>();
-  auto* filter_data = filter.mutable_data<float>();
-  auto* output_data = output.mutable_data<float>();
-  auto* output_ref_data = output_ref.mutable_data<float>();
-  for (int i = 0; i < input.dims().production(); i++) {
-    input_data[i] = static_cast<float>(i % 128);
-  }
-  for (int i = 0; i < filter.dims().production(); i++) {
-    filter_data[i] = i / 1000.0f;
-  }
-  param.x = &input;
-  param.filter = &filter;
-  param.output = &output;
-  param.bias = nullptr;
-  // TODO(hong19860320) param.relu = flag_relu;
-  param.paddings = std::vector<int>({padding, padding});
-  param.strides = std::vector<int>({stride, stride});
-  param.dilations = std::vector<int>({dilation, dilation});
-  param.groups = group;
-  conv.SetParam(param);
-  conv.Run();
-  param.output = &output_ref;
-  conv_compute_ref<float>(param);
-  for (int i = 0; i < output.dims().production(); i++) {
-    EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-3);
-  }
-#endif
 }
} // namespace arm
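For reference, the output-shape arithmetic the test uses throughout: dkernel = dilation * (ks - 1) + 1 is the effective (dilated) kernel extent, and each spatial output size is (in + 2 * padding - dkernel) / stride + 1. A small standalone check with values from the test's ranges (the helper name here is made up):

#include <cstdio>

// Output spatial size for one dimension, mirroring the test's formula.
int ConvOutSize(int in, int padding, int ks, int dilation, int stride) {
  const int dkernel = dilation * (ks - 1) + 1;  // effective kernel extent
  return (in + 2 * padding - dkernel) / stride + 1;
}

int main() {
  // ih = 18, padding = 1, ks = 3, dilation = 2, stride = 1:
  // dkernel = 5, oh = (18 + 2 - 5) / 1 + 1 = 16
  std::printf("%d\n", ConvOutSize(18, 1, 3, 2, 1));  // prints 16
  return 0;
}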