提交 97b2c1a9 编写于 作者: H hjchen2

Revert to padding inside the winograd input transform; it looks faster than padding in advance

上级 b92caeb2
......@@ -130,8 +130,8 @@ inline void WinogradConv3x3(const ConvParam<CPU> &param) {
// Computes a padding amount for one spatial dimension from the input extent
// `width` and the convolution padding `pad`.
// `tile` and `kernel` are captured by reference from the enclosing scope;
// their definitions are outside this view.
// NOTE(review): presumably the result is the trailing (bottom/right) padding
// that makes the padded extent line up with whole tiles; confirm against the
// callers.
auto winograd_pad = [&](int width, int pad) {
// Step between consecutive tiles; presumably the number of outputs each tile
// yields for a stride-1 convolution (TODO confirm).
int output_tile = tile - kernel + 1;
// Alternative formula, left commented out (the two lines appear twice in this
// view):
// int tiles = (width + pad - kernel) / output_tile + 1;
// return (tiles - 1) * output_tile + tile - width;
// int tiles = (width + pad - kernel) / output_tile + 1;
// return (tiles - 1) * output_tile + tile - width;
// Integer division followed by multiplication truncates
// (width + 2 * pad - kernel) to a multiple of output_tile (rounds down when
// the operand is non-negative).
int pad_width = (width + 2 * pad - kernel) / output_tile * output_tile;
// Extent of the truncated span plus one full tile, minus the original extent.
return pad_width + tile - width;
};
......@@ -141,8 +141,10 @@ inline void WinogradConv3x3(const ConvParam<CPU> &param) {
for (int i = 0; i < batch_size; ++i) {
Tensor in_batch = input->Slice(i, i + 1);
Tensor out_batch = output->Slice(i, i + 1);
int pad_bottom = winograd_pad(in_batch.dims()[2], paddings[0]);
int pad_right = winograd_pad(in_batch.dims()[3], paddings[1]);
// int pad_bottom = winograd_pad(in_batch.dims()[2], paddings[0]);
// int pad_right = winograd_pad(in_batch.dims()[3], paddings[1]);
int pad_bottom = paddings[0];
int pad_right = paddings[1];
if (paddings[0] || paddings[1] || pad_bottom || pad_right) {
framework::DDim pad_shape = in_batch.dims();
pad_shape[2] += paddings[0] + pad_bottom;
......
......@@ -336,6 +336,23 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input,
memset(outptr, 0, output->numel() * sizeof(float));
const float *inptr = input.data<float>();
int inter_h = (height - 2) / 6;
int inter_w = (width - 2) / 6;
int remain_h = height - (inter_h * 6);
int remain_w = width - (inter_w * 6);
framework::Tensor input_pad;
if (remain_h > 2 || remain_w > 2) {
inter_h += (remain_h > 2);
inter_w += (remain_w > 2);
height = (inter_h - 1) * 6 + 8;
width = (inter_w - 1) * 6 + 8;
framework::DDim input_shape =
framework::make_ddim(std::vector<int>{1, channel, height, width});
PadFunctor<CPU, float> pad;
inptr = input_pad.mutable_data<float>(input_shape);
pad(input, 0, height - input.dims()[2], 0, width - input.dims()[3],
&input_pad);
}
size_t image_size = height * width;
const float transform_matrix[8] = {5.25f, -5.f, -4.25f, -2.5f,
2.f, -1.25f, 0.5f, 0.25f};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册