提交 1efe8b6c 编写于 作者: 吴承辉

Merge branch 'winograd6x6' into 'master'

Implement winograd (6x6, 3x3)

See merge request !359
......@@ -7,6 +7,7 @@
// winograd is always superior to neon impl during benchmark
#define USE_WINOGRAD 1
#define WINOGRAD_OUT_TILE_SIZE 6
namespace mace {
namespace kernels {
......@@ -164,9 +165,9 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const Tensor *input,
&& stride_w == 1
&& dilation_h == 1 && dilation_w == 1
&& input_channels >= 8 && channels >= 8) {
extra_output_height = RoundUp<index_t>(height, 2);
extra_output_height = RoundUp<index_t>(height, WINOGRAD_OUT_TILE_SIZE);
extra_input_height = std::max(padded_input_height, extra_output_height + 2);
extra_output_width = RoundUp<index_t>(width, 2);
extra_output_width = RoundUp<index_t>(width, WINOGRAD_OUT_TILE_SIZE);
extra_input_width = std::max(padded_input_width, extra_output_width + 2);
if (extra_input_height != padded_input_height) {
pad_bottom += (extra_input_height - padded_input_height);
......@@ -175,12 +176,15 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const Tensor *input,
pad_right += (extra_input_width - padded_input_width);
}
index_t tile_height_count = (extra_output_height + 1) / 2;
index_t tile_width_count = (extra_output_width + 1) / 2;
index_t tile_height_count = extra_output_height / WINOGRAD_OUT_TILE_SIZE;
index_t tile_width_count = extra_output_width / WINOGRAD_OUT_TILE_SIZE;
index_t tile_count = tile_height_count * tile_width_count;
transformed_input_.Resize({16, batch, input_channels, tile_count});
transformed_filter_.Resize({16, channels, input_channels});
transformed_output_.Resize({16, batch, channels, tile_count});
index_t in_tile_area =
(WINOGRAD_OUT_TILE_SIZE + 2) * (WINOGRAD_OUT_TILE_SIZE + 2);
transformed_input_.Resize({in_tile_area, batch, input_channels,
tile_count});
transformed_filter_.Resize({in_tile_area, channels, input_channels});
transformed_output_.Resize({in_tile_area, batch, channels, tile_count});
conv_func = [=](const float *pad_input, float *pad_output) {
WinoGradConv3x3s1(pad_input,
......@@ -190,6 +194,7 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const Tensor *input,
extra_input_width,
input_channels,
channels,
WINOGRAD_OUT_TILE_SIZE,
transformed_input_.mutable_data<float>(),
transformed_filter_.mutable_data<float>(),
transformed_output_.mutable_data<float>(),
......
......@@ -8,19 +8,20 @@
#include "mace/kernels/arm/conv_winograd.h"
#include "mace/kernels/gemm.h"
#include "mace/utils/utils.h"
#include "mace/utils/logging.h"
namespace mace {
namespace kernels {
namespace {
// NCHW => TNCB (T: in tile pixels, B: tile indices)
void TransformInput(const float *input,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t tile_count,
float *output) {
void TransformInput4x4(const float *input,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t tile_count,
float *output) {
const index_t stride = batch * in_channels * tile_count;
const index_t in_height_width = in_height * in_width;
......@@ -101,12 +102,124 @@ void TransformInput(const float *input,
}
}
// NCHW => TNCB (T: in tile pixels, B: tile indices)
/**
* BT =
⎡1 0 -21/4 0 21/4 0 -1 0⎤
⎢ ⎥
⎢0 1 1 -17/4 -17/4 1 1 0⎥
⎢ ⎥
⎢0 -1 1 17/4 -17/4 -1 1 0⎥
⎢ ⎥
⎢0 1/2 1/4 -5/2 -5/4 2 1 0⎥
⎢ ⎥
⎢0 -1/2 1/4 5/2 -5/4 -2 1 0⎥
⎢ ⎥
⎢0 2 4 -5/2 -5 1/2 1 0⎥
⎢ ⎥
⎢0 -2 4 5/2 -5 -1/2 1 0⎥
⎢ ⎥
⎣0 -1 0 21/4 0 -21/4 0 1⎦
* @param input
* @param batch
* @param in_height
* @param in_width
* @param in_channels
* @param tile_count
* @param output
*/
void TransformInput8x8(const float *input,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t tile_count,
float *output) {
const index_t stride = batch * in_channels * tile_count;
const index_t in_height_width = in_height * in_width;
#pragma omp parallel for
for (index_t nc = 0; nc < batch * in_channels; ++nc) {
index_t tile_index = nc * tile_count;
float s[8][8];
for (index_t h = 0; h < in_height - 2; h += 6) {
for (index_t w = 0; w < in_width - 2; w += 6) {
index_t tile_offset = nc * in_height_width + h * in_width + w;
for (int i = 0; i < 8; ++i) {
float d0, d1, d2, d3, d4, d5, d6, d7;
d0 = input[tile_offset];
d1 = input[tile_offset + 1];
d2 = input[tile_offset + 2];
d3 = input[tile_offset + 3];
d4 = input[tile_offset + 4];
d5 = input[tile_offset + 5];
d6 = input[tile_offset + 6];
d7 = input[tile_offset + 7];
s[i][0] = d0 - d6 + (d4 - d2) * 5.25;
s[i][7] = d7 - d1 + (d3 - d5) * 5.25;
float u = d2 + d6 - d4 * 4.25;
float v = d1 + d5 - d3 * 4.25;
s[i][1] = u + v;
s[i][2] = u - v;
u = d6 + d2 * 0.25 - d4 * 1.25;
v = d1 * 0.5 - d3 * 2.5 + d5 * 2;
s[i][3] = u + v;
s[i][4] = u - v;
u = d6 + (d2 - d4 * 1.25) * 4;
v = d1 * 2 - d3 * 2.5 + d5 * 0.5;
s[i][5] = u + v;
s[i][6] = u - v;
tile_offset += in_width;
}
for (int i = 0; i < 8; ++i) {
float d0, d1, d2, d3, d4, d5, d6, d7;
d0 = s[0][i];
d1 = s[1][i];
d2 = s[2][i];
d3 = s[3][i];
d4 = s[4][i];
d5 = s[5][i];
d6 = s[6][i];
d7 = s[7][i];
output[tile_index + i * stride] = d0 - d6 + (d4 - d2) * 5.25;
output[tile_index + (56 + i) * stride] = d7 - d1 + (d3 - d5) * 5.25;
float u = d2 + d6 - d4 * 4.25;
float v = d1 + d5 - d3 * 4.25;
output[tile_index + (8 + i) * stride] = u + v;
output[tile_index + (16 + i) * stride] = u - v;
u = d6 + d2 * 0.25 - d4 * 1.25;
v = d1 * 0.5 - d3 * 2.5 + d5 * 2;
output[tile_index + (24 + i) * stride] = u + v;
output[tile_index + (32 + i) * stride] = u - v;
u = d6 + (d2 - d4 * 1.25) * 4;
v = d1 * 2 - d3 * 2.5 + d5 * 0.5;
output[tile_index + (40 + i) * stride] = u + v;
output[tile_index + (48 + i) * stride] = u - v;
}
++tile_index;
}
}
}
}
// OCHW => TOC
// no need to optimize, it will exist in converter
void TransformFilter(const float *filter,
const index_t in_channels,
const index_t out_channels,
float *output) {
void TransformFilter4x4(const float *filter,
const index_t in_channels,
const index_t out_channels,
float *output) {
const index_t stride = out_channels * in_channels;
#pragma omp parallel for collapse(2)
......@@ -171,6 +284,83 @@ void TransformFilter(const float *filter,
}
}
// OCHW => TOC
// no need to optimize, it will exist in converter
/**
* G =
⎡ 1 0 0 ⎤
⎢ ⎥
⎢-2/9 -2/9 -2/9 ⎥
⎢ ⎥
⎢-2/9 2/9 -2/9 ⎥
⎢ ⎥
⎢1/90 1/45 2/45 ⎥
⎢ ⎥
⎢1/90 -1/45 2/45 ⎥
⎢ ⎥
⎢1/45 1/90 1/180⎥
⎢ ⎥
⎢1/45 -1/90 1/180⎥
⎢ ⎥
⎣ 0 0 1 ⎦
*
* @param filter
* @param in_channels
* @param out_channels
* @param output
*/
void TransformFilter8x8(const float *filter,
const index_t in_channels,
const index_t out_channels,
float *output) {
const index_t stride = out_channels * in_channels;
const float G[8][3] = {
{1.0f, 0.0f, 0.0f},
{-2.0f / 9, -2.0f / 9, -2.0f / 9},
{-2.0f / 9, 2.0f / 9, -2.0f / 9},
{1.0f / 90, 1.0f / 45, 2.0f / 45},
{1.0f / 90, -1.0f / 45, 2.0f / 45},
{1.0f / 45, 1.0f / 90, 1.0f / 180},
{1.0f / 45, -1.0f / 90, 1.0f / 180},
{0.0f, 0.0f, 1.0f}
};
#pragma omp parallel for collapse(2)
for (index_t m = 0; m < out_channels; ++m) {
for (index_t c = 0; c < in_channels; ++c) {
// load filter
index_t filter_offset = (m * in_channels + c) * 9;
float g0, g1, g2, g3, g4, g5, g6, g7, g8;
g0 = filter[filter_offset];
g1 = filter[filter_offset + 1];
g2 = filter[filter_offset + 2];
g3 = filter[filter_offset + 3];
g4 = filter[filter_offset + 4];
g5 = filter[filter_offset + 5];
g6 = filter[filter_offset + 6];
g7 = filter[filter_offset + 7];
g8 = filter[filter_offset + 8];
float s[3][8];
for (int i = 0; i < 8; ++i) {
s[0][i] = g0 * G[i][0] + g1 * G[i][1] + g2 * G[i][2];
s[1][i] = g3 * G[i][0] + g4 * G[i][1] + g5 * G[i][2];
s[2][i] = g6 * G[i][0] + g7 * G[i][1] + g8 * G[i][2];
}
// store output
index_t output_offset = m * in_channels + c;
for (int i = 0; i < 8; ++i) {
for (int j = 0; j < 8; ++j) {
output[output_offset + (i * 8 + j) * stride] =
G[i][0] * s[0][j] + G[i][1] * s[1][j] + G[i][2] * s[2][j];
}
}
}
}
}
// TOC * TNCB => TNOB
void BatchGemm(const float *input,
const float *filter,
......@@ -178,17 +368,24 @@ void BatchGemm(const float *input,
index_t in_channels,
index_t out_channels,
index_t tile_count,
int out_tile_size,
float *output) {
const index_t in_stride = batch * in_channels * tile_count;
const index_t in_channels_tile_count = in_channels * tile_count;
const index_t filter_stride = out_channels * in_channels;
const index_t out_stride = batch * out_channels * tile_count;
const index_t out_channels_tile_count = out_channels * tile_count;
const int in_tile_area = (out_tile_size + 2) * (out_tile_size + 2);
if (batch == 1) {
Gemm(filter, input, 16, out_channels, in_channels, tile_count, output);
Gemm(filter,
input,
in_tile_area,
out_channels,
in_channels,
tile_count,
output);
} else {
for (int i = 0; i < 16; ++i) {
for (int i = 0; i < in_tile_area; ++i) {
for (int b = 0; b < batch; ++b) {
const float
*in_ptr = input + i * in_stride + b * in_channels_tile_count;
......@@ -207,13 +404,13 @@ void BatchGemm(const float *input,
}
// TNOB => ToNOB => NOHoWo
void TransformOutput(const float *input,
index_t batch,
index_t out_height,
index_t out_width,
index_t out_channels,
index_t tile_count,
float *output) {
void TransformOutput4x4(const float *input,
index_t batch,
index_t out_height,
index_t out_width,
index_t out_channels,
index_t tile_count,
float *output) {
const index_t in_stride = batch * out_channels * tile_count;
#pragma omp parallel for
......@@ -271,6 +468,107 @@ void TransformOutput(const float *input,
}
}
}
// TNOB => ToNOB => NOHoWo
/**
* AT =
⎡1 1 1 1 1 32 32 0⎤
⎢ ⎥
⎢0 1 -1 2 -2 16 -16 0⎥
⎢ ⎥
⎢0 1 1 4 4 8 8 0⎥
⎢ ⎥
⎢0 1 -1 8 -8 4 -4 0⎥
⎢ ⎥
⎢0 1 1 16 16 2 2 0⎥
⎢ ⎥
⎣0 1 -1 32 -32 1 -1 1⎦
*
* @param input
* @param batch
* @param out_height
* @param out_width
* @param out_channels
* @param tile_count
* @param output
*/
void TransformOutput8x8(const float *input,
index_t batch,
index_t out_height,
index_t out_width,
index_t out_channels,
index_t tile_count,
float *output) {
const index_t in_stride = batch * out_channels * tile_count;
#pragma omp parallel for
for (index_t nm = 0; nm < batch * out_channels; ++nm) {
index_t tile_offset = nm * tile_count;
float s[8][6];
for (index_t h = 0; h < out_height; h += 6) {
for (index_t w = 0; w < out_width; w += 6) {
index_t tile_offset_tmp = tile_offset;
for (int i = 0; i < 8; ++i) {
float d0, d1, d2, d3, d4, d5, d6, d7;
d0 = input[tile_offset_tmp + 0 * in_stride];
d1 = input[tile_offset_tmp + 1 * in_stride];
d2 = input[tile_offset_tmp + 2 * in_stride];
d3 = input[tile_offset_tmp + 3 * in_stride];
d4 = input[tile_offset_tmp + 4 * in_stride];
d5 = input[tile_offset_tmp + 5 * in_stride];
d6 = input[tile_offset_tmp + 6 * in_stride];
d7 = input[tile_offset_tmp + 7 * in_stride];
float u = d1 + d2;
float v = d1 - d2;
float w = d3 + d4;
float x = d3 - d4;
float y = d5 + d6;
float z = d5 - d6;
s[i][0] = d0 + u + w + y * 32;
s[i][1] = v + x + x + z * 16;
s[i][2] = u + w * 4 + y * 8;
s[i][3] = v + x * 8 + z * 4;
s[i][4] = u + w * 16 + y + y;
s[i][5] = v + x * 32 + z + d7;
tile_offset_tmp += 8 * in_stride;
}
index_t out_offset = nm * out_height * out_width + h * out_width + w;
for (int i = 0; i < 6; ++i) {
float d0, d1, d2, d3, d4, d5, d6, d7;
d0 = s[0][i];
d1 = s[1][i];
d2 = s[2][i];
d3 = s[3][i];
d4 = s[4][i];
d5 = s[5][i];
d6 = s[6][i];
d7 = s[7][i];
float u = d1 + d2;
float v = d1 - d2;
float w = d3 + d4;
float x = d3 - d4;
float y = d5 + d6;
float z = d5 - d6;
output[out_offset + 0 * out_width + i] = d0 + u + w + y * 32;
output[out_offset + 1 * out_width + i] = v + x + x + z * 16;
output[out_offset + 2 * out_width + i] = u + w * 4 + y * 8;
output[out_offset + 3 * out_width + i] = v + x * 8 + z * 4;
output[out_offset + 4 * out_width + i] = u + w * 16 + y + y;
output[out_offset + 5 * out_width + i] = v + x * 32 + z + d7;
}
++tile_offset;
}
}
}
}
} // namespace
void WinoGradConv3x3s1(const float *input,
......@@ -280,6 +578,7 @@ void WinoGradConv3x3s1(const float *input,
const index_t in_width,
const index_t in_channels,
const index_t out_channels,
const int out_tile_size,
float *transformed_input,
float *transformed_filter,
float *transformed_output,
......@@ -287,22 +586,52 @@ void WinoGradConv3x3s1(const float *input,
float *output) {
index_t out_height = in_height - 2;
index_t out_width = in_width - 2;
index_t tile_height_count = (out_height + 1) / 2;
index_t tile_width_count = (out_width + 1) / 2;
index_t tile_height_count =
RoundUpDiv(out_height, static_cast<index_t>(out_tile_size));
index_t tile_width_count =
RoundUpDiv(out_width, static_cast<index_t>(out_tile_size));
index_t tile_count = tile_height_count * tile_width_count;
TransformInput(input,
batch,
in_height,
in_width,
in_channels,
tile_count,
transformed_input);
switch (out_tile_size) {
case 2:
TransformInput4x4(input,
batch,
in_height,
in_width,
in_channels,
tile_count,
transformed_input);
break;
case 6:
TransformInput8x8(input,
batch,
in_height,
in_width,
in_channels,
tile_count,
transformed_input);
break;
default:MACE_NOT_IMPLEMENTED;
}
// TODO(liyin): put it in model converter, but do not worry, it is fast and
// will only do once
if (!is_filter_transformed) {
TransformFilter(filter, in_channels, out_channels, transformed_filter);
switch (out_tile_size) {
case 2:
TransformFilter4x4(filter,
in_channels,
out_channels,
transformed_filter);
break;
case 6:
TransformFilter8x8(filter,
in_channels,
out_channels,
transformed_filter);
break;
default:MACE_NOT_IMPLEMENTED;
}
}
BatchGemm(transformed_input,
......@@ -311,15 +640,30 @@ void WinoGradConv3x3s1(const float *input,
in_channels,
out_channels,
tile_count,
out_tile_size,
transformed_output);
TransformOutput(transformed_output,
batch,
out_height,
out_width,
out_channels,
tile_count,
output);
switch (out_tile_size) {
case 2:
TransformOutput4x4(transformed_output,
batch,
out_height,
out_width,
out_channels,
tile_count,
output);
break;
case 6:
TransformOutput8x8(transformed_output,
batch,
out_height,
out_width,
out_channels,
tile_count,
output);
break;
default:MACE_NOT_IMPLEMENTED;
}
}
void WinoGradConv3x3s1(const float *input,
......@@ -329,16 +673,21 @@ void WinoGradConv3x3s1(const float *input,
const index_t in_width,
const index_t in_channels,
const index_t out_channels,
const int out_tile_size,
float *output) {
index_t out_height = in_height - 2;
index_t out_width = in_width - 2;
index_t tile_height_count = (out_height + 1) / 2;
index_t tile_width_count = (out_width + 1) / 2;
index_t tile_height_count =
RoundUpDiv(out_height, static_cast<index_t>(out_tile_size));
index_t tile_width_count =
RoundUpDiv(out_width, static_cast<index_t>(out_tile_size));
index_t tile_count = tile_height_count * tile_width_count;
index_t transformed_input_size = 16 * batch * in_channels * tile_count;
index_t transformed_filter_size = 16 * out_channels * in_channels;
index_t transformed_output_size = 16 * batch * out_channels * tile_count;
index_t in_tile_area = (out_tile_size + 2) * (out_tile_size + 2);
index_t transformed_input_size =
in_tile_area * batch * in_channels * tile_count;
index_t transformed_filter_size = in_tile_area * out_channels * in_channels;
index_t
transformed_output_size = in_tile_area * batch * out_channels * tile_count;
float *transformed_input = new float[transformed_input_size]; // TNCB
float *transformed_filter = new float[transformed_filter_size]; // TOC
......@@ -351,6 +700,7 @@ void WinoGradConv3x3s1(const float *input,
in_width,
in_channels,
out_channels,
out_tile_size,
transformed_input,
transformed_filter,
transformed_output,
......@@ -362,7 +712,6 @@ void WinoGradConv3x3s1(const float *input,
delete[]transformed_output;
}
void ConvRef3x3s1(const float *input,
const float *filter,
const index_t batch,
......@@ -391,7 +740,8 @@ void ConvRef3x3s1(const float *input,
((b * in_channels + c) * in_height + ih) * in_width + iw;
index_t
filter_offset = (((m * in_channels) + c) * 3 + kh) * 3 + kw;
output[out_offset] += input[in_offset] * filter[filter_offset];
output[out_offset] +=
input[in_offset] * filter[filter_offset];
}
}
}
......
......@@ -21,6 +21,7 @@ void WinoGradConv3x3s1(const float *input,
const index_t in_width,
const index_t in_channels,
const index_t out_channels,
const int out_tile_size,
float *output);
void WinoGradConv3x3s1(const float *input,
......@@ -30,6 +31,7 @@ void WinoGradConv3x3s1(const float *input,
const index_t in_width,
const index_t in_channels,
const index_t out_channels,
const int out_tile_size,
float *transformed_input,
float *transformed_filter,
float *transformed_output,
......
......@@ -58,11 +58,12 @@ TEST(ConvWinogradTest, winograd) {
in_width,
in_channels,
out_channels,
6,
output_data);
// test
for (index_t i = 0; i < output_size; ++i) {
EXPECT_NEAR(output_data_ref[i], output_data[i], 0.1);
EXPECT_NEAR(output_data_ref[i], output_data[i], 0.1) << " with index " << i;
}
delete[]input_data;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册