diff --git a/mace/kernels/opencl/cl/conv_2d_1x1.cl b/mace/kernels/opencl/cl/conv_2d_1x1.cl index cd9f22ae3f981823d4d45a876ee0cf18e4a0f456..7d41efc05df9f5c6b61bb14eee19708a41c145d9 100644 --- a/mace/kernels/opencl/cl/conv_2d_1x1.cl +++ b/mace/kernels/opencl/cl/conv_2d_1x1.cl @@ -1,51 +1,77 @@ /* * Split work item along output channels and pixels */ -void kernel conv_2d_1x1_naive(global const float *input, /* n, c, h, w */ - global const float *filter, /* o, i, kh, kw */ - global float *output, /* n, c, h, w */ - private const int in_offset, - private const int out_offset, - private const int pixel_num, - private const int in_chan_num, - private const int out_chan_num) { +void kernel conv_2d_1x1_nchw(global const float *input, /* n, c, h, w */ + global const float *filter, /* o, i, kh, kw */ + global float *output, /* n, c, h, w */ + private const int in_offset, + private const int out_offset, + private const int pixel_num, + private const int in_chan_num, + private const int out_chan_num) { int out_chan_blk = get_global_id(0); int out_pixel_blk = get_global_id(1); - const int out_chan_begin = out_chan_blk << 2; + const int out_chan_begin = out_chan_blk * 4; const int out_chan_end = min(out_chan_begin + 4, out_chan_num); - const int out_pixel_begin = out_pixel_blk << 3; - const int out_pixel_end = min(out_pixel_begin + 8, pixel_num); + const int out_pixel_begin = out_pixel_blk * 4; + const int out_pixel_end = min(out_pixel_begin + 4, pixel_num); const float *input_base = input + in_offset + out_pixel_begin; float *output_base = output + out_offset + out_pixel_begin; - int pixels = out_pixel_end - out_pixel_begin; - for (int in_chan = 0; in_chan < in_chan_num; ++in_chan) { - const float *input_ptr = input_base + in_chan * pixel_num; - if (pixels == 8) { - /* TODO fix '#pragma unroll' build error */ - for (int out_chan = out_chan_begin; out_chan < out_chan_end; ++out_chan) { - float weights = filter[out_chan * in_chan_num + in_chan]; + int pixels = out_pixel_end - out_pixel_begin; + int in_chan = 0; + if (pixels == 4) { + for (; in_chan + 3 < in_chan_num; in_chan += 4) { + const float *input_ptr = input_base + in_chan * pixel_num; + int out_chan = out_chan_begin; + for (; out_chan + 3 < out_chan_end; out_chan += 4) { + const float* filter_ptr = filter + out_chan * in_chan_num + in_chan; float *output_ptr = output_base + out_chan * pixel_num; - for (int p = 0; p < 2; ++p) { - float4 in = vload4(p, input_ptr); - float4 out = vload4(p, output_ptr); - out += in * weights; - vstore4(out, p, output_ptr); + float4 in0 = vload4(0, input_ptr); + float4 in1 = vload4(0, input_ptr + pixel_num); + float4 in2 = vload4(0, input_ptr + 2 * pixel_num); + float4 in3 = vload4(0, input_ptr + 3 * pixel_num); + for (int oc = 0; oc < 4; ++oc) { + float4 weights = vload4(0, filter_ptr + oc * in_chan_num); + float4 out = vload4(0, output_ptr + oc * pixel_num); + out += in0 * weights.x; + out += in1 * weights.y; + out += in2 * weights.z; + out += in3 * weights.w; + vstore4(out, 0, output_ptr + oc * pixel_num); } } - } else { - for (int out_chan = out_chan_begin; out_chan < out_chan_end; ++out_chan) { - float weights = filter[out_chan * in_chan_num + in_chan]; + for (; out_chan < out_chan_end; ++out_chan) { + const float* filter_ptr = filter + out_chan * in_chan_num + in_chan; float *output_ptr = output_base + out_chan * pixel_num; + float4 weights = vload4(0, filter_ptr); + float4 in0 = vload4(0, input_ptr); + float4 in1 = vload4(0, input_ptr + pixel_num); + float4 in2 = vload4(0, input_ptr + 2 * pixel_num); + float4 in3 = vload4(0, input_ptr + 3 * pixel_num); + float4 out = vload4(0, output_ptr); + out += in0 * weights.x; + out += in1 * weights.y; + out += in2 * weights.z; + out += in3 * weights.w; + vstore4(out, 0, output_ptr); + } + } + } - for (int p = 0; p < pixels; ++p) { - float in = input_ptr[p]; - float out = output_ptr[p]; - out += in * weights; - output_ptr[p] = out; - } + for (; in_chan < in_chan_num; ++in_chan) { + const float *input_ptr = input_base + in_chan * pixel_num; + for (int out_chan = out_chan_begin; out_chan < out_chan_end; ++out_chan) { + float weights = filter[out_chan * in_chan_num + in_chan]; + float *output_ptr = output_base + out_chan * pixel_num; + + for (int p = 0; p < pixels; ++p) { + float in = input_ptr[p]; + float out = output_ptr[p]; + out += in * weights; + output_ptr[p] = out; } } } diff --git a/mace/kernels/opencl/conv_2d_opencl_1x1.cc b/mace/kernels/opencl/conv_2d_opencl_1x1.cc index d37fb2931bb1b782dc20d88c8410b5a3c638fb06..7f6e1d19fafa9e17c47abd9a11ef1d8bb006da7e 100644 --- a/mace/kernels/opencl/conv_2d_opencl_1x1.cc +++ b/mace/kernels/opencl/conv_2d_opencl_1x1.cc @@ -50,44 +50,53 @@ void AssignBias(Tensor *output, const Tensor *bias) { } } -extern void Conv2dOpenclK1x1S1(const Tensor *input, const Tensor *filter, - const Tensor *bias, Tensor *output) { +void Conv1x1NCHW(const Tensor *input, + const Tensor *filter, + Tensor *output) { const index_t batch = output->shape()[0]; const index_t channels = output->shape()[1]; const index_t height = output->shape()[2]; const index_t width = output->shape()[3]; - - const index_t input_batch = input->shape()[0]; const index_t input_channels = input->shape()[1]; - const index_t input_height = input->shape()[2]; - const index_t input_width = input->shape()[3]; - - MACE_CHECK(input_batch == batch && input_height == height && - input_width == width); - - AssignBias(output, bias); auto runtime = OpenCLRuntime::Get(); auto program = runtime->program(); auto conv_2d = cl::KernelFunctor(program, "conv_2d_1x1_naive"); + int, int, int, int, int>(program, "conv_2d_1x1_nchw"); const index_t total_pixels = height * width; for (int b = 0; b < batch; ++b) { int input_offset = b * input_channels * total_pixels; int output_offset = b * channels * total_pixels; int chan_blk_num = (channels + 3) >> 2; // each 4 output channels - int pixel_blk_num = (total_pixels + 7) >> 3; // each 8 pixels + int pixel_blk_num = (total_pixels + 3) >> 2; // each 4 pixels cl_int error; conv_2d(cl::EnqueueArgs(runtime->command_queue(), cl::NDRange(chan_blk_num, pixel_blk_num), - cl::NDRange(1, 64)), + cl::NDRange(1, 256)), *(static_cast(input->buffer())), *(static_cast(filter->buffer())), *(static_cast(output->buffer())), input_offset, output_offset, total_pixels, input_channels, channels, error); MACE_CHECK(error == CL_SUCCESS); } +} + +extern void Conv2dOpenclK1x1S1(const Tensor *input, const Tensor *filter, + const Tensor *bias, Tensor *output) { + const index_t batch = output->shape()[0]; + const index_t height = output->shape()[2]; + const index_t width = output->shape()[3]; + + const index_t input_batch = input->shape()[0]; + const index_t input_height = input->shape()[2]; + const index_t input_width = input->shape()[3]; + + MACE_CHECK(input_batch == batch && input_height == height && + input_width == width); + + AssignBias(output, bias); + Conv1x1NCHW(input, filter, output); }; } // namespace kernels