提交 97b2c1a9 编写于 作者: H hjchen2

Revert padding in winograd input transform, it looks faster than padding in advance

上级 b92caeb2
...@@ -141,8 +141,10 @@ inline void WinogradConv3x3(const ConvParam<CPU> &param) { ...@@ -141,8 +141,10 @@ inline void WinogradConv3x3(const ConvParam<CPU> &param) {
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
Tensor in_batch = input->Slice(i, i + 1); Tensor in_batch = input->Slice(i, i + 1);
Tensor out_batch = output->Slice(i, i + 1); Tensor out_batch = output->Slice(i, i + 1);
int pad_bottom = winograd_pad(in_batch.dims()[2], paddings[0]); // int pad_bottom = winograd_pad(in_batch.dims()[2], paddings[0]);
int pad_right = winograd_pad(in_batch.dims()[3], paddings[1]); // int pad_right = winograd_pad(in_batch.dims()[3], paddings[1]);
int pad_bottom = paddings[0];
int pad_right = paddings[1];
if (paddings[0] || paddings[1] || pad_bottom || pad_right) { if (paddings[0] || paddings[1] || pad_bottom || pad_right) {
framework::DDim pad_shape = in_batch.dims(); framework::DDim pad_shape = in_batch.dims();
pad_shape[2] += paddings[0] + pad_bottom; pad_shape[2] += paddings[0] + pad_bottom;
......
...@@ -336,6 +336,23 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input, ...@@ -336,6 +336,23 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input,
memset(outptr, 0, output->numel() * sizeof(float)); memset(outptr, 0, output->numel() * sizeof(float));
const float *inptr = input.data<float>(); const float *inptr = input.data<float>();
int inter_h = (height - 2) / 6;
int inter_w = (width - 2) / 6;
int remain_h = height - (inter_h * 6);
int remain_w = width - (inter_w * 6);
framework::Tensor input_pad;
if (remain_h > 2 || remain_w > 2) {
inter_h += (remain_h > 2);
inter_w += (remain_w > 2);
height = (inter_h - 1) * 6 + 8;
width = (inter_w - 1) * 6 + 8;
framework::DDim input_shape =
framework::make_ddim(std::vector<int>{1, channel, height, width});
PadFunctor<CPU, float> pad;
inptr = input_pad.mutable_data<float>(input_shape);
pad(input, 0, height - input.dims()[2], 0, width - input.dims()[3],
&input_pad);
}
size_t image_size = height * width; size_t image_size = height * width;
const float transform_matrix[8] = {5.25f, -5.f, -4.25f, -2.5f, const float transform_matrix[8] = {5.25f, -5.f, -4.25f, -2.5f,
2.f, -1.25f, 0.5f, 0.25f}; 2.f, -1.25f, 0.5f, 0.25f};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册