From 5e5d4ae51565e42335224f47f945b7b5e0a8b5c8 Mon Sep 17 00:00:00 2001
From: shixiaowei02
Date: Tue, 25 Jun 2019 05:28:11 +0000
Subject: [PATCH] add gemm-like conv

---
 paddle/fluid/lite/kernels/CMakeLists.txt      |   4 +-
 paddle/fluid/lite/kernels/arm/conv_compute.cc |  32 +-
 paddle/fluid/lite/kernels/arm/conv_compute.h  |   1 +
 .../lite/kernels/arm/conv_compute_test.cc     | 280 ++++++++++++++----
 paddle/fluid/lite/operators/op_params.h       |   6 +
 5 files changed, 253 insertions(+), 70 deletions(-)

diff --git a/paddle/fluid/lite/kernels/CMakeLists.txt b/paddle/fluid/lite/kernels/CMakeLists.txt
index 1fad136cb..abc6d65bb 100644
--- a/paddle/fluid/lite/kernels/CMakeLists.txt
+++ b/paddle/fluid/lite/kernels/CMakeLists.txt
@@ -1,5 +1,7 @@
 message(STATUS "add lite kernels")
-set(lite_kernel_deps type_system kernel_lite op_lite op_registry_lite context_lite ${tensor_lite})
+
+set(lite_kernel_deps type_system kernel_lite op_lite op_registry_lite context_lite ${tensor_lite} CACHE INTERNAL "" FORCE)
+
 add_subdirectory(host)
 add_subdirectory(arm)
 add_subdirectory(cuda)
diff --git a/paddle/fluid/lite/kernels/arm/conv_compute.cc b/paddle/fluid/lite/kernels/arm/conv_compute.cc
index af8f8e124..44223ee37 100644
--- a/paddle/fluid/lite/kernels/arm/conv_compute.cc
+++ b/paddle/fluid/lite/kernels/arm/conv_compute.cc
@@ -92,8 +92,24 @@ void ConvCompute::Run() {
   // }
 }
 
-void ConvComputeInt8::PrepareForRun() {}
-void ConvComputeInt8::Run() {}
+template <PrecisionType Ptype_out>
+void ConvComputeInt8<Ptype_out>::PrepareForRun() {
+  auto& param = this->Param<param_t>();
+  auto& ctx = this->ctx_->template As<ARMContext>();
+  impl_ = new lite::arm::math::GemmLikeConvInt8<Ptype_out>;
+  CHECK(this->impl_->create(param, &ctx));
+}
+
+template <PrecisionType Ptype_out>
+void ConvComputeInt8<Ptype_out>::Run() {
+  auto& param = this->Param<param_t>();
+  CHECK(impl_);
+  impl_->run(param);
+}
+
+template class ConvComputeInt8<PRECISION(kInt8)>;
+template class ConvComputeInt8<PRECISION(kFloat)>;
+template class ConvComputeInt8<PRECISION(kInt32)>;
 
 }  // namespace arm
 }  // namespace kernels
@@ -116,8 +132,9 @@ REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW,
     .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
 
