From 87dc6e00d2dbefc29be8f4f0eba609aa596ae50f Mon Sep 17 00:00:00 2001 From: Liangliang He Date: Fri, 3 Nov 2017 20:24:00 +0800 Subject: [PATCH] Rollback private storage to fix conv_2d 1x1 --- mace/kernels/opencl/cl/common.h | 9 ++++++++ mace/kernels/opencl/cl/conv_2d_1x1.cl | 32 ++++++--------------------- mace/ops/conv_2d_benchmark.cc | 2 ++ 3 files changed, 18 insertions(+), 25 deletions(-) diff --git a/mace/kernels/opencl/cl/common.h b/mace/kernels/opencl/cl/common.h index a6be3c53..74c5b67a 100644 --- a/mace/kernels/opencl/cl/common.h +++ b/mace/kernels/opencl/cl/common.h @@ -1,2 +1,11 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_KERNELS_OPENCL_CL_COMMON_H_ +#define MACE_KERNELS_OPENCL_CL_COMMON_H_ + #pragma OPENCL EXTENSION cl_khr_fp16 : enable #pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable + +#endif // MACE_KERNELS_OPENCL_CL_COMMON_H_ diff --git a/mace/kernels/opencl/cl/conv_2d_1x1.cl b/mace/kernels/opencl/cl/conv_2d_1x1.cl index dc5b1f81..ae004658 100644 --- a/mace/kernels/opencl/cl/conv_2d_1x1.cl +++ b/mace/kernels/opencl/cl/conv_2d_1x1.cl @@ -35,8 +35,6 @@ __kernel void conv_2d_1x1_v2(__global const float *input, /* n, c, h, w */ int out_chan_blk = get_global_id(1); int out_pixel_blk = get_global_id(2); - __private float output_slice[4 * 4]; - const int out_chan_begin = out_chan_blk * 4; const int out_chan_end = min(out_chan_begin + 4, out_chan_num); const int out_pixel_begin = out_pixel_blk * 4; @@ -52,10 +50,10 @@ __kernel void conv_2d_1x1_v2(__global const float *input, /* n, c, h, w */ int pixel_len = out_pixel_end - out_pixel_begin; for (int out_chan = out_chan_begin; out_chan < out_chan_end; ++out_chan) { + float *output_ptr = output_base + out_chan * pixel_num; float bias_value = bias[out_chan]; - int out_chan_offset = out_chan - out_chan_begin; for (int p = 0; p < pixel_len; ++p) { - output_slice[out_chan_offset * 4 + p] = bias_value; + output_ptr[p] = bias_value; } } @@ -74,29 +72,28 @@ __kernel void conv_2d_1x1_v2(__global const float *input, /* n, c, h, w */ #pragma unroll for (int oc = 0; oc < 4; ++oc) { float4 weights = vload4(0, filter_ptr + oc * in_chan_num); - float4 out = vload4(oc, output_slice); + float4 out = vload4(0, output_ptr + oc * pixel_num); out += in0 * weights.x; out += in1 * weights.y; out += in2 * weights.z; out += in3 * weights.w; - vstore4(out, oc, output_slice); + vstore4(out, 0, output_ptr + oc * pixel_num); } } for (; out_chan < out_chan_end; ++out_chan) { const float* filter_ptr = filter + out_chan * in_chan_num + in_chan; float *output_ptr = output_base + out_chan * pixel_num; - int out_chan_offset = out_chan - out_chan_begin; float4 weights = vload4(0, filter_ptr); float4 in0 = vload4(0, input_ptr); float4 in1 = vload4(0, input_ptr + pixel_num); float4 in2 = vload4(0, input_ptr + 2 * pixel_num); float4 in3 = vload4(0, input_ptr + 3 * pixel_num); - float4 out = vload4(out_chan_offset, output_slice); + float4 out = vload4(0, output_ptr); out += in0 * weights.x; out += in1 * weights.y; out += in2 * weights.z; out += in3 * weights.w; - vstore4(out, out_chan_offset, output_slice); + vstore4(out, 0, output_ptr); } } } @@ -106,25 +103,10 @@ __kernel void conv_2d_1x1_v2(__global const float *input, /* n, c, h, w */ for (int out_chan = out_chan_begin; out_chan < out_chan_end; ++out_chan) { float weights = filter[out_chan * in_chan_num + in_chan]; float *output_ptr = output_base + out_chan * pixel_num; - int out_chan_offset = out_chan - out_chan_begin; for (int p = 0; p < pixel_len; ++p) { float in = input_ptr[p]; - output_slice[out_chan_offset * 4 + p] += in * weights; - } - } - } - - for (int out_chan_offset = 0; out_chan_offset < out_chan_len; ++out_chan_offset) { - int out_chan = out_chan_begin + out_chan_offset; - float *output_ptr = output_base + out_chan * pixel_num; - if (pixel_len == 4) { - float4 out = vload4(out_chan_offset, output_slice); - vstore4(out, 0, output_ptr); - } else { - int offset = out_chan_offset << 2; - for (int p = 0; p < pixel_len; ++p) { - output_ptr[p] = output_slice[offset + p]; + output_ptr[p] += in * weights; } } } diff --git a/mace/ops/conv_2d_benchmark.cc b/mace/ops/conv_2d_benchmark.cc index fb859da8..0d201a2f 100644 --- a/mace/ops/conv_2d_benchmark.cc +++ b/mace/ops/conv_2d_benchmark.cc @@ -80,6 +80,8 @@ constexpr int kItersToSync = 10; BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, NEON); \ BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, OPENCL); +BM_CONV_2D(1, 3, 4032, 3016, 1, 1, 1, VALID, 3, float); // Test RGB <-> YUV +BM_CONV_2D(1, 3, 480, 480, 1, 1, 1, VALID, 3, float); // Test RGB <-> YUV BM_CONV_2D(1, 64, 32, 32, 1, 1, 1, VALID, 128, float); BM_CONV_2D(1, 64, 33, 31, 1, 1, 1, VALID, 128, float); // Test bad alignments BM_CONV_2D(1, 3, 512, 512, 1, 1, 1, VALID, 3, float); -- GitLab