提交 25a874f7 编写于 作者: L liuqi

Filter format from [kh*kw*Ic, (Oc+3)/4] to [Ic, kh*kw*(Oc+3)/4]

上级 81414574
......@@ -86,7 +86,12 @@ void BufferToImageFunctor<DeviceType::OPENCL, T>::operator()(
static_cast<uint32_t>(buffer->buffer_offset() /
GetEnumTypeSize(buffer->dtype())));
}
if (type == ARGUMENT) {
if (type == CONV2D_FILTER) {
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(0)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
} else if (type == ARGUMENT) {
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(0)));
} else if (type == WEIGHT_HEIGHT || type == WEIGHT_WIDTH) {
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(0)));
......
......@@ -2,22 +2,25 @@
__kernel void filter_buffer_to_image(__global const DATA_TYPE *input, /* h, w, oc, ic */
__private const int input_offset,
__private const int filter_h,
__private const int filter_w,
__private const int out_channel,
__private const int in_channel,
__write_only image2d_t output) {
int w = get_global_id(0);
int h = get_global_id(1);
const int out_channel_idx = h * 4;
const int rounded_in_channel = ((in_channel + 3) / 4) * 4;
const int hw_idx = w / rounded_in_channel;
const int in_channel_idx = w % rounded_in_channel;
const int in_channel_idx = w;
const int hw_size = filter_w * filter_h;
const int out_channel_idx = h / hw_size * 4;
const int hw_idx = h % hw_size;
const int h_idx = hw_idx / filter_w;
const int w_idx = hw_idx % filter_w;
const int offset = input_offset + ((h_idx * filter_w + w_idx) * out_channel + out_channel_idx) * in_channel
const int offset = input_offset
+ ((h_idx * filter_w + w_idx) * out_channel
+ out_channel_idx) * in_channel
+ in_channel_idx;
VEC_DATA_TYPE(DATA_TYPE, 4) values = 0;
DATA_TYPE4 values = 0;
if (out_channel_idx < out_channel) {
const int size = out_channel - out_channel_idx;
if (size < 4) {
......@@ -38,28 +41,30 @@ __kernel void filter_buffer_to_image(__global const DATA_TYPE *input, /* h, w, o
}
int2 coord = (int2)(w, h);
CMD_TYPE(write_image, CMD_DATA_TYPE)(output, coord, values);
WRITE_IMAGET(output, coord, values);
}
__kernel void filter_image_to_buffer(__global DATA_TYPE *output, /* h, w, oc, ic */
__private const int filter_h,
__private const int filter_w,
__private const int out_channel,
__private const int in_channel,
__read_only image2d_t input) {
int w = get_global_id(0);
int h = get_global_id(1);
const int out_channel_idx = h * 4;
const int rounded_in_channel = ((in_channel + 3) / 4) * 4;
const int hw_idx = w / rounded_in_channel;
const int in_channel_idx = w % rounded_in_channel;
const int in_channel_idx = w;
const int hw_size = filter_w * filter_h;
const int out_channel_idx = h / hw_size * 4;
const int hw_idx = h % hw_size;
const int h_idx = hw_idx / filter_w;
const int w_idx = hw_idx % filter_w;
const int offset = ((h_idx * filter_w + w_idx) * out_channel + out_channel_idx) * in_channel
const int offset = ((h_idx * filter_w + w_idx) * out_channel
+ out_channel_idx) * in_channel
+ in_channel_idx;
if (out_channel_idx < out_channel) {
int2 coord = (int2)(w, h);
VEC_DATA_TYPE(DATA_TYPE, 4) values = CMD_TYPE(read_image, CMD_DATA_TYPE)(input, SAMPLER, coord);
DATA_TYPE4 values = READ_IMAGET(input, SAMPLER, coord);
const int size = (out_channel - out_channel_idx);
if (size < 4) {
switch (size) {
......@@ -145,7 +150,7 @@ __kernel void in_out_buffer_to_image(__global const DATA_TYPE *input, /* nhwc */
+ channel_idx;
const int size = channels - channel_idx;
VEC_DATA_TYPE(DATA_TYPE, 4) values = 0;
DATA_TYPE4 values = 0;
if (size < 4) {
switch(size) {
case 3:
......@@ -159,7 +164,7 @@ __kernel void in_out_buffer_to_image(__global const DATA_TYPE *input, /* nhwc */
values = vload4(0, input + offset);
}
int2 coord = (int2)(w, h);
CMD_TYPE(write_image, CMD_DATA_TYPE)(output, coord, values);
WRITE_IMAGET(output, coord, values);
}
__kernel void in_out_image_to_buffer(__global DATA_TYPE *output, /* nhwc */
......@@ -177,7 +182,7 @@ __kernel void in_out_image_to_buffer(__global DATA_TYPE *output, /* nhwc */
+ channel_idx;
int2 coord = (int2)(w, h);
VEC_DATA_TYPE(DATA_TYPE, 4) values = CMD_TYPE(read_image, CMD_DATA_TYPE)(input, SAMPLER, coord);
DATA_TYPE4 values = READ_IMAGET(input, SAMPLER, coord);
const int size = channels - channel_idx;
if (size < 4) {
switch (size) {
......@@ -204,7 +209,7 @@ __kernel void arg_buffer_to_image(__global const DATA_TYPE *input, /* nhwc */
const int size = count - w * 4;
VEC_DATA_TYPE(DATA_TYPE, 4) values = 0;
DATA_TYPE4 values = 0;
if (size < 4) {
switch(size) {
case 3:
......@@ -218,7 +223,7 @@ __kernel void arg_buffer_to_image(__global const DATA_TYPE *input, /* nhwc */
values = vload4(0, input + offset);
}
int2 coord = (int2)(w, h);
CMD_TYPE(write_image, CMD_DATA_TYPE)(output, coord, values);
WRITE_IMAGET(output, coord, values);
}
__kernel void arg_image_to_buffer(__global DATA_TYPE *output, /* nhwc */
......@@ -229,7 +234,7 @@ __kernel void arg_image_to_buffer(__global DATA_TYPE *output, /* nhwc */
const int offset = w * 4;
int2 coord = (int2)(w, h);
VEC_DATA_TYPE(DATA_TYPE, 4) values = CMD_TYPE(read_image, CMD_DATA_TYPE)(input, SAMPLER, coord);
DATA_TYPE4 values = READ_IMAGET(input, SAMPLER, coord);
const int size = count - offset;
if (size < 4) {
switch (size) {
......
#include <common.h>
__kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
__read_only image2d_t filter, /* cout%4 * cin * kh * kw, cout/4 */
__read_only image2d_t filter, /* cout%4 * cin, kh * kw * cout/4 */
#ifdef BIAS
__read_only image2d_t bias, /* cout%4 * cout/4 */
#endif
......@@ -23,7 +23,6 @@ __kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
const int out_w_blk = get_global_id(1);
const int out_w_blks = get_global_size(1);
const int out_hb = get_global_id(2);
const int rounded_in_ch = in_ch_blks << 2;
#ifdef BIAS
DATA_TYPE4 out0 =
......@@ -46,21 +45,21 @@ __kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
const int height_idx = mad24((out_hb % out_height), stride, -padding_top);
const int batch_idx = mul24((out_hb / out_height), in_height);
const int rounded_in_ch_x_filter_width = mul24(rounded_in_ch, filter_width);
const int filter_hw = mul24(filter_width, filter_height);
DATA_TYPE4 in0, in1, in2, in3;
DATA_TYPE4 weights0, weights1, weights2, weights3;
for (short in_ch_blk = 0; in_ch_blk < in_ch_blks; ++in_ch_blk) {
const int in_idx = mul24(in_ch_blk, in_width);
int filter_x_part0 = in_ch_blk << 2;
int filter_x_idx = in_ch_blk << 2;
int filter_y_idx = mul24(out_ch_blk, filter_hw);
for (short hb_idx = 0; hb_idx < filter_height; ++hb_idx) {
// TODO(heliangliang) optimize out these muls
int in_hb_value = height_idx + mul24(hb_idx, dilation_h);
in_hb_value = select(in_hb_value + batch_idx,
-1,
(in_hb_value < 0 || in_hb_value >= in_height));
int filter_x_part1 = 0;
#pragma unroll
for (short width_idx = 0; width_idx < filter_width; ++width_idx) {
int in_width_value;
#define READ_INPUT(i) \
......@@ -78,11 +77,10 @@ __kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
#undef READ_INPUT
// int filter_idx = (hb_idx * filter_width + width_idx) * rounded_in_ch + (in_ch_blk << 2);
int filter_idx = filter_x_part0 + filter_x_part1;
weights0 = READ_IMAGET(filter, SAMPLER, (int2)(filter_idx + 0, out_ch_blk));
weights1 = READ_IMAGET(filter, SAMPLER, (int2)(filter_idx + 1, out_ch_blk));
weights2 = READ_IMAGET(filter, SAMPLER, (int2)(filter_idx + 2, out_ch_blk));
weights3 = READ_IMAGET(filter, SAMPLER, (int2)(filter_idx + 3, out_ch_blk));
weights0 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 0, filter_y_idx));
weights1 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 1, filter_y_idx));
weights2 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 2, filter_y_idx));
weights3 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 3, filter_y_idx));
out0 = mad(in0.x, weights0, out0);
out0 = mad(in0.y, weights1, out0);
......@@ -105,9 +103,8 @@ __kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
out3 = mad(in3.z, weights2, out3);
out3 = mad(in3.w, weights3, out3);
filter_x_part1 += rounded_in_ch;
filter_y_idx += 1;
}
filter_x_part0 += rounded_in_ch_x_filter_width;
}
}
......
#include <common.h>
__kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
__read_only image2d_t filter, /* cout%4 * cin * kh * kw, cout/4 */
__read_only image2d_t filter, /* cout%4 * cin , kh * kw * cout/4 */
#ifdef BIAS
__read_only image2d_t bias, /* cout%4 * cout/4 */
#endif
......@@ -21,7 +21,6 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
const int out_w_blk = get_global_id(1);
const int out_w_blks = get_global_size(1);
const int out_hb = get_global_id(2);
const int rounded_in_ch = in_ch_blks << 2;
#ifdef BIAS
DATA_TYPE4 out0 =
......@@ -47,19 +46,18 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
const int height_idx = mad24((out_hb % out_height), stride, -padding_top);
const int batch_idx = mul24((out_hb / out_height), in_height);
const int rounded_in_ch_x_3 = (rounded_in_ch << 1) + rounded_in_ch;
DATA_TYPE4 in0, in1, in2, in3, in4;
DATA_TYPE4 weights0, weights1, weights2, weights3;
for (short in_ch_blk = 0; in_ch_blk < in_ch_blks; ++in_ch_blk) {
const int in_idx = mul24(in_ch_blk, in_width);
int filter_x_part0 = in_ch_blk << 2;
int filter_x_idx = in_ch_blk << 2;
int filter_y_idx = mul24(out_ch_blk, 9);
int in_hb_idx = height_idx;
for (short hb_idx = 0; hb_idx < 3; ++hb_idx) {
int in_hb_value = select(in_hb_idx + batch_idx,
-1,
(in_hb_idx < 0 || in_hb_idx >= in_height));
int filter_x_part1 = 0;
int in_width_idx = 0;
for (short width_idx = 0; width_idx < 3; ++width_idx) {
int in_width_value;
......@@ -79,11 +77,10 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
#undef READ_INPUT
// int filter_idx = (hb_idx * 3 + width_idx) * rounded_in_ch + (in_ch_blk << 2);
int filter_idx = filter_x_part0 + filter_x_part1;
weights0 = READ_IMAGET(filter, SAMPLER, (int2)(filter_idx + 0, out_ch_blk));
weights1 = READ_IMAGET(filter, SAMPLER, (int2)(filter_idx + 1, out_ch_blk));
weights2 = READ_IMAGET(filter, SAMPLER, (int2)(filter_idx + 2, out_ch_blk));
weights3 = READ_IMAGET(filter, SAMPLER, (int2)(filter_idx + 3, out_ch_blk));
weights0 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 0, filter_y_idx));
weights1 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 1, filter_y_idx));
weights2 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 2, filter_y_idx));
weights3 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 3, filter_y_idx));
out0 = mad(in0.x, weights0, out0);
out0 = mad(in0.y, weights1, out0);
......@@ -111,10 +108,9 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
out4 = mad(in4.z, weights2, out4);
out4 = mad(in4.w, weights3, out4);
filter_x_part1 += rounded_in_ch;
in_width_idx += dilation_w;
filter_y_idx += 1;
}
filter_x_part0 += rounded_in_ch_x_3;
in_hb_idx += dilation_h;
}
}
......
......@@ -23,13 +23,13 @@ void CalInOutputImageShape(const std::vector<index_t> &shape, /* NHWC */
(*image_shape)[1] = shape[0] * shape[1];
}
// [RoundUp<4>(Ic) * H * W, (Oc + 3) / 4]
// [RoundUp<4>(Ic), H * W * (Oc + 3) / 4]
void CalConv2dFilterImageShape(const std::vector<index_t> &shape, /* HWOI */
std::vector<size_t> *image_shape) {
MACE_CHECK(shape.size() == 4);
image_shape->resize(2);
(*image_shape)[0] = shape[0] * shape[1] * RoundUp<index_t>(shape[3], 4);
(*image_shape)[1] = RoundUpDiv4(shape[2]);
(*image_shape)[0] = RoundUp<index_t>(shape[3], 4);
(*image_shape)[1] = shape[0] * shape[1] * RoundUpDiv4(shape[2]);
}
// [H * W * M, (Ic + 3) / 4]
......
......@@ -114,6 +114,7 @@ static void Conv2d(int iters,
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, OPENCL); \
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, half, OPENCL);
BM_CONV_2D(1, 256, 64, 64, 3, 3, 1, 1, VALID, 256);
BM_CONV_2D(1, 512, 15, 15, 1, 1, 1, 1, VALID, 1024);
......@@ -135,6 +136,8 @@ BM_CONV_2D(1, 64, 33, 31, 3, 3, 2, 1, SAME, 128);
BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, 1, SAME, 128);
BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, 1, SAME, 128);
BM_CONV_2D(1, 1024, 16, 16, 15, 1, 1, 1, SAME, 2);
// Dilation
BM_CONV_2D(1, 32, 256, 256, 3, 3, 1, 2, VALID, 32);
BM_CONV_2D(1, 32, 256, 256, 3, 3, 1, 4, VALID, 32);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册