diff --git a/mace/core/runtime/opencl/opencl_runtime.cc b/mace/core/runtime/opencl/opencl_runtime.cc index 4f95a9e7abd446ec8839b1998e13e5c7594dfd97..6ce4ed75e0221c46054a6a92f31adef3b18898ff 100644 --- a/mace/core/runtime/opencl/opencl_runtime.cc +++ b/mace/core/runtime/opencl/opencl_runtime.cc @@ -139,6 +139,7 @@ const std::map OpenCLRuntime::program_map_ = { {"addn", "addn.cl"}, {"batch_norm", "batch_norm.cl"}, + {"conv_2d", "conv_2d.cl"}, {"conv_2d_1x1", "conv_2d_1x1.cl"}, {"conv_2d_3x3", "conv_2d_3x3.cl"}, {"depthwise_conv_3x3", "depthwise_conv_3x3.cl"},
diff --git a/mace/kernels/opencl/cl/conv_2d.cl b/mace/kernels/opencl/cl/conv_2d.cl
new file mode 100644
index 0000000000000000000000000000000000000000..d4a65670ecb2028fb4a8b409845aba1d99c4ae55
--- /dev/null
+++ b/mace/kernels/opencl/cl/conv_2d.cl
@@ -0,0 +1,149 @@
+#include  // NOTE(review): the header name is missing in this copy of the patch; restore it from the original commit
+// General 2-D convolution over image2d tensors: one work-item accumulates up to four output pixels (columns w, w + out_w_blks, ...) of one 4-channel output block.
+__kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
+                      __read_only image2d_t filter, /* cout%4 * cin * kw * kh, cout/4 */
+#ifdef BIAS
+                      __read_only image2d_t bias, /* cout%4 * cout/4 */
+#endif
+                      __write_only image2d_t output, /* written at x = out_ch_blk * out_width + w, y = out_hb */
+                      __private const int in_height,
+                      __private const int in_width,
+                      __private const int in_ch_blks, /* input channel blocks; 4 channels per block */
+                      __private const int out_height,
+                      __private const int out_width,
+                      __private const int filter_height,
+                      __private const int filter_width,
+                      __private const int padding_top, /* subtracted from the output row to get the first input row */
+                      __private const int padding_left) { /* subtracted from the output column to get the first input column */
+  const int out_ch_blk = get_global_id(0);  // output channel block (one DATA_TYPE4 texel)
+  const int out_w_blk = get_global_id(1);  // first output column handled by this work-item
+  const int out_w_blks = get_global_size(1);  // column step between the four output pixels of a work-item
+  const int out_hb = get_global_id(2);  // combined index: batch * out_height + output row
+  const int rounded_in_ch = in_ch_blks * 4;  // input channel count in whole blocks of 4
+  // Pixel (unnormalised) coordinates, nearest texel. CLAMP addressing matters: the loops below use coordinate -1 for padding, which presumably reads back as the border colour (zero) -- confirm for the image format in use.
+  const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
+#ifdef BIAS
+  DATA_TYPE4 out0 =
+     READ_IMAGET(bias, sampler, (int2)(out_ch_blk, 0));  // bias texel of this output block seeds every accumulator
+  DATA_TYPE4 out1 = out0;
+  DATA_TYPE4 out2 = out0;
+  DATA_TYPE4 out3 = out0;
+#else
+  DATA_TYPE4 out0 = 0;
+  DATA_TYPE4 out1 = 0;
+  DATA_TYPE4 out2 = 0;
+  DATA_TYPE4 out3 = 0;
+#endif
+  // in_widthN: input column of filter tap 0 for the N-th output pixel; height_idx: input row of filter tap 0 (both can be negative inside the padding).
+#if STRIDE == 1
+  int in_width0 = out_w_blk - padding_left;
+  int in_width1 = in_width0 + out_w_blks;
+  int in_width2 = in_width1 + out_w_blks;
+  int in_width3 = in_width2 + out_w_blks;
+  const int height_idx = (out_hb % out_height) - padding_top;
+#else  // every other STRIDE value is handled as stride 2
+  int in_width0 = out_w_blk * 2 - padding_left;
+  int in_width1 = (out_w_blk + out_w_blks) * 2 - padding_left;
+  int in_width2 = (out_w_blk + 2 * out_w_blks) * 2 - padding_left;
+  int in_width3 = (out_w_blk + 3 * out_w_blks) * 2 - padding_left;
+  const int height_idx = (out_hb % out_height) * 2 - padding_top;
+#endif
+
+  const int batch_idx = (out_hb / out_height) * in_height;  // first image row of this batch (batches are stacked along y)
+
+  DATA_TYPE4 in0, in1, in2, in3;
+  DATA_TYPE4 weights0, weights1, weights2, weights3;
+  int in_idx, in_width_idx;  // NOTE(review): in_width_idx is never used in this kernel
+  // Unrolling this loop hurt performance
+  for (short in_ch_blk = 0; in_ch_blk < in_ch_blks; ++in_ch_blk) {
+    for (short hb_idx = 0; hb_idx < filter_height; ++hb_idx) {
+      for (short width_idx = 0; width_idx < filter_width; ++width_idx) {
+
+        in_idx = in_ch_blk * in_width;  // x offset of this input channel block; NOTE(review): does not depend on the two inner loops
+        // Input row for this tap, or -1 (outside the image) when the tap lies in the top/bottom padding.
+        int in_hb_value = height_idx + hb_idx;
+        in_hb_value = select(in_hb_value + batch_idx,
+                             -1,
+                             (in_hb_value < 0 || in_hb_value >= in_height));
+        // READ_INPUT(i) does the same for the column of output pixel i and loads the texel into in<i>.
+        int in_width_value;
+#define READ_INPUT(i) \
+        in_width_value = in_width##i + width_idx; \
+        in_width_value = select(in_idx + in_width_value, \
+                                -1, \
+                                (in_width_value < 0 || in_width_value >= in_width)); \
+        in##i = READ_IMAGET(input, sampler, (int2)(in_width_value, in_hb_value));
+
+        READ_INPUT(0);
+        READ_INPUT(1);
+        READ_INPUT(2);
+        READ_INPUT(3);
+
+#undef READ_INPUT
+        // Filter texels for this tap: x = tap * rounded_in_ch + 4 * in_ch_blk + k for input channel k of the block, y = output channel block.
+        int filter_idx = (in_ch_blk << 2) + (hb_idx * filter_width + width_idx) * rounded_in_ch;
+        weights0 = READ_IMAGET(filter, sampler, (int2)(filter_idx + 0, out_ch_blk));
+        weights1 = READ_IMAGET(filter, sampler, (int2)(filter_idx + 1, out_ch_blk));
+        weights2 = READ_IMAGET(filter, sampler, (int2)(filter_idx + 2, out_ch_blk));
+        weights3 = READ_IMAGET(filter, sampler, (int2)(filter_idx + 3, out_ch_blk));
+
+        // Will prefetch L2 improve performance? How to prefetch image data?
+
+        // Interleaving load and mul does not improve performance as expected
+        out0 += in0.x * weights0;  // one scalar input channel times the 4 output-channel weights read for it
+        out0 += in0.y * weights1;
+        out0 += in0.z * weights2;
+        out0 += in0.w * weights3;
+
+        out1 += in1.x * weights0;
+        out1 += in1.y * weights1;
+        out1 += in1.z * weights2;
+        out1 += in1.w * weights3;
+
+        out2 += in2.x * weights0;
+        out2 += in2.y * weights1;
+        out2 += in2.z * weights2;
+        out2 += in2.w * weights3;
+
+        out3 += in3.x * weights0;
+        out3 += in3.y * weights1;
+        out3 += in3.z * weights2;
+        out3 += in3.w * weights3;
+
+      }
+    }
+  }
+
+#ifdef FUSED_RELU
+  // TODO relux (only max(x, 0) is applied here)
+  out0 = fmax(out0, 0);
+  out1 = fmax(out1, 0);
+  out2 = fmax(out2, 0);
+  out3 = fmax(out3, 0);
+#endif
+  // Store: the first pixel is written unconditionally; each later one only while w < out_width.
+  const int out_x_base = out_ch_blk * out_width;
+  int w = out_w_blk;
+  WRITE_IMAGET(output,
+               (int2)(out_x_base + w, out_hb),
+               out0);
+
+  w += out_w_blks;
+  if (w >= out_width) return;
+  WRITE_IMAGET(output,
+               (int2)(out_x_base + w, out_hb),
+               out1);
+
+  w += out_w_blks;
+  if (w >= out_width) return;
+  WRITE_IMAGET(output,
+               (int2)(out_x_base + w, out_hb),
+               out2);
+
+  w += out_w_blks;
+  if (w >= out_width) return;
+  WRITE_IMAGET(output,
+               (int2)(out_x_base + w, out_hb),
+               out3);
+
+}
diff --git a/mace/kernels/opencl/conv_2d_opencl.cc b/mace/kernels/opencl/conv_2d_opencl.cc index cb48be92c1112a8148be742f8a029ef37d37e716..c40481543796215c80f4367e8e5f01a59b32c3be 100644 --- a/mace/kernels/opencl/conv_2d_opencl.cc +++ b/mace/kernels/opencl/conv_2d_opencl.cc @@ -28,6 +28,11 @@ extern void Conv2dOpenclK3x3S2(const Tensor *input, const Tensor *filter, const int *padding, const DataType dt, Tensor *output); +extern void Conv2dOpencl(const Tensor *input, const Tensor *filter, + const Tensor *bias, const bool fused_relu, + const uint32_t stride, const int *padding, + const DataType dt, Tensor *output); + template void Conv2dFunctor::operator()(const Tensor *input, const Tensor *filter, @@ -47,17 +52,13 @@ void Conv2dFunctor::operator()(const Tensor *input, index_t kernel_h =
filter->dim(0); index_t kernel_w = filter->dim(1); - if (kernel_h != kernel_w || kernel_h > 5 || strides_[0] != strides_[1] || - strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1 || - selector[kernel_h - 1][strides_[0] - 1] == nullptr) { + if (!input->is_image() || strides_[0] != strides_[1] || + strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1) { LOG(WARNING) << "OpenCL conv2d kernel with " << "filter" << kernel_h << "x" << kernel_w << "," << " stride " << strides_[0] << "x" << strides_[1] << " is not implemented yet, using slow version"; - // TODO(heliangliang) The CPU/NEON kernel should map the buffer - Conv2dFunctor(strides_, paddings_, dilations_)( - input, filter, bias, output); - return; + MACE_NOT_IMPLEMENTED; } std::vector output_shape(4); @@ -66,16 +67,18 @@ void Conv2dFunctor::operator()(const Tensor *input, input->shape().data(), filter->shape().data(), dilations_, strides_, paddings_, output_shape.data(), paddings.data()); - if (input->is_image()) { - std::vector output_image_shape; - CalImage2DShape(output_shape, BufferType::IN_OUT, output_image_shape); - output->ResizeImage(output_shape, output_image_shape); + std::vector output_image_shape; + CalImage2DShape(output_shape, BufferType::IN_OUT, output_image_shape); + output->ResizeImage(output_shape, output_image_shape); + + if (kernel_h == kernel_w && kernel_h <= 5 && + selector[kernel_h - 1][strides_[0] - 1] != nullptr) { + auto conv2d_func = selector[kernel_h - 1][strides_[0] - 1]; + conv2d_func(input, filter, bias, false, paddings.data(), DataTypeToEnum::value, output); } else { - output->Resize(output_shape); + Conv2dOpencl(input, filter, bias, false, strides_[0], paddings.data(), DataTypeToEnum::value, output); } - auto conv2d_func = selector[kernel_h - 1][strides_[0] - 1]; - conv2d_func(input, filter, bias, false, paddings.data(), DataTypeToEnum::value, output); } template diff --git a/mace/kernels/opencl/conv_2d_opencl_general.cc 
b/mace/kernels/opencl/conv_2d_opencl_general.cc new file mode 100644 index 0000000000000000000000000000000000000000..e46ecbcaca06e811de44b5a29e08abb1e3418906 --- /dev/null +++ b/mace/kernels/opencl/conv_2d_opencl_general.cc @@ -0,0 +1,73 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/core/common.h" +#include "mace/core/runtime/opencl/opencl_runtime.h" +#include "mace/kernels/conv_2d.h" +#include "mace/kernels/opencl/helper.h" +#include "mace/utils/utils.h" + +namespace mace { +namespace kernels { + +void Conv2dOpencl(const Tensor *input, const Tensor *filter, + const Tensor *bias, const bool fused_relu, + const uint32_t stride, const int *padding, + const DataType dt, Tensor *output) { + const index_t batch = output->dim(0); + const index_t height = output->dim(1); + const index_t width = output->dim(2); + const index_t channels = output->dim(3); + const index_t input_channels = input->dim(3); + + const index_t channel_blocks = RoundUpDiv4(channels); + const index_t input_channel_blocks = RoundUpDiv4(input_channels); + const index_t width_blocks = RoundUpDiv4(width); + + std::set built_options; + built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); + built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); + built_options.emplace(bias != nullptr ? 
"-DBIAS" : ""); + built_options.emplace("-DSTRIDE=" + ToString(stride)); + if (fused_relu) { + built_options.emplace("-DFUSED_RELU"); + } + + auto runtime = OpenCLRuntime::Get(); + auto program = runtime->program(); + + auto conv_2d_kernel = runtime->BuildKernel("conv_2d", "conv_2d", built_options); + const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(conv_2d_kernel); + + uint32_t idx = 0; + conv_2d_kernel.setArg(idx++, *(static_cast(input->buffer()))); + conv_2d_kernel.setArg(idx++, *(static_cast(filter->buffer()))); + if (bias != nullptr) { + conv_2d_kernel.setArg(idx++, *(static_cast(bias->buffer()))); + } + conv_2d_kernel.setArg(idx++, *(static_cast(output->buffer()))); + conv_2d_kernel.setArg(idx++, static_cast(input->dim(1))); + conv_2d_kernel.setArg(idx++, static_cast(input->dim(2))); + conv_2d_kernel.setArg(idx++, static_cast(input_channel_blocks)); + conv_2d_kernel.setArg(idx++, static_cast(height)); + conv_2d_kernel.setArg(idx++, static_cast(width)); + conv_2d_kernel.setArg(idx++, static_cast(filter->dim(0))); + conv_2d_kernel.setArg(idx++, static_cast(filter->dim(1))); + conv_2d_kernel.setArg(idx++, padding[0] / 2); + conv_2d_kernel.setArg(idx++, padding[1] / 2); + + auto command_queue = runtime->command_queue(); + cl_int error; + error = command_queue.enqueueNDRangeKernel( + conv_2d_kernel, cl::NullRange, + cl::NDRange(static_cast(channel_blocks), static_cast(width_blocks), + static_cast(height * batch)), + cl::NDRange(16, 16, 4), + NULL, OpenCLRuntime::Get()->GetDefaultEvent()); + MACE_CHECK(error == CL_SUCCESS, error); + +} + +} // namespace kernels +} // namespace mace diff --git a/mace/ops/buffer_to_image_test.cc b/mace/ops/buffer_to_image_test.cc index 3836a7ae90291dbbfb80da20cf78a1bb1c79d87e..43092084d3f75cacf48ecf9dc9dd3fd3861f557d 100644 --- a/mace/ops/buffer_to_image_test.cc +++ b/mace/ops/buffer_to_image_test.cc @@ -118,14 +118,13 @@ void TestDiffTypeBidirectionTransform(const int type, const std::vector .Input("B2IOutput") 
.Output("I2BOutput") .AddIntArg("buffer_type", type) - .AddIntArg("T", DataTypeToEnum::value) .Finalize(net.NewOperatorDef()); // Run net.RunOp(D); // Check - ExpectTensorNear(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"), 1e-2); + ExpectTensorNear(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"), 1e-3); } TEST(BufferToImageTest, ArgFloatToHalfSmall) { diff --git a/mace/ops/conv_2d_test.cc b/mace/ops/conv_2d_test.cc index b4fd374b578d3b1eef058f495d331c2182619246..711bf3891211451429fc3ad0e80e1f55611a4b70 100644 --- a/mace/ops/conv_2d_test.cc +++ b/mace/ops/conv_2d_test.cc @@ -558,18 +558,20 @@ TEST_F(Conv2dOpTest, OPENCLUnalignedConvNxNS12) { } template -static void TestHalfComplexConvNxNS12(const std::vector &shape) { +static void TestHalfComplexConvNxNS12(const std::vector &input_shape, + const std::vector &filter_shape) { testing::internal::LogToStderr(); - auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w, - Padding type) { - srand(time(NULL)); + srand(time(NULL)); + auto func = [&](int stride_h, int stride_w, Padding padding) { // generate random input index_t batch = 3 + (rand() % 10); - index_t height = shape[0]; - index_t width = shape[1]; - index_t input_channels = shape[2] + (rand() % 10); - index_t output_channels = shape[3] + (rand() % 10); + index_t height = input_shape[0]; + index_t width = input_shape[1]; + index_t kernel_h = filter_shape[0]; + index_t kernel_w = filter_shape[1]; + index_t input_channels = filter_shape[2] + (rand() % 10); + index_t output_channels = filter_shape[3] + (rand() % 10); // Construct graph OpsTestNet net; OpDefBuilder("Conv2D", "Conv2dTest") @@ -578,7 +580,7 @@ static void TestHalfComplexConvNxNS12(const std::vector &shape) { .Input("Bias") .Output("Output") .AddIntsArg("strides", {stride_h, stride_w}) - .AddIntArg("padding", type) + .AddIntArg("padding", padding) .AddIntsArg("dilations", {1, 1}) .Finalize(net.NewOperatorDef()); @@ -611,7 +613,7 @@ static void TestHalfComplexConvNxNS12(const 
std::vector &shape) { .Input("BiasImage") .Output("OutputImage") .AddIntsArg("strides", {stride_h, stride_w}) - .AddIntArg("padding", type) + .AddIntArg("padding", padding) .AddIntsArg("dilations", {1, 1}) .AddIntArg("T", static_cast(DataType::DT_HALF)) .Finalize(net.NewOperatorDef()); @@ -620,20 +622,46 @@ static void TestHalfComplexConvNxNS12(const std::vector &shape) { ImageToBuffer(net, "OutputImage", "OPENCLOutput", kernels::BufferType::IN_OUT); - ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 0.2); + ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 0.5); }; - for (int kernel_size : {1, 3}) { - for (int stride : {1, 2}) { - func(kernel_size, kernel_size, stride, stride, VALID); - } + for (int stride : {1, 2}) { + func(stride, stride, VALID); + func(stride, stride, SAME); } } -TEST_F(Conv2dOpTest, OPENCLHalfAlignedConvNxNS12) { - TestHalfComplexConvNxNS12({32, 32, 32, 64}); +TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv1x1S12) { + TestHalfComplexConvNxNS12({32, 32}, + {1, 1, 32, 64}); +} + +TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv3x3S12) { + TestHalfComplexConvNxNS12({32, 32}, + {3, 3, 32, 64}); +} + +TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv15x1S12) { + TestHalfComplexConvNxNS12({32, 32}, + {15, 1, 256, 2}); +} + +TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv1x15S12) { + TestHalfComplexConvNxNS12({32, 32}, + {1, 15, 256, 2}); +} + +TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv7x75S12) { + TestHalfComplexConvNxNS12({32, 32}, + {7, 7, 3, 64}); +} + +TEST_F(Conv2dOpTest, OPENCLHalfUnalignedConv1x1S12) { + TestHalfComplexConvNxNS12({107, 113}, + {1, 1, 5, 7}); } -TEST_F(Conv2dOpTest, OPENCLHalfUnalignedConvNxNS12) { - TestHalfComplexConvNxNS12({107, 113, 5, 7}); +TEST_F(Conv2dOpTest, OPENCLHalfUnalignedConv3x3S12) { + TestHalfComplexConvNxNS12({107, 113}, + {3, 3, 5, 7}); }