From c101c8f515e3c0a79e6e411d3ac22e20dbed9914 Mon Sep 17 00:00:00 2001 From: liuqi Date: Sun, 12 Nov 2017 12:20:21 +0800 Subject: [PATCH] Finish avg and max pooling of opencl kernel. --- mace/core/tensor.h | 4 + mace/kernels/conv_pool_2d_util.cc | 70 +++++-- mace/kernels/conv_pool_2d_util.h | 6 +- mace/kernels/neon/conv_2d_neon.cc | 14 +- mace/kernels/neon/depthwise_conv_neon.cc | 3 +- mace/kernels/neon/pooling_neon.cc | 15 +- mace/kernels/opencl/cl/pooling.cl | 186 +++++++++++++++++++ mace/kernels/opencl/conv_2d_opencl.cc | 4 +- mace/kernels/opencl/depthwise_conv_opencl.cc | 4 +- mace/kernels/opencl/pooling_opencl.cc | 135 ++++++++++++++ mace/kernels/pooling.h | 28 +-- mace/ops/BUILD | 16 -- mace/ops/pooling.cc | 1 + mace/ops/pooling.h | 7 +- mace/ops/pooling_test.cc | 186 ++++++++++++++++++- 15 files changed, 597 insertions(+), 82 deletions(-) create mode 100644 mace/kernels/opencl/cl/pooling.cl create mode 100644 mace/kernels/opencl/pooling_opencl.cc diff --git a/mace/core/tensor.h b/mace/core/tensor.h index 37b01a42..24f1ec2a 100644 --- a/mace/core/tensor.h +++ b/mace/core/tensor.h @@ -207,7 +207,11 @@ class Tensor { os.str(""); os.clear(); + MappingGuard guard(this); for (int i = 0; i < size_; ++i) { + if ( i != 0 && i % shape_[3] == 0) { + os << "\n"; + } CASES(dtype_, (os << this->data()[i]) << ", "); } LOG(INFO) << os.str(); diff --git a/mace/kernels/conv_pool_2d_util.cc b/mace/kernels/conv_pool_2d_util.cc index 44fc4f70..f3fe94c8 100644 --- a/mace/kernels/conv_pool_2d_util.cc +++ b/mace/kernels/conv_pool_2d_util.cc @@ -61,9 +61,11 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW // based on the model accuracy. padding_size[0] = - (output_height - 1) * strides[0] + k_extent_height - input_shape[2]; + std::max(0, (output_height - 1) * strides[0] + + k_extent_height - input_shape[2]); padding_size[1] = - (output_width - 1) * strides[1] + k_extent_width - input_shape[3]; + std::max(0, (output_width - 1) * strides[1] + + k_extent_width - input_shape[3]); output_shape[0] = input_shape[0]; output_shape[1] = output_channels; @@ -110,15 +112,21 @@ void CalPaddingSize(const index_t *input_shape, // NCHW // utilize the more centered features. We need to benchmark // based on the model accuracy. padding_size[0] = - (output_height - 1) * strides[0] + k_extent_height - input_shape[2]; + std::max(0, (output_height - 1) * strides[0] + + k_extent_height - input_shape[2]); padding_size[1] = - (output_width - 1) * strides[1] + k_extent_width - input_shape[3]; + std::max(0, (output_width - 1) * strides[1] + + k_extent_width - input_shape[3]); } -void ConstructInputWithPadding(const float *input, - const index_t *input_shape, +void ConstructInputWithPadding(const Tensor *input_tensor, const int *paddings, - Tensor *output_tensor) { + Tensor *output_tensor, + bool padding_same_value) { + Tensor::MappingGuard input_mapper(input_tensor); + const float *input = input_tensor->data(); + const index_t *input_shape = input_tensor->shape().data(); + index_t batch = input_shape[0]; index_t channels = input_shape[1]; index_t height = input_shape[2]; @@ -133,21 +141,51 @@ void ConstructInputWithPadding(const float *input, output_tensor->Resize(output_shape); - Tensor::MappingGuard padded_input_mapper(output_tensor); + Tensor::MappingGuard padded_output_mapper(output_tensor); float *output_ptr = output_tensor->mutable_data(); memset(output_ptr, 0, output_tensor->size() * sizeof(float)); // Skip the padded top rows - output_ptr += padded_top * output_width; - for (int i = 0; i < batch; ++i) { - for (int j = 0; j < channels; ++j) { - for (int k = 0; k < height; ++k) { - memcpy(output_ptr + padded_left, input, width * sizeof(float)); + if (padding_same_value) { +#define COPY_INPUT \ + std::fill(output_ptr, output_ptr+padded_left, input[0]); \ + output_ptr += padded_left; \ + memcpy(output_ptr, input, width * sizeof(float)); \ + output_ptr += width; \ + std::fill(output_ptr , output_ptr + padded_right, input[width-1]); \ + output_ptr += padded_right; + + const int padded_bottom = paddings[0] - padded_top; + const int padded_right = paddings[1] - padded_left; + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + for (int k = 0; k < padded_top; ++k) { + COPY_INPUT; + } + for (int k = 0; k < height; ++k) { + COPY_INPUT; + input += width; + } + input -= width; + for (int k = 0; k < padded_bottom; ++k) { + COPY_INPUT; + } input += width; - output_ptr += output_width; } - // Skip the padded bottom in this channel and top in the next channel - output_ptr += paddings[0] * output_width; + } +#undef COPY_INPUT + } else { + output_ptr += padded_top * output_width; + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + for (int k = 0; k < height; ++k) { + memcpy(output_ptr + padded_left, input, width * sizeof(float)); + input += width; + output_ptr += output_width; + } + // Skip the padded bottom in this channel and top in the next channel + output_ptr += paddings[0] * output_width; + } } } } diff --git a/mace/kernels/conv_pool_2d_util.h b/mace/kernels/conv_pool_2d_util.h index 0424f43d..958b65d5 100644 --- a/mace/kernels/conv_pool_2d_util.h +++ b/mace/kernels/conv_pool_2d_util.h @@ -32,10 +32,10 @@ void CalPaddingSize(const index_t *input_shape, // NCHW Padding padding, int *padding_size); -void ConstructInputWithPadding(const float *input, - const index_t *input_shape, +void ConstructInputWithPadding(const Tensor *input, const int *paddings, - Tensor *output_tensor); + Tensor *output_tensor, + bool padding_same_value = false); } // namespace kernels } // namespace mace diff --git a/mace/kernels/neon/conv_2d_neon.cc b/mace/kernels/neon/conv_2d_neon.cc index 4e7752dc..e774b943 100644 --- a/mace/kernels/neon/conv_2d_neon.cc +++ b/mace/kernels/neon/conv_2d_neon.cc @@ -75,6 +75,12 @@ void Conv2dFunctor::operator()(const Tensor *input, return; } + Tensor padded_input; + // Keep this alive during kernel execution + if (paddings_[0] > 0 || paddings_[1] > 0) { + ConstructInputWithPadding(input, paddings_.data(), &padded_input); + input = &padded_input; + } Tensor::MappingGuard input_mapper(input); Tensor::MappingGuard filter_mapper(filter); Tensor::MappingGuard bias_mapper(bias); @@ -86,14 +92,6 @@ void Conv2dFunctor::operator()(const Tensor *input, auto output_data = output->mutable_data(); auto output_shape = output->shape().data(); - // Keep this alive during kernel execution - Tensor padded_input; - if (paddings_[0] > 0 || paddings_[1] > 0) { - ConstructInputWithPadding(input_data, input->shape().data(), - paddings_.data(), &padded_input); - input_data = padded_input.data(); - input_shape = padded_input.shape().data(); - } auto conv2d_neon_func = selector[kernel_h - 1][strides_[0] - 1]; conv2d_neon_func(input_data, input_shape, filter_data, nullptr, bias_data, output_data, output_shape); diff --git a/mace/kernels/neon/depthwise_conv_neon.cc b/mace/kernels/neon/depthwise_conv_neon.cc index cbae961b..42b7fa35 100644 --- a/mace/kernels/neon/depthwise_conv_neon.cc +++ b/mace/kernels/neon/depthwise_conv_neon.cc @@ -67,8 +67,7 @@ void DepthwiseConv2dFunctor::operator()( // Keep this alive during kernel execution Tensor padded_input; if (paddings_[0] > 0 || paddings_[1] > 0) { - ConstructInputWithPadding(input_ptr, input_shape, paddings_.data(), - &padded_input); + ConstructInputWithPadding(input, paddings_.data(), &padded_input); input_ptr = padded_input.data(); input_shape = padded_input.shape().data(); } diff --git a/mace/kernels/neon/pooling_neon.cc b/mace/kernels/neon/pooling_neon.cc index 7112234d..0f916234 100644 --- a/mace/kernels/neon/pooling_neon.cc +++ b/mace/kernels/neon/pooling_neon.cc @@ -55,10 +55,13 @@ extern void PoolingAvgNeonK3x3S2x2Padded(const float *input, template <> void PoolingFunctor::operator()( - const float *input, - const index_t *input_shape, - float *output, - const index_t *output_shape) { + const Tensor *input_tensor, + Tensor *output_tensor) { + + const float *input = input_tensor->data(); + float *output = output_tensor->mutable_data(); + const index_t *input_shape = input_tensor->shape().data(); + const index_t *output_shape = output_tensor->shape().data(); int paddings[2]; std::vector filter_shape = {input_shape[1], input_shape[0], @@ -67,7 +70,7 @@ void PoolingFunctor::operator()( strides_, this->padding_, paddings); #ifdef __COPY_MAKE_PADDING Tensor padded_input; - ConstructInputWithPadding(input, input_shape, paddings, &padded_input); + ConstructInputWithPadding(input_tensor, paddings, &padded_input); input = padded_input.data(); input_shape = padded_input.shape().data(); #endif @@ -111,7 +114,7 @@ void PoolingFunctor::operator()( } else { // not implement yet PoolingFunctor(pooling_type_, kernels_, strides_, padding_, dilations_)( - input, input_shape, output, output_shape); + input_tensor, output_tensor); } } diff --git a/mace/kernels/opencl/cl/pooling.cl b/mace/kernels/opencl/cl/pooling.cl new file mode 100644 index 00000000..9f9e38d4 --- /dev/null +++ b/mace/kernels/opencl/cl/pooling.cl @@ -0,0 +1,186 @@ +float4 vec_pooling_3_s1(const float *input_ptr, const int in_width) { + float4 row00 = vload4(0, input_ptr); + float2 row01 = vload2(0, input_ptr + 4); + float4 row10 = vload4(0, input_ptr + in_width); + float2 row11 = vload2(0, input_ptr + in_width + 4); + float4 row20 = vload4(0, input_ptr + in_width * 2); + float2 row21 = vload2(0, input_ptr + in_width * 2 + 4); + + float8 data00 = (float8)(row00.s01212323); + float4 data01 = (float4)(row01.s0, row00.s3, row01.s01); + float8 data10 = (float8)(row10.s01212323); + float4 data11 = (float4)(row11.s0, row10.s3, row11.s01); + float8 data20 = (float8)(row20.s01212323); + float4 data21 = (float4)(row21.s0, row20.s3, row21.s01); + + float8 left = fmax(fmax(data00, data10), data20); + float4 right = fmax(fmax(data01, data11), data21); + + float4 res = fmax((float4)(left.s036, right.s1), (float4)(left.s147, right.s2)); + res = fmax(res, (float4)(left.s25, right.s03)); + + return res; +} +float4 vec_pooling_3_s2(const float *input_ptr, const int in_width) { + float8 row00 = vload8(0, input_ptr); + float row01 = *(input_ptr + 8); + float8 row10 = vload8(0, input_ptr + in_width); + float row11 = *(input_ptr + in_width + 8); + float8 row20 = vload8(0, input_ptr + in_width * 2); + float row21 = *(input_ptr + in_width * 2 + 8); + + float8 data00 = (float8)(row00.s01223445); + float4 data01 = (float4)(row00.s667, row01); + float8 data10 = (float8)(row10.s01223445); + float4 data11 = (float4)(row10.s667, row11); + float8 data20 = (float8)(row20.s01223445); + float4 data21 = (float4)(row20.s667, row21); + + float8 left = fmax(fmax(data00, data10), data20); + float4 right = fmax(fmax(data01, data11), data21); + + float4 res = fmax((float4)(left.s036, right.s1), (float4)(left.s147, right.s2)); + res = fmax(res, (float4)(left.s25, right.s03)); + + return res; +} + +float inner_pooling_3(const float *input_ptr, const int in_width) { + float3 row0 = vload3(0, input_ptr); + float3 row1 = vload3(0, input_ptr + in_width); + float3 row2 = vload3(0, input_ptr + in_width * 2); + + float3 data = fmax(fmax(row0, row1), row2); + + float res = fmax(fmax(data.s0, data.s1), data.s2); + return res; +} + +__kernel void pooling3(__global const float *input, /* n, c, h, w */ + __private const int in_height, + __private const int in_width, + __private const int out_chan_num, + __private const int out_height, + __private const int out_width, + __private const int stride, + __global float *output) { + int batch = get_global_id(0); + int out_chan_blk = get_global_id(1); + int out_pixel_blk = get_global_id(2); + + const int round_out_width = (out_width + 3) / 4; + const int out_pixel_height = out_pixel_blk / round_out_width; + const int out_pixel_width = out_pixel_blk % round_out_width; + + const int out_chan_begin = out_chan_blk * 4; + const int out_chan_end = min(out_chan_begin + 4, out_chan_num); + const int out_pixel_begin = out_pixel_height * out_width + out_pixel_width * 4; + const int out_pixel_end = min(out_pixel_begin + 4, (out_pixel_height + 1) * out_width); + const int in_pixel_begin = out_pixel_height * stride * in_width + out_pixel_width * stride * 4; + + const int in_pixel = in_height * in_width; + const int out_pixel = out_height * out_width; + + const int in_offset = batch * out_chan_num * in_pixel; + const int out_offset = batch * out_chan_num * out_pixel; + const float *input_base = input + in_offset + in_pixel_begin; + float *output_base = output + out_offset + out_pixel_begin; + + const int pixels = out_pixel_end - out_pixel_begin; + + for (int i = out_chan_begin; i < out_chan_end; ++i) { + const float *input_ptr = input_base + i * in_pixel; + float *output_ptr = output_base + i * out_pixel; + if (pixels == 4) { + float4 res; + if (stride == 1) { + res = vec_pooling_3_s1(input_ptr, in_width); + } else { + res = vec_pooling_3_s2(input_ptr, in_width); + } + vstore4(res, 0, output_ptr); + } else { + for (int p = 0; p < pixels; ++p) { + output_ptr[p] = inner_pooling_3(input_ptr, in_width); + input_ptr += stride; + } + } + } +} + +int calculate_avg_block_size(const int pos_h, + const int pos_w, + const int pool_size, + const int pad_h, + const int pad_w, + const int h_size, + const int w_size) { + const int h_start = max(0, pos_h - pad_h); + const int w_start = max(0, pos_w - pad_w); + const int h_end = min(pos_h + pool_size - pad_h, h_size); + const int w_end = min(pos_w + pool_size - pad_w, w_size); + return (h_end - h_start) * (w_end - w_start); +} + +__kernel void poolingn(__global const float *input, /* n, c, h, w */ + __private const int in_height, + __private const int in_width, + __private const int out_chan_num, + __private const int out_height, + __private const int out_width, + __private const int stride, + __private const int pad_h, + __private const int pad_w, + __private const int pooling_size, + __global float *output) { + int batch = get_global_id(0); + int out_chan_idx = get_global_id(1); + int out_pixel_idx = get_global_id(2); + + const int out_pixel_height = out_pixel_idx / out_width; + const int out_pixel_width = out_pixel_idx % out_width; + + const int out_chan_begin = out_chan_idx * 4; + const int out_chan_end = min(out_chan_begin + 4, out_chan_num); + const int in_pixel_idx = out_pixel_height * stride * in_width + + out_pixel_width * stride; + + const int in_pixel = in_height * in_width; + const int out_pixel = out_height * out_width; + + const int in_offset = batch * out_chan_num * in_pixel; + const int out_offset = batch * out_chan_num * out_pixel; + const float *input_base = input + in_offset + in_pixel_idx; + float *output_base = output + out_offset + out_pixel_idx; + + const int block_size = calculate_avg_block_size( + out_pixel_height * stride, + out_pixel_width * stride, + pooling_size, + pad_h/2, + pad_w/2, + in_height - pad_h, + in_width - pad_w); + for (int i = out_chan_begin; i < out_chan_end; ++i) { + float8 sum8 = 0.0f; + float sum1 = 0.0f; + float *output_ptr = output_base + i * out_pixel; + for (int y = 0; y < pooling_size; ++y) { + const float *input_ptr = input_base + i * in_pixel + y * in_width; + int x = 0; + for (; x < (pooling_size-8); x += 8) { + float8 data = vload8(0, input_ptr); + sum8 += data; + input_ptr += 8; + } + for (; x < pooling_size; ++x) { + sum1 += *input_ptr; + input_ptr++; + } + } + float4 sum4 = sum8.s0123 + sum8.s4567; + float2 sum2 = sum4.s01 + sum4.s23; + + *output_ptr = (sum2.s0 + sum2.s1 + sum1) / block_size; + } +} diff --git a/mace/kernels/opencl/conv_2d_opencl.cc b/mace/kernels/opencl/conv_2d_opencl.cc index db05068c..2ff4a9c5 100644 --- a/mace/kernels/opencl/conv_2d_opencl.cc +++ b/mace/kernels/opencl/conv_2d_opencl.cc @@ -47,9 +47,7 @@ void Conv2dFunctor::operator()(const Tensor *input, auto conv2d_func = selector[kernel_h - 1][strides_[0] - 1]; if (paddings_[0] > 0 || paddings_[1] > 0) { Tensor padded_input(GetDeviceAllocator(DeviceType::OPENCL), DataTypeToEnum::v()); - Tensor::MappingGuard input_mapper(input); - ConstructInputWithPadding(input->data(), input->shape().data(), paddings_.data(), - &padded_input); + ConstructInputWithPadding(input, paddings_.data(), &padded_input); conv2d_func(&padded_input, filter, bias, output); }else { conv2d_func(input, filter, bias, output); diff --git a/mace/kernels/opencl/depthwise_conv_opencl.cc b/mace/kernels/opencl/depthwise_conv_opencl.cc index 90704974..7e75fc00 100644 --- a/mace/kernels/opencl/depthwise_conv_opencl.cc +++ b/mace/kernels/opencl/depthwise_conv_opencl.cc @@ -45,9 +45,7 @@ void DepthwiseConv2dFunctor::operator()(const Tensor auto conv2d_func = selector[kernel_h - 1][strides_[0] - 1]; if (paddings_[0] > 0 || paddings_[1] > 0) { Tensor padded_input(GetDeviceAllocator(DeviceType::OPENCL), DataTypeToEnum::v()); - Tensor::MappingGuard input_mapper(input); - ConstructInputWithPadding(input->data(), input->shape().data(), paddings_.data(), - &padded_input); + ConstructInputWithPadding(input, paddings_.data(), &padded_input); conv2d_func(&padded_input, filter, bias, output); }else { conv2d_func(input, filter, bias, output); diff --git a/mace/kernels/opencl/pooling_opencl.cc b/mace/kernels/opencl/pooling_opencl.cc new file mode 100644 index 00000000..8daa78f6 --- /dev/null +++ b/mace/kernels/opencl/pooling_opencl.cc @@ -0,0 +1,135 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/kernels/pooling.h" +#include "mace/core/runtime/opencl/cl2_header.h" +#include "mace/core/runtime/opencl/opencl_runtime.h" + +namespace mace { +namespace kernels { + +static void Pooling3(const Tensor *input, + const int *stride, + const PoolingType type, + Tensor *output) { + if (type != MAX) { + MACE_NOT_IMPLEMENTED; + } + index_t batch = output->dim(0); + index_t channels = output->dim(1); + index_t out_height = output->dim(2); + index_t out_width = output->dim(3); + + index_t channel_blk = (channels + 3) / 4; + const index_t pixel_width = (out_width + 3) / 4 ; + const uint32_t gws[3] = { + static_cast(batch), + static_cast(channel_blk), + static_cast(pixel_width * out_height), + }; + + auto runtime = OpenCLRuntime::Get(); + auto program = runtime->program(); + + auto max_pooling_kernel = cl::Kernel(program, "pooling3"); + + const uint32_t lws[3] = {1, 8, 128}; + + uint32_t idx = 0; + max_pooling_kernel.setArg(idx++, *(static_cast(input->buffer()))); + max_pooling_kernel.setArg(idx++, static_cast(input->dim(2))); + max_pooling_kernel.setArg(idx++, static_cast(input->dim(3))); + max_pooling_kernel.setArg(idx++, static_cast(channels)); + max_pooling_kernel.setArg(idx++, static_cast(out_height)); + max_pooling_kernel.setArg(idx++, static_cast(out_width)); + max_pooling_kernel.setArg(idx++, stride[0]); + max_pooling_kernel.setArg(idx++, *(static_cast(output->buffer()))); + + cl_int error = runtime->command_queue().enqueueNDRangeKernel( + max_pooling_kernel, cl::NullRange, + cl::NDRange(gws[0], gws[1], gws[2]), + cl::NDRange(lws[0], lws[1], lws[2])); + MACE_CHECK(error == CL_SUCCESS); +} + +static void PoolingN(const Tensor *input, + const int *stride, + const int *paddings, + const int pooling_size, + const PoolingType type, + Tensor *output) { + if (type != AVG) { + MACE_NOT_IMPLEMENTED; + } + index_t batch = output->dim(0); + index_t channels = output->dim(1); + index_t out_height = output->dim(2); + index_t out_width = output->dim(3); + + index_t channel_blk = (channels + 3) / 4; + const uint32_t gws[3] = { + static_cast(batch), + static_cast(channel_blk), + static_cast(out_height * out_width), + }; + + auto runtime = OpenCLRuntime::Get(); + auto program = runtime->program(); + + auto pooling_kernel = cl::Kernel(program, "poolingn"); + + const uint32_t lws[3] = {1, 8, 128}; + + uint32_t idx = 0; + pooling_kernel.setArg(idx++, *(static_cast(input->buffer()))); + pooling_kernel.setArg(idx++, static_cast(input->dim(2))); + pooling_kernel.setArg(idx++, static_cast(input->dim(3))); + pooling_kernel.setArg(idx++, static_cast(channels)); + pooling_kernel.setArg(idx++, static_cast(out_height)); + pooling_kernel.setArg(idx++, static_cast(out_width)); + pooling_kernel.setArg(idx++, stride[0]); + pooling_kernel.setArg(idx++, paddings[0]); + pooling_kernel.setArg(idx++, paddings[1]); + pooling_kernel.setArg(idx++, pooling_size); + pooling_kernel.setArg(idx++, *(static_cast(output->buffer()))); + + cl_int error = runtime->command_queue().enqueueNDRangeKernel( + pooling_kernel, cl::NullRange, + cl::NDRange(gws[0], gws[1], gws[2]), + cl::NDRange(lws[0], lws[1], lws[2])); + MACE_CHECK(error == CL_SUCCESS); +} + +template <> +void PoolingFunctor::operator()(const Tensor *input, + Tensor *output) { + int paddings[2]; + std::vector filter_shape = {input->dim(1), input->dim(0), + kernels_[0], kernels_[1]}; + kernels::CalPaddingSize(input->shape().data(), filter_shape.data(), this->dilations_, + strides_, this->padding_, paddings); +#define POOLING_HELPER \ + switch(kernels_[0]) { \ + case 3: \ + Pooling3(input, strides_, pooling_type_, output); \ + break; \ + default: \ + PoolingN(input, strides_, paddings, kernels_[0], \ + pooling_type_, output); \ + break; \ + } + + if (paddings[0] > 0 || paddings[1] > 0) { + Tensor padded_input(GetDeviceAllocator(DeviceType::OPENCL), DataTypeToEnum::v()); + ConstructInputWithPadding(input, paddings, &padded_input, pooling_type_ == MAX); + input = &padded_input; + POOLING_HELPER + } else { + POOLING_HELPER + } +#undef POOLING_HELPER +} + +} // namespace kernels +} // namespace mace diff --git a/mace/kernels/pooling.h b/mace/kernels/pooling.h index 821baf3a..11c05e47 100644 --- a/mace/kernels/pooling.h +++ b/mace/kernels/pooling.h @@ -31,10 +31,14 @@ struct PoolingFunctor { padding_(padding), dilations_(dilations) {} - void operator()(const T *input, - const index_t *input_shape, - T *output, - const index_t *output_shape) { + void operator()(const Tensor *input_tensor, + Tensor *output_tensor) { + Tensor::MappingGuard in_guard(input_tensor); + Tensor::MappingGuard out_guard(output_tensor); + const T *input = input_tensor->data(); + T *output = output_tensor->mutable_data(); + const index_t *input_shape = input_tensor->shape().data(); + const index_t *output_shape = output_tensor->shape().data(); index_t batch = output_shape[0]; index_t channels = output_shape[1]; index_t height = output_shape[2]; @@ -99,6 +103,7 @@ struct PoolingFunctor { for (int h = 0; h < height; ++h) { for (int w = 0; w < width; ++w) { T sum = 0; + int block_size = 0; for (int kh = 0; kh < kernel_h; ++kh) { for (int kw = 0; kw < kernel_w; ++kw) { int inh = padded_h_start + h * stride_h + dilation_h * kh; @@ -107,10 +112,11 @@ struct PoolingFunctor { inw < input_width) { index_t input_offset = in_offset + inh * input_width + inw; sum += input[input_offset]; + block_size += 1; } } } - output[out_offset] = sum / (kernel_h * kernel_w); + output[out_offset] = sum / block_size; out_offset += 1; } } @@ -128,17 +134,13 @@ struct PoolingFunctor { template <> void PoolingFunctor::operator()( - const float *input, - const index_t *input_shape, - float *output, - const index_t *output_shape); + const Tensor *input_tensor, + Tensor *output_tensor); template <> void PoolingFunctor::operator()( - const float *input, - const index_t *input_shape, - float *output, - const index_t *output_shape); + const Tensor *input_tensor, + Tensor *output_tensor); } // namespace kernels } // namespace mace diff --git a/mace/ops/BUILD b/mace/ops/BUILD index cacf9103..e823136d 100644 --- a/mace/ops/BUILD +++ b/mace/ops/BUILD @@ -62,22 +62,6 @@ cc_test( ], ) -cc_test( - name = "pooling_test", - testonly = 1, - srcs = glob( - ["pooling_test.cc"], - ), - copts = ["-std=c++11"], - linkopts = ["-fopenmp"] + if_android(["-ldl"]), - linkstatic = 1, - deps = [ - ":ops", - ":test", - "@gtest//:gtest_main", - ], -) - cc_test( name = "ops_benchmark", testonly = 1, diff --git a/mace/ops/pooling.cc b/mace/ops/pooling.cc index 4b972647..1c4f1af2 100644 --- a/mace/ops/pooling.cc +++ b/mace/ops/pooling.cc @@ -12,4 +12,5 @@ REGISTER_CPU_OPERATOR(Pooling, PoolingOp); REGISTER_NEON_OPERATOR(Pooling, PoolingOp); #endif // __ARM_NEON +REGISTER_OPENCL_OPERATOR(Pooling, PoolingOp); } // namespace mace diff --git a/mace/ops/pooling.h b/mace/ops/pooling.h index d1172d6f..f62992f5 100644 --- a/mace/ops/pooling.h +++ b/mace/ops/pooling.h @@ -20,8 +20,8 @@ class PoolingOp : public ConvPool2dOpBase { pooling_type_( static_cast(OperatorBase::GetSingleArgument( "pooling_type", static_cast(AVG)))), - functor_(pooling_type_, kernels_.data(), ConvPool2dOpBase::strides_.data(), - ConvPool2dOpBase::padding_, ConvPool2dOpBase::dilations_.data()){}; + functor_(pooling_type_, kernels_.data(), this->strides_.data(), + this->padding_, this->dilations_.data()){}; bool Run() override { const Tensor *input = this->Input(INPUT); @@ -42,8 +42,7 @@ class PoolingOp : public ConvPool2dOpBase { paddings.data()); output->Resize(output_shape); - functor_(input->data(), input->shape().data(), - output->mutable_data(), output->shape().data()); + functor_(input, output); return true; }; diff --git a/mace/ops/pooling_test.cc b/mace/ops/pooling_test.cc index cd2dd609..bf2b1824 100644 --- a/mace/ops/pooling_test.cc +++ b/mace/ops/pooling_test.cc @@ -150,9 +150,11 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) { ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); } -TEST_F(PoolingOpTest, MAX_k3x3s2x2) { + +template +static void SimpleMaxPooling3S2() { // Construct graph - auto &net = test_net(); + OpsTestNet net; OpDefBuilder("Pooling", "PoolingTest") .Input("Input") .Output("Output") @@ -164,12 +166,12 @@ TEST_F(PoolingOpTest, MAX_k3x3s2x2) { .Finalize(net.NewOperatorDef()); // Add input data - net.AddInputFromArray( + net.AddInputFromArray( "Input", {1, 1, 3, 9}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26}); // Run - net.RunOp(DeviceType::NEON); + net.RunOp(D); // Check auto expected = CreateTensor({1, 1, 1, 4}, {20, 22, 24, 26}); @@ -177,9 +179,95 @@ TEST_F(PoolingOpTest, MAX_k3x3s2x2) { ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); } -TEST_F(PoolingOpTest, AVG_k2x2s2x2) { +TEST_F(PoolingOpTest, CPUSimpleMaxPooling3S2) { + SimpleMaxPooling3S2(); +} +TEST_F(PoolingOpTest, NEONSimpleMaxPooling3S2) { + SimpleMaxPooling3S2(); +} +TEST_F(PoolingOpTest, OPENCLSimpleMaxPooling3S2) { + SimpleMaxPooling3S2(); +} + +template +static void AlignedMaxPooling3S2(Padding padding) { // Construct graph - auto &net = test_net(); + OpsTestNet net; + OpDefBuilder("Pooling", "PoolingTest") + .Input("Input") + .Output("Output") + .AddIntArg("pooling_type", PoolingType::MAX) + .AddIntsArg("kernels", {3, 3}) + .AddIntsArg("strides", {2, 2}) + .AddIntArg("padding", padding) + .AddIntsArg("dilations", {1, 1}) + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddRandomInput("Input", {3, 128, 64, 64}); + // Run + net.RunOp(D); + Tensor expected; + expected.Copy(*net.GetOutput("Output")); + + // Run on cpu + net.RunOp(); + + ExpectTensorNear(*net.GetOutput("Output"), expected, 0.001); +} + +// TODO(chenghui) : there is a bug. +//TEST_F(PoolingOpTest, NEONAlignedMaxPooling3S2) { +// AlignedMaxPooling3S2(Padding::VALID); +// AlignedMaxPooling3S2(Padding::SAME); +//} + +TEST_F(PoolingOpTest, OPENCLAlignedMaxPooling3S2) { + AlignedMaxPooling3S2(Padding::VALID); + AlignedMaxPooling3S2(Padding::SAME); +} + +template +static void UnalignedMaxPooling3S2(Padding padding) { + // Construct graph + OpsTestNet net; + OpDefBuilder("Pooling", "PoolingTest") + .Input("Input") + .Output("Output") + .AddIntArg("pooling_type", PoolingType::MAX) + .AddIntsArg("kernels", {3, 3}) + .AddIntsArg("strides", {2, 2}) + .AddIntArg("padding", padding) + .AddIntsArg("dilations", {1, 1}) + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddRandomInput("Input", {3, 113, 43, 47}); + // Run + net.RunOp(D); + Tensor expected; + expected.Copy(*net.GetOutput("Output")); + + // Run on cpu + net.RunOp(); + + ExpectTensorNear(*net.GetOutput("Output"), expected, 0.001); +} + +// TODO(chenghui) : there is a bug. +//TEST_F(PoolingOpTest, NEONUnalignedMaxPooling3S2) { +// UnalignedMaxPooling3S2(); +//} + +TEST_F(PoolingOpTest, OPENCLUnalignedMaxPooling3S2) { + UnalignedMaxPooling3S2(Padding::VALID); + UnalignedMaxPooling3S2(Padding::SAME); +} + +template +static void SimpleAvgPoolingTest() { + // Construct graph + OpsTestNet net; OpDefBuilder("Pooling", "PoolingTest") .Input("Input") .Output("Output") @@ -191,14 +279,96 @@ TEST_F(PoolingOpTest, AVG_k2x2s2x2) { .Finalize(net.NewOperatorDef()); // Add input data - net.AddInputFromArray( + net.AddInputFromArray( "Input", {1, 1, 2, 8}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); // Run - net.RunOp(DeviceType::NEON); + net.RunOp(D); // Check auto expected = CreateTensor({1, 1, 1, 4}, {4.5, 6.5, 8.5, 10.5}); ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); } + +TEST_F(PoolingOpTest, NEONSimpleAvgPooling) { + SimpleAvgPoolingTest(); +} + +TEST_F(PoolingOpTest, OPENCLSimpleAvgPooling) { + SimpleAvgPoolingTest(); +} + +template +static void AlignedAvgPoolingTest(Padding padding) { + // Construct graph + OpsTestNet net; + OpDefBuilder("Pooling", "PoolingTest") + .Input("Input") + .Output("Output") + .AddIntArg("pooling_type", PoolingType::AVG) + .AddIntsArg("kernels", {4, 4}) + .AddIntsArg("strides", {4, 4}) + .AddIntArg("padding", padding) + .AddIntsArg("dilations", {1, 1}) + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddRandomInput("Input", {3, 128, 15, 15}); + // Run + net.RunOp(D); + Tensor expected; + expected.Copy(*net.GetOutput("Output")); + + // Run on cpu + net.RunOp(); + + ExpectTensorNear(*net.GetOutput("Output"), expected, 1e-5); +} + +TEST_F(PoolingOpTest, NEONAlignedAvgPooling) { + AlignedAvgPoolingTest(Padding::VALID); + AlignedAvgPoolingTest(Padding::SAME); +} + +TEST_F(PoolingOpTest, OPENCLAlignedAvgPooling) { + AlignedAvgPoolingTest(Padding::VALID); + AlignedAvgPoolingTest(Padding::SAME); +} + +template +static void UnAlignedAvgPoolingTest(Padding padding) { + // Construct graph + OpsTestNet net; + OpDefBuilder("Pooling", "PoolingTest") + .Input("Input") + .Output("Output") + .AddIntArg("pooling_type", PoolingType::AVG) + .AddIntsArg("kernels", {7, 7}) + .AddIntsArg("strides", {7, 7}) + .AddIntArg("padding", padding) + .AddIntsArg("dilations", {1, 1}) + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddRandomInput("Input", {3, 128, 31, 37}); + // Run + net.RunOp(D); + Tensor expected; + expected.Copy(*net.GetOutput("Output")); + + // Run on cpu + net.RunOp(); + + ExpectTensorNear(*net.GetOutput("Output"), expected, 1e-5); +} + +TEST_F(PoolingOpTest, NEONUnAlignedAvgPooling) { + UnAlignedAvgPoolingTest(Padding::VALID); + UnAlignedAvgPoolingTest(Padding::SAME); +} + +TEST_F(PoolingOpTest, OPENCLUnAlignedAvgPooling) { + UnAlignedAvgPoolingTest(Padding::VALID); + UnAlignedAvgPoolingTest(Padding::SAME); +} -- GitLab