提交 5e5d4ae5 编写于 作者: S shixiaowei02 提交者: xingzhaolong

add gemm-like conv

上级 e28b5a3c
message(STATUS "add lite kernels")
set(lite_kernel_deps type_system kernel_lite op_lite op_registry_lite context_lite ${tensor_lite})
set(lite_kernel_deps type_system kernel_lite op_lite op_registry_lite context_lite ${tensor_lite} CACHE INTERNAL "" FORCE)
add_subdirectory(host)
add_subdirectory(arm)
add_subdirectory(cuda)
......
......@@ -92,8 +92,24 @@ void ConvCompute::Run() {
// }
}
void ConvComputeInt8::PrepareForRun() {}
void ConvComputeInt8::Run() {}
template <PrecisionType Ptype_out>
void ConvComputeInt8<Ptype_out>::PrepareForRun() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
impl_ = new lite::arm::math::GemmLikeConvInt8<Ptype_out>;
CHECK(this->impl_->create(param, &ctx));
}
template <PrecisionType Ptype_out>
void ConvComputeInt8<Ptype_out>::Run() {
auto& param = this->Param<param_t>();
CHECK(impl_);
impl_->run(param);
}
template class ConvComputeInt8<PRECISION(kInt8)>;
template class ConvComputeInt8<PRECISION(kFloat)>;
template class ConvComputeInt8<PRECISION(kInt32)>;
} // namespace arm
} // namespace kernels
......@@ -116,8 +132,9 @@ REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW,
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW,
paddle::lite::kernels::arm::ConvComputeInt8, def)
REGISTER_LITE_KERNEL(
conv2d, kARM, kInt8, kNCHW,
paddle::lite::kernels::arm::ConvComputeInt8<PRECISION(kInt8)>, int8_out)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("Filter",
......@@ -126,12 +143,13 @@ REGISTER_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW,
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.Finalize();
REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kInt8, kNCHW,
paddle::lite::kernels::arm::ConvComputeInt8, def)
REGISTER_LITE_KERNEL(
conv2d, kARM, kInt8, kNCHW,
paddle::lite::kernels::arm::ConvComputeInt8<PRECISION(kFloat)>, fp32_out)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.Finalize();
......@@ -41,6 +41,7 @@ class ConvCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
nullptr};
};
template <PrecisionType Ptype_out>
class ConvComputeInt8 : public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
public:
using param_t = operators::ConvParam;
......
......@@ -14,9 +14,11 @@
#include "paddle/fluid/lite/kernels/arm/conv_compute.h"
#include <gtest/gtest.h>
#include <limits>
#include <memory>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/arm/math/type_trans.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
......@@ -24,89 +26,133 @@ namespace lite {
namespace kernels {
namespace arm {
template <typename dtype>
void conv_compute_ref(const operators::ConvParam& param) {
auto input = param.x;
auto filter = param.filter;
auto output = param.output;
DDim input_dims = param.x->dims();
DDim filter_dims = param.filter->dims();
DDim output_dims = param.output->dims();
std::vector<int> paddings = param.paddings;
std::vector<int> strides = param.strides;
std::vector<int> dilations = param.dilations;
int groups = param.groups;
auto input_data = param.x->data<float>();
auto output_data = param.output->mutable_data<float>();
auto filter_data = param.filter->mutable_data<float>();
const float* bias_data = nullptr;
if (param.bias != nullptr) {
bias_data = param.bias->mutable_data<float>();
static float compute_max_kernel(const float* din, int64_t size) {
float max_value = -std::numeric_limits<float>::max();
for (int64_t i = 0; i < size; i++) {
max_value = max_value > din[0] ? max_value : din[0];
}
bool flag_bias = bias_data != nullptr;
bool flag_relu = param.fuse_relu;
LOG(INFO) << "[max_value]: " << max_value;
return max_value;
}
static std::vector<float> get_tensor_scale_n(const float* in_data,
int axis_size, int64_t inner_size,
float scale_factor) {
std::vector<float> scale_out(axis_size);
for (int c = 0; c < axis_size; ++c) { // num
const float* ptr_in = in_data + c * inner_size; // channel*width*height
scale_out[c] = compute_max_kernel(ptr_in, inner_size) / scale_factor;
}
for (auto s : scale_out) {
LOG(INFO) << "[Scale out]: " << s;
}
return scale_out;
}
template <typename Dtype1, typename Dtype2>
static void conv_basic(const Dtype1* din, Dtype2* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
const Dtype1* weights, const Dtype2* bias, int group,
int kernel_w, int kernel_h, int stride_w, int stride_h,
int dila_w, int dila_h, int pad_w, int pad_h,
bool flag_bias, bool flag_relu) {
Dtype2 beta = 0;
auto src_data = din;
auto dst_data_ref = dout;
auto weights_data = weights;
auto with_bias = flag_bias;
auto bias_data = bias;
int in_num = num;
int out_channels = chout;
int out_h = hout;
int out_w = wout;
int num = input_dims[0];
int chout = output_dims[1];
int hout = output_dims[2];
int wout = output_dims[3];
int chin = input_dims[1];
int hin = input_dims[2];
int win = input_dims[3];
int out_c_group = chout / groups;
int in_c_group = chin / groups;
int stride_h = strides[0];
int stride_w = strides[1];
int dilation_h = dilations[0];
int dilation_w = dilations[1];
int padding_h = paddings[0];
int padding_w = paddings[1];
int kernel_h = filter_dims[2];
int kernel_w = filter_dims[3];
for (int n = 0; n < num; ++n) {
for (int g = 0; g < groups; ++g) {
int in_channel = chin;
int in_h = hin;
int in_w = win;
int out_c_group = out_channels / group;
int in_c_group = in_channel / group;
for (int n = 0; n < in_num; ++n) {
for (int g = 0; g < group; ++g) {
for (int oc = 0; oc < out_c_group; ++oc) {
for (int oh = 0; oh < hout; ++oh) {
for (int ow = 0; ow < wout; ++ow) {
int out_idx = n * groups * out_c_group * hout * wout +
g * out_c_group * hout * wout + oc * hout * wout +
oh * wout + ow;
output_data[out_idx] =
flag_bias ? static_cast<float>(bias_data[g * out_c_group + oc])
: 0.f;
for (int oh = 0; oh < out_h; ++oh) {
for (int ow = 0; ow < out_w; ++ow) {
int out_idx = n * group * out_c_group * out_h * out_w +
g * out_c_group * out_h * out_w + oc * out_h * out_w +
oh * out_w + ow;
Dtype2 bias_d =
with_bias ? (bias_data[g * out_c_group + oc]) : (Dtype2)0;
dst_data_ref[out_idx] = bias_d; // + dst_data_ref[out_idx] * beta;
for (int ic = 0; ic < in_c_group; ++ic) {
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int iw = ow * stride_w - padding_w + kw * (dilation_w);
int ih = oh * stride_h - padding_h + kh * (dilation_h);
if (iw < 0 || iw >= win) continue;
if (ih < 0 || ih >= hin) continue;
int iw = ow * stride_w - pad_w + kw * (dila_w);
int ih = oh * stride_h - pad_h + kh * (dila_h);
if (iw < 0 || iw >= in_w) continue;
if (ih < 0 || ih >= in_h) continue;
int iidx = n * chin * hin * win + g * in_c_group * hin * win +
ic * hin * win + ih * win + iw;
int iidx = n * in_channel * in_h * in_w +
g * in_c_group * in_h * in_w + ic * in_h * in_w +
ih * in_w + iw;
int widx =
g * out_c_group * in_c_group * kernel_h * kernel_w +
oc * in_c_group * kernel_h * kernel_w +
ic * kernel_h * kernel_w + kh * kernel_w + kw;
output_data[out_idx] +=
(dtype)input_data[iidx] * (dtype)filter_data[widx];
dst_data_ref[out_idx] += src_data[iidx] * weights_data[widx];
}
}
}
if (flag_relu) {
output_data[out_idx] =
output_data[out_idx] > 0.f ? output_data[out_idx] : 0.f;
dst_data_ref[out_idx] = dst_data_ref[out_idx] > (Dtype2)0
? dst_data_ref[out_idx]
: (Dtype2)0;
}
}
}
}
}
}
}
template <typename Dtype1, typename Dtype2>
void conv_compute_ref(const operators::ConvParam& param) {
const Dtype1* din = param.x->data<Dtype1>();
Dtype2* dout = param.output->mutable_data<Dtype2>();
int num = param.x->dims()[0];
int chout = param.output->dims()[1];
int hout = param.output->dims()[2];
int wout = param.output->dims()[3];
int chin = param.x->dims()[1];
int hin = param.x->dims()[2];
int win = param.x->dims()[3];
const Dtype1* weights = param.filter->mutable_data<Dtype1>();
Dtype2* bias = nullptr;
if (param.bias != nullptr) {
bias = param.bias->mutable_data<Dtype2>();
}
int group = param.groups;
int kernel_w = param.filter->dims()[2];
int kernel_h = param.filter->dims()[3];
int stride_w = param.strides[0];
int stride_h = param.strides[1];
int dila_w = param.dilations[0];
int dila_h = param.dilations[1];
int pad_w = param.paddings[0];
int pad_h = param.paddings[1];
bool flag_bias = (param.bias != nullptr);
bool flag_relu = param.fuse_relu;
conv_basic(din, dout, num, chout, hout, wout, chin, hin, win, weights, bias,
group, kernel_w, kernel_h, stride_w, stride_h, dila_w, dila_h,
pad_w, pad_h, flag_bias, flag_relu);
}
TEST(conv_arm, retrive_op) {
......@@ -116,12 +162,122 @@ TEST(conv_arm, retrive_op) {
ASSERT_TRUE(conv.front());
}
TEST(conv_arm_int8, retrive_op) {
auto conv =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kInt8)>("conv2d");
ASSERT_FALSE(conv.empty());
ASSERT_TRUE(conv.front());
}
TEST(conv_arm, init) {
ConvCompute conv;
ASSERT_EQ(conv.precision(), PRECISION(kFloat));
ASSERT_EQ(conv.target(), TARGET(kARM));
}
TEST(conv_arm_int8, init) {
ConvComputeInt8<PRECISION(kFloat)> float_out;
ASSERT_EQ(float_out.precision(), PRECISION(kInt8));
ASSERT_EQ(float_out.target(), TARGET(kARM));
ConvComputeInt8<PRECISION(kInt8)> int8_out;
ASSERT_EQ(float_out.precision(), PRECISION(kInt8));
ASSERT_EQ(float_out.target(), TARGET(kARM));
}
TEST(conv_arm_int8, compute) {
DeviceInfo::Init();
for (auto n : {2}) {
for (auto ic : {6}) {
for (auto oc : {6}) {
for (auto ih : {9}) {
for (auto iw : {9}) {
for (auto flag_bias : {false, /*true*/}) {
for (auto flag_relu : {false, /*true*/}) {
for (auto depthwise : {false, /*true*/}) {
for (auto dilation : {1}) {
for (auto stride : {1}) {
for (auto padding : {0}) {
for (auto ks : {1}) {
int group = 1;
if (depthwise) { // depthwise convolution ?
group = oc = ic;
}
const int dks = dilation * (ks - 1) + 1;
int oh = (ih + 2 * padding - dks) / stride + 1;
int ow = (iw + 2 * padding - dks) / stride + 1;
std::vector<int64_t> input_shape = {n, ic, ih, iw};
std::vector<int64_t> filter_shape = {oc, ic / group,
ks, ks};
std::vector<int64_t> output_shape({n, oc, oh, ow});
Tensor input_int8;
Tensor filter_int8;
Tensor output_int32, output_int32_ref;
input_int8.Resize(input_shape);
filter_int8.Resize(filter_shape);
output_int32.Resize(output_shape);
output_int32_ref.Resize(output_shape);
int8_t* input_int8_data =
input_int8.mutable_data<int8_t>();
int8_t* filter_int8_data =
filter_int8.mutable_data<int8_t>();
for (int i = 0; i < input_int8.dims().production();
i++) {
input_int8_data[i] = 1.f;
}
for (int i = 0; i < filter_int8.dims().production();
i++) {
filter_int8_data[i] = 1.f;
}
operators::ConvParam param;
param.x = &input_int8;
param.filter = &filter_int8;
param.bias = nullptr;
param.fuse_relu = false;
param.paddings = std::vector<int>({padding, padding});
param.strides = std::vector<int>({stride, stride});
param.dilations =
std::vector<int>({dilation, dilation});
param.groups = group;
param.output = &output_int32_ref;
conv_compute_ref<int8_t, int>(param);
param.output = &output_int32;
std::unique_ptr<KernelContext> ctx(new KernelContext);
lite::arm::math::GemmLikeConvInt8<PRECISION(kInt32)>
int8gemm_int32;
int8gemm_int32.init(param, &ctx->As<ARMContext>());
int8gemm_int32.create(param, &ctx->As<ARMContext>());
int8gemm_int32.run(param);
int32_t* output_int32_data =
output_int32.mutable_data<int32_t>();
int32_t* output_int32_ref_data =
output_int32_ref.mutable_data<int32_t>();
for (int i = 0; i < output_int32.dims().production();
i++) {
EXPECT_NEAR(output_int32_data[i],
output_int32_ref_data[i], 1e-3);
}
}
}
}
}
}
}
}
}
}
}
}
}
}
TEST(conv_arm, compute) {
DeviceInfo::Init();
#if 1
......@@ -219,7 +375,7 @@ TEST(conv_arm, compute) {
conv.Launch();
// invoking ref implementation and compare results
param.output = &output_ref;
conv_compute_ref<float>(param);
conv_compute_ref<float, float>(param);
auto* output_ref_data =
output_ref.mutable_data<float>();
for (int i = 0; i < output.dims().production(); i++) {
......
......@@ -19,6 +19,11 @@
#include "paddle/fluid/lite/core/framework.pb.h"
#include "paddle/fluid/lite/utils/all.h"
#define WITH_INT8_CONFIG \
bool enable_int8; \
float input_scale; \
std::vector<float> weight_scale{}; \
float output_scale;
/*
* This file contains all the argument parameter data structure for operators.
*/
......@@ -147,6 +152,7 @@ struct ConvParam {
float scale_weights{1.0f}; // only used with mkl-dnn int8
bool force_fp32_output{false}; // only used in mkl-dnn int8
std::string data_format{"Anylayout"};
WITH_INT8_CONFIG
};
// For BatchNorm op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册