提交 b41c7a95 编写于 作者: H hjchen2

Refine int8 conv5x5 implementation

上级 f756739e
......@@ -156,7 +156,7 @@ class AttrReader {
template <typename T>
inline T Get(const string &name) const {
PADDLE_MOBILE_ENFORCE(attrs_.count(name) != 0,
"%s should be in AttributeMap", name);
"%s should be in AttributeMap", name.c_str());
return ((Attribute)attrs_.at(name)).Get<T>();
}
......
......@@ -21,9 +21,7 @@ namespace operators {
void conv3x3s1_int8(const framework::Tensor& input,
const framework::Tensor& weight,
framework::Tensor* output) {
// TODO(hjchen2)
}
framework::Tensor* output) {}
} // namespace operators
} // namespace paddle_mobile
......
......@@ -37,16 +37,12 @@ void conv5x5s1_int8(const framework::Tensor& input,
int out_image_size = output_h * output_w;
memset(out_data, 0, output_c * out_image_size * sizeof(int32_t));
#pragma omp parallel for
for (int oc = 0; oc < output_c; ++oc) {
for (int ic = 0; ic < input_c; ++ic) {
const int8_t* kernel = w_data + (oc * input_c + ic) * 25;
int32_t* output0 = out_data;
int32_t* output1 = out_data + output_w;
// load kernel
asm volatile("vld1.8 {d0-d3}, [%0] \n"
: "=r"(kernel)
: // no output
: "memory", "q0", "q1");
int32_t* output0 = out_data + oc * out_image_size;
int32_t* output1 = output0 + output_w;
int oh = 0;
for (; oh < output_h - 1; oh += 2) {
const int8_t* r0 = in_data + ic * image_size + oh * input_w;
......@@ -59,6 +55,10 @@ void conv5x5s1_int8(const framework::Tensor& input,
int ow = output_w >> 3;
int remain = output_w & 0x7;
if (ow > 0) {
asm volatile("vld1.8 {d0-d3}, [%[kernel]] \n"
: [kernel] "+r"(kernel)
:
: "cc", "memory", "q0", "q1");
asm volatile(
"0: \n"
"vld1.8 {d4-d5}, [%[r0]] \n" // r0
......@@ -262,6 +262,10 @@ void conv5x5s1_int8(const framework::Tensor& input,
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
}
if (remain > 0) {
asm volatile("vld1.8 {d0-d3}, [%[kernel]] \n"
: [kernel] "+r"(kernel)
:
: "cc", "memory", "q0", "q1");
asm volatile(
"0: \n"
"vld1.8 d4, [%[r0]] \n"
......@@ -346,6 +350,10 @@ void conv5x5s1_int8(const framework::Tensor& input,
int ow = output_w >> 3;
int remain = output_w & 0x7;
if (ow > 0) {
asm volatile("vld1.8 {d0-d3}, [%[kernel]] \n"
: [kernel] "+r"(kernel)
:
: "cc", "memory", "q0", "q1");
asm volatile(
"0: \n"
"vld1.8 {d4-d5}, [%[r0]] \n" // r0
......@@ -474,7 +482,12 @@ void conv5x5s1_int8(const framework::Tensor& input,
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
}
if (remain > 0) {
asm volatile("vld1.8 {d0-d3}, [%[kernel]] \n"
: [kernel] "+r"(kernel)
:
: "cc", "memory", "q0", "q1");
asm volatile(
"0: \n"
"vld1.8 d4, [%[r0]] \n"
......@@ -523,7 +536,6 @@ void conv5x5s1_int8(const framework::Tensor& input,
}
}
}
out_data += out_image_size;
}
#else
// TODO(hjchen2)
......
......@@ -114,16 +114,15 @@ inline void ConvBasic(const ConvParam<CPU> &param) {
}
}
inline void ConvBasic_int8(const ConvParam<CPU> &param) {
inline void ConvCompute_int8(const ConvParam<CPU> &param) {
typedef void (*ConvFunc)(const Tensor &input, const Tensor &kernel,
Tensor *output);
static ConvFunc conv_funcs_table[7][5] = {
{0, 0, 0, 0, 0}, // k = 1
{0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, // k = 3
{0, 0, 0, 0, 0}, {conv3x3s1_int8, 0, 0, 0, 0}, // k = 3
{0, 0, 0, 0, 0}, {conv5x5s1_int8, 0, 0, 0, 0}, // k = 5
{0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, // k = 7
};
const Tensor *input = param.Input();
Tensor *filter = param.Filter();
Tensor *output = param.Output();
......@@ -150,11 +149,12 @@ inline void ConvBasic_int8(const ConvParam<CPU> &param) {
input_pad.mutable_data<int8_t>(pad_shape);
pad(in_batch, paddings[0], paddings[1], &input_pad);
}
// int8 only used while dilation==1 and groups==1
if (strides[1] == strides[0] && strides[1] < 6 && kernel_h == kernel_w &&
kernel_h < 8 && dilations[0] == 0 && dilations[1] == 0 && groups == 1) {
ConvFunc conv_func = conv_funcs_table[kernel_h - 1][strides[1] - 1];
if (!conv_func) {
kernel_h < 8 && groups == 1 && dilations[0] == dilations[1] &&
dilations[1] == 1) {
ConvFunc conv_func = conv_funcs_table[kernel_h - 1][strides[0] - 1];
if (conv_func) {
conv_func(input_pad, *filter, &out_batch);
} else {
// TODO(hjchen2)
......@@ -167,21 +167,21 @@ inline void ConvBasic_int8(const ConvParam<CPU> &param) {
template <typename P>
void ConvCompute(const ConvParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
nullptr, false);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3) {
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
if (param.Input()->type() == typeid(int8_t)) {
ConvCompute_int8(param);
} else {
if (param.Input()->type() == typeid(int8_t)) {
ConvBasic_int8(param);
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
nullptr, false);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3) {
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
} else {
ConvBasic(param);
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../test_helper.h"
#include "../test_include.h"
#include "operators/conv_op.h"
namespace paddle_mobile {
// Reference convolution for checking results:
// accumulate through explicit loops over input, output, and filters.
template <typename Itype, typename Otype>
void conv2d(const framework::Tensor *input, const framework::Tensor *filter,
const framework::AttributeMap &attrs, framework::Tensor *output) {
framework::AttrReader attr_reader(attrs);
std::vector<int> paddings = attr_reader.Get<std::vector<int>>("paddings");
std::vector<int> strides = attr_reader.Get<std::vector<int>>("strides");
std::vector<int> dilations = attr_reader.Get<std::vector<int>>("dilations");
int groups = attr_reader.Get<int>("groups");
int kernel_h = filter->dims()[2];
int kernel_w = filter->dims()[3];
int pad_h = paddings[0];
int pad_w = paddings[1];
int stride_h = strides[0];
int stride_w = strides[1];
int dilation_h = dilations[0];
int dilation_w = dilations[1];
auto in_shape = input->dims();
auto out_shape = output->dims();
const bool has_depth = 0;
int kernel_d, pad_d, stride_d, dilation_d;
if (has_depth) {
kernel_d = kernel_h;
stride_d = stride_h;
pad_d = pad_h;
dilation_d = dilation_h;
} else {
kernel_d = stride_d = dilation_d = 1;
pad_d = 0;
}
// Groups
int o_g = out_shape[1] / groups;
int k_g = in_shape[1] / groups;
int o_head, k_head;
// Convolution
vector<int> weight_offset(4 + has_depth);
vector<int> in_offset(4 + has_depth);
vector<int> out_offset(4 + has_depth);
auto offset = [](const framework::Tensor *input, const vector<int> &indics) {
framework::DDim shape = input->dims();
size_t count = 0;
for (int i = 0; i < indics.size(); ++i) {
count *= shape[i];
count += indics[i];
}
return count;
};
const Itype *in_data = input->data<Itype>();
const Itype *w_data = filter->data<Itype>();
Otype *out_data = output->mutable_data<Otype>();
memset(out_data, 0, output->numel() * sizeof(Otype));
for (int n = 0; n < out_shape[0]; n++) {
for (int g = 0; g < groups; g++) {
o_head = o_g * g;
k_head = k_g * g;
for (int o = 0; o < o_g; o++) {
for (int k = 0; k < k_g; k++) {
for (int z = 0; z < (has_depth ? out_shape[2] : 1); z++) {
for (int y = 0; y < out_shape[2 + has_depth]; y++) {
for (int x = 0; x < out_shape[3 + has_depth]; x++) {
for (int r = 0; r < kernel_d; r++) {
for (int p = 0; p < kernel_h; p++) {
for (int q = 0; q < kernel_w; q++) {
int in_z = z * stride_d - pad_d + r * dilation_d;
int in_y = y * stride_h - pad_h + p * dilation_h;
int in_x = x * stride_w - pad_w + q * dilation_w;
if (in_z >= 0 && in_z < (has_depth ? in_shape[2] : 1) &&
in_y >= 0 && in_y < in_shape[2 + has_depth] &&
in_x >= 0 && in_x < in_shape[3 + has_depth]) {
weight_offset[0] = o + o_head;
weight_offset[1] = k;
if (has_depth) {
weight_offset[2] = r;
}
weight_offset[2 + has_depth] = p;
weight_offset[3 + has_depth] = q;
in_offset[0] = n;
in_offset[1] = k + k_head;
if (has_depth) {
in_offset[2] = in_z;
}
in_offset[2 + has_depth] = in_y;
in_offset[3 + has_depth] = in_x;
out_offset[0] = n;
out_offset[1] = o + o_head;
if (has_depth) {
out_offset[2] = z;
}
out_offset[2 + has_depth] = y;
out_offset[3 + has_depth] = x;
out_data[offset(output, out_offset)] +=
in_data[offset(input, in_offset)] *
w_data[offset(filter, weight_offset)];
}
}
}
}
}
}
}
}
}
}
}
}
template <typename Itype, typename Otype, int Kernel, int Pad, int Stride>
int TestConvOp() {
int kernel_h = Kernel;
int kernel_w = Kernel;
int pad_h = Pad;
int pad_w = Pad;
int stride_h = Stride;
int stride_w = Stride;
int dilation_h = 1;
int dilation_w = 1;
int batch_size = 2;
int input_c = 3;
int input_h = 100;
int input_w = 100;
int output_c = 32;
framework::DDim input_shape =
framework::make_ddim({batch_size, input_c, input_h, input_w});
framework::DDim filter_shape =
framework::make_ddim({output_c, input_c, kernel_h, kernel_w});
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["Input"] = std::vector<std::string>({"input"});
inputs["Filter"] = std::vector<std::string>({"filter"});
outputs["Output"] = std::vector<std::string>({"output"});
auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<Itype>(input, input_shape, -127, 127);
auto filter_var = scope.get()->Var("filter");
auto filter = filter_var->template GetMutable<framework::LoDTensor>();
SetupTensor<Itype>(filter, filter_shape, -127, 127);
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
attrs["strides"].Set<vector<int>>(std::vector<int>({stride_h, stride_w}));
attrs["paddings"].Set<vector<int>>(std::vector<int>({pad_h, pad_w}));
attrs["dilations"].Set<vector<int>>(
std::vector<int>({dilation_h, dilation_w}));
attrs["groups"].Set<int>(1);
auto *op = new operators::ConvOp<CPU, float>("conv2d", inputs, outputs, attrs,
scope);
op->InferShape();
op->Run();
int kernel_extent_h = dilation_h * (kernel_h - 1) + 1;
int kernel_extent_w = dilation_w * (kernel_w - 1) + 1;
int output_h = (input_h + 2 * pad_h - kernel_extent_h) / stride_h + 1;
int output_w = (input_w + 2 * pad_w - kernel_extent_w) / stride_w + 1;
auto output_shape = framework::make_ddim(
std::vector<int>({batch_size, output_c, output_h, output_w}));
framework::Tensor output_cmp;
output_cmp.mutable_data<Otype>(output_shape);
conv2d<Itype, Otype>(input, filter, attrs, &output_cmp);
// compare results
auto output = output_var->template Get<framework::LoDTensor>();
const Otype *output_data = output->data<Otype>();
Otype *output_cmp_data = output_cmp.data<Otype>();
for (int i = 0; i < output->numel(); ++i) {
PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i],
"output[%d] = %d, output_cmp[%d] = %d", i,
output_data[i], i, output_cmp_data[i]);
}
delete op;
return 0;
}
} // namespace paddle_mobile
int main() { return paddle_mobile::TestConvOp<int8_t, int32_t, 5, 2, 1>(); }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册