From cdd4b4c7084335fb398b1845bd8c1825e71cbc41 Mon Sep 17 00:00:00 2001 From: Liangliang He Date: Tue, 7 Nov 2017 20:44:04 +0800 Subject: [PATCH] Add naive resize bilinear opencl kernel without caching --- mace/kernels/conv_2d.h | 24 ++++---- mace/kernels/neon/conv_2d_neon.cc | 4 +- mace/kernels/opencl/cl/resize_bilinear.cl | 36 ++++++++++++ mace/kernels/opencl/conv_2d_opencl.cc | 4 +- mace/kernels/opencl/conv_2d_opencl_1x1.cc | 42 +++++++------- mace/kernels/opencl/conv_2d_opencl_3x3.cc | 6 +- mace/kernels/opencl/depthwise_conv_opencl.cc | 4 +- mace/kernels/opencl/resize_bilinear_opencl.cc | 48 +++++++++++++++- mace/kernels/resize_bilinear.h | 2 + mace/ops/resize_bilinear.h | 1 - mace/ops/resize_bilinear_benchmark.cc | 8 +-- mace/ops/resize_bilinear_test.cc | 56 ++++++++++++++++++- 12 files changed, 185 insertions(+), 50 deletions(-) create mode 100644 mace/kernels/opencl/cl/resize_bilinear.cl diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h index c960f285..e32dc92d 100644 --- a/mace/kernels/conv_2d.h +++ b/mace/kernels/conv_2d.h @@ -27,18 +27,18 @@ struct Conv2dFunctor { MACE_CHECK_NOTNULL(filter); MACE_CHECK_NOTNULL(output); - index_t batch = output->shape()[0]; - index_t channels = output->shape()[1]; - index_t height = output->shape()[2]; - index_t width = output->shape()[3]; - - index_t input_batch = input->shape()[0]; - index_t input_channels = input->shape()[1]; - index_t input_height = input->shape()[2]; - index_t input_width = input->shape()[3]; - - index_t kernel_h = filter->shape()[2]; - index_t kernel_w = filter->shape()[3]; + index_t batch = output->dim(0); + index_t channels = output->dim(1); + index_t height = output->dim(2); + index_t width = output->dim(3); + + index_t input_batch = input->dim(0); + index_t input_channels = input->dim(1); + index_t input_height = input->dim(2); + index_t input_width = input->dim(3); + + index_t kernel_h = filter->dim(2); + index_t kernel_w = filter->dim(3); int stride_h = strides_[0]; int stride_w = strides_[1]; diff --git a/mace/kernels/neon/conv_2d_neon.cc b/mace/kernels/neon/conv_2d_neon.cc index 7f912c3d..4e7752dc 100644 --- a/mace/kernels/neon/conv_2d_neon.cc +++ b/mace/kernels/neon/conv_2d_neon.cc @@ -61,8 +61,8 @@ void Conv2dFunctor::operator()(const Tensor *input, {nullptr, nullptr}, {Conv2dNeonK5x5S1, nullptr}}; // not implement yet - index_t kernel_h = filter->shape()[2]; - index_t kernel_w = filter->shape()[3]; + index_t kernel_h = filter->dim(2); + index_t kernel_w = filter->dim(3); if (kernel_h != kernel_w || kernel_h > 5 || strides_[0] != strides_[1] || strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1 || selector[kernel_h - 1][strides_[0] - 1] == nullptr) { diff --git a/mace/kernels/opencl/cl/resize_bilinear.cl b/mace/kernels/opencl/cl/resize_bilinear.cl new file mode 100644 index 00000000..b1e987ec --- /dev/null +++ b/mace/kernels/opencl/cl/resize_bilinear.cl @@ -0,0 +1,36 @@ +__kernel void resize_bilinear_nocache(__global const float *input, /* n * c, h, w */ + __global float *output /* n * c, h, w */, + __private const float height_scale, + __private const float width_scale, + __private const int in_height, + __private const int in_width) { + const int c = get_global_id(0); + const int h = get_global_id(1); + const int w = get_global_id(2); + const int channels = get_global_size(0); + const int height = get_global_size(1); + const int width = get_global_size(2); + + const float h_in = h * height_scale; + const float w_in = w * width_scale; + const int h_lower = max(0, (int) floor(h_in)); + const int h_upper = min(in_height - 1, h_lower + 1); + const int w_lower = max(0, (int) floor(w_in)); + const int w_upper = min(in_width - 1, w_lower + 1); + + const float h_lerp = h_in - h_lower; + const float w_lerp = w_in - w_lower; + + const float *input_base = input + c * in_height * in_width; + float *output_base = output + c * height * width; + + float top_left = input_base[h_lower * in_width + w_lower]; + float top_right = input_base[h_lower * in_width + w_upper]; + float bottom_left = input_base[h_upper * in_width + w_lower]; + float bottom_right = input_base[h_upper * in_width + w_upper]; + + const float top = top_left + (top_right - top_left) * w_lerp; + const float bottom = bottom_left + (bottom_right - bottom_left) * w_lerp; + output_base[h * width + w] = top + (bottom - top) * h_lerp; +} + diff --git a/mace/kernels/opencl/conv_2d_opencl.cc b/mace/kernels/opencl/conv_2d_opencl.cc index fcdb3de2..db05068c 100644 --- a/mace/kernels/opencl/conv_2d_opencl.cc +++ b/mace/kernels/opencl/conv_2d_opencl.cc @@ -30,8 +30,8 @@ void Conv2dFunctor::operator()(const Tensor *input, {nullptr, nullptr}, {nullptr, nullptr}}; - index_t kernel_h = filter->shape()[2]; - index_t kernel_w = filter->shape()[3]; + index_t kernel_h = filter->dim(2); + index_t kernel_w = filter->dim(3); if (kernel_h != kernel_w || kernel_h > 5 || strides_[0] != strides_[1] || strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1 || selector[kernel_h - 1][strides_[0] - 1] == nullptr) { diff --git a/mace/kernels/opencl/conv_2d_opencl_1x1.cc b/mace/kernels/opencl/conv_2d_opencl_1x1.cc index 130ca4b7..3e988d83 100644 --- a/mace/kernels/opencl/conv_2d_opencl_1x1.cc +++ b/mace/kernels/opencl/conv_2d_opencl_1x1.cc @@ -15,11 +15,11 @@ void Conv1x1Naive(const Tensor *input, const Tensor *filter, const Tensor *bias, Tensor *output) { - const index_t batch = output->shape()[0]; - const index_t channels = output->shape()[1]; - const index_t height = output->shape()[2]; - const index_t width = output->shape()[3]; - const index_t input_channels = input->shape()[1]; + const index_t batch = output->dim(0); + const index_t channels = output->dim(1); + const index_t height = output->dim(2); + const index_t width = output->dim(3); + const index_t input_channels = input->dim(1); auto runtime = OpenCLRuntime::Get(); auto program = runtime->program(); @@ -46,11 +46,11 @@ void Conv1x1V2(const Tensor *input, const Tensor *filter, const Tensor *bias, Tensor *output) { - const index_t batch = output->shape()[0]; - const index_t channels = output->shape()[1]; - const index_t height = output->shape()[2]; - const index_t width = output->shape()[3]; - const index_t input_channels = input->shape()[1]; + const index_t batch = output->dim(0); + const index_t channels = output->dim(1); + const index_t height = output->dim(2); + const index_t width = output->dim(3); + const index_t input_channels = input->dim(1); auto runtime = OpenCLRuntime::Get(); auto program = runtime->program(); @@ -88,11 +88,11 @@ void Conv1x1V3(const Tensor *input, const Tensor *filter, const Tensor *bias, Tensor *output) { - const index_t batch = output->shape()[0]; - const index_t channels = output->shape()[1]; - const index_t height = output->shape()[2]; - const index_t width = output->shape()[3]; - const index_t input_channels = input->shape()[1]; + const index_t batch = output->dim(0); + const index_t channels = output->dim(1); + const index_t height = output->dim(2); + const index_t width = output->dim(3); + const index_t input_channels = input->dim(1); auto runtime = OpenCLRuntime::Get(); auto program = runtime->program(); @@ -174,13 +174,13 @@ extern void Conv2dOpenclK1x1S1(const Tensor *input, const Tensor *filter, const Tensor *bias, Tensor *output) { - const index_t batch = output->shape()[0]; - const index_t height = output->shape()[2]; - const index_t width = output->shape()[3]; + const index_t batch = output->dim(0); + const index_t height = output->dim(2); + const index_t width = output->dim(3); - const index_t input_batch = input->shape()[0]; - const index_t input_height = input->shape()[2]; - const index_t input_width = input->shape()[3]; + const index_t input_batch = input->dim(0); + const index_t input_height = input->dim(2); + const index_t input_width = input->dim(3); MACE_CHECK(input_batch == batch && input_height == height && input_width == width); diff --git a/mace/kernels/opencl/conv_2d_opencl_3x3.cc b/mace/kernels/opencl/conv_2d_opencl_3x3.cc index 41dccf4c..537204f7 100644 --- a/mace/kernels/opencl/conv_2d_opencl_3x3.cc +++ b/mace/kernels/opencl/conv_2d_opencl_3x3.cc @@ -11,9 +11,9 @@ namespace kernels { static void InnerConv2dK3x3S12(const Tensor *input, const Tensor *filter, const Tensor *bias, const uint32_t stride, Tensor *output) { - const index_t channels = output->shape()[1]; - const index_t height = output->shape()[2]; - const index_t width = output->shape()[3]; + const index_t channels = output->dim(1); + const index_t height = output->dim(2); + const index_t width = output->dim(3); MACE_CHECK(input->dim(0) == output->dim(0)); diff --git a/mace/kernels/opencl/depthwise_conv_opencl.cc b/mace/kernels/opencl/depthwise_conv_opencl.cc index 7dcb996f..90704974 100644 --- a/mace/kernels/opencl/depthwise_conv_opencl.cc +++ b/mace/kernels/opencl/depthwise_conv_opencl.cc @@ -27,8 +27,8 @@ void DepthwiseConv2dFunctor::operator()(const Tensor {nullptr, nullptr}, {nullptr, nullptr}}; - index_t kernel_h = filter->shape()[2]; - index_t kernel_w = filter->shape()[3]; + index_t kernel_h = filter->dim(2); + index_t kernel_w = filter->dim(3); if (kernel_h != kernel_w || kernel_h > 5 || strides_[0] != strides_[1] || strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1 || selector[kernel_h - 1][strides_[0] - 1] == nullptr) { diff --git a/mace/kernels/opencl/resize_bilinear_opencl.cc b/mace/kernels/opencl/resize_bilinear_opencl.cc index 62f58d0b..f0147abd 100644 --- a/mace/kernels/opencl/resize_bilinear_opencl.cc +++ b/mace/kernels/opencl/resize_bilinear_opencl.cc @@ -2,15 +2,59 @@ // Copyright (c) 2017 XiaoMi All rights reserved. // -#include "mace/kernels/resize_bilinear.h" +#include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/core/tensor.h" +#include "mace/kernels/resize_bilinear.h" namespace mace { namespace kernels { template <> void ResizeBilinearFunctor::operator()( - const Tensor *input, const Tensor *resize_dims, Tensor *output) {} + const Tensor *input, const Tensor *resize_dims, Tensor *output) { + const index_t batch = input->dim(0); + const index_t channels = input->dim(1); + const index_t in_height = input->dim(2); + const index_t in_width = input->dim(3); + + index_t out_height; + index_t out_width; + { + MACE_CHECK(resize_dims->dim_size() == 1); + Tensor::MappingGuard resize_dims_mapper(resize_dims); + auto dims_data = resize_dims->data(); + out_height = dims_data[0]; + out_width = dims_data[1]; + } + + std::vector out_shape{batch, channels, out_height, out_width}; + output->Resize(out_shape); + + float height_scale = + CalculateResizeScale(in_height, out_height, align_corners_); + float width_scale = CalculateResizeScale(in_width, out_width, align_corners_); + + auto runtime = OpenCLRuntime::Get(); + auto program = runtime->program(); + + auto rb_kernel = cl::Kernel(program, "resize_bilinear_nocache"); + const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(rb_kernel); + uint32_t idx = 0; + rb_kernel.setArg(idx++, *(static_cast(input->buffer()))); + rb_kernel.setArg(idx++, *(static_cast(output->buffer()))); + rb_kernel.setArg(idx++, static_cast(height_scale)); + rb_kernel.setArg(idx++, static_cast(width_scale)); + rb_kernel.setArg(idx++, static_cast(in_height)); + rb_kernel.setArg(idx++, static_cast(in_width)); + + auto command_queue = runtime->command_queue(); + cl_int error = command_queue.enqueueNDRangeKernel( + rb_kernel, cl::NullRange, + cl::NDRange(static_cast(batch * channels), + static_cast(out_height), static_cast(out_width)), + cl::NDRange(1, 16, kwg_size / 16)); + MACE_CHECK(error == CL_SUCCESS, error); +} } // namespace kernels } // namespace mace diff --git a/mace/kernels/resize_bilinear.h b/mace/kernels/resize_bilinear.h index 1e59d112..aaed3d9c 100644 --- a/mace/kernels/resize_bilinear.h +++ b/mace/kernels/resize_bilinear.h @@ -127,6 +127,8 @@ struct ResizeBilinearFunctor { vector out_shape{n, channels, out_height, out_width}; output->Resize(out_shape); + Tensor::MappingGuard input_mapper(input); + Tensor::MappingGuard output_mapper(output); const T *input_data = input->data(); T *output_data = output->mutable_data(); diff --git a/mace/ops/resize_bilinear.h b/mace/ops/resize_bilinear.h index fd6d95bf..e25e8ebc 100644 --- a/mace/ops/resize_bilinear.h +++ b/mace/ops/resize_bilinear.h @@ -28,7 +28,6 @@ class ResizeBilinearOp : public Operator { MACE_CHECK(resize_dims->dim_size() == 1, "resize dim must be 2-dimensional.", resize_dims->dim_size()); - functor_(input, resize_dims, output); return true; } diff --git a/mace/ops/resize_bilinear_benchmark.cc b/mace/ops/resize_bilinear_benchmark.cc index c8af5ac7..37c07cd2 100644 --- a/mace/ops/resize_bilinear_benchmark.cc +++ b/mace/ops/resize_bilinear_benchmark.cc @@ -26,10 +26,10 @@ static void ResizeBilinearBenchmark(int iters, .Finalize(net.NewOperatorDef()); // Add input data - net.AddRandomInput( - "Input", {batch, channels, input_height, input_width}); - net.AddInputFromArray( - "OutSize", {2}, {output_height, output_width}); + net.AddRandomInput("Input", + {batch, channels, input_height, input_width}); + net.AddInputFromArray("OutSize", {2}, + {output_height, output_width}); // Warm-up for (int i = 0; i < 5; ++i) { diff --git a/mace/ops/resize_bilinear_test.cc b/mace/ops/resize_bilinear_test.cc index 9d95564b..d569ad71 100644 --- a/mace/ops/resize_bilinear_test.cc +++ b/mace/ops/resize_bilinear_test.cc @@ -10,7 +10,7 @@ using namespace mace; class ResizeBilinearTest : public OpsTestBase {}; -TEST_F(ResizeBilinearTest, ResizeBilinearWOAlignCorners) { +TEST_F(ResizeBilinearTest, CPUResizeBilinearWOAlignCorners) { testing::internal::LogToStderr(); // Construct graph auto &net = test_net(); @@ -60,3 +60,57 @@ TEST_F(ResizeBilinearTest, ResizeBilinearWAlignCorners) { ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); } + +template +void TestRandomResizeBilinear() { + srand(time(nullptr)); + testing::internal::LogToStderr(); + for (int round = 0; round < 10; ++round) { + index_t batch = 1 + rand() % 5; + index_t channels = 1 + rand() % 100; + index_t height = 1 + rand() % 100; + index_t width = 1 + rand() % 100; + index_t in_height = 1 + rand() % 100; + index_t in_width = 1 + rand() % 100; + + // Construct graph + OpsTestNet net; + OpDefBuilder("ResizeBilinear", "ResizeBilinearTest") + .Input("Input") + .Input("OutSize") + .Output("Output") + .AddIntArg("align_corners", 1) + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddRandomInput("Input", + {batch, channels, in_height, in_width}); + net.AddInputFromArray("OutSize", {2}, {height, width}); + /* + vector input(24); + std::iota(begin(input), end(input), 0); + net.AddInputFromArray("Input", {1, 3, 2, 4}, input); + net.AddInputFromArray("OutSize", {2}, {1, 2}); + */ + + // Run + net.RunOp(D); + Tensor actual; + actual.Copy(*net.GetOutput("Output")); + + // Run on CPU + net.RunOp(DeviceType::CPU); + Tensor *expected = net.GetOutput("Output"); + + // Check + ExpectTensorNear(*expected, actual, 0.001); + } +} + +TEST_F(ResizeBilinearTest, NEONRandomResizeBilinear) { + TestRandomResizeBilinear(); +} + +TEST_F(ResizeBilinearTest, OPENCLRandomResizeBilinear) { + TestRandomResizeBilinear(); +} -- GitLab