From e396c38808dd6febfcd8bd171305896bbe6d1734 Mon Sep 17 00:00:00 2001 From: Liangliang He Date: Fri, 1 Dec 2017 18:33:09 +0800 Subject: [PATCH] Change resize bilinear from buffer to image2d --- mace/core/opencl_allocator.cc | 3 +- mace/kernels/opencl/cl/resize_bilinear.cl | 49 ++++--- mace/kernels/opencl/resize_bilinear_opencl.cc | 48 ++++--- mace/kernels/resize_bilinear.h | 127 ++++++++++-------- mace/ops/resize_bilinear.cc | 5 + mace/ops/resize_bilinear_benchmark.cc | 31 +++-- mace/ops/resize_bilinear_test.cc | 56 +++++--- mace/python/tools/tf_ops_stats.py | 1 + 8 files changed, 199 insertions(+), 121 deletions(-) diff --git a/mace/core/opencl_allocator.cc b/mace/core/opencl_allocator.cc index 75004f75..0c4cf8f0 100644 --- a/mace/core/opencl_allocator.cc +++ b/mace/core/opencl_allocator.cc @@ -54,10 +54,11 @@ void *OpenCLAllocator::NewImage(const std::vector &image_shape, cl_int error; cl::Image2D *cl_image = new cl::Image2D(OpenCLRuntime::Get()->context(), - CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR , + CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, img_format, image_shape[0], image_shape[1], 0, nullptr, &error); + MACE_CHECK(error == CL_SUCCESS); return cl_image; } diff --git a/mace/kernels/opencl/cl/resize_bilinear.cl b/mace/kernels/opencl/cl/resize_bilinear.cl index f34e63cb..efb769d2 100644 --- a/mace/kernels/opencl/cl/resize_bilinear.cl +++ b/mace/kernels/opencl/cl/resize_bilinear.cl @@ -1,18 +1,19 @@ #include -// Supported data type: half/float -__kernel void resize_bilinear_nocache(__global const DATA_TYPE *input, /* n * c, h, w */ - __global DATA_TYPE *output /* n * c, h, w */, +__kernel void resize_bilinear_nocache(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ + __write_only image2d_t output, __private const float height_scale, __private const float width_scale, __private const int in_height, - __private const int in_width) { - const int c = get_global_id(0); - const int h = get_global_id(1); - const int w = get_global_id(2); - const int 
channels = get_global_size(0); - const int height = get_global_size(1); - const int width = get_global_size(2); + __private const int in_width, + __private const int out_height) { + const int ch_blk = get_global_id(0); + const int ch_blks = get_global_size(0); + const int w = get_global_id(1); + const int out_width = get_global_size(1); + const int hb = get_global_id(2); + const int b = hb / out_height; + const int h = hb % out_height; const float h_in = h * height_scale; const float w_in = w * width_scale; @@ -24,16 +25,26 @@ __kernel void resize_bilinear_nocache(__global const DATA_TYPE *input, /* n * c, const float h_lerp = h_in - h_lower; const float w_lerp = w_in - w_lower; - const DATA_TYPE *input_base = input + c * in_height * in_width; - DATA_TYPE *output_base = output + c * height * width; + const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + const int in_w_offset = ch_blk * in_width; + const int in_h_offset = b * in_height; - DATA_TYPE top_left = input_base[h_lower * in_width + w_lower]; - DATA_TYPE top_right = input_base[h_lower * in_width + w_upper]; - DATA_TYPE bottom_left = input_base[h_upper * in_width + w_lower]; - DATA_TYPE bottom_right = input_base[h_upper * in_width + w_upper]; + DATA_TYPE4 top_left = READ_IMAGET(input, sampler, + (int2)(in_w_offset + w_lower, in_h_offset + h_lower)); + DATA_TYPE4 top_right = READ_IMAGET(input, sampler, + (int2)(in_w_offset + w_upper, in_h_offset + h_lower)); + DATA_TYPE4 bottom_left = READ_IMAGET(input, sampler, + (int2)(in_w_offset + w_lower, in_h_offset + h_upper)); + DATA_TYPE4 bottom_right = READ_IMAGET(input, sampler, + (int2)(in_w_offset + w_upper, in_h_offset + h_upper)); - const DATA_TYPE top = top_left + (top_right - top_left) * w_lerp; - const DATA_TYPE bottom = bottom_left + (bottom_right - bottom_left) * w_lerp; - output_base[h * width + w] = top + (bottom - top) * h_lerp; + DATA_TYPE4 top = top_left + (top_right - top_left) * w_lerp; + DATA_TYPE4 bottom 
= bottom_left + (bottom_right - bottom_left) * w_lerp; + + DATA_TYPE4 out = top + (bottom - top) * h_lerp; + + const int out_w_offset = ch_blk * out_width; + const int out_h_offset = b * out_height; + WRITE_IMAGET(output, (int2)(out_w_offset + w, out_h_offset + h), out); } diff --git a/mace/kernels/opencl/resize_bilinear_opencl.cc b/mace/kernels/opencl/resize_bilinear_opencl.cc index 7b77afea..15355bd5 100644 --- a/mace/kernels/opencl/resize_bilinear_opencl.cc +++ b/mace/kernels/opencl/resize_bilinear_opencl.cc @@ -6,24 +6,33 @@ #include "mace/core/tensor.h" #include "mace/kernels/resize_bilinear.h" #include "mace/kernels/opencl/helper.h" +#include "mace/utils/utils.h" namespace mace { namespace kernels { -template <> -void ResizeBilinearFunctor::operator()( +template +void ResizeBilinearFunctor::operator()( const Tensor *input, const Tensor *resize_dims, Tensor *output) { const index_t batch = input->dim(0); - const index_t channels = input->dim(1); - const index_t in_height = input->dim(2); - const index_t in_width = input->dim(3); + const index_t in_height = input->dim(1); + const index_t in_width = input->dim(2); + const index_t channels = input->dim(3); + + const index_t channel_blocks = RoundUpDiv4(channels); index_t out_height; index_t out_width; GetOutputSize(resize_dims, &out_height, &out_width); MACE_CHECK(out_height > 0 && out_width > 0); - std::vector out_shape {batch, channels, out_height, out_width}; - output->Resize(out_shape); + std::vector output_shape {batch, out_height, out_width, channels}; + if (input->is_image()) { + std::vector output_image_shape; + CalImage2DShape(output_shape, BufferType::IN_OUT, output_image_shape); + output->ResizeImage(output_shape, output_image_shape); + } else { + output->Resize(output_shape); + } float height_scale = CalculateResizeScale(in_height, out_height, align_corners_); @@ -32,28 +41,35 @@ void ResizeBilinearFunctor::operator()( auto runtime = OpenCLRuntime::Get(); std::set built_options; 
built_options.emplace("-DDATA_TYPE=" + DataTypeToCLType(input->dtype())); + built_options.emplace("-DCMD_DATA_TYPE=" + DataTypeToOPENCLCMDDataType(input->dtype())); auto rb_kernel = runtime->BuildKernel("resize_bilinear", "resize_bilinear_nocache", built_options); const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(rb_kernel); + uint32_t idx = 0; - rb_kernel.setArg(idx++, *(static_cast(input->buffer()))); - rb_kernel.setArg(idx++, *(static_cast(output->buffer()))); + rb_kernel.setArg(idx++, *(static_cast(input->buffer()))); + rb_kernel.setArg(idx++, *(static_cast(output->buffer()))); rb_kernel.setArg(idx++, height_scale); rb_kernel.setArg(idx++, width_scale); - rb_kernel.setArg(idx++, static_cast(in_height)); - rb_kernel.setArg(idx++, static_cast(in_width)); + rb_kernel.setArg(idx++, static_cast(in_height)); + rb_kernel.setArg(idx++, static_cast(in_width)); + rb_kernel.setArg(idx++, static_cast(out_height)); auto command_queue = runtime->command_queue(); cl_int error = command_queue.enqueueNDRangeKernel( rb_kernel, cl::NullRange, - cl::NDRange(static_cast(batch * channels), - static_cast(out_height), static_cast(out_width)), - // TODO (heliangliang) tuning and fix when kwg_size < devisor - cl::NDRange(1, 16, kwg_size / 16), - NULL, OpenCLRuntime::Get()->GetDefaultEvent()); + cl::NDRange(static_cast(channel_blocks), + static_cast(out_width), + static_cast(out_height * batch)), + // TODO tuning + cl::NDRange(1, static_cast(out_width > kwg_size ? 
kwg_size : out_width), 1), + nullptr, OpenCLRuntime::Get()->GetDefaultEvent()); MACE_CHECK(error == CL_SUCCESS, error); } +template struct ResizeBilinearFunctor; +template struct ResizeBilinearFunctor; + } // namespace kernels } // namespace mace diff --git a/mace/kernels/resize_bilinear.h b/mace/kernels/resize_bilinear.h index 59bb2505..27415ebd 100644 --- a/mace/kernels/resize_bilinear.h +++ b/mace/kernels/resize_bilinear.h @@ -61,63 +61,90 @@ void ResizeImage(const T *images, const index_t channels, const std::vector &xs_vec, const std::vector &ys, - float *output) { - const index_t in_channel_size = in_height * in_width; - const index_t in_batch_num_values = channels * in_channel_size; - const index_t out_channel_size = out_height * out_width; - const index_t out_batch_num_values = channels * out_channel_size; + T *output) { + const index_t in_batch_num_values = channels * in_height * in_width; + const index_t out_batch_num_values = channels * out_height * out_width; const CachedInterpolation *xs = xs_vec.data(); -#pragma omp parallel for collapse(2) +#pragma omp parallel for for (index_t b = 0; b < batch_size; ++b) { - for (index_t c = 0; c < channels; ++c) { - const T *input_ptr = - images + in_batch_num_values * b + in_channel_size * c; - float *output_ptr = - output + out_batch_num_values * b + out_channel_size * c; - for (index_t y = 0; y < out_height; ++y) { - const T *ys_input_lower_ptr = input_ptr + ys[y].lower * in_width; - const T *ys_input_upper_ptr = input_ptr + ys[y].upper * in_width; - const float ys_lerp = ys[y].lerp; - for (index_t x = 0; x < out_width; ++x) { - auto xs_lower = xs[x].lower; - auto xs_upper = xs[x].upper; - auto xs_lerp = xs[x].lerp; - - const float top_left = ys_input_lower_ptr[xs_lower]; - const float top_right = ys_input_lower_ptr[xs_upper]; - const float bottom_left = ys_input_upper_ptr[xs_lower]; - const float bottom_right = ys_input_upper_ptr[xs_upper]; - - output_ptr[x] = ComputeLerp(top_left, top_right, bottom_left, - 
bottom_right, xs_lerp, ys_lerp); + const T *batch_input_ptr = images + in_batch_num_values * b; + T *batch_output_ptr = output + out_batch_num_values * b; + + for (index_t y = 0; y < out_height; ++y) { + const T *y_lower_input_ptr = + batch_input_ptr + ys[y].lower * in_width * channels; + const T *y_upper_input_ptr = + batch_input_ptr + ys[y].upper * in_width * channels; + T *y_output_ptr = batch_output_ptr + y * out_width * channels; + const float ys_lerp = ys[y].lerp; + + for (index_t x = 0; x < out_width; ++x) { + const float xs_lerp = xs[x].lerp; + const T *top_left_ptr = y_lower_input_ptr + xs[x].lower * channels; + const T *top_right_ptr = y_lower_input_ptr + xs[x].upper * channels; + const T *bottom_left_ptr = y_upper_input_ptr + xs[x].lower * channels; + const T *bottom_right_ptr = y_upper_input_ptr + xs[x].upper * channels; + T *output_ptr = y_output_ptr + x * channels; + + for (index_t c = 0; c < channels; ++c) { + const T top_left = top_left_ptr[c]; + const T top_right = top_right_ptr[c]; + const T bottom_left = bottom_left_ptr[c]; + const T bottom_right = bottom_right_ptr[c]; + + output_ptr[c] = ComputeLerp(top_left, top_right, bottom_left, + bottom_right, xs_lerp, ys_lerp); } - output_ptr += out_width; } } } } } +struct ResizeBilinearFunctorBase { + ResizeBilinearFunctorBase(const std::vector &size, + bool align_corners) + : align_corners_(align_corners), size_(size) {} + + protected: + void GetOutputSize(const Tensor *resize_dims, + index_t *out_height, + index_t *out_width) { + if (size_[0] < 0 || size_[1] < 0) { + MACE_CHECK(resize_dims != nullptr && resize_dims->dim_size() == 1); + Tensor::MappingGuard resize_dims_mapper(resize_dims); + auto dims_data = resize_dims->data(); + *out_height = dims_data[0]; + *out_width = dims_data[1]; + } else { + *out_height = size_[0]; + *out_width = size_[1]; + } + } + + bool align_corners_; + std::vector size_; +}; + template -class ResizeBilinearFunctor { - public: +struct ResizeBilinearFunctor : 
ResizeBilinearFunctorBase { ResizeBilinearFunctor(const std::vector &size, bool align_corners) - : align_corners_(align_corners), size_(size) {} + : ResizeBilinearFunctorBase(size, align_corners) {} void operator()(const Tensor *input, const Tensor *resize_dims, Tensor *output) { const index_t batch = input->dim(0); - const index_t channels = input->dim(1); - const index_t in_height = input->dim(2); - const index_t in_width = input->dim(3); + const index_t in_height = input->dim(1); + const index_t in_width = input->dim(2); + const index_t channels = input->dim(3); index_t out_height; index_t out_width; GetOutputSize(resize_dims, &out_height, &out_width); MACE_CHECK(out_height > 0 && out_width > 0); - std::vector out_shape{batch, channels, out_height, out_width}; + std::vector out_shape{batch, out_height, out_width, channels}; output->Resize(out_shape); Tensor::MappingGuard input_mapper(input); @@ -146,32 +173,18 @@ class ResizeBilinearFunctor { ResizeImage(input_data, batch, in_height, in_width, out_height, out_width, channels, xs, ys, output_data); } +}; - protected: - void GetOutputSize(const Tensor *resize_dims, - index_t *out_height, - index_t *out_width) { - if (size_[0] < 0 || size_[1] < 0) { - MACE_CHECK(resize_dims != nullptr && resize_dims->dim_size() == 1); - Tensor::MappingGuard resize_dims_mapper(resize_dims); - auto dims_data = resize_dims->data(); - *out_height = dims_data[0]; - *out_width = dims_data[1]; - } else { - *out_height = size_[0]; - *out_width = size_[1]; - } - } +template +struct ResizeBilinearFunctor : ResizeBilinearFunctorBase { + ResizeBilinearFunctor(const std::vector &size, bool align_corners) + : ResizeBilinearFunctorBase(size, align_corners) {} - private: - bool align_corners_; - std::vector size_; + void operator()(const Tensor *input, + const Tensor *resize_dims, + Tensor *output); }; -template <> -void ResizeBilinearFunctor::operator()( - const Tensor *input, const Tensor *resize_dims, Tensor *output); - } // namespace kernels } 
// namespace mace diff --git a/mace/ops/resize_bilinear.cc b/mace/ops/resize_bilinear.cc index c3510f68..8eae7181 100644 --- a/mace/ops/resize_bilinear.cc +++ b/mace/ops/resize_bilinear.cc @@ -23,4 +23,9 @@ REGISTER_OPENCL_OPERATOR(OpKeyBuilder("ResizeBilinear") .Build(), ResizeBilinearOp); +REGISTER_OPENCL_OPERATOR(OpKeyBuilder("ResizeBilinear") + .TypeConstraint("T") + .Build(), + ResizeBilinearOp); + } // namespace mace diff --git a/mace/ops/resize_bilinear_benchmark.cc b/mace/ops/resize_bilinear_benchmark.cc index 8429fd6b..d9453908 100644 --- a/mace/ops/resize_bilinear_benchmark.cc +++ b/mace/ops/resize_bilinear_benchmark.cc @@ -19,18 +19,30 @@ static void ResizeBilinearBenchmark(int iters, mace::testing::StopTiming(); OpsTestNet net; - OpDefBuilder("ResizeBilinear", "ResizeBilinearBenchmark") - .Input("Input") - .Input("OutSize") - .Output("Output") - .AddIntsArg("size", {output_height, output_width}) - .Finalize(net.NewOperatorDef()); // Add input data net.AddRandomInput("Input", - {batch, channels, input_height, input_width}); + {batch, input_height, input_width, channels}); net.AddInputFromArray("OutSize", {2}, {output_height, output_width}); + if (D == DeviceType::OPENCL) { + BufferToImage(net, "Input", "InputImage", kernels::BufferType::IN_OUT); + OpDefBuilder("ResizeBilinear", "ResizeBilinearBenchmark") + .Input("InputImage") + .Input("OutSize") + .Output("OutputImage") + .AddIntsArg("size", {output_height, output_width}) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); + } else { + OpDefBuilder("ResizeBilinear", "ResizeBilinearBenchmark") + .Input("Input") + .Input("OutSize") + .Output("Output") + .AddIntsArg("size", {output_height, output_width}) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); + } // Warm-up for (int i = 0; i < 5; ++i) { @@ -58,9 +70,12 @@ static void ResizeBilinearBenchmark(int iters, #define BM_RESIZE_BILINEAR(N, C, H0, W0, H1, W1, TYPE) \ 
BM_RESIZE_BILINEAR_MACRO(N, C, H0, W0, H1, W1, TYPE, CPU); \ - BM_RESIZE_BILINEAR_MACRO(N, C, H0, W0, H1, W1, TYPE, NEON); \ BM_RESIZE_BILINEAR_MACRO(N, C, H0, W0, H1, W1, TYPE, OPENCL); +// SNPE 835 GPU: 6870us +BM_RESIZE_BILINEAR(1, 128, 120, 120, 480, 480, half); +BM_RESIZE_BILINEAR(1, 128, 120, 120, 480, 480, float); + BM_RESIZE_BILINEAR(1, 256, 7, 7, 15, 15, float); BM_RESIZE_BILINEAR(1, 256, 15, 15, 30, 30, float); BM_RESIZE_BILINEAR(1, 128, 30, 30, 60, 60, float); diff --git a/mace/ops/resize_bilinear_test.cc b/mace/ops/resize_bilinear_test.cc index 7b7cee9d..3e50c3b4 100644 --- a/mace/ops/resize_bilinear_test.cc +++ b/mace/ops/resize_bilinear_test.cc @@ -23,14 +23,14 @@ TEST_F(ResizeBilinearTest, CPUResizeBilinearWOAlignCorners) { // Add input data vector input(24); std::iota(begin(input), end(input), 0); - net.AddInputFromArray("Input", {1, 3, 2, 4}, input); + net.AddInputFromArray("Input", {1, 2, 4, 3}, input); net.AddInputFromArray("OutSize", {2}, {1, 2}); // Run net.RunOp(); // Check - auto expected = CreateTensor({1, 3, 1, 2}, {0, 2, 8, 10, 16, 18}); + auto expected = CreateTensor({1, 1, 2, 3}, {0, 1, 2, 6, 7, 8}); ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); } @@ -49,14 +49,14 @@ TEST_F(ResizeBilinearTest, ResizeBilinearWAlignCorners) { // Add input data vector input(24); std::iota(begin(input), end(input), 0); - net.AddInputFromArray("Input", {1, 3, 2, 4}, input); + net.AddInputFromArray("Input", {1, 2, 4, 3}, input); net.AddInputFromArray("OutSize", {2}, {1, 2}); // Run net.RunOp(); // Check - auto expected = CreateTensor({1, 3, 1, 2}, {0, 3, 8, 11, 16, 19}); + auto expected = CreateTensor({1, 1, 2, 3}, {0, 1, 2, 9, 10, 11}); ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); } @@ -65,6 +65,7 @@ template void TestRandomResizeBilinear() { srand(time(nullptr)); testing::internal::LogToStderr(); + for (int round = 0; round < 10; ++round) { int batch = 1 + rand() % 5; int channels = 1 + rand() % 100; @@ -72,39 +73,54 @@ void 
TestRandomResizeBilinear() { int width = 1 + rand() % 100; int in_height = 1 + rand() % 100; int in_width = 1 + rand() % 100; + int align_corners = rand() % 2; // Construct graph OpsTestNet net; - OpDefBuilder("ResizeBilinear", "ResizeBilinearTest") - .Input("Input") - .Input("OutSize") - .Output("Output") - .AddIntArg("align_corners", 1) - .AddIntsArg("size", {height, width}) - .Finalize(net.NewOperatorDef()); - // Add input data net.AddRandomInput("Input", - {batch, channels, in_height, in_width}); + {batch, in_height, in_width, channels}); net.AddInputFromArray("OutSize", {2}, {height, width}); - // Run - net.RunOp(D); - Tensor actual; - actual.Copy(*net.GetOutput("Output")); - + OpDefBuilder("ResizeBilinear", "ResizeBilinearTest") + .Input("Input") + .Input("OutSize") + .Output("Output") + .AddIntArg("align_corners", align_corners) + .AddIntsArg("size", {height, width}) + .Finalize(net.NewOperatorDef()); // Run on CPU net.RunOp(DeviceType::CPU); - Tensor *expected = net.GetOutput("Output"); + Tensor expected; + expected.Copy(*net.GetOutput("Output")); + + if (D == DeviceType::OPENCL) { + BufferToImage(net, "Input", "InputImage", kernels::BufferType::IN_OUT); + + OpDefBuilder("ResizeBilinear", "ResizeBilinearTest") + .Input("InputImage") + .Input("OutSize") + .Output("OutputImage") + .AddIntArg("align_corners", align_corners) + .AddIntsArg("size", {height, width}) + .Finalize(net.NewOperatorDef()); + // Run + net.RunOp(D); + ImageToBuffer(net, "OutputImage", "DeviceOutput", kernels::BufferType::IN_OUT); + } else { + // TODO support NEON + } // Check - ExpectTensorNear(*expected, actual, 0.001); + ExpectTensorNear(expected, *net.GetOutput("DeviceOutput"), 0.001); } } +/* TEST_F(ResizeBilinearTest, NEONRandomResizeBilinear) { TestRandomResizeBilinear(); } +*/ TEST_F(ResizeBilinearTest, OPENCLRandomResizeBilinear) { TestRandomResizeBilinear(); diff --git a/mace/python/tools/tf_ops_stats.py b/mace/python/tools/tf_ops_stats.py index d1016aff..d60487a9 100644 --- 
a/mace/python/tools/tf_ops_stats.py +++ b/mace/python/tools/tf_ops_stats.py @@ -92,6 +92,7 @@ def main(unused_args): size = tensor_values[input_name] break key = '%s(size=%s, align_corners=%s)' % (op.type, size, align_corners) + print(key) hist_inc(stats, key) elif op.type in ['AvgPool', 'MaxPool']: padding = op.get_attr('padding') -- GitLab