提交 5bbd271e 编写于 作者: L liuqi

Finish conv 1x1 with stride 2 opencl kernel.

上级 446780e5
......@@ -24,33 +24,89 @@ __kernel void conv_2d_1x1_naive(__global const float *input, /* n, c, h, w */
}
}
#define vec_conv_2d_1x1_s1(out_size) \
do { \
float4 in0 = vload4(0, input_ptr); \
float4 in1 = vload4(0, input_ptr + in_pixel); \
float4 in2 = vload4(0, input_ptr + 2 * in_pixel); \
float4 in3 = vload4(0, input_ptr + 3 * in_pixel); \
for (int oc = 0; oc < out_size; ++oc) { \
float4 weights = vload4(0, filter_ptr + oc * in_chan_num); \
float4 out = vload4(0, output_ptr + oc * out_pixel); \
out += in0 * weights.x; \
out += in1 * weights.y; \
out += in2 * weights.z; \
out += in3 * weights.w; \
vstore4(out, 0, output_ptr + oc * out_pixel); \
} \
} while(0)
#define vec_conv_2d_1x1_s2(out_size) \
do { \
float4 in00 = vload4(0, input_ptr); \
float3 in01 = vload3(0, input_ptr + 4); \
float4 in10 = vload4(0, input_ptr + in_pixel); \
float3 in11 = vload3(0, input_ptr + in_pixel + 4); \
float4 in20 = vload4(0, input_ptr + 2 * in_pixel); \
float3 in21 = vload3(0, input_ptr + 2 * in_pixel + 4);\
float4 in30 = vload4(0, input_ptr + 3 * in_pixel); \
float3 in31 = vload3(0, input_ptr + 3 * in_pixel + 4); \
float4 in0 = (float4)(in00.s02, in01.s02); \
float4 in1 = (float4)(in10.s02, in11.s02); \
float4 in2 = (float4)(in20.s02, in21.s02); \
float4 in3 = (float4)(in30.s02, in31.s02); \
for (int oc = 0; oc < out_size; ++oc) { \
float4 weights = vload4(0, filter_ptr + oc * in_chan_num); \
float4 out = vload4(0, output_ptr + oc * out_pixel); \
out += in0 * weights.x; \
out += in1 * weights.y; \
out += in2 * weights.z; \
out += in3 * weights.w; \
vstore4(out, 0, output_ptr + oc * out_pixel); \
} \
} while(0)
__kernel void conv_2d_1x1_v2(__global const float *input, /* n, c, h, w */
__global const float *filter, /* o, i, kh, kw */
__global const float *bias, /* o */
__global float *output, /* n, c, h, w */
__private const int in_chan_num,
__private const int out_chan_num,
__private const int pixel_num) {
__private const int in_height,
__private const int in_width,
__private const int out_height,
__private const int out_width,
__private const int stride) {
int batch = get_global_id(0);
int out_chan_blk = get_global_id(1);
int out_pixel_blk = get_global_id(2);
const int in_pixel = in_height * in_width;
const int out_pixel = out_height * out_width;
const int round_out_width = (out_width + 3) / 4;
const int out_pixel_height = out_pixel_blk / round_out_width;
const int out_pixel_width = out_pixel_blk % round_out_width;
const int out_chan_begin = out_chan_blk * 4;
const int out_chan_end = min(out_chan_begin + 4, out_chan_num);
const int out_pixel_begin = out_pixel_blk * 4;
const int out_pixel_end = min(out_pixel_begin + 4, pixel_num);
const int out_pixel_begin = out_pixel_height * out_width + out_pixel_width * 4;
const int out_pixel_end = min(out_pixel_begin + 4, (out_pixel_height + 1) * out_width);
const int in_pixel_begin = out_pixel_height * stride * in_width + out_pixel_width * stride * 4;
const int in_offset = batch * in_chan_num * pixel_num;
const int out_offset = batch * out_chan_num * pixel_num;
const int in_offset = batch * in_chan_num * in_pixel;
const int out_offset = batch * out_chan_num * out_pixel;
const float *input_base = input + in_offset + out_pixel_begin;
const float *input_base = input + in_offset + in_pixel_begin;
float *output_base = output + out_offset + out_pixel_begin;
int out_chan_len = out_chan_end - out_chan_begin;
int pixel_len = out_pixel_end - out_pixel_begin;
for (int out_chan = out_chan_begin; out_chan < out_chan_end; ++out_chan) {
float *output_ptr = output_base + out_chan * pixel_num;
float *output_ptr = output_base + out_chan * out_pixel;
float bias_value = bias == NULL ? 0 : bias[out_chan];
for (int p = 0; p < pixel_len; ++p) {
output_ptr[p] = bias_value;
......@@ -60,52 +116,37 @@ __kernel void conv_2d_1x1_v2(__global const float *input, /* n, c, h, w */
int in_chan = 0;
if (pixel_len == 4) {
for (; in_chan + 3 < in_chan_num; in_chan += 4) {
const float *input_ptr = input_base + in_chan * pixel_num;
const float *input_ptr = input_base + in_chan * in_pixel;
int out_chan = out_chan_begin;
for (; out_chan + 3 < out_chan_end; out_chan += 4) {
const float* filter_ptr = filter + out_chan * in_chan_num + in_chan;
float *output_ptr = output_base + out_chan * pixel_num;
float4 in0 = vload4(0, input_ptr);
float4 in1 = vload4(0, input_ptr + pixel_num);
float4 in2 = vload4(0, input_ptr + 2 * pixel_num);
float4 in3 = vload4(0, input_ptr + 3 * pixel_num);
#pragma unroll
for (int oc = 0; oc < 4; ++oc) {
float4 weights = vload4(0, filter_ptr + oc * in_chan_num);
float4 out = vload4(0, output_ptr + oc * pixel_num);
out += in0 * weights.x;
out += in1 * weights.y;
out += in2 * weights.z;
out += in3 * weights.w;
vstore4(out, 0, output_ptr + oc * pixel_num);
float *output_ptr = output_base + out_chan * out_pixel;
if (stride == 1) {
vec_conv_2d_1x1_s1(4);
} else if (stride == 2) {
vec_conv_2d_1x1_s2(4);
}
}
for (; out_chan < out_chan_end; ++out_chan) {
const float* filter_ptr = filter + out_chan * in_chan_num + in_chan;
float *output_ptr = output_base + out_chan * pixel_num;
float4 weights = vload4(0, filter_ptr);
float4 in0 = vload4(0, input_ptr);
float4 in1 = vload4(0, input_ptr + pixel_num);
float4 in2 = vload4(0, input_ptr + 2 * pixel_num);
float4 in3 = vload4(0, input_ptr + 3 * pixel_num);
float4 out = vload4(0, output_ptr);
out += in0 * weights.x;
out += in1 * weights.y;
out += in2 * weights.z;
out += in3 * weights.w;
vstore4(out, 0, output_ptr);
float *output_ptr = output_base + out_chan * out_pixel;
if (stride == 1) {
vec_conv_2d_1x1_s1(1);
} else if (stride == 2) {
vec_conv_2d_1x1_s2(1);
}
}
}
}
for (; in_chan < in_chan_num; ++in_chan) {
const float *input_ptr = input_base + in_chan * pixel_num;
const float *input_ptr = input_base + in_chan * in_pixel;
for (int out_chan = out_chan_begin; out_chan < out_chan_end; ++out_chan) {
float weights = filter[out_chan * in_chan_num + in_chan];
float *output_ptr = output_base + out_chan * pixel_num;
float *output_ptr = output_base + out_chan * out_pixel;
for (int p = 0; p < pixel_len; ++p) {
float in = input_ptr[p];
float in = input_ptr[p*stride];
output_ptr[p] += in * weights;
}
}
......
......@@ -10,6 +10,9 @@ namespace kernels {
extern void Conv2dOpenclK1x1S1(const Tensor *input, const Tensor *filter,
const Tensor *bias, Tensor *output);
extern void Conv2dOpenclK1x1S2(const Tensor *input, const Tensor *filter,
const Tensor *bias, Tensor *output);
extern void Conv2dOpenclK3x3S1(const Tensor *input, const Tensor *filter,
const Tensor *bias, Tensor *output);
......@@ -24,7 +27,7 @@ void Conv2dFunctor<DeviceType::OPENCL, float>::operator()(const Tensor *input,
const Tensor *bias, Tensor *output);
// Selection matrix: kernel_size x stride_size
static const Conv2dOpenclFunction selector[5][2] = {
{Conv2dOpenclK1x1S1, nullptr},
{Conv2dOpenclK1x1S1, Conv2dOpenclK1x1S2},
{nullptr, nullptr},
{Conv2dOpenclK3x3S1, Conv2dOpenclK3x3S2},
{nullptr, nullptr},
......
......@@ -45,6 +45,7 @@ void Conv1x1Naive(const Tensor *input,
void Conv1x1V2(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const int stride,
Tensor *output) {
const index_t batch = output->dim(0);
const index_t channels = output->dim(1);
......@@ -54,9 +55,8 @@ void Conv1x1V2(const Tensor *input,
auto runtime = OpenCLRuntime::Get();
auto program = runtime->program();
const index_t pixels = height * width;
const index_t channel_blocks = (channels + 3) / 4;
const index_t pixel_blocks = (pixels + 3) / 4;
const index_t pixel_blocks = (width + 3) / 4 * height;
// TODO KernelFunctor has an extra clReleaseCommandQueue due to a copy
// TODO check wired clReleaseCommandQueue latency
......@@ -77,7 +77,11 @@ void Conv1x1V2(const Tensor *input,
conv_2d_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(output->buffer())));
conv_2d_kernel.setArg(idx++, static_cast<int>(input_channels));
conv_2d_kernel.setArg(idx++, static_cast<int>(channels));
conv_2d_kernel.setArg(idx++, static_cast<int>(pixels));
conv_2d_kernel.setArg(idx++, static_cast<int>(input->dim(2)));
conv_2d_kernel.setArg(idx++, static_cast<int>(input->dim(3)));
conv_2d_kernel.setArg(idx++, static_cast<int>(height));
conv_2d_kernel.setArg(idx++, static_cast<int>(width));
conv_2d_kernel.setArg(idx++, stride);
auto command_queue = runtime->command_queue();
cl_int error = command_queue.enqueueNDRangeKernel(
......@@ -189,7 +193,16 @@ extern void Conv2dOpenclK1x1S1(const Tensor *input,
MACE_CHECK(input_batch == batch && input_height == height &&
input_width == width);
Conv1x1V2(input, filter, bias, output);
Conv1x1V2(input, filter, bias, 1, output);
};
extern void Conv2dOpenclK1x1S2(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output) {
MACE_CHECK(input->dim(0) == output->dim(0));
Conv1x1V2(input, filter, bias, 2, output);
};
} // namespace kernels
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册