diff --git a/mace/kernels/opencl/cl/conv_2d_1x1.cl b/mace/kernels/opencl/cl/conv_2d_1x1.cl index d371698862460087fa67812282cc72b181e431df..020e6bdc8374c3025f979d1aedba2b841c398e91 100644 --- a/mace/kernels/opencl/cl/conv_2d_1x1.cl +++ b/mace/kernels/opencl/cl/conv_2d_1x1.cl @@ -24,33 +24,89 @@ __kernel void conv_2d_1x1_naive(__global const float *input, /* n, c, h, w */ } } +#define vec_conv_2d_1x1_s1(out_size) \ +do { \ + float4 in0 = vload4(0, input_ptr); \ + float4 in1 = vload4(0, input_ptr + in_pixel); \ + float4 in2 = vload4(0, input_ptr + 2 * in_pixel); \ + float4 in3 = vload4(0, input_ptr + 3 * in_pixel); \ + for (int oc = 0; oc < out_size; ++oc) { \ + float4 weights = vload4(0, filter_ptr + oc * in_chan_num); \ + float4 out = vload4(0, output_ptr + oc * out_pixel); \ + out += in0 * weights.x; \ + out += in1 * weights.y; \ + out += in2 * weights.z; \ + out += in3 * weights.w; \ + vstore4(out, 0, output_ptr + oc * out_pixel); \ + } \ +} while(0) + +#define vec_conv_2d_1x1_s2(out_size) \ +do { \ + float4 in00 = vload4(0, input_ptr); \ + float3 in01 = vload3(0, input_ptr + 4); \ + float4 in10 = vload4(0, input_ptr + in_pixel); \ + float3 in11 = vload3(0, input_ptr + in_pixel + 4); \ + float4 in20 = vload4(0, input_ptr + 2 * in_pixel); \ + float3 in21 = vload3(0, input_ptr + 2 * in_pixel + 4);\ + float4 in30 = vload4(0, input_ptr + 3 * in_pixel); \ + float3 in31 = vload3(0, input_ptr + 3 * in_pixel + 4); \ + float4 in0 = (float4)(in00.s02, in01.s02); \ + float4 in1 = (float4)(in10.s02, in11.s02); \ + float4 in2 = (float4)(in20.s02, in21.s02); \ + float4 in3 = (float4)(in30.s02, in31.s02); \ + for (int oc = 0; oc < out_size; ++oc) { \ + float4 weights = vload4(0, filter_ptr + oc * in_chan_num); \ + float4 out = vload4(0, output_ptr + oc * out_pixel); \ + out += in0 * weights.x; \ + out += in1 * weights.y; \ + out += in2 * weights.z; \ + out += in3 * weights.w; \ + vstore4(out, 0, output_ptr + oc * out_pixel); \ + } \ +} while(0) + + + __kernel void conv_2d_1x1_v2(__global const float *input, /* n, c, h, w */ __global const float *filter, /* o, i, kh, kw */ __global const float *bias, /* o */ __global float *output, /* n, c, h, w */ __private const int in_chan_num, __private const int out_chan_num, - __private const int pixel_num) { + __private const int in_height, + __private const int in_width, + __private const int out_height, + __private const int out_width, + __private const int stride) { int batch = get_global_id(0); int out_chan_blk = get_global_id(1); int out_pixel_blk = get_global_id(2); + const int in_pixel = in_height * in_width; + const int out_pixel = out_height * out_width; + + const int round_out_width = (out_width + 3) / 4; + const int out_pixel_height = out_pixel_blk / round_out_width; + const int out_pixel_width = out_pixel_blk % round_out_width; + const int out_chan_begin = out_chan_blk * 4; const int out_chan_end = min(out_chan_begin + 4, out_chan_num); - const int out_pixel_begin = out_pixel_blk * 4; - const int out_pixel_end = min(out_pixel_begin + 4, pixel_num); + const int out_pixel_begin = out_pixel_height * out_width + out_pixel_width * 4; + const int out_pixel_end = min(out_pixel_begin + 4, (out_pixel_height + 1) * out_width); + const int in_pixel_begin = out_pixel_height * stride * in_width + out_pixel_width * stride * 4; - const int in_offset = batch * in_chan_num * pixel_num; - const int out_offset = batch * out_chan_num * pixel_num; + const int in_offset = batch * in_chan_num * in_pixel; + const int out_offset = batch * out_chan_num * out_pixel; - const float *input_base = input + in_offset + out_pixel_begin; + const float *input_base = input + in_offset + in_pixel_begin; float *output_base = output + out_offset + out_pixel_begin; int out_chan_len = out_chan_end - out_chan_begin; int pixel_len = out_pixel_end - out_pixel_begin; for (int out_chan = out_chan_begin; out_chan < out_chan_end; ++out_chan) { - float *output_ptr = output_base + out_chan * pixel_num; + float *output_ptr = output_base + out_chan * out_pixel; float bias_value = bias == NULL ? 0 : bias[out_chan]; for (int p = 0; p < pixel_len; ++p) { output_ptr[p] = bias_value; @@ -60,52 +116,37 @@ __kernel void conv_2d_1x1_v2(__global const float *input, /* n, c, h, w */ int in_chan = 0; if (pixel_len == 4) { for (; in_chan + 3 < in_chan_num; in_chan += 4) { - const float *input_ptr = input_base + in_chan * pixel_num; + const float *input_ptr = input_base + in_chan * in_pixel; int out_chan = out_chan_begin; for (; out_chan + 3 < out_chan_end; out_chan += 4) { const float* filter_ptr = filter + out_chan * in_chan_num + in_chan; - float *output_ptr = output_base + out_chan * pixel_num; - float4 in0 = vload4(0, input_ptr); - float4 in1 = vload4(0, input_ptr + pixel_num); - float4 in2 = vload4(0, input_ptr + 2 * pixel_num); - float4 in3 = vload4(0, input_ptr + 3 * pixel_num); - #pragma unroll - for (int oc = 0; oc < 4; ++oc) { - float4 weights = vload4(0, filter_ptr + oc * in_chan_num); - float4 out = vload4(0, output_ptr + oc * pixel_num); - out += in0 * weights.x; - out += in1 * weights.y; - out += in2 * weights.z; - out += in3 * weights.w; - vstore4(out, 0, output_ptr + oc * pixel_num); + float *output_ptr = output_base + out_chan * out_pixel; + if (stride == 1) { + vec_conv_2d_1x1_s1(4); + } else if (stride == 2) { + vec_conv_2d_1x1_s2(4); } } for (; out_chan < out_chan_end; ++out_chan) { const float* filter_ptr = filter + out_chan * in_chan_num + in_chan; - float *output_ptr = output_base + out_chan * pixel_num; - float4 weights = vload4(0, filter_ptr); - float4 in0 = vload4(0, input_ptr); - float4 in1 = vload4(0, input_ptr + pixel_num); - float4 in2 = vload4(0, input_ptr + 2 * pixel_num); - float4 in3 = vload4(0, input_ptr + 3 * pixel_num); - float4 out = vload4(0, output_ptr); - out += in0 * weights.x; - out += in1 * weights.y; - out += in2 * weights.z; - out += in3 * weights.w; - vstore4(out, 0, output_ptr); + float *output_ptr = output_base + out_chan * out_pixel; + if (stride == 1) { + vec_conv_2d_1x1_s1(1); + } else if (stride == 2) { + vec_conv_2d_1x1_s2(1); + } } } } for (; in_chan < in_chan_num; ++in_chan) { - const float *input_ptr = input_base + in_chan * pixel_num; + const float *input_ptr = input_base + in_chan * in_pixel; for (int out_chan = out_chan_begin; out_chan < out_chan_end; ++out_chan) { float weights = filter[out_chan * in_chan_num + in_chan]; - float *output_ptr = output_base + out_chan * pixel_num; + float *output_ptr = output_base + out_chan * out_pixel; for (int p = 0; p < pixel_len; ++p) { - float in = input_ptr[p]; + float in = input_ptr[p*stride]; output_ptr[p] += in * weights; } } diff --git a/mace/kernels/opencl/conv_2d_opencl.cc b/mace/kernels/opencl/conv_2d_opencl.cc index 2ff4a9c50da2e533d92d7e6dece0db285f91406f..ffb0314549f46ca64fee2ab4c88bc630459d0592 100644 --- a/mace/kernels/opencl/conv_2d_opencl.cc +++ b/mace/kernels/opencl/conv_2d_opencl.cc @@ -10,6 +10,9 @@ namespace kernels { extern void Conv2dOpenclK1x1S1(const Tensor *input, const Tensor *filter, const Tensor *bias, Tensor *output); +extern void Conv2dOpenclK1x1S2(const Tensor *input, const Tensor *filter, + const Tensor *bias, Tensor *output); + extern void Conv2dOpenclK3x3S1(const Tensor *input, const Tensor *filter, const Tensor *bias, Tensor *output); @@ -24,7 +27,7 @@ void Conv2dFunctor::operator()(const Tensor *input, const Tensor *bias, Tensor *output); // Selection matrix: kernel_size x stride_size static const Conv2dOpenclFunction selector[5][2] = { - {Conv2dOpenclK1x1S1, nullptr}, + {Conv2dOpenclK1x1S1, Conv2dOpenclK1x1S2}, {nullptr, nullptr}, {Conv2dOpenclK3x3S1, Conv2dOpenclK3x3S2}, {nullptr, nullptr}, diff --git a/mace/kernels/opencl/conv_2d_opencl_1x1.cc b/mace/kernels/opencl/conv_2d_opencl_1x1.cc index ba784d0552bd3f5a67558ab1392905db35ae2c4a..0c043b8c8758da3079e041f03d875a43fb2fd200 100644 --- a/mace/kernels/opencl/conv_2d_opencl_1x1.cc +++ b/mace/kernels/opencl/conv_2d_opencl_1x1.cc @@ -45,6 +45,7 @@ void Conv1x1Naive(const Tensor *input, void Conv1x1V2(const Tensor *input, const Tensor *filter, const Tensor *bias, + const int stride, Tensor *output) { const index_t batch = output->dim(0); const index_t channels = output->dim(1); @@ -54,9 +55,8 @@ void Conv1x1V2(const Tensor *input, auto runtime = OpenCLRuntime::Get(); auto program = runtime->program(); - const index_t pixels = height * width; const index_t channel_blocks = (channels + 3) / 4; - const index_t pixel_blocks = (pixels + 3) / 4; + const index_t pixel_blocks = (width + 3) / 4 * height; // TODO KernelFunctor has an extra clReleaseCommandQueue due to a copy // TODO check wired clReleaseCommandQueue latency @@ -77,7 +77,11 @@ void Conv1x1V2(const Tensor *input, conv_2d_kernel.setArg(idx++, *(static_cast(output->buffer()))); conv_2d_kernel.setArg(idx++, static_cast(input_channels)); conv_2d_kernel.setArg(idx++, static_cast(channels)); - conv_2d_kernel.setArg(idx++, static_cast(pixels)); + conv_2d_kernel.setArg(idx++, static_cast(input->dim(2))); + conv_2d_kernel.setArg(idx++, static_cast(input->dim(3))); + conv_2d_kernel.setArg(idx++, static_cast(height)); + conv_2d_kernel.setArg(idx++, static_cast(width)); + conv_2d_kernel.setArg(idx++, stride); auto command_queue = runtime->command_queue(); cl_int error = command_queue.enqueueNDRangeKernel( @@ -189,7 +193,16 @@ extern void Conv2dOpenclK1x1S1(const Tensor *input, MACE_CHECK(input_batch == batch && input_height == height && input_width == width); - Conv1x1V2(input, filter, bias, output); + Conv1x1V2(input, filter, bias, 1, output); +}; + +extern void Conv2dOpenclK1x1S2(const Tensor *input, + const Tensor *filter, + const Tensor *bias, + Tensor *output) { + MACE_CHECK(input->dim(0) == output->dim(0)); + + Conv1x1V2(input, filter, bias, 2, output); }; } // namespace kernels