-REGISTER_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW,
-                     paddle::lite::kernels::arm::ConvComputeInt8, def)
+REGISTER_LITE_KERNEL(
+    conv2d, kARM, kInt8, kNCHW,
+    paddle::lite::kernels::arm::ConvComputeInt8<PRECISION(kInt8)>, int8_out)
     .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
     .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
     .BindInput("Filter",
@@ -126,12 +143,13 @@ REGISTER_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW,
                {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
     .Finalize();
 
-REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kInt8, kNCHW,
-                     paddle::lite::kernels::arm::ConvComputeInt8, def)
+REGISTER_LITE_KERNEL(
+    conv2d, kARM, kInt8, kNCHW,
+    paddle::lite::kernels::arm::ConvComputeInt8<PRECISION(kFloat)>, fp32_out)
     .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
     .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
     .BindInput("Filter",
                {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
     .BindOutput("Output",
-                {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
+                {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
     .Finalize();
diff --git a/paddle/fluid/lite/kernels/arm/conv_compute.h b/paddle/fluid/lite/kernels/arm/conv_compute.h
index e5d5721a3..28bf6ea7d 100644
--- a/paddle/fluid/lite/kernels/arm/conv_compute.h
+++ b/paddle/fluid/lite/kernels/arm/conv_compute.h
@@ -41,6 +41,7 @@ class ConvCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
       nullptr};
 };
 
+template <PrecisionType Ptype_out>
 class ConvComputeInt8 : public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
  public:
   using param_t = operators::ConvParam;
diff --git a/paddle/fluid/lite/kernels/arm/conv_compute_test.cc b/paddle/fluid/lite/kernels/arm/conv_compute_test.cc
index f25a5cf07..6fc05dbe1 100644
--- a/paddle/fluid/lite/kernels/arm/conv_compute_test.cc
+++ b/paddle/fluid/lite/kernels/arm/conv_compute_test.cc
@@ -14,9 +14,11 @@
 
 #include "paddle/fluid/lite/kernels/arm/conv_compute.h"
 #include <gtest/gtest.h>
+#include <limits>
 #include <memory>
 #include <utility>
 #include <vector>
+#include "paddle/fluid/lite/arm/math/type_trans.h"
 #include "paddle/fluid/lite/core/op_registry.h"
 
 namespace paddle {
@@ -24,83 +26,89 @@
 namespace lite {
 namespace kernels {
 namespace arm {
 
-template <typename dtype>
-void conv_compute_ref(const operators::ConvParam& param) {
-  auto input = param.x;
-  auto filter = param.filter;
-  auto output = param.output;
-  DDim input_dims = param.x->dims();
-  DDim filter_dims = param.filter->dims();
-  DDim output_dims = param.output->dims();
-  std::vector<int> paddings = param.paddings;
-  std::vector<int> strides = param.strides;
-  std::vector<int> dilations = param.dilations;
-  int groups = param.groups;
-
-  auto input_data = param.x->data<dtype>();
-  auto output_data = param.output->mutable_data<dtype>();
-  auto filter_data = param.filter->mutable_data<dtype>();
-  const float* bias_data = nullptr;
-  if (param.bias != nullptr) {
-    bias_data = param.bias->mutable_data<float>();
+static float compute_max_kernel(const float* din, int64_t size) {
+  float max_value = -std::numeric_limits<float>::max();
+  for (int64_t i = 0; i < size; i++) {
+    max_value = max_value > din[i] ? max_value : din[i];
   }
-  bool flag_bias = bias_data != nullptr;
-  bool flag_relu = param.fuse_relu;
+  LOG(INFO) << "[max_value]: " << max_value;
+  return max_value;
+}
+
+static std::vector<float> get_tensor_scale_n(const float* in_data,
+                                             int axis_size, int64_t inner_size,
+                                             float scale_factor) {
+  std::vector<float> scale_out(axis_size);
+  for (int c = 0; c < axis_size; ++c) {  // num
+    const float* ptr_in = in_data + c * inner_size;  // channel * width * height
+    scale_out[c] = compute_max_kernel(ptr_in, inner_size) / scale_factor;
+  }
+  for (auto s : scale_out) {
+    LOG(INFO) << "[Scale out]: " << s;
+  }
+  return scale_out;
+}
+
+template <typename Dtype1, typename Dtype2>
+static void conv_basic(const Dtype1* din, Dtype2* dout, int num, int chout,
+                       int hout, int wout, int chin, int hin, int win,
+                       const Dtype1* weights, const Dtype2* bias, int group,
+                       int kernel_w, int kernel_h, int stride_w, int stride_h,
+                       int dila_w, int dila_h, int pad_w, int pad_h,
+                       bool flag_bias, bool flag_relu) {
+  Dtype2 beta = 0;
+  auto src_data = din;
+  auto dst_data_ref = dout;
+  auto weights_data = weights;
+  auto with_bias = flag_bias;
+  auto bias_data = bias;
+
+  int in_num = num;
+  int out_channels = chout;
+  int out_h = hout;
+  int out_w = wout;
-  int num = input_dims[0];
-  int chout = output_dims[1];
-  int hout = output_dims[2];
-  int wout = output_dims[3];
-
-  int chin = input_dims[1];
-  int hin = input_dims[2];
-  int win = input_dims[3];
-  int out_c_group = chout / groups;
-  int in_c_group = chin / groups;
-
-  int stride_h = strides[0];
-  int stride_w = strides[1];
-  int dilation_h = dilations[0];
-  int dilation_w = dilations[1];
-  int padding_h = paddings[0];
-  int padding_w = paddings[1];
-  int kernel_h = filter_dims[2];
-  int kernel_w = filter_dims[3];
-
-  for (int n = 0; n < num; ++n) {
-    for (int g = 0; g < groups; ++g) {
+  int in_channel = chin;
+  int in_h = hin;
+  int in_w = win;
+  int out_c_group = out_channels / group;
+  int in_c_group = in_channel / group;
+
+  for (int n = 0; n < in_num; ++n) {
+    for (int g = 0; g < group; ++g) {
       for (int oc = 0; oc < out_c_group; ++oc) {
-        for (int oh = 0; oh < hout; ++oh) {
-          for (int ow = 0; ow < wout; ++ow) {
-            int out_idx = n * groups * out_c_group * hout * wout +
-                          g * out_c_group * hout * wout + oc * hout * wout +
-                          oh * wout + ow;
-            output_data[out_idx] =
-                flag_bias ? static_cast<dtype>(bias_data[g * out_c_group + oc])
-                          : 0.f;
+        for (int oh = 0; oh < out_h; ++oh) {
+          for (int ow = 0; ow < out_w; ++ow) {
+            int out_idx = n * group * out_c_group * out_h * out_w +
+                          g * out_c_group * out_h * out_w + oc * out_h * out_w +
+                          oh * out_w + ow;
+            Dtype2 bias_d =
+                with_bias ? (bias_data[g * out_c_group + oc]) : (Dtype2)0;
+            dst_data_ref[out_idx] = bias_d;  // + dst_data_ref[out_idx] * beta;
             for (int ic = 0; ic < in_c_group; ++ic) {
               for (int kh = 0; kh < kernel_h; ++kh) {
                 for (int kw = 0; kw < kernel_w; ++kw) {
-                  int iw = ow * stride_w - padding_w + kw * (dilation_w);
-                  int ih = oh * stride_h - padding_h + kh * (dilation_h);
-                  if (iw < 0 || iw >= win) continue;
-                  if (ih < 0 || ih >= hin) continue;
+                  int iw = ow * stride_w - pad_w + kw * (dila_w);
+                  int ih = oh * stride_h - pad_h + kh * (dila_h);
+                  if (iw < 0 || iw >= in_w) continue;
+                  if (ih < 0 || ih >= in_h) continue;
 
-                  int iidx = n * chin * hin * win + g * in_c_group * hin * win +
-                             ic * hin * win + ih * win + iw;
+                  int iidx = n * in_channel * in_h * in_w +
+                             g * in_c_group * in_h * in_w + ic * in_h * in_w +
+                             ih * in_w + iw;
                   int widx = g * out_c_group * in_c_group * kernel_h * kernel_w +
                              oc * in_c_group * kernel_h * kernel_w +
                              ic * kernel_h * kernel_w + kh * kernel_w + kw;
-                  output_data[out_idx] +=
-                      (dtype)input_data[iidx] * (dtype)filter_data[widx];
+                  dst_data_ref[out_idx] += src_data[iidx] * weights_data[widx];
                 }
               }
             }
             if (flag_relu) {
-              output_data[out_idx] =
-                  output_data[out_idx] > 0.f ? output_data[out_idx] : 0.f;
+              dst_data_ref[out_idx] = dst_data_ref[out_idx] > (Dtype2)0
+                                          ? dst_data_ref[out_idx]
+                                          : (Dtype2)0;
             }
           }
         }
@@ -109,6 +117,44 @@ void conv_compute_ref(const operators::ConvParam& param) {
   }
 }
 
+template <typename Dtype1, typename Dtype2>
+void conv_compute_ref(const operators::ConvParam& param) {
+  const Dtype1* din = param.x->data<Dtype1>();
+  Dtype2* dout = param.output->mutable_data<Dtype2>();
+
+  int num = param.x->dims()[0];
+  int chout = param.output->dims()[1];
+  int hout = param.output->dims()[2];
+  int wout = param.output->dims()[3];
+
+  int chin = param.x->dims()[1];
+  int hin = param.x->dims()[2];
+  int win = param.x->dims()[3];
+
+  const Dtype1* weights = param.filter->mutable_data<Dtype1>();
+  Dtype2* bias = nullptr;
+  if (param.bias != nullptr) {
+    bias = param.bias->mutable_data<Dtype2>();
+  }
+
+  int group = param.groups;
+  int kernel_h = param.filter->dims()[2];
+  int kernel_w = param.filter->dims()[3];
+  int stride_h = param.strides[0];
+  int stride_w = param.strides[1];
+  int dila_h = param.dilations[0];
+  int dila_w = param.dilations[1];
+
+  int pad_h = param.paddings[0];
+  int pad_w = param.paddings[1];
+  bool flag_bias = (param.bias != nullptr);
+  bool flag_relu = param.fuse_relu;
+
+  conv_basic(din, dout, num, chout, hout, wout, chin, hin, win, weights, bias,
+             group, kernel_w, kernel_h, stride_w, stride_h, dila_w, dila_h,
+             pad_w, pad_h, flag_bias, flag_relu);
+}
+
 TEST(conv_arm, retrive_op) {
   auto conv = KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
       "conv2d");
@@ -116,12 +162,122 @@ TEST(conv_arm, retrive_op) {
   ASSERT_FALSE(conv.empty());
   ASSERT_TRUE(conv.front());
 }
 
+TEST(conv_arm_int8, retrive_op) {
+  auto conv =
+      KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kInt8)>("conv2d");
+  ASSERT_FALSE(conv.empty());
+  ASSERT_TRUE(conv.front());
+}
+
 TEST(conv_arm, init) {
   ConvCompute conv;
   ASSERT_EQ(conv.precision(), PRECISION(kFloat));
   ASSERT_EQ(conv.target(), TARGET(kARM));
 }
 
+TEST(conv_arm_int8, init) {
+  ConvComputeInt8<PRECISION(kFloat)> float_out;
+  ASSERT_EQ(float_out.precision(), PRECISION(kInt8));
+  ASSERT_EQ(float_out.target(), TARGET(kARM));
+  ConvComputeInt8<PRECISION(kInt8)> int8_out;
+  ASSERT_EQ(int8_out.precision(), PRECISION(kInt8));
+  ASSERT_EQ(int8_out.target(), TARGET(kARM));
+}
+
+TEST(conv_arm_int8, compute) {
+  DeviceInfo::Init();
+  for (auto n : {2}) {
+    for (auto ic : {6}) {
+      for (auto oc : {6}) {
+        for (auto ih : {9}) {
+          for (auto iw : {9}) {
+            for (auto flag_bias : {false, /*true*/}) {
+              for (auto flag_relu : {false, /*true*/}) {
+                for (auto depthwise : {false, /*true*/}) {
+                  for (auto dilation : {1}) {
+                    for (auto stride : {1}) {
+                      for (auto padding : {0}) {
+                        for (auto ks : {1}) {
+                          int group = 1;
+                          if (depthwise) {  // depthwise convolution?
+                            group = oc = ic;
+                          }
+
+                          const int dks = dilation * (ks - 1) + 1;
+                          int oh = (ih + 2 * padding - dks) / stride + 1;
+                          int ow = (iw + 2 * padding - dks) / stride + 1;
+                          std::vector<int64_t> input_shape = {n, ic, ih, iw};
+                          std::vector<int64_t> filter_shape = {oc, ic / group,
+                                                               ks, ks};
+                          std::vector<int64_t> output_shape({n, oc, oh, ow});
+
+                          Tensor input_int8;
+                          Tensor filter_int8;
+                          Tensor output_int32, output_int32_ref;
+
+                          input_int8.Resize(input_shape);
+                          filter_int8.Resize(filter_shape);
+                          output_int32.Resize(output_shape);
+                          output_int32_ref.Resize(output_shape);
+
+                          int8_t* input_int8_data =
+                              input_int8.mutable_data<int8_t>();
+                          int8_t* filter_int8_data =
+                              filter_int8.mutable_data<int8_t>();
+                          for (int i = 0; i < input_int8.dims().production();
+                               i++) {
+                            input_int8_data[i] = 1;
+                          }
+                          for (int i = 0; i < filter_int8.dims().production();
+                               i++) {
+                            filter_int8_data[i] = 1;
+                          }
+
+                          operators::ConvParam param;
+                          param.x = &input_int8;
+                          param.filter = &filter_int8;
+                          param.bias = nullptr;
+                          param.fuse_relu = false;
+                          param.paddings = std::vector<int>({padding, padding});
+                          param.strides = std::vector<int>({stride, stride});
+                          param.dilations =
+                              std::vector<int>({dilation, dilation});
+                          param.groups = group;
+                          param.output = &output_int32_ref;
+                          conv_compute_ref<int8_t, int32_t>(param);
+
+                          param.output = &output_int32;
+                          std::unique_ptr<KernelContext> ctx(new KernelContext);
+                          lite::arm::math::GemmLikeConvInt8<PRECISION(kInt32)>
+                              int8gemm_int32;
+                          int8gemm_int32.init(param, &ctx->As<ARMContext>());
+                          int8gemm_int32.create(param, &ctx->As<ARMContext>());
+                          int8gemm_int32.run(param);
+
+                          int32_t* output_int32_data =
+                              output_int32.mutable_data<int32_t>();
+                          int32_t* output_int32_ref_data =
+                              output_int32_ref.mutable_data<int32_t>();
+
+                          for (int i = 0; i < output_int32.dims().production();
+                               i++) {
+                            EXPECT_NEAR(output_int32_data[i],
+                                        output_int32_ref_data[i], 1e-3);
+                          }
+                        }
+                      }
+                    }
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
 TEST(conv_arm, compute) {
   DeviceInfo::Init();
 #if 1
@@ -219,7 +375,7 @@ TEST(conv_arm, compute) {
                   conv.Launch();
                   // invoking ref implementation and compare results
                   param.output = &output_ref;
-                  conv_compute_ref<float>(param);
+                  conv_compute_ref<float, float>(param);
 
                   auto* output_ref_data = output_ref.mutable_data<float>();
                   for (int i = 0; i < output.dims().production(); i++) {
diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h
index f06ecf545..416791aa8 100644
--- a/paddle/fluid/lite/operators/op_params.h
+++ b/paddle/fluid/lite/operators/op_params.h
@@ -19,6 +19,11 @@
 #include "paddle/fluid/lite/core/framework.pb.h"
 #include "paddle/fluid/lite/utils/all.h"
 
+#define WITH_INT8_CONFIG             \
+  bool enable_int8;                  \
+  float input_scale;                 \
+  std::vector<float> weight_scale{}; \
+  float output_scale;
 /*
  * This file contains all the argument parameter data structure for operators.
  */
@@ -147,6 +152,7 @@ struct ConvParam {
   float scale_weights{1.0f};      // only used with mkl-dnn int8
   bool force_fp32_output{false};  // only used in mkl-dnn int8
   std::string data_format{"Anylayout"};
+  WITH_INT8_CONFIG
 };
 
 // For BatchNorm op
--
GitLab