Commit 5262c860, authored by liuqi

Refactor opencl conv kernel and op.

Parent: b028e4de
@@ -24,8 +24,8 @@ void BufferToImageFunctor<DeviceType::OPENCL, T>::operator()(Tensor *buffer,
   }
   std::set<std::string> built_options;
-  built_options.emplace("-DDATA_TYPE=" + DataTypeToCLType(image->dtype()));
-  built_options.emplace("-DCMD_DATA_TYPE=" + DataTypeToOPENCLCMDDataType(image->dtype()));
+  built_options.emplace("-DDATA_TYPE=" + DataTypeToCLType(DataTypeToEnum<T>::value));
+  built_options.emplace("-DCMD_DATA_TYPE=" + DataTypeToOPENCLCMDDataType(DataTypeToEnum<T>::value));
   auto runtime = OpenCLRuntime::Get();
   string kernel_name;
   switch (type) {
...
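The hunk above keys the BufferToImage kernel's build options on the functor's compile-time template parameter T instead of the image tensor's runtime dtype. Below is a minimal, self-contained C++ sketch of that pattern; the helper bodies are illustrative stand-ins for MACE's real DataTypeToCLType / DataTypeToOPENCLCMDDataType, whose implementations are not part of this diff.

#include <iostream>
#include <set>
#include <string>

enum DataType { DT_FLOAT, DT_HALF };
struct half_t {};  // stand-in for the device fp16 type

// Compile-time type -> runtime enum, mirroring how DataTypeToEnum<T>::value is used above.
template <typename T> struct DataTypeToEnum;
template <> struct DataTypeToEnum<float>  { static constexpr DataType value = DT_FLOAT; };
template <> struct DataTypeToEnum<half_t> { static constexpr DataType value = DT_HALF; };

// Illustrative stand-ins for the helpers called in the diff.
std::string DataTypeToCLType(DataType dt) { return dt == DT_FLOAT ? "float" : "half"; }
std::string DataTypeToOPENCLCMDDataType(DataType dt) { return dt == DT_FLOAT ? "f" : "h"; }

template <typename T>
std::set<std::string> BuildOptionsFor() {
  std::set<std::string> built_options;
  built_options.emplace("-DDATA_TYPE=" + DataTypeToCLType(DataTypeToEnum<T>::value));
  built_options.emplace("-DCMD_DATA_TYPE=" + DataTypeToOPENCLCMDDataType(DataTypeToEnum<T>::value));
  return built_options;
}

int main() {
  for (const auto &opt : BuildOptionsFor<half_t>())
    std::cout << opt << "\n";  // prints -DCMD_DATA_TYPE=h and -DDATA_TYPE=half
}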
@@ -23,15 +23,15 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
   const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
 #ifdef BIAS
-  float4 out0 = convert_float4(READ_IMAGET(bias, sampler, (int2)(out_ch_blk, 0)));
-  float4 out1 = out0;
-  float4 out2 = out0;
-  float4 out3 = out0;
+  DATA_TYPE4 out0 = READ_IMAGET(bias, sampler, (int2)(out_ch_blk, 0));
+  DATA_TYPE4 out1 = out0;
+  DATA_TYPE4 out2 = out0;
+  DATA_TYPE4 out3 = out0;
 #else
-  float4 out0 = 0;
-  float4 out1 = 0;
-  float4 out2 = 0;
-  float4 out3 = 0;
+  DATA_TYPE4 out0 = 0;
+  DATA_TYPE4 out1 = 0;
+  DATA_TYPE4 out2 = 0;
+  DATA_TYPE4 out3 = 0;
 #endif
   int4 w;
@@ -62,16 +62,16 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
   int in_x_base = 0;
   for (int in_ch_blk = 0; in_ch_blk < in_ch_blks; ++in_ch_blk) {
-    float4 in0 = convert_float4(READ_IMAGET(input, sampler, (int2)(in_x_base + w.x, out_hb_idx)));
-    float4 in1 = convert_float4(READ_IMAGET(input, sampler, (int2)(in_x_base + w.y, out_hb_idx)));
-    float4 in2 = convert_float4(READ_IMAGET(input, sampler, (int2)(in_x_base + w.z, out_hb_idx)));
-    float4 in3 = convert_float4(READ_IMAGET(input, sampler, (int2)(in_x_base + w.w, out_hb_idx)));
+    DATA_TYPE4 in0 = READ_IMAGET(input, sampler, (int2)(in_x_base + w.x, out_hb_idx));
+    DATA_TYPE4 in1 = READ_IMAGET(input, sampler, (int2)(in_x_base + w.y, out_hb_idx));
+    DATA_TYPE4 in2 = READ_IMAGET(input, sampler, (int2)(in_x_base + w.z, out_hb_idx));
+    DATA_TYPE4 in3 = READ_IMAGET(input, sampler, (int2)(in_x_base + w.w, out_hb_idx));
     const int filter_x0 = in_ch_blk << 2;
-    float4 weights0 = convert_float4(READ_IMAGET(filter, sampler, (int2)(filter_x0, out_ch_blk)));
-    float4 weights1 = convert_float4(READ_IMAGET(filter, sampler, (int2)(filter_x0 + 1, out_ch_blk)));
-    float4 weights2 = convert_float4(READ_IMAGET(filter, sampler, (int2)(filter_x0 + 2, out_ch_blk)));
-    float4 weights3 = convert_float4(READ_IMAGET(filter, sampler, (int2)(filter_x0 + 3, out_ch_blk)));
+    DATA_TYPE4 weights0 = READ_IMAGET(filter, sampler, (int2)(filter_x0, out_ch_blk));
+    DATA_TYPE4 weights1 = READ_IMAGET(filter, sampler, (int2)(filter_x0 + 1, out_ch_blk));
+    DATA_TYPE4 weights2 = READ_IMAGET(filter, sampler, (int2)(filter_x0 + 2, out_ch_blk));
+    DATA_TYPE4 weights3 = READ_IMAGET(filter, sampler, (int2)(filter_x0 + 3, out_ch_blk));
     // Will prefetch L2 improve performance? How to pretch image data?
     out0 += in0.x * weights0;
@@ -99,18 +99,18 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
 #ifdef FUSED_BATCH_NORM
   // batch norm
-  float4 bn_scale_value =
-      convert_float4(READ_IMAGET(bn_scale, sampler, (int2)(out_ch_blk, 0)));
-  float4 scale0 = (float4)(bn_scale_value.x);
-  float4 scale1 = (float4)(bn_scale_value.y);
-  float4 scale2 = (float4)(bn_scale_value.z);
-  float4 scale3 = (float4)(bn_scale_value.w);
-  float4 bn_offset_value =
+  DATA_TYPE4 bn_scale_value =
+      READ_IMAGET(bn_scale, sampler, (int2)(out_ch_blk, 0));
+  DATA_TYPE4 scale0 = (DATA_TYPE4)(bn_scale_value.x);
+  DATA_TYPE4 scale1 = (DATA_TYPE4)(bn_scale_value.y);
+  DATA_TYPE4 scale2 = (DATA_TYPE4)(bn_scale_value.z);
+  DATA_TYPE4 scale3 = (DATA_TYPE4)(bn_scale_value.w);
+  DATA_TYPE4 bn_offset_value =
       READ_IMAGET(bn_offset, sampler, (int2)(out_ch_blk, 0));
-  float4 offset0 = (float4)(bn_offset_value.x);
-  float4 offset1 = (float4)(bn_offset_value.y);
-  float4 offset2 = (float4)(bn_offset_value.z);
-  float4 offset3 = (float4)(bn_offset_value.w);
+  DATA_TYPE4 offset0 = (DATA_TYPE4)(bn_offset_value.x);
+  DATA_TYPE4 offset1 = (DATA_TYPE4)(bn_offset_value.y);
+  DATA_TYPE4 offset2 = (DATA_TYPE4)(bn_offset_value.z);
+  DATA_TYPE4 offset3 = (DATA_TYPE4)(bn_offset_value.w);
   out0 = out0 * scale0 + offset0;
   out1 = out1 * scale1 + offset1;
@@ -126,7 +126,6 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
   out3 = fmax(out3, 0);
 #endif
-#ifdef TYPE_FLOAT
   const int out_x_base = out_ch_blk * width;
   int out_x_idx = out_w_blk;
   WRITE_IMAGET(output, (int2)(out_x_base + out_x_idx, out_hb), out0);
@@ -142,21 +141,5 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
   out_x_idx += out_w_blks;
   if (out_x_idx >= width) return;
   WRITE_IMAGET(output, (int2)(out_x_base + out_x_idx, out_hb), out3);
-#else
-  const int out_x_base = out_ch_blk * width;
-  int out_x_idx = out_w_blk;
-  WRITE_IMAGET(output, (int2)(out_x_base + out_x_idx, out_hb), convert_half4(out0));
-  out_x_idx += out_w_blks;
-  if (out_x_idx >= width) return;
-  WRITE_IMAGET(output, (int2)(out_x_base + out_x_idx, out_hb), convert_half4(out1));
-  out_x_idx += out_w_blks;
-  if (out_x_idx >= width) return;
-  WRITE_IMAGET(output, (int2)(out_x_base + out_x_idx, out_hb), convert_half4(out2));
-  out_x_idx += out_w_blks;
-  if (out_x_idx >= width) return;
-  WRITE_IMAGET(output, (int2)(out_x_base + out_x_idx, out_hb), convert_half4(out3));
-#endif
 }
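The kernel-side change works because DATA_TYPE4, READ_IMAGET, and WRITE_IMAGET expand to the matching float or half variants at kernel build time, so the explicit convert_float4/convert_half4 calls and the TYPE_FLOAT write branch become redundant. Below is a self-contained C++ sketch of how such token-pasting macros could be defined; the real shared kernel header is not part of this diff, and the PASTE/STR helper names are assumptions made for illustration.

#include <cstdio>

// What "-DDATA_TYPE=half -DCMD_DATA_TYPE=h" would inject at kernel build time.
#ifndef DATA_TYPE
#define DATA_TYPE half
#endif
#ifndef CMD_DATA_TYPE
#define CMD_DATA_TYPE h
#endif

#define PASTE_(a, b) a##b
#define PASTE(a, b) PASTE_(a, b)
#define DATA_TYPE4 PASTE(DATA_TYPE, 4)                  // -> half4 (or float4)
#define READ_IMAGET PASTE(read_image, CMD_DATA_TYPE)    // -> read_imageh (or read_imagef)
#define WRITE_IMAGET PASTE(write_image, CMD_DATA_TYPE)  // -> write_imageh (or write_imagef)

#define STR_(x) #x
#define STR(x) STR_(x)

int main() {
  // Print the identifiers the kernel would actually use after preprocessing.
  std::printf("DATA_TYPE4   -> %s\n", STR(DATA_TYPE4));
  std::printf("READ_IMAGET  -> %s\n", STR(READ_IMAGET));
  std::printf("WRITE_IMAGET -> %s\n", STR(WRITE_IMAGET));
}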
@@ -19,21 +19,20 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
   const int out_hb = get_global_id(2);
   const int rounded_in_ch = in_ch_blks * 4;
   const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
 #ifdef BIAS
-  float4 out0 =
-      convert_float4(READ_IMAGET(bias, sampler, (int2)(out_ch_blk, 0)));
-  float4 out1 = out0;
-  float4 out2 = out0;
-  float4 out3 = out0;
-  float4 out4 = out0;
+  DATA_TYPE4 out0 =
+      READ_IMAGET(bias, sampler, (int2)(out_ch_blk, 0));
+  DATA_TYPE4 out1 = out0;
+  DATA_TYPE4 out2 = out0;
+  DATA_TYPE4 out3 = out0;
+  DATA_TYPE4 out4 = out0;
 #else
-  float4 out0 = 0;
-  float4 out1 = 0;
-  float4 out2 = 0;
-  float4 out3 = 0;
-  float4 out4 = 0;
+  DATA_TYPE4 out0 = 0;
+  DATA_TYPE4 out1 = 0;
+  DATA_TYPE4 out2 = 0;
+  DATA_TYPE4 out3 = 0;
+  DATA_TYPE4 out4 = 0;
 #endif
 #if STRIDE == 1
@@ -54,8 +53,8 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
   const int batch_idx = (out_hb / out_height) * in_height;
-  float4 in0, in1, in2, in3, in4;
-  float4 weights0, weights1, weights2, weights3;
+  DATA_TYPE4 in0, in1, in2, in3, in4;
+  DATA_TYPE4 weights0, weights1, weights2, weights3;
   int in_idx, hb_idx, width_idx, in_width_idx;
   // Unrolling this loop hurt perfmance
   for (short in_ch_blk = 0; in_ch_blk < in_ch_blks; ++in_ch_blk) {
@@ -75,7 +74,7 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
     in_width_value = select(in_idx + in_width_value, \
                             -1, \
                             (in_width_value < 0 || in_width_value >= in_width)); \
-    in##i = convert_float4(READ_IMAGET(input, sampler, (int2)(in_width_value, in_hb_value)));
+    in##i = READ_IMAGET(input, sampler, (int2)(in_width_value, in_hb_value));
     READ_INPUT(0);
     READ_INPUT(1);
@@ -86,10 +85,10 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
 #undef READ_INPUT
     int filter_idx = (in_ch_blk << 2) + (hb_idx * 3 + width_idx) * rounded_in_ch;
-    weights0 = convert_float4(READ_IMAGET(filter, sampler, (int2)(filter_idx + 0, out_ch_blk)));
-    weights1 = convert_float4(READ_IMAGET(filter, sampler, (int2)(filter_idx + 1, out_ch_blk)));
-    weights2 = convert_float4(READ_IMAGET(filter, sampler, (int2)(filter_idx + 2, out_ch_blk)));
-    weights3 = convert_float4(READ_IMAGET(filter, sampler, (int2)(filter_idx + 3, out_ch_blk)));
+    weights0 = READ_IMAGET(filter, sampler, (int2)(filter_idx + 0, out_ch_blk));
+    weights1 = READ_IMAGET(filter, sampler, (int2)(filter_idx + 1, out_ch_blk));
+    weights2 = READ_IMAGET(filter, sampler, (int2)(filter_idx + 2, out_ch_blk));
+    weights3 = READ_IMAGET(filter, sampler, (int2)(filter_idx + 3, out_ch_blk));
     // Will prefetch L2 improve performance? How to pretch image data?
@@ -122,7 +121,6 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
     }
   }
-#ifdef TYPE_FLOAT
   const int out_x_base = out_ch_blk * out_width;
   int w = out_w_blk;
   WRITE_IMAGET(output,
@@ -152,36 +150,5 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
   WRITE_IMAGET(output,
               (int2)(out_x_base + w, out_hb),
               out4);
-#else
-  const int out_x_base = out_ch_blk * out_width;
-  int w = out_w_blk;
-  WRITE_IMAGET(output,
-              (int2)(out_x_base + w, out_hb),
-              convert_half4(out0));
-  w += out_w_blks;
-  if (w >= out_width) return;
-  WRITE_IMAGET(output,
-              (int2)(out_x_base + w, out_hb),
-              convert_half4(out1));
-  w += out_w_blks;
-  if (w >= out_width) return;
-  WRITE_IMAGET(output,
-              (int2)(out_x_base + w, out_hb),
-              convert_half4(out2));
-  w += out_w_blks;
-  if (w >= out_width) return;
-  WRITE_IMAGET(output,
-              (int2)(out_x_base + w, out_hb),
-              convert_half4(out3));
-  w += out_w_blks;
-  if (w >= out_width) return;
-  WRITE_IMAGET(output,
-              (int2)(out_x_base + w, out_hb),
-              convert_half4(out4));
-#endif
 }
@@ -10,19 +10,19 @@ namespace kernels {
 extern void Conv2dOpenclK1x1S1(const Tensor *input, const Tensor *filter,
                                const Tensor *bias, const int *padding,
-                               Tensor *output);
+                               const DataType dt, Tensor *output);
 extern void Conv2dOpenclK1x1S2(const Tensor *input, const Tensor *filter,
                                const Tensor *bias, const int *padding,
-                               Tensor *output);
+                               const DataType dt, Tensor *output);
 extern void Conv2dOpenclK3x3S1(const Tensor *input, const Tensor *filter,
                                const Tensor *bias, const int *padding,
-                               Tensor *output);
+                               const DataType dt, Tensor *output);
 extern void Conv2dOpenclK3x3S2(const Tensor *input, const Tensor *filter,
                                const Tensor *bias, const int *padding,
-                               Tensor *output);
+                               const DataType dt, Tensor *output);
 template<typename T>
 void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
@@ -31,7 +31,7 @@ void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
                                                       Tensor *output) {
   typedef void (*Conv2dOpenclFunction)(const Tensor *input, const Tensor *filter,
                                        const Tensor *bias, const int *padding,
-                                       Tensor *output);
+                                       DataType dt, Tensor *output);
   // Selection matrix: kernel_size x stride_size
   static const Conv2dOpenclFunction selector[5][2] = {
       {Conv2dOpenclK1x1S1, Conv2dOpenclK1x1S2},
@@ -70,7 +70,7 @@ void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
   }
   auto conv2d_func = selector[kernel_h - 1][strides_[0] - 1];
-  conv2d_func(input, filter, bias, paddings.data(), output);
+  conv2d_func(input, filter, bias, paddings.data(), DataTypeToEnum<T>::value, output);
 }
 template struct Conv2dFunctor<DeviceType::OPENCL, float>;
...
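Threading DataType through the function-pointer signature is why every Conv2dOpenclK* declaration above changes in lockstep with the typedef. A compilable sketch of that dispatch shape follows; the Tensor type, enum values, and stub bodies are placeholders for illustration, not MACE code.

#include <cstdio>

enum DataType { DT_FLOAT, DT_HALF };
struct Tensor {};  // stand-in for mace::Tensor

typedef void (*Conv2dOpenclFunction)(const Tensor *input, const Tensor *filter,
                                     const Tensor *bias, const int *padding,
                                     DataType dt, Tensor *output);

static void Conv2dOpenclK1x1S1(const Tensor *, const Tensor *, const Tensor *,
                               const int *, DataType dt, Tensor *) {
  std::printf("1x1 stride-1 kernel, dt=%d\n", dt);
}
static void Conv2dOpenclK3x3S1(const Tensor *, const Tensor *, const Tensor *,
                               const int *, DataType dt, Tensor *) {
  std::printf("3x3 stride-1 kernel, dt=%d\n", dt);
}

int main() {
  // Selection matrix: kernel_size x stride_size (only two entries stubbed here).
  static const Conv2dOpenclFunction selector[3][2] = {
      {Conv2dOpenclK1x1S1, nullptr}, {nullptr, nullptr}, {Conv2dOpenclK3x3S1, nullptr}};
  Tensor input, filter, bias, output;
  const int paddings[2] = {0, 0};
  const int kernel_h = 3, stride = 1;
  selector[kernel_h - 1][stride - 1](&input, &filter, &bias, paddings, DT_HALF, &output);
}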
@@ -15,6 +15,7 @@ void Conv1x1(const Tensor *input,
              const Tensor *filter,
              const Tensor *bias,
              const int stride,
+             const DataType dt,
              Tensor *output) {
   const index_t batch = output->dim(0);
   const index_t height = output->dim(1);
@@ -32,8 +33,8 @@ void Conv1x1(const Tensor *input,
   MACE_CHECK(input_batch == batch);
   std::set<std::string> built_options;
-  built_options.emplace(input->dtype() == DT_FLOAT ? "-DTYPE_FLOAT" : "");
-  built_options.emplace("-DCMD_DATA_TYPE=" + DataTypeToOPENCLCMDDataType(input->dtype()));
+  built_options.emplace("-DDATA_TYPE=" + DataTypeToCLType(dt));
+  built_options.emplace("-DCMD_DATA_TYPE=" + DataTypeToOPENCLCMDDataType(dt));
   built_options.emplace("-DSTRIDE=" + ToString(stride));
   if (bias != nullptr) {
     built_options.emplace("-DBIAS");
@@ -74,16 +75,18 @@ extern void Conv2dOpenclK1x1S1(const Tensor *input,
                               const Tensor *filter,
                               const Tensor *bias,
                               const int *padding,
+                              const DataType dt,
                               Tensor *output) {
-  Conv1x1(input, filter, bias, 1, output);
+  Conv1x1(input, filter, bias, 1, dt, output);
 };
 extern void Conv2dOpenclK1x1S2(const Tensor *input,
                               const Tensor *filter,
                               const Tensor *bias,
                               const int *padding,
+                              const DataType dt,
                               Tensor *output) {
-  Conv1x1(input, filter, bias, 2, output);
+  Conv1x1(input, filter, bias, 2, dt, output);
 };
 } // namespace kernels
...
@@ -13,7 +13,8 @@ namespace kernels {
 static void Conv2d3x3S12(const Tensor *input, const Tensor *filter,
                          const Tensor *bias, const uint32_t stride,
-                         const int *padding, Tensor *output) {
+                         const int *padding, const DataType dt,
+                         Tensor *output) {
   const index_t batch = output->dim(0);
   const index_t height = output->dim(1);
   const index_t width = output->dim(2);
@@ -25,8 +26,8 @@ static void Conv2d3x3S12(const Tensor *input, const Tensor *filter,
   const index_t width_blocks = RoundUpDiv<index_t, 5>(width);
   std::set<std::string> built_options;
-  built_options.emplace(input->dtype() == DT_FLOAT ? "-DTYPE_FLOAT" : "");
-  built_options.emplace("-DCMD_DATA_TYPE=" + DataTypeToOPENCLCMDDataType(input->dtype()));
+  built_options.emplace("-DDATA_TYPE=" + DataTypeToCLType(dt));
+  built_options.emplace("-DCMD_DATA_TYPE=" + DataTypeToOPENCLCMDDataType(dt));
   built_options.emplace(bias != nullptr ? "-DBIAS" : "");
   built_options.emplace("-DSTRIDE=" + ToString(stride));
@@ -63,13 +64,15 @@ static void Conv2d3x3S12(const Tensor *input, const Tensor *filter,
 }
 void Conv2dOpenclK3x3S1(const Tensor *input, const Tensor *filter,
-                        const Tensor *bias, const int *padding, Tensor *output) {
-  Conv2d3x3S12(input, filter, bias, 1, padding, output);
+                        const Tensor *bias, const int *padding,
+                        const DataType dt, Tensor *output) {
+  Conv2d3x3S12(input, filter, bias, 1, padding, dt, output);
 };
 void Conv2dOpenclK3x3S2(const Tensor *input, const Tensor *filter,
-                        const Tensor *bias, const int *padding, Tensor *output) {
-  Conv2d3x3S12(input, filter, bias, 2, padding, output);
+                        const Tensor *bias, const int *padding,
+                        const DataType dt, Tensor *output) {
+  Conv2d3x3S12(input, filter, bias, 2, padding, dt, output);
 };
 } // namespace kernels
...
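The 3x3 path above splits the output width into blocks of 5 because each work-item produces out0..out4, which is where RoundUpDiv<index_t, 5>(width) comes from. Below is a self-contained sketch of that ceiling division with an assumed minimal implementation; the real helper lives elsewhere in the tree and is not shown in this diff.

#include <cstdio>

typedef long long index_t;  // stand-in for MACE's index_t

// Assumed minimal version of the RoundUpDiv helper used above: ceil(value / N).
template <typename T, T N>
T RoundUpDiv(T value) {
  return (value + N - 1) / N;
}

int main() {
  // Each 3x3 work-item writes 5 output columns, so a width of 113 needs 23 blocks.
  std::printf("width_blocks = %lld\n", RoundUpDiv<index_t, 5>(113));
}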
@@ -558,14 +558,14 @@ static void TestComplexConvNxNS12(const std::vector<index_t> &shape) {
 }
 TEST_F(Conv2dOpTest, OPENCLAlignedConvNxNS12) {
-  TestComplexConvNxNS12<DeviceType::OPENCL, float>({32, 32, 64, 128});
+  TestComplexConvNxNS12<DeviceType::OPENCL, float>({32, 32, 32, 64});
 }
 TEST_F(Conv2dOpTest, OPENCLUnalignedConvNxNS12) {
   TestComplexConvNxNS12<DeviceType::OPENCL, float>({107, 113, 5, 7});
 }
-template<DeviceType D, typename T>
+template<DeviceType D>
 static void TestHalfComplexConvNxNS12(const std::vector<index_t> &shape) {
   testing::internal::LogToStderr();
   auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w,
@@ -612,15 +612,15 @@ static void TestHalfComplexConvNxNS12(const std::vector<index_t> &shape) {
     std::vector<half> input_data(float_input_data.begin(), float_input_data.end());
     std::vector<half> filter_data(float_filter_data.begin(), float_filter_data.end());
     std::vector<half> bias_data(float_bias_data.begin(), float_bias_data.end());
-    net.AddInputFromArray<D, T>("InputHalf", {batch, height, width, input_channels}, input_data);
-    net.AddInputFromArray<D, T>(
+    net.AddInputFromArray<D, half>("InputHalf", {batch, height, width, input_channels}, input_data);
+    net.AddInputFromArray<D, half>(
         "FilterHalf", {kernel_h, kernel_w, input_channels, output_channels}, filter_data);
-    net.AddInputFromArray<D, T>("BiasHalf", {output_channels}, bias_data);
+    net.AddInputFromArray<D, half>("BiasHalf", {output_channels}, bias_data);
     // run on gpu
-    BufferToImage<D, T>(net, "InputHalf", "InputImage", kernels::BufferType::IN_OUT);
-    BufferToImage<D, T>(net, "FilterHalf", "FilterImage", kernels::BufferType::FILTER);
-    BufferToImage<D, T>(net, "BiasHalf", "BiasImage", kernels::BufferType::ARGUMENT);
+    BufferToImage<D, half>(net, "InputHalf", "InputImage", kernels::BufferType::IN_OUT);
+    BufferToImage<D, half>(net, "FilterHalf", "FilterImage", kernels::BufferType::FILTER);
+    BufferToImage<D, half>(net, "BiasHalf", "BiasImage", kernels::BufferType::ARGUMENT);
     OpDefBuilder("Conv2D", "Conv2dTest")
         .Input("InputImage")
@@ -630,24 +630,26 @@ static void TestHalfComplexConvNxNS12(const std::vector<index_t> &shape) {
         .AddIntsArg("strides", {stride_h, stride_w})
         .AddIntArg("padding", type)
         .AddIntsArg("dilations", {1, 1})
+        .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
         .Finalize(net.NewOperatorDef());
     // Run on device
     net.RunOp(D);
-    ImageToBuffer<D, T>(net, "OutputImage", "OPENCLOutput", kernels::BufferType::IN_OUT);
-    ExpectTensorNear<float, T>(expected, *net.GetOutput("OPENCLOutput"), 1.0);
+    ImageToBuffer<D, half>(net, "OutputImage", "OPENCLOutput", kernels::BufferType::IN_OUT);
+    ExpectTensorNear<float, half>(expected, *net.GetOutput("OPENCLOutput"), 0.2);
   };
-  for (int kernel_size : {3}) {
-    for (int stride : {1}) {
+  for (int kernel_size : {1, 3}) {
+    for (int stride : {1, 2}) {
       func(kernel_size, kernel_size, stride, stride, VALID);
     }
  }
 }
-// TODO: support half input & float computation
-//TEST_F(Conv2dOpTest, OPENCLHalfAlignedConvNxNS12) {
-//  TestHalfComplexConvNxNS12<DeviceType::OPENCL, half>({32, 32, 64, 128});
-//}
+TEST_F(Conv2dOpTest, OPENCLHalfAlignedConvNxNS12) {
+  TestHalfComplexConvNxNS12<DeviceType::OPENCL>({32, 32, 32, 64});
+}
+TEST_F(Conv2dOpTest, OPENCLHalfUnalignedConvNxNS12) {
+  TestHalfComplexConvNxNS12<DeviceType::OPENCL>({107, 113, 5, 7});
+}
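Because the half path computes and stores in fp16 while the reference runs in float on the CPU, the enabled tests compare with an absolute tolerance of 0.2 rather than near-exact equality. A self-contained sketch of that style of check follows; ExpectTensorNearSketch is a stand-in for the helper in the test utilities, and the sample values are invented for illustration.

#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

// Stand-in comparator: every element must be within abs_error of the reference.
template <typename EXP_T, typename RES_T>
void ExpectTensorNearSketch(const std::vector<EXP_T> &expected,
                            const std::vector<RES_T> &result, double abs_error) {
  assert(expected.size() == result.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    assert(std::fabs(static_cast<double>(expected[i]) -
                     static_cast<double>(result[i])) <= abs_error);
  }
}

int main() {
  std::vector<float> cpu_reference = {1.001f, -0.520f, 3.140f};
  std::vector<float> gpu_half_result = {1.050f, -0.490f, 3.020f};  // fp16 accumulation drift
  ExpectTensorNearSketch(cpu_reference, gpu_half_result, 0.2);      // passes at 0.2
  std::puts("within tolerance");
}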