From 3d1bf3ebebd81b8bf9cc41a9d1fa1ce142c114ce Mon Sep 17 00:00:00 2001 From: liuqi Date: Mon, 22 Jan 2018 19:23:05 +0800 Subject: [PATCH] Finish winograd convolution of 3x3 with valid padding. --- mace/core/operator.cc | 8 + mace/kernels/gemm.h | 6 +- mace/kernels/opencl/addn.cc | 3 +- mace/kernels/opencl/buffer_to_image.cc | 32 +- mace/kernels/opencl/cl/buffer_to_image.cl | 209 +++++++++++ mace/kernels/opencl/cl/gemm.cl | 75 ++-- mace/kernels/opencl/cl/winograd_transform.cl | 185 ++++++++++ mace/kernels/opencl/concat.cc | 2 +- mace/kernels/opencl/conv_2d_opencl.cc | 2 +- mace/kernels/opencl/gemm.cc | 13 +- mace/kernels/opencl/helper.cc | 58 ++- mace/kernels/opencl/helper.h | 12 +- mace/kernels/opencl/pooling_opencl.cc | 2 +- mace/kernels/opencl/resize_bilinear_opencl.cc | 2 +- mace/kernels/opencl/space_to_batch_opencl.cc | 2 +- mace/kernels/opencl/winograd_transform.cc | 141 ++++++++ mace/kernels/winograd_transform.h | 88 +++++ mace/ops/addn_benchmark.cc | 2 +- mace/ops/addn_test.cc | 4 +- mace/ops/batch_norm_benchmark.cc | 2 +- mace/ops/batch_norm_test.cc | 20 +- mace/ops/batch_to_space_benchmark.cc | 2 +- mace/ops/bias_add_benchmark.cc | 2 +- mace/ops/bias_add_test.cc | 12 +- mace/ops/buffer_to_image_test.cc | 12 +- mace/ops/concat_benchmark.cc | 4 +- mace/ops/concat_test.cc | 4 +- mace/ops/conv_2d_benchmark.cc | 27 +- mace/ops/conv_2d_test.cc | 32 +- mace/ops/folded_batch_norm_test.cc | 20 +- mace/ops/fused_conv_2d_test.cc | 36 +- mace/ops/gemm.h | 5 +- mace/ops/gemm_benchmark.cc | 17 +- mace/ops/gemm_test.cc | 58 +-- mace/ops/pooling_test.cc | 16 +- mace/ops/resize_bilinear_benchmark.cc | 2 +- mace/ops/resize_bilinear_test.cc | 4 +- mace/ops/softmax_benchmark.cc | 2 +- mace/ops/softmax_test.cc | 8 +- mace/ops/space_to_batch_benchmark.cc | 2 +- mace/ops/space_to_batch_test.cc | 8 +- mace/ops/winograd_inverse_transform.cc | 22 ++ mace/ops/winograd_inverse_transform.h | 42 +++ mace/ops/winograd_transform.cc | 22 ++ mace/ops/winograd_transform.h | 41 +++ mace/ops/winograd_transform_benchmark.cc | 111 ++++++ mace/ops/winograd_transform_test.cc | 334 ++++++++++++++++++ 47 files changed, 1501 insertions(+), 212 deletions(-) create mode 100644 mace/kernels/opencl/cl/winograd_transform.cl create mode 100644 mace/kernels/opencl/winograd_transform.cc create mode 100644 mace/kernels/winograd_transform.h create mode 100644 mace/ops/winograd_inverse_transform.cc create mode 100644 mace/ops/winograd_inverse_transform.h create mode 100644 mace/ops/winograd_transform.cc create mode 100644 mace/ops/winograd_transform.h create mode 100644 mace/ops/winograd_transform_benchmark.cc create mode 100644 mace/ops/winograd_transform_test.cc diff --git a/mace/core/operator.cc b/mace/core/operator.cc index 7554fbba..94e4f22f 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -77,6 +77,10 @@ extern void Register_Pooling(OperatorRegistry *op_registry); extern void Register_ResizeBilinear(OperatorRegistry *op_registry); extern void Register_Softmax(OperatorRegistry *op_registry); extern void Register_SpaceToBatchND(OperatorRegistry *op_registry); +extern void Register_FoldedBatchNorm(OperatorRegistry *op_registry); +extern void Register_GEMM(OperatorRegistry *op_registry); +extern void Register_WinogradTransform(OperatorRegistry *op_registry); +extern void Register_WinogradInverseTransform(OperatorRegistry *op_registry); OperatorRegistry::OperatorRegistry() { Register_Activation(this); @@ -97,6 +101,10 @@ OperatorRegistry::OperatorRegistry() { Register_ResizeBilinear(this); Register_Softmax(this); 
Register_SpaceToBatchND(this); + Register_FoldedBatchNorm(this); + Register_GEMM(this); + Register_WinogradTransform(this); + Register_WinogradInverseTransform(this); } } // namespace mace diff --git a/mace/kernels/gemm.h b/mace/kernels/gemm.h index 94cd2bdc..146c07e1 100644 --- a/mace/kernels/gemm.h +++ b/mace/kernels/gemm.h @@ -19,12 +19,12 @@ struct GEMMFunctor { Tensor *C, StatsFuture *future) { - std::vector c_shape = {A->dim(0), A->dim(1), 1, B->dim(3)}; + std::vector c_shape = {A->dim(0), A->dim(1), B->dim(2), 1}; C->Resize(c_shape); const index_t N = C->dim(0); const index_t height = C->dim(1); - const index_t width = C->dim(3); - const index_t K = A->dim(3); + const index_t width = C->dim(2); + const index_t K = A->dim(2); Tensor::MappingGuard guarda(A); Tensor::MappingGuard guardb(B); Tensor::MappingGuard guardc(C); diff --git a/mace/kernels/opencl/addn.cc b/mace/kernels/opencl/addn.cc index 261efde0..65a2cf20 100644 --- a/mace/kernels/opencl/addn.cc +++ b/mace/kernels/opencl/addn.cc @@ -17,7 +17,6 @@ static void AddN(const std::vector &input_tensors, if (input_tensors.size() > 4) { MACE_NOT_IMPLEMENTED; } - output->ResizeLike(input_tensors[0]); const index_t batch = output->dim(0); const index_t height = output->dim(1); @@ -82,7 +81,7 @@ void AddNFunctor::operator()( std::vector output_shape = input_tensors[0]->shape(); std::vector output_image_shape; - CalImage2DShape(output_shape, BufferType::IN_OUT, output_image_shape); + CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, output_image_shape); output_tensor->ResizeImage(output_shape, output_image_shape); AddN(input_tensors, output_tensor, future); diff --git a/mace/kernels/opencl/buffer_to_image.cc b/mace/kernels/opencl/buffer_to_image.cc index ae81d32f..98035498 100644 --- a/mace/kernels/opencl/buffer_to_image.cc +++ b/mace/kernels/opencl/buffer_to_image.cc @@ -18,13 +18,21 @@ void BufferToImageFunctor::operator()(Tensor *buffer, std::vector image_shape; if (!i2b_) { CalImage2DShape(buffer->shape(), type, image_shape); - image->ResizeImage(buffer->shape(), image_shape); + if(type == WINOGRAD_FILTER) { + std::vector new_shape = + CalWinogradShape(buffer->shape(), type); + image->ResizeImage(new_shape, image_shape); + } else { + image->ResizeImage(buffer->shape(), image_shape); + } buffer->MarkUnused(); } else { image_shape = image->image_shape(); buffer->Resize(image->shape()); } + size_t gws[2] = {image_shape[0], + image_shape[1]}; string kernel_name; switch (type) { case CONV2D_FILTER: @@ -33,12 +41,23 @@ void BufferToImageFunctor::operator()(Tensor *buffer, case DW_CONV2D_FILTER: kernel_name = i2b_ ? "dw_filter_image_to_buffer" : "dw_filter_buffer_to_image"; break; - case IN_OUT: + case IN_OUT_CHANNEL: kernel_name = i2b_ ? "in_out_image_to_buffer" : "in_out_buffer_to_image"; break; case ARGUMENT: kernel_name = i2b_ ? "arg_image_to_buffer" : "arg_buffer_to_image"; break; + case IN_OUT_HEIGHT: + kernel_name = i2b_ ? "in_out_height_image_to_buffer" : "in_out_height_buffer_to_image"; + break; + case IN_OUT_WIDTH: + MACE_CHECK(!i2b_) << "IN_OUT_WIDTH only support buffer to image now"; + kernel_name = "in_out_width_buffer_to_image"; + break; + case WINOGRAD_FILTER: + gws[1] /= 16; + kernel_name = i2b_ ? 
"winograd_filter_image_to_buffer" : "winograd_filter_buffer_to_image"; + break; } string obfuscated_kernel_name = MACE_OBFUSCATE_SYMBOL(kernel_name); std::set built_options; @@ -68,16 +87,13 @@ void BufferToImageFunctor::operator()(Tensor *buffer, } b2f_kernel.setArg(idx++, *(static_cast(image->buffer()))); - const size_t gws[3] = {image_shape[0], - image_shape[1], - 1}; const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(b2f_kernel); - const std::vector lws = {16, 64, 1}; + const std::vector lws = {16, 64}; cl::Event event; cl_int error = runtime->command_queue().enqueueNDRangeKernel( b2f_kernel, cl::NullRange, - cl::NDRange(gws[0], gws[1], gws[2]), - cl::NDRange(lws[0], lws[1], lws[2]), + cl::NDRange(gws[0], gws[1]), + cl::NDRange(lws[0], lws[1]), nullptr, &event); MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error; diff --git a/mace/kernels/opencl/cl/buffer_to_image.cl b/mace/kernels/opencl/cl/buffer_to_image.cl index 2ac05209..f95029c0 100644 --- a/mace/kernels/opencl/cl/buffer_to_image.cl +++ b/mace/kernels/opencl/cl/buffer_to_image.cl @@ -233,3 +233,212 @@ __kernel void arg_image_to_buffer(__global DATA_TYPE *output, /* nhwc */ vstore4(values, 0, output + offset); } } + + +__kernel void in_out_height_buffer_to_image(__global const DATA_TYPE *input, //nhwc + __private const int height, + __private const int width, + __private const int channels, + __write_only image2d_t output) { + int w = get_global_id(0); + int h = get_global_id(1); + const int wc = width * channels; + const int height_blks = (height + 3) / 4; + const int batch_idx = h / height_blks; + const int height_idx = (h % height_blks) << 2; + const int width_idx = w % width; + const int channel_idx = w / width; + int offset = ((batch_idx * height + height_idx) * width + width_idx) * channels + + channel_idx; + + int size = height - height_idx; + size = size >= 4 ? 
0 : size; + DATA_TYPE4 values = 0; + switch(size) { + case 0: + values.w = *(input + offset + wc * 3); + case 3: + values.z = *(input + offset + wc * 2); + case 2: + values.y = *(input + offset + wc); + case 1: + values.x = *(input + offset); + } + int2 coord = (int2)(w, h); + WRITE_IMAGET(output, coord, values); +} + +__kernel void in_out_height_image_to_buffer(__global DATA_TYPE *output, //nhwc + __private const int height, + __private const int width, + __private const int channels, + __read_only image2d_t input) { + int w = get_global_id(0); + int h = get_global_id(1); + const int height_blks = (height + 3) / 4; + const int batch_idx = h / height_blks; + const int height_idx = (h % height_blks) << 2; + const int width_idx = w % width; + const int channel_idx = w / width; + int offset = ((batch_idx * height + height_idx) * width + width_idx) * channels + + channel_idx; + + int2 coord = (int2)(w, h); + DATA_TYPE4 values = READ_IMAGET(input, SAMPLER, coord); + output[offset] = values.x; + if (height_idx + 1 >= height) return; + offset += width * channels; + output[offset] = values.y; + if (height_idx + 2 >= height) return; + offset += width * channels; + output[offset] = values.z; + if (height_idx + 3 >= height) return; + offset += width * channels; + output[offset] = values.w; +} + + +__kernel void in_out_width_buffer_to_image(__global const DATA_TYPE *input, /* nhwc */ + __private const int height, + __private const int width, + __private const int channels, + __write_only image2d_t output) { + int w = get_global_id(0); + int h = get_global_id(1); + const int batch_idx = h / height; + const int height_idx = h % height; + const int width_idx = (w % width) << 2; + const int channel_idx = w / width; + const int offset = ((batch_idx * height + height_idx) * width + width_idx) * channels + + channel_idx; + + int size = width - width_idx; + size = size >= 4 ? 
0 : size; + DATA_TYPE4 values = 0; + switch(size) { + case 0: + values.w = *(input + offset + channels * 3); + case 3: + values.z = *(input + offset + channels * 2); + case 2: + values.y = *(input + offset + channels); + case 1: + values.x = *(input + offset); + } + int2 coord = (int2)(w, h); + WRITE_IMAGET(output, coord, values); +} + +// only support 3x3 now +__kernel void winograd_filter_buffer_to_image(__global const DATA_TYPE *input, //Oc, Ic, H, W + __private const int in_channels, + __private const int height, + __private const int width, + __write_only image2d_t output) { + int w = get_global_id(0); + int h = get_global_id(1); + const int out_channels = get_global_size(1); + const int out_channel_idx = h; + const int in_channel_idx = w << 2; + const int offset = (out_channel_idx * in_channels + in_channel_idx) * height * width; + const int length = min((in_channels - in_channel_idx) * 9, 36); + DATA_TYPE in[36] = {0}; + DATA_TYPE4 tt; + DATA_TYPE4 tu0[4], tu1[4], tu2[4], tu3[4]; + +#pragma unroll + for (short i = 0; i < length; ++i) { + in[i] = *(input + offset + i); + } + tt = ((DATA_TYPE4)(in[0], in[9], in[18], in[27]) + + (DATA_TYPE4)(in[6], in[15], in[24], in[33])) / 2; + tu1[0] = tt + ((DATA_TYPE4)(in[3], in[12], in[21], in[30]) / 2); + tu2[0] = tt - ((DATA_TYPE4)(in[3], in[12], in[21], in[30]) / 2); + tt = ((DATA_TYPE4)(in[1], in[10], in[19], in[28]) + + (DATA_TYPE4)(in[7], in[16], in[25], in[34])) / 2; + tu1[1] = tt + ((DATA_TYPE4)(in[4], in[13], in[22], in[31]) / 2); + tu2[1] = tt - ((DATA_TYPE4)(in[4], in[13], in[22], in[31]) / 2); + tt = ((DATA_TYPE4)(in[2], in[11], in[20], in[29]) + + (DATA_TYPE4)(in[8], in[17], in[26], in[35])) / 2; + tu1[2] = tt + ((DATA_TYPE4)(in[5], in[14], in[23], in[32]) / 2); + tu2[2] = tt - ((DATA_TYPE4)(in[5], in[14], in[23], in[32]) / 2); + tu0[0] = (DATA_TYPE4)(in[0], in[9], in[18], in[27]); + tu0[1] = (DATA_TYPE4)(in[1], in[10], in[19], in[28]); + tu0[2] = (DATA_TYPE4)(in[2], in[11], in[20], in[29]); + tu3[0] = (DATA_TYPE4)(in[6], in[15], in[24], in[33]); + tu3[1] = (DATA_TYPE4)(in[7], in[16], in[25], in[34]); + tu3[2] = (DATA_TYPE4)(in[8], in[17], in[26], in[35]); + + tt = (tu0[0] + tu0[2]) / 2; + tu0[3] = tu0[2]; + tu0[2] = tt - tu0[1] / 2; + tu0[1] = tt + tu0[1] / 2; + tt = (tu1[0] + tu1[2]) / 2; + tu1[3] = tu1[2]; + tu1[2] = tt - tu1[1] / 2; + tu1[1] = tt + tu1[1] / 2; + tt = (tu2[0] + tu2[2]) / 2; + tu2[3] = tu2[2]; + tu2[2] = tt - tu2[1] / 2; + tu2[1] = tt + tu2[1] / 2; + tt = (tu3[0] + tu3[2]) / 2; + tu3[3] = tu3[2]; + tu3[2] = tt - tu3[1] / 2; + tu3[1] = tt + tu3[1] / 2; + + int2 coord = (int2)(w, h); +#pragma unroll + for (short i = 0; i < 4; ++i) { + WRITE_IMAGET(output, coord, tu0[i]); + coord.y += out_channels; + } +#pragma unroll + for (short i = 0; i < 4; ++i) { + WRITE_IMAGET(output, coord, tu1[i]); + coord.y += out_channels; + } +#pragma unroll + for (short i = 0; i < 4; ++i) { + WRITE_IMAGET(output, coord, tu2[i]); + coord.y += out_channels; + } +#pragma unroll + for (short i = 0; i < 4; ++i) { + WRITE_IMAGET(output, coord, tu3[i]); + coord.y += out_channels; + } +} + +// only support 3x3 now +__kernel void winograd_filter_image_to_buffer(__global DATA_TYPE *output, //Oc, Ic, H, W + __private const int height, + __private const int width, + __private const int channel, + __read_only image2d_t input) { + const int w = get_global_id(0); + const int h = get_global_id(1); + const int width_idx = w << 2; + const int size = width - width_idx; + int offset = h * width + width_idx; + + int2 coord = (int2)(w, h); + DATA_TYPE4 values; 
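+  // The preceding winograd_filter_buffer_to_image kernel stores, for each
+  // (out_channel, in_channel) pair, the transformed 3x3 filter U = G * g * G^T
+  // of F(2x2, 3x3), where
+  //     G = [ 1     0     0  ]
+  //         [ 1/2   1/2   1/2]
+  //         [ 1/2  -1/2   1/2]
+  //         [ 0     0     1  ]
+  // The 16 elements of U are laid out as 16 slices along the image height;
+  // the loop below copies them back to the buffer, one slice per iteration.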
+ for (short i = 0; i < 16; ++i) { + values = READ_IMAGET(input, SAMPLER, coord); + if (size < 4) { + switch (size) { + case 3: + output[offset+2] = values.z; + case 2: + output[offset+1] = values.y; + case 1: + output[offset] = values.x; + } + } else { + vstore4(values, 0, output + offset); + } + + coord.y += height; + offset += height * width; + } +} diff --git a/mace/kernels/opencl/cl/gemm.cl b/mace/kernels/opencl/cl/gemm.cl index 994a190a..2f07579a 100644 --- a/mace/kernels/opencl/cl/gemm.cl +++ b/mace/kernels/opencl/cl/gemm.cl @@ -5,56 +5,47 @@ __kernel void gemm(__read_only image2d_t A, __read_only image2d_t B, __write_only image2d_t C, __private const int M, + __private const int N, + __private const int K, __private const int height_blocks, - __private const int K) { - const int gx = get_global_id(0); + __private const int k_blocks) { + const int gx = get_global_id(0) << 2; const int hb = get_global_id(1); const int batch = hb / height_blocks; - const int gy = (hb % height_blocks) << 2; - const int bm = mul24(batch, M); - const int bk = mul24(batch, K); + const int ty = (hb % height_blocks); + const int gy = mad24(batch, height_blocks, ty); + const int bm = mad24(batch, M, ty << 2); + const int bk = mul24(batch, k_blocks); + float4 a0, a1, a2, a3; float4 b0, b1, b2, b3; - float4 c0, c1, c2, c3; + float4 c0 = 0, c1 = 0, c2 = 0, c3 = 0; - for (short pos = 0; pos < K; pos += 4) { - a0 = READ_IMAGET(A, SAMPLER, (int2)(pos >> 2, (bm + gy))); - a1 = READ_IMAGET(A, SAMPLER, (int2)(pos >> 2, (bm + gy + 1))); - a2 = READ_IMAGET(A, SAMPLER, (int2)(pos >> 2, (bm + gy + 2))); - a3 = READ_IMAGET(A, SAMPLER, (int2)(pos >> 2, (bm + gy + 3))); + for (short pos = 0; pos < k_blocks; pos += 1) { + a0 = READ_IMAGET(A, SAMPLER, (int2)(pos, (bm))); + a1 = READ_IMAGET(A, SAMPLER, (int2)(pos, (bm + 1))); + a2 = READ_IMAGET(A, SAMPLER, (int2)(pos, (bm + 2))); + a3 = READ_IMAGET(A, SAMPLER, (int2)(pos, (bm + 3))); b0 = READ_IMAGET(B, SAMPLER, (int2)(gx, (bk + pos))); - b1 = READ_IMAGET(B, SAMPLER, (int2)(gx, (bk + pos + 1))); - b2 = READ_IMAGET(B, SAMPLER, (int2)(gx, (bk + pos + 2))); - b3 = READ_IMAGET(B, SAMPLER, (int2)(gx, (bk + pos + 3))); - - c0 = mad(a0.x, b0, c0); - c0 = mad(a0.y, b1, c0); - c0 = mad(a0.z, b2, c0); - c0 = mad(a0.w, b3, c0); - - c1 = mad(a1.x, b0, c1); - c1 = mad(a1.y, b1, c1); - c1 = mad(a1.z, b2, c1); - c1 = mad(a1.w, b3, c1); - - c2 = mad(a2.x, b0, c2); - c2 = mad(a2.y, b1, c2); - c2 = mad(a2.z, b2, c2); - c2 = mad(a2.w, b3, c2); - - c3 = mad(a3.x, b0, c3); - c3 = mad(a3.y, b1, c3); - c3 = mad(a3.z, b2, c3); - c3 = mad(a3.w, b3, c3); + b1 = READ_IMAGET(B, SAMPLER, (int2)(gx + 1, (bk + pos))); + b2 = READ_IMAGET(B, SAMPLER, (int2)(gx + 2, (bk + pos))); + b3 = READ_IMAGET(B, SAMPLER, (int2)(gx + 3, (bk + pos))); + + c0 += (DATA_TYPE4)(dot(a0, b0), dot(a1, b0), dot(a2, b0), dot(a3, b0)); + + c1 += (DATA_TYPE4)(dot(a0, b1), dot(a1, b1), dot(a2, b1), dot(a3, b1)); + + c2 += (DATA_TYPE4)(dot(a0, b2), dot(a1, b2), dot(a2, b2), dot(a3, b2)); + + c3 += (DATA_TYPE4)(dot(a0, b3), dot(a1, b3), dot(a2, b3), dot(a3, b3)); } - if (gy >= M) return; - WRITE_IMAGET(C, (int2)(gx, (bm + gy)), c0); - if ((gy + 1) >= M) return; - WRITE_IMAGET(C, (int2)(gx, (bm + gy + 1)), c1); - if ((gy + 2) >= M) return; - WRITE_IMAGET(C, (int2)(gx, (bm + gy + 2)), c2); - if ((gy + 3) >= M) return; - WRITE_IMAGET(C, (int2)(gx, (bm + gy + 3)), c3); + WRITE_IMAGET(C, (int2)(gx, gy), c0); + if ((gx + 1) >= N) return; + WRITE_IMAGET(C, (int2)(gx + 1, gy), c1); + if ((gx + 2) >= N) return; + WRITE_IMAGET(C, (int2)(gx + 2, 
gy), c2); + if ((gx + 3) >= N) return; + WRITE_IMAGET(C, (int2)(gx + 3, gy), c3); } diff --git a/mace/kernels/opencl/cl/winograd_transform.cl b/mace/kernels/opencl/cl/winograd_transform.cl new file mode 100644 index 00000000..6ab698dd --- /dev/null +++ b/mace/kernels/opencl/cl/winograd_transform.cl @@ -0,0 +1,185 @@ +#include + +__kernel void winograd_transform_2x2(__read_only image2d_t input, + __write_only image2d_t output, + __private const int in_height, + __private const int in_width, + __private const int in_channel, + __private const int round_hw, + __private const int round_w, + __private const int padding_top, + __private const int padding_left) { + int out_width_idx = get_global_id(0); + int chan_blk_idx = get_global_id(1); + const int chan_blk_size = get_global_size(1); + + const int batch_idx = out_width_idx / round_hw; + const int t_idx = out_width_idx % round_hw; + const int height_idx = ((t_idx / round_w) << 1) - padding_top; + const int width_idx = ((t_idx % round_w) << 1) - padding_left; + + const int nh_idx = mad24(batch_idx, in_height, height_idx); + const int wc_idx = mad24(chan_blk_idx, in_width, width_idx); + + DATA_TYPE4 input0[4]; + DATA_TYPE4 input1[4]; + DATA_TYPE4 input2[4]; + DATA_TYPE4 input3[4]; + + DATA_TYPE4 tv0[4]; + DATA_TYPE4 tv1[4]; + DATA_TYPE4 tv2[4]; + DATA_TYPE4 tv3[4]; + + int y = nh_idx; +#pragma unroll + for (short i = 0; i < 4; ++i) { + int x = width_idx + i; + x = select(wc_idx + i, -1, x >= in_width); + input0[i] = READ_IMAGET(input, SAMPLER, (int2)(x, y)); + } + y = select(nh_idx + 1, -1, height_idx + 1 >= in_height); +#pragma unroll + for (short i = 0; i < 4; ++i) { + int x = width_idx + i; + x = select(wc_idx + i, -1, x >= in_width); + input1[i] = READ_IMAGET(input, SAMPLER, (int2)(x, y)); + } + y = select(nh_idx + 2, -1, height_idx + 2 >= in_height); +#pragma unroll + for (short i = 0; i < 4; ++i) { + int x = width_idx + i; + x = select(wc_idx + i, -1, x >= in_width); + input2[i] = READ_IMAGET(input, SAMPLER, (int2)(x, y)); + } + y = select(nh_idx + 3, -1, height_idx + 3 >= in_height); +#pragma unroll + for (short i = 0; i < 4; ++i) { + int x = width_idx + i; + x = select(wc_idx + i, -1, x >= in_width); + input3[i] = READ_IMAGET(input, SAMPLER, (int2)(x, y)); + } + +#pragma unroll + for (short i = 0; i < 4; ++i) { + tv0[i] = input0[i] - input2[i]; + tv1[i] = input1[i] + input2[i]; + tv2[i] = input2[i] - input1[i]; + tv3[i] = input1[i] - input3[i]; + } + input0[0] = tv0[0] - tv0[2]; + input0[1] = tv0[1] + tv0[2]; + input0[2] = tv0[2] - tv0[1]; + input0[3] = tv0[1] - tv0[3]; + input1[0] = tv1[0] - tv1[2]; + input1[1] = tv1[1] + tv1[2]; + input1[2] = tv1[2] - tv1[1]; + input1[3] = tv1[1] - tv1[3]; + input2[0] = tv2[0] - tv2[2]; + input2[1] = tv2[1] + tv2[2]; + input2[2] = tv2[2] - tv2[1]; + input2[3] = tv2[1] - tv2[3]; + input3[0] = tv3[0] - tv3[2]; + input3[1] = tv3[1] + tv3[2]; + input3[2] = tv3[2] - tv3[1]; + input3[3] = tv3[1] - tv3[3]; + +#pragma unroll + for (short i = 0; i < 4; ++i) { + WRITE_IMAGET(output, (int2)(out_width_idx, chan_blk_idx), input0[i]); + chan_blk_idx += chan_blk_size; + } +#pragma unroll + for (short i = 0; i < 4; ++i) { + WRITE_IMAGET(output, (int2)(out_width_idx, chan_blk_idx), input1[i]); + chan_blk_idx += chan_blk_size; + } +#pragma unroll + for (short i = 0; i < 4; ++i) { + WRITE_IMAGET(output, (int2)(out_width_idx, chan_blk_idx), input2[i]); + chan_blk_idx += chan_blk_size; + } +#pragma unroll + for (short i = 0; i < 4; ++i) { + WRITE_IMAGET(output, (int2)(out_width_idx, chan_blk_idx), input3[i]); + 
chan_blk_idx += chan_blk_size; + } +} + +__kernel void winograd_inverse_transform_2x2(__read_only image2d_t input, + __write_only image2d_t output, + __private const int out_height, + __private const int out_width, + __private const int round_hw, + __private const int round_w) { + const int width_idx = get_global_id(0); + const int height_idx = get_global_id(1); + const int out_channel = get_global_size(1); + int width = width_idx; + int height = height_idx; + + DATA_TYPE4 in0[4], in1[4], in2[4], in3[4]; + +#pragma unroll + for (short i = 0; i < 4; ++i) { + in0[i] = READ_IMAGET(input, SAMPLER, (int2)(width, height)); + height += out_channel; + } +#pragma unroll + for (short i = 0; i < 4; ++i) { + in1[i] = READ_IMAGET(input, SAMPLER, (int2)(width_idx, height)); + height += out_channel; + } +#pragma unroll + for (short i = 0; i < 4; ++i) { + in2[i] = READ_IMAGET(input, SAMPLER, (int2)(width_idx, height)); + height += out_channel; + } +#pragma unroll + for (short i = 0; i < 4; ++i) { + in3[i] = READ_IMAGET(input, SAMPLER, (int2)(width_idx, height)); + height += out_channel; + } + + in0[0] = in0[0] + in1[0] + in2[0]; + in0[1] = in0[1] + in1[1] + in2[1]; + in0[2] = in0[2] + in1[2] + in2[2]; + in0[3] = in0[3] + in1[3] + in2[3]; + + in0[0] = in0[0] + in0[1] + in0[2]; + in0[1] = in0[1] - in0[2] - in0[3]; + + in1[0] = in1[0] - in2[0] - in3[0]; + in1[1] = in1[1] - in2[1] - in3[1]; + in1[2] = in1[2] - in2[2] - in3[2]; + in1[3] = in1[3] - in2[3] - in3[3]; + + in1[0] = in1[0] + in1[1] + in1[2]; + in1[1] = in1[1] - in1[2] - in1[3]; + + const int batch = width_idx / round_hw; + int t = width_idx % round_hw; + const int out_height_idx = (t / round_w) << 1; + const int out_width_idx = (t % round_w) << 1; + const int out_chan_idx = height_idx; + const int coord_x = mad24(out_chan_idx, out_width, out_width_idx); + const int coord_y = mad24(batch, out_height, out_height_idx); + + WRITE_IMAGET(output, (int2)(coord_x, coord_y), in0[0]); + + t = 0; + if (out_width_idx + 1 < out_width) { + WRITE_IMAGET(output, (int2)(coord_x + 1, coord_y), in0[1]); + t += 1; + } + if (out_height_idx + 1 < out_height) { + WRITE_IMAGET(output, (int2)(coord_x, coord_y + 1), in1[0]); + t += 1; + } + if (t == 2) { + WRITE_IMAGET(output, (int2)(coord_x + 1, coord_y + 1), in1[1]); + } + + + +} diff --git a/mace/kernels/opencl/concat.cc b/mace/kernels/opencl/concat.cc index b47a096e..b1c3850d 100644 --- a/mace/kernels/opencl/concat.cc +++ b/mace/kernels/opencl/concat.cc @@ -85,7 +85,7 @@ void ConcatFunctor::operator()(const std::vectordim(axis_); } std::vector image_shape; - CalImage2DShape(output_shape, BufferType::IN_OUT, image_shape); + CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, image_shape); output->ResizeImage(output_shape, image_shape); switch (inputs_count) { diff --git a/mace/kernels/opencl/conv_2d_opencl.cc b/mace/kernels/opencl/conv_2d_opencl.cc index 072a0abf..03883d3a 100644 --- a/mace/kernels/opencl/conv_2d_opencl.cc +++ b/mace/kernels/opencl/conv_2d_opencl.cc @@ -109,7 +109,7 @@ void Conv2dFunctor::operator()(const Tensor *input, paddings_, output_shape.data(), paddings.data()); std::vector output_image_shape; - CalImage2DShape(output_shape, BufferType::IN_OUT, output_image_shape); + CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, output_image_shape); output->ResizeImage(output_shape, output_image_shape); if (kernel_h == kernel_w && kernel_h <= 5 && diff --git a/mace/kernels/opencl/gemm.cc b/mace/kernels/opencl/gemm.cc index 934dea79..60b632c5 100644 --- a/mace/kernels/opencl/gemm.cc +++ 
b/mace/kernels/opencl/gemm.cc @@ -17,17 +17,17 @@ void GEMMFunctor::operator()( Tensor *C, StatsFuture *future) { - std::vector c_shape = {A->dim(0), A->dim(1), 1, B->dim(3)}; + std::vector c_shape = {A->dim(0), A->dim(1), B->dim(2), 1}; std::vector c_image_shape; - CalImage2DShape(c_shape, BufferType::IN_OUT, c_image_shape); + CalImage2DShape(c_shape, BufferType::IN_OUT_HEIGHT, c_image_shape); C->ResizeImage(c_shape, c_image_shape); const index_t batch = C->dim(0); const index_t height = C->dim(1); - const index_t width = C->dim(3); + const index_t width = C->dim(2); - const index_t width_blocks = RoundUpDiv4(width); const index_t height_blocks = RoundUpDiv4(height); + const index_t width_blocks = RoundUpDiv4(width); auto runtime = OpenCLRuntime::Global(); std::set built_options; @@ -45,8 +45,10 @@ void GEMMFunctor::operator()( *(static_cast(B->buffer()))); gemm_kernel.setArg(idx++, *(static_cast(C->buffer()))); gemm_kernel.setArg(idx++, static_cast(height)); + gemm_kernel.setArg(idx++, static_cast(width)); + gemm_kernel.setArg(idx++, static_cast(A->dim(2))); gemm_kernel.setArg(idx++, static_cast(height_blocks)); - gemm_kernel.setArg(idx++, static_cast(A->dim(3))); + gemm_kernel.setArg(idx++, static_cast(RoundUpDiv4(A->dim(2)))); const uint32_t gws[3] = { static_cast(width_blocks), @@ -61,6 +63,7 @@ void GEMMFunctor::operator()( return {{local_ws[0], local_ws[1]}, {local_ws[1], local_ws[0]}, {kwg_size / 4, 4}, + {kwg_size / 8, 8}, {kwg_size / 16, 16}, {kwg_size / 32, 32}, {kwg_size / 64, 64}, diff --git a/mace/kernels/opencl/helper.cc b/mace/kernels/opencl/helper.cc index 56c157a6..ca55968d 100644 --- a/mace/kernels/opencl/helper.cc +++ b/mace/kernels/opencl/helper.cc @@ -45,6 +45,34 @@ void CalArgImageShape(const std::vector &shape, image_shape[1] = 1; } +// Only support 3x3 now +// [ (Ic + 3) / 4, 16 * Oc] +void CalWinogradFilterImageShape(const std::vector &shape, /* Oc, Ic, H, W*/ + std::vector &image_shape) { + MACE_CHECK(shape.size() == 4); + image_shape.resize(2); + image_shape[0] = RoundUpDiv4(shape[1]); + image_shape[1] = (shape[0] << 4); +} + +// [W * C, N * RoundUp<4>(H)] +void CalInOutHeightImageShape(const std::vector &shape, /* NHWC */ + std::vector &image_shape) { + MACE_CHECK(shape.size() == 4); + image_shape.resize(2); + image_shape[0] = shape[2] * shape[3]; + image_shape[1] = shape[0] * RoundUpDiv4(shape[1]); +} + +// [RoundUp<4>(W) * C, N * H] +void CalInOutWidthImageShape(const std::vector &shape, /* NHWC */ + std::vector &image_shape) { + MACE_CHECK(shape.size() == 4); + image_shape.resize(2); + image_shape[0] = RoundUpDiv4(shape[2]) * shape[3]; + image_shape[1] = shape[0] * shape[1]; +} + void CalImage2DShape(const std::vector &shape, /* NHWC */ const BufferType type, std::vector &image_shape) { @@ -55,13 +83,39 @@ void CalImage2DShape(const std::vector &shape, /* NHWC */ case DW_CONV2D_FILTER: CalDepthwiseConv2dFilterImageShape(shape, image_shape); break; - case IN_OUT: + case IN_OUT_CHANNEL: CalInOutputImageShape(shape, image_shape); break; case ARGUMENT: CalArgImageShape(shape, image_shape); break; - default:LOG(FATAL) << "Mace not supported yet."; + case IN_OUT_HEIGHT: + CalInOutHeightImageShape(shape, image_shape); + break; + case IN_OUT_WIDTH: + CalInOutWidthImageShape(shape, image_shape); + break; + case WINOGRAD_FILTER: + CalWinogradFilterImageShape(shape, image_shape); + break; + default: + LOG(FATAL) << "Mace not supported yet."; + } +} + + +std::vector CalWinogradShape(const std::vector &shape, + const BufferType type) { + if (type == WINOGRAD_FILTER) { + 
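+    // A 3x3 filter of shape (Oc, Ic, 3, 3) expands to its 16 (= 4x4)
+    // Winograd-transformed elements, i.e. a {16, Oc, Ic, 1} tensor, matching
+    // the [ (Ic + 3) / 4, 16 * Oc ] image from CalWinogradFilterImageShape.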
return {16, shape[0], shape[1], 1}; + }else if (type == IN_OUT_HEIGHT) { + index_t out_width = shape[0] * + ((shape[1] - 1) / 2) * + ((shape[2] - 1) / 2); + return {16, shape[3], out_width, 1}; + } else { + LOG(FATAL) << "Mace not supported yet."; + return std::vector(); } } diff --git a/mace/kernels/opencl/helper.h b/mace/kernels/opencl/helper.h index dc40514f..01e29289 100644 --- a/mace/kernels/opencl/helper.h +++ b/mace/kernels/opencl/helper.h @@ -18,15 +18,21 @@ const float kMaxKernelExeTime = 1000.0; // microseconds enum BufferType { CONV2D_FILTER = 0, - DW_CONV2D_FILTER = 1, - IN_OUT = 2, - ARGUMENT = 3 + IN_OUT_CHANNEL = 1, + ARGUMENT = 2, + IN_OUT_HEIGHT = 3, + IN_OUT_WIDTH = 4, + WINOGRAD_FILTER = 5, + DW_CONV2D_FILTER = 6, }; void CalImage2DShape(const std::vector &shape, /* NHWC */ const BufferType type, std::vector &image_shape); +std::vector CalWinogradShape(const std::vector &shape, + const BufferType type); + std::string DtToCLCMDDt(const DataType dt); std::string DtToUpstreamCLCMDDt(const DataType dt); diff --git a/mace/kernels/opencl/pooling_opencl.cc b/mace/kernels/opencl/pooling_opencl.cc index b147c15a..248bf6a7 100644 --- a/mace/kernels/opencl/pooling_opencl.cc +++ b/mace/kernels/opencl/pooling_opencl.cc @@ -92,7 +92,7 @@ void PoolingFunctor::operator()(const Tensor *input, output_shape.data(), paddings.data()); std::vector output_image_shape; - CalImage2DShape(output_shape, BufferType::IN_OUT, output_image_shape); + CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, output_image_shape); output->ResizeImage(output_shape, output_image_shape); Pooling(input, strides_, paddings.data(), kernels_[0], pooling_type_, diff --git a/mace/kernels/opencl/resize_bilinear_opencl.cc b/mace/kernels/opencl/resize_bilinear_opencl.cc index f8d3aed2..5f0c4e33 100644 --- a/mace/kernels/opencl/resize_bilinear_opencl.cc +++ b/mace/kernels/opencl/resize_bilinear_opencl.cc @@ -28,7 +28,7 @@ void ResizeBilinearFunctor::operator()( std::vector output_shape {batch, out_height, out_width, channels}; if (input->is_image()) { std::vector output_image_shape; - CalImage2DShape(output_shape, BufferType::IN_OUT, output_image_shape); + CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, output_image_shape); output->ResizeImage(output_shape, output_image_shape); } else { output->Resize(output_shape); diff --git a/mace/kernels/opencl/space_to_batch_opencl.cc b/mace/kernels/opencl/space_to_batch_opencl.cc index 8ef3f7c4..09897048 100644 --- a/mace/kernels/opencl/space_to_batch_opencl.cc +++ b/mace/kernels/opencl/space_to_batch_opencl.cc @@ -21,7 +21,7 @@ void SpaceToBatchFunctor::operator()(Tensor *space_tensor Tensor *batch_tensor, StatsFuture *future) { std::vector output_image_shape; - CalImage2DShape(output_shape, BufferType::IN_OUT, output_image_shape); + CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, output_image_shape); const char *kernel_name = nullptr; if (b2s_) { space_tensor->ResizeImage(output_shape, output_image_shape); diff --git a/mace/kernels/opencl/winograd_transform.cc b/mace/kernels/opencl/winograd_transform.cc new file mode 100644 index 00000000..3508329a --- /dev/null +++ b/mace/kernels/opencl/winograd_transform.cc @@ -0,0 +1,141 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. 
+// + +#include "mace/kernels/winograd_transform.h" +#include "mace/core/runtime/opencl/cl2_header.h" +#include "mace/core/runtime/opencl/opencl_runtime.h" +#include "mace/kernels/opencl/helper.h" + +namespace mace { +namespace kernels { + +template +void WinogradTransformFunctor::operator()(const Tensor *input_tensor, + Tensor *output_tensor, + StatsFuture *future) { + std::vector output_shape(4); + std::vector filter_shape = {3, 3, input_tensor->dim(3), 1}; + std::vector paddings(2); + kernels::CalcNHWCPaddingAndOutputSize( + input_tensor->shape().data(), filter_shape.data(), dilations_.data(), + strides_.data(), paddings_, output_shape.data(), paddings.data()); + + const index_t round_h = (output_shape[1] + 1) / 2; + const index_t round_w = (output_shape[2] + 1) / 2; + const index_t out_width = input_tensor->dim(0) * round_h * round_w; + output_shape = {16, input_tensor->dim(3), out_width, 1}; + std::vector image_shape; + CalImage2DShape(output_shape, BufferType::IN_OUT_HEIGHT, image_shape); + output_tensor->ResizeImage(output_shape, image_shape); + + string obfuscated_kernel_name = MACE_OBFUSCATE_SYMBOL("winograd_transform_2x2"); + std::set built_options; + built_options.emplace("-Dwinograd_transform_2x2=" + obfuscated_kernel_name); + built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(DataTypeToEnum::value)); + built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(DataTypeToEnum::value)); + auto runtime = OpenCLRuntime::Global(); + auto b2f_kernel = runtime->BuildKernel("winograd_transform", + obfuscated_kernel_name, + built_options); + + uint32_t idx = 0; + b2f_kernel.setArg(idx++, *(static_cast(input_tensor->buffer()))); + b2f_kernel.setArg(idx++, *(static_cast(output_tensor->buffer()))); + b2f_kernel.setArg(idx++, static_cast(input_tensor->dim(1))); + b2f_kernel.setArg(idx++, static_cast(input_tensor->dim(2))); + b2f_kernel.setArg(idx++, static_cast(input_tensor->dim(3))); + b2f_kernel.setArg(idx++, static_cast(round_h * round_w)); + b2f_kernel.setArg(idx++, static_cast(round_w)); + b2f_kernel.setArg(idx++, static_cast(paddings[0] / 2)); + b2f_kernel.setArg(idx++, static_cast(paddings[1] / 2)); + + const size_t gws[2] = {static_cast(out_width), + static_cast(RoundUpDiv4(input_tensor->dim(3)))}; + const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(b2f_kernel); + const std::vector lws = {128, 8}; + cl::Event event; + cl_int error = runtime->command_queue().enqueueNDRangeKernel( + b2f_kernel, cl::NullRange, + cl::NDRange(gws[0], gws[1]), + cl::NDRange(lws[0], lws[1]), + nullptr, &event); + MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error; + + if (future != nullptr) { + future->wait_fn = [runtime, event](CallStats *stats) { + event.wait(); + if (stats != nullptr) { + runtime->GetCallStats(event, stats); + } + }; + } +} + +template +void WinogradInverseTransformFunctor::operator()(const Tensor *input_tensor, + Tensor *output_tensor, + StatsFuture *future) { + std::vector output_shape = {batch_, height_, width_, input_tensor->dim(1)}; + std::vector image_shape; + CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, image_shape); + output_tensor->ResizeImage(output_shape, image_shape); + + string obfuscated_kernel_name = MACE_OBFUSCATE_SYMBOL("winograd_inverse_transform_2x2"); + std::set built_options; + built_options.emplace("-Dwinograd_inverse_transform_2x2=" + obfuscated_kernel_name); + built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(DataTypeToEnum::value)); + built_options.emplace("-DCMD_DATA_TYPE=" + 
DtToUpstreamCLCMDDt(DataTypeToEnum::value)); + if ((input_tensor->dim(1) % 4 == 0 || input_tensor->dim(0) == 1) && + input_tensor->dim(2) % 4 == 0) { + built_options.emplace("-DDIVISIBLE_FOUR"); + } + auto runtime = OpenCLRuntime::Global(); + auto b2f_kernel = runtime->BuildKernel("winograd_transform", + obfuscated_kernel_name, + built_options); + + const uint32_t round_h = (height_ + 1) / 2; + const uint32_t round_w = (width_ + 1) / 2; + uint32_t idx = 0; + b2f_kernel.setArg(idx++, *(static_cast(input_tensor->buffer()))); + b2f_kernel.setArg(idx++, *(static_cast(output_tensor->buffer()))); + b2f_kernel.setArg(idx++, static_cast(output_shape[1])); + b2f_kernel.setArg(idx++, static_cast(output_shape[2])); + b2f_kernel.setArg(idx++, static_cast(round_h * round_w)); + b2f_kernel.setArg(idx++, static_cast(round_w)); + + const size_t gws[2] = {static_cast(input_tensor->dim(2)), + static_cast(RoundUpDiv4(input_tensor->dim(1)))}; + const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(b2f_kernel); + const std::vector lws = {128, 8}; + cl::Event event; + cl_int error = runtime->command_queue().enqueueNDRangeKernel( + b2f_kernel, cl::NullRange, + cl::NDRange(gws[0], gws[1]), + cl::NDRange(lws[0], lws[1]), + nullptr, &event); + MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error; + + if (future != nullptr) { + future->wait_fn = [runtime, event](CallStats *stats) { + event.wait(); + if (stats != nullptr) { + runtime->GetCallStats(event, stats); + } + }; + } +} + +template +struct WinogradTransformFunctor; +template +struct WinogradTransformFunctor; + +template +struct WinogradInverseTransformFunctor; +template +struct WinogradInverseTransformFunctor; + +} // namespace kernels +} // namespace mace diff --git a/mace/kernels/winograd_transform.h b/mace/kernels/winograd_transform.h new file mode 100644 index 00000000..62284a07 --- /dev/null +++ b/mace/kernels/winograd_transform.h @@ -0,0 +1,88 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. 
+// + +#ifndef MACE_KERNELS_WINOGRAD_TRANSFORM_H_ +#define MACE_KERNELS_WINOGRAD_TRANSFORM_H_ + +#include "mace/core/future.h" +#include "mace/core/tensor.h" +#include "mace/kernels/conv_pool_2d_util.h" + +namespace mace { +namespace kernels { + +struct WinogradTransformFunctorBase { + WinogradTransformFunctorBase(const Padding &paddings) + : strides_({1, 1}), dilations_({1, 1}), paddings_(paddings) {} + + const std::vector strides_; // [stride_h, stride_w] + const std::vector dilations_; // [dilation_h, dilation_w] + Padding paddings_; +}; + +template +struct WinogradTransformFunctor : WinogradTransformFunctorBase { + WinogradTransformFunctor(const Padding &paddings) + : WinogradTransformFunctorBase(paddings) {} + + void operator()(const Tensor *input, + Tensor *output, + StatsFuture *future) { + MACE_NOT_IMPLEMENTED; + } + +}; + +template +struct WinogradTransformFunctor : WinogradTransformFunctorBase { + WinogradTransformFunctor(const Padding &paddings) + : WinogradTransformFunctorBase(paddings) {} + + void operator()(const Tensor *input, + Tensor *output, + StatsFuture *future); +}; + +struct WinogradInverseTransformFunctorBase { + WinogradInverseTransformFunctorBase(const int batch, + const int height, + const int width) + : batch_(batch), height_(height), width_(width) {} + + const int batch_; + const int height_; + const int width_; +}; + +template +struct WinogradInverseTransformFunctor : WinogradInverseTransformFunctorBase { + WinogradInverseTransformFunctor(const int batch, + const int height, + const int width) + : WinogradInverseTransformFunctorBase(batch, height, width) {} + + void operator()(const Tensor *input, + Tensor *output, + StatsFuture *future) { + MACE_NOT_IMPLEMENTED; + } + +}; + +template +struct WinogradInverseTransformFunctor : WinogradInverseTransformFunctorBase { + WinogradInverseTransformFunctor(const int batch, + const int height, + const int width) + : WinogradInverseTransformFunctorBase(batch, height, width) {} + + void operator()(const Tensor *input, + Tensor *output, + StatsFuture *future); +}; + +} // namespace kernels +} // namespace mace + +#endif // MACE_KERNELS_WINOGRAD_TRANSFORM_H_ diff --git a/mace/ops/addn_benchmark.cc b/mace/ops/addn_benchmark.cc index 41fb6e9e..7e9d9856 100644 --- a/mace/ops/addn_benchmark.cc +++ b/mace/ops/addn_benchmark.cc @@ -23,7 +23,7 @@ static void AddNBenchmark(int iters, int inputs, int n, int h, int w, int c) { for (int i = 0; i < inputs; ++i) { BufferToImage(net, internal::MakeString("Input", i).c_str(), internal::MakeString("InputImage", i).c_str(), - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); } OpDefBuilder op_def_builder("AddN", "AddNBM"); for (int i = 0; i < inputs; ++i) { diff --git a/mace/ops/addn_test.cc b/mace/ops/addn_test.cc index 5f9bd2bf..691b1571 100644 --- a/mace/ops/addn_test.cc +++ b/mace/ops/addn_test.cc @@ -104,7 +104,7 @@ void RandomTest() { for (int i = 0; i < input_num; ++i) { BufferToImage(net, "Input" + ToString(i), "InputImage" + ToString(i), - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); } auto op_def_cl = OpDefBuilder("AddN", "AddNTest"); @@ -119,7 +119,7 @@ void RandomTest() { net.RunOp(D); ImageToBuffer(net, "OutputImage", "OPENCLOutput", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 0.1); } diff --git a/mace/ops/batch_norm_benchmark.cc b/mace/ops/batch_norm_benchmark.cc index 976bc241..abfe85a6 100644 --- a/mace/ops/batch_norm_benchmark.cc +++ 
b/mace/ops/batch_norm_benchmark.cc @@ -24,7 +24,7 @@ static void BatchNorm( if (D == DeviceType::OPENCL) { BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Scale", "ScaleImage", kernels::BufferType::ARGUMENT); BufferToImage(net, "Offset", "OffsetImage", diff --git a/mace/ops/batch_norm_test.cc b/mace/ops/batch_norm_test.cc index 595635e7..a312df78 100644 --- a/mace/ops/batch_norm_test.cc +++ b/mace/ops/batch_norm_test.cc @@ -23,7 +23,7 @@ void Simple() { if (D == DeviceType::OPENCL) { BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Scale", "ScaleImage", kernels::BufferType::ARGUMENT); BufferToImage(net, "Offset", "OffsetImage", @@ -47,7 +47,7 @@ void Simple() { // Transfer output ImageToBuffer(net, "OutputImage", "Output", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); } else { OpDefBuilder("BatchNorm", "BatchNormTest") .Input("Input") @@ -204,7 +204,7 @@ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) { // Run on opencl BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Scale", "ScaleImage", kernels::BufferType::ARGUMENT); BufferToImage(net, "Offset", "OffsetImage", @@ -234,7 +234,7 @@ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) { net.Sync(); ImageToBuffer(net, "OutputImage", "OPENCLOutput", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 1e-2); } @@ -276,7 +276,7 @@ TEST_F(BatchNormOpTest, SimpleRandomHalfOPENCL) { // Run on opencl BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Scale", "ScaleImage", kernels::BufferType::ARGUMENT); BufferToImage(net, "Offset", "OffsetImage", @@ -307,7 +307,7 @@ TEST_F(BatchNormOpTest, SimpleRandomHalfOPENCL) { net.Sync(); ImageToBuffer(net, "OutputImage", "OPENCLOutput", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 0.5); } @@ -349,7 +349,7 @@ TEST_F(BatchNormOpTest, ComplexRandomOPENCL) { // Run on opencl BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Scale", "ScaleImage", kernels::BufferType::ARGUMENT); BufferToImage(net, "Offset", "OffsetImage", @@ -379,7 +379,7 @@ TEST_F(BatchNormOpTest, ComplexRandomOPENCL) { net.Sync(); ImageToBuffer(net, "OutputImage", "OPENCLOutput", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 1e-2); } @@ -421,7 +421,7 @@ TEST_F(BatchNormOpTest, ComplexRandomHalfOPENCL) { // Run on opencl BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Scale", "ScaleImage", kernels::BufferType::ARGUMENT); BufferToImage(net, "Offset", "OffsetImage", @@ -452,7 +452,7 @@ TEST_F(BatchNormOpTest, ComplexRandomHalfOPENCL) { net.Sync(); ImageToBuffer(net, "OutputImage", "OPENCLOutput", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 0.5); } } diff --git a/mace/ops/batch_to_space_benchmark.cc b/mace/ops/batch_to_space_benchmark.cc index 93df21f9..02da45ca 100644 --- 
a/mace/ops/batch_to_space_benchmark.cc +++ b/mace/ops/batch_to_space_benchmark.cc @@ -15,7 +15,7 @@ static void BMBatchToSpace( OpsTestNet net; net.AddRandomInput("Input", {batch, height, width, channels}); BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); OpDefBuilder("BatchToSpaceND", "BatchToSpaceNDTest") .Input("InputImage") .Output("OutputImage") diff --git a/mace/ops/bias_add_benchmark.cc b/mace/ops/bias_add_benchmark.cc index 917c28a1..09f96267 100644 --- a/mace/ops/bias_add_benchmark.cc +++ b/mace/ops/bias_add_benchmark.cc @@ -20,7 +20,7 @@ static void BiasAdd(int iters, int batch, int channels, int height, int width) { if (D == DeviceType::OPENCL) { BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT); OpDefBuilder("BiasAdd", "BiasAddBM") diff --git a/mace/ops/bias_add_test.cc b/mace/ops/bias_add_test.cc index ce83ebd7..91bc96e4 100644 --- a/mace/ops/bias_add_test.cc +++ b/mace/ops/bias_add_test.cc @@ -20,7 +20,7 @@ void BiasAddSimple() { if (D == DeviceType::OPENCL) { BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT); @@ -34,7 +34,7 @@ void BiasAddSimple() { // Transfer output ImageToBuffer(net, "OutputImage", "Output", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); } else { OpDefBuilder("BiasAdd", "BiasAddTest") .Input("Input") @@ -90,7 +90,7 @@ TEST_F(BiasAddOpTest, SimpleRandomOPENCL) { // Run on opencl BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT); @@ -105,7 +105,7 @@ TEST_F(BiasAddOpTest, SimpleRandomOPENCL) { net.Sync(); ImageToBuffer(net, "OutputImage", "OPENCLOutput", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 1e-2); } @@ -140,7 +140,7 @@ TEST_F(BiasAddOpTest, ComplexRandomOPENCL) { // Run on opencl BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT); @@ -155,7 +155,7 @@ TEST_F(BiasAddOpTest, ComplexRandomOPENCL) { net.Sync(); ImageToBuffer(net, "OutputImage", "OPENCLOutput", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 1e-2); } } diff --git a/mace/ops/buffer_to_image_test.cc b/mace/ops/buffer_to_image_test.cc index f77bbde0..760103d0 100644 --- a/mace/ops/buffer_to_image_test.cc +++ b/mace/ops/buffer_to_image_test.cc @@ -55,23 +55,23 @@ TEST(BufferToImageTest, ArgLarge) { } TEST(BufferToImageTest, InputSmallSingleChannel) { - TestBidirectionTransform(kernels::IN_OUT, {1, 2, 3, 1}); + TestBidirectionTransform(kernels::IN_OUT_CHANNEL, {1, 2, 3, 1}); } TEST(BufferToImageTest, InputSmallMultipleChannel) { - TestBidirectionTransform(kernels::IN_OUT, {1, 2, 3, 3}); + TestBidirectionTransform(kernels::IN_OUT_CHANNEL, {1, 2, 3, 3}); } TEST(BufferToImageTest, InputSmallMultipleBatchAndChannel) { - TestBidirectionTransform(kernels::IN_OUT, {3, 2, 3, 3}); + TestBidirectionTransform(kernels::IN_OUT_CHANNEL, {3, 2, 3, 3}); } TEST(BufferToImageTest, InputMedia) { - 
TestBidirectionTransform(kernels::IN_OUT, {3, 13, 17, 128}); + TestBidirectionTransform(kernels::IN_OUT_CHANNEL, {3, 13, 17, 128}); } TEST(BufferToImageTest, InputLarge) { - TestBidirectionTransform(kernels::IN_OUT, {3, 64, 64, 256}); + TestBidirectionTransform(kernels::IN_OUT_CHANNEL, {3, 64, 64, 256}); } TEST(BufferToImageTest, Filter1x1Small) { @@ -124,7 +124,7 @@ void TestDiffTypeBidirectionTransform(const int type, const std::vector net.RunOp(D); // Check - ExpectTensorNear(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"), 1e-3); + ExpectTensorNear(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"), 1e-2); } TEST(BufferToImageTest, ArgFloatToHalfSmall) { diff --git a/mace/ops/concat_benchmark.cc b/mace/ops/concat_benchmark.cc index 6a3dda02..11d7de4b 100644 --- a/mace/ops/concat_benchmark.cc +++ b/mace/ops/concat_benchmark.cc @@ -61,9 +61,9 @@ static void OpenclConcatHelper(int iters, net.AddRandomInput("Input1", shape1); BufferToImage(net, "Input0", "InputImage0", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Input1", "InputImage1", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); OpDefBuilder("Concat", "ConcatBM") .Input("InputImage0") .Input("InputImage1") diff --git a/mace/ops/concat_test.cc b/mace/ops/concat_test.cc index 49d55d2a..dff64dbf 100644 --- a/mace/ops/concat_test.cc +++ b/mace/ops/concat_test.cc @@ -153,7 +153,7 @@ void OpenclRandomTest(const std::vector> &shapes, concat_axis_size += shapes[i][axis]; net.AddRandomInput(input_name, shapes[i]); BufferToImage(net, input_name, image_name, - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); } auto builder = OpDefBuilder("Concat", "ConcatTest"); @@ -170,7 +170,7 @@ void OpenclRandomTest(const std::vector> &shapes, net.RunOp(DeviceType::OPENCL); ImageToBuffer(net, "OutputImage", "Output", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); // Check auto output = net.GetOutput("Output"); diff --git a/mace/ops/conv_2d_benchmark.cc b/mace/ops/conv_2d_benchmark.cc index 42d187b3..713e08eb 100644 --- a/mace/ops/conv_2d_benchmark.cc +++ b/mace/ops/conv_2d_benchmark.cc @@ -34,7 +34,7 @@ static void Conv2d(int iters, if (D == DeviceType::OPENCL) { BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Filter", "FilterImage", kernels::BufferType::CONV2D_FILTER); BufferToImage(net, "Bias", "BiasImage", @@ -96,17 +96,20 @@ static void Conv2d(int iters, BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, OPENCL); // ICNet -BM_CONV_2D(1, 512, 15, 15, 1, 1, 1, VALID, 1024, half); -// SNPE GPU ExecutionDuration = 448us, % ALU Utilization = 105 -BM_CONV_2D(1, 64, 60, 60, 1, 1, 1, VALID, 128, half); -// SNPE GPU ExecutionDuration = 258us, % ALU Utilization = 108 -BM_CONV_2D(1, 32, 60, 60, 1, 1, 1, VALID, 128, half); - -BM_CONV_2D(1, 128, 60, 60, 3, 3, 1, VALID, 128, half); -// SNPE GPU ExecutionDuration = 506us, % ALU Utilization = 106.8 -BM_CONV_2D(1, 32, 60, 60, 3, 3, 1, SAME, 32, half); -BM_CONV_2D(1, 3, 512, 512, 7, 7, 2, SAME, 64, half); -BM_CONV_2D(1, 512, 64, 64, 1, 1, 1, SAME, 256, half); +//BM_CONV_2D(1, 512, 15, 15, 1, 1, 1, VALID, 1024, half); +//// SNPE GPU ExecutionDuration = 448us, % ALU Utilization = 105 +//BM_CONV_2D(1, 64, 60, 60, 1, 1, 1, VALID, 128, half); +//// SNPE GPU ExecutionDuration = 258us, % ALU Utilization = 108 +//BM_CONV_2D(1, 32, 60, 60, 1, 1, 1, VALID, 128, half); +// +//BM_CONV_2D(1, 128, 60, 60, 3, 3, 1, VALID, 128, 
half); +//// SNPE GPU ExecutionDuration = 506us, % ALU Utilization = 106.8 +//BM_CONV_2D(1, 32, 60, 60, 3, 3, 1, SAME, 32, half); +//BM_CONV_2D(1, 3, 512, 512, 7, 7, 2, SAME, 64, half); +//BM_CONV_2D(1, 512, 64, 64, 1, 1, 1, SAME, 256, half); +BM_CONV_2D(1, 128, 16, 16, 3, 3, 1, VALID, 32, half); +BM_CONV_2D(1, 128, 64, 64, 3, 3, 1, VALID, 32, half); +BM_CONV_2D(1, 128, 128, 128, 3, 3, 1, VALID, 32, half); // Test RGB <-> YUV // BM_CONV_2D(1, 3, 2160, 1080, 1, 1, 1, VALID, 3, float); diff --git a/mace/ops/conv_2d_test.cc b/mace/ops/conv_2d_test.cc index a12842e2..877da76d 100644 --- a/mace/ops/conv_2d_test.cc +++ b/mace/ops/conv_2d_test.cc @@ -100,7 +100,7 @@ void TestNHWCSimple3x3VALID() { if (D == DeviceType::OPENCL) { BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Filter", "FilterImage", kernels::BufferType::CONV2D_FILTER); BufferToImage(net, "Bias", "BiasImage", @@ -120,7 +120,7 @@ void TestNHWCSimple3x3VALID() { // Transfer output ImageToBuffer(net, "OutputImage", "Output", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); } else { OpDefBuilder("Conv2D", "Conv2dTest") @@ -157,7 +157,7 @@ void TestNHWCSimple3x3SAME() { if (D == DeviceType::OPENCL) { BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Filter", "FilterImage", kernels::BufferType::CONV2D_FILTER); BufferToImage(net, "Bias", "BiasImage", @@ -177,7 +177,7 @@ void TestNHWCSimple3x3SAME() { // Transfer output ImageToBuffer(net, "OutputImage", "Output", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); } else { OpDefBuilder("Conv2D", "Conv2dTest") @@ -262,7 +262,7 @@ void TestNHWCSimple3x3WithoutBias() { if (D == DeviceType::OPENCL) { BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Filter", "FilterImage", kernels::BufferType::CONV2D_FILTER); @@ -279,7 +279,7 @@ void TestNHWCSimple3x3WithoutBias() { net.RunOp(D); // Transfer output ImageToBuffer(net, "OutputImage", "Output", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); } else { OpDefBuilder("Conv2D", "Conv2dTest") .Input("Input") @@ -369,7 +369,7 @@ static void TestNHWCCombined3x3() { if (D == DeviceType::OPENCL) { BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Filter", "FilterImage", kernels::BufferType::CONV2D_FILTER); BufferToImage(net, "Bias", "BiasImage", @@ -389,7 +389,7 @@ static void TestNHWCCombined3x3() { net.RunOp(D); ImageToBuffer(net, "OutputImage", "Output", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); } else { OpDefBuilder("Conv2D", "Conv2DTest") .Input("Input") @@ -442,7 +442,7 @@ void TestConv1x1() { if (D == DeviceType::OPENCL) { BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Filter", "FilterImage", kernels::BufferType::CONV2D_FILTER); BufferToImage(net, "Bias", "BiasImage", @@ -461,7 +461,7 @@ void TestConv1x1() { net.RunOp(D); ImageToBuffer(net, "OutputImage", "Output", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); } else { OpDefBuilder("Conv2D", "Conv2DTest") .Input("Input") @@ -533,7 +533,7 @@ static void TestComplexConvNxNS12(const std::vector &shape) { // run on gpu BufferToImage(net, "Input", "InputImage", - 
kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Filter", "FilterImage", kernels::BufferType::CONV2D_FILTER); BufferToImage(net, "Bias", "BiasImage", @@ -553,7 +553,7 @@ static void TestComplexConvNxNS12(const std::vector &shape) { net.RunOp(D); ImageToBuffer(net, "OutputImage", "OPENCLOutput", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 0.001); }; @@ -626,7 +626,7 @@ static void TestHalfComplexConvNxNS12(const std::vector &input_shape, // run on gpu BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Filter", "FilterImage", kernels::BufferType::CONV2D_FILTER); BufferToImage(net, "Bias", "BiasImage", @@ -646,7 +646,7 @@ static void TestHalfComplexConvNxNS12(const std::vector &input_shape, net.RunOp(D); ImageToBuffer(net, "OutputImage", "OPENCLOutput", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 0.5); }; @@ -758,7 +758,7 @@ static void TestDilationConvNxN(const std::vector &shape, const int dil expected.Copy(*net.GetOutput("Output")); // run on gpu - BufferToImage(net, "Input", "InputImage", kernels::BufferType::IN_OUT); + BufferToImage(net, "Input", "InputImage", kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Filter", "FilterImage", kernels::BufferType::CONV2D_FILTER); BufferToImage(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT); @@ -775,7 +775,7 @@ static void TestDilationConvNxN(const std::vector &shape, const int dil // Run on device net.RunOp(D); - ImageToBuffer(net, "OutputImage", "OPENCLOutput", kernels::BufferType::IN_OUT); + ImageToBuffer(net, "OutputImage", "OPENCLOutput", kernels::BufferType::IN_OUT_CHANNEL); ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 0.001); }; diff --git a/mace/ops/folded_batch_norm_test.cc b/mace/ops/folded_batch_norm_test.cc index 5ee0a947..45bd6736 100644 --- a/mace/ops/folded_batch_norm_test.cc +++ b/mace/ops/folded_batch_norm_test.cc @@ -38,7 +38,7 @@ void Simple() { if (D == DeviceType::OPENCL) { BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Scale", "ScaleImage", kernels::BufferType::ARGUMENT); BufferToImage(net, "Offset", "OffsetImage", @@ -55,7 +55,7 @@ void Simple() { // Transfer output ImageToBuffer(net, "OutputImage", "Output", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); } else { OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest") .Input("Input") @@ -204,7 +204,7 @@ TEST_F(FoldedBatchNormOpTest, SimpleRandomOPENCL) { // Run on opencl BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Scale", "ScaleImage", kernels::BufferType::ARGUMENT); BufferToImage(net, "Offset", "OffsetImage", @@ -222,7 +222,7 @@ TEST_F(FoldedBatchNormOpTest, SimpleRandomOPENCL) { net.Sync(); ImageToBuffer(net, "OutputImage", "OPENCLOutput", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 1e-2); } @@ -259,7 +259,7 @@ TEST_F(FoldedBatchNormOpTest, SimpleRandomHalfOPENCL) { // Run on opencl BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Scale", "ScaleImage", kernels::BufferType::ARGUMENT); 
BufferToImage(net, "Offset", "OffsetImage", @@ -278,7 +278,7 @@ TEST_F(FoldedBatchNormOpTest, SimpleRandomHalfOPENCL) { net.Sync(); ImageToBuffer(net, "OutputImage", "OPENCLOutput", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 0.5); } @@ -315,7 +315,7 @@ TEST_F(FoldedBatchNormOpTest, ComplexRandomOPENCL) { // Run on opencl BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Scale", "ScaleImage", kernels::BufferType::ARGUMENT); BufferToImage(net, "Offset", "OffsetImage", @@ -332,7 +332,7 @@ TEST_F(FoldedBatchNormOpTest, ComplexRandomOPENCL) { net.RunOp(DeviceType::OPENCL); ImageToBuffer(net, "OutputImage", "OPENCLOutput", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 1e-2); } @@ -369,7 +369,7 @@ TEST_F(FoldedBatchNormOpTest, ComplexRandomHalfOPENCL) { // Run on opencl BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Scale", "ScaleImage", kernels::BufferType::ARGUMENT); BufferToImage(net, "Offset", "OffsetImage", @@ -387,7 +387,7 @@ TEST_F(FoldedBatchNormOpTest, ComplexRandomHalfOPENCL) { net.RunOp(DeviceType::OPENCL); ImageToBuffer(net, "OutputImage", "OPENCLOutput", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 0.5); } } diff --git a/mace/ops/fused_conv_2d_test.cc b/mace/ops/fused_conv_2d_test.cc index bdc4c3cf..87d99b9e 100644 --- a/mace/ops/fused_conv_2d_test.cc +++ b/mace/ops/fused_conv_2d_test.cc @@ -24,7 +24,7 @@ void TestNHWCSimple3x3VALID() { if (D == DeviceType::OPENCL) { BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Filter", "FilterImage", kernels::BufferType::CONV2D_FILTER); BufferToImage(net, "Bias", "BiasImage", @@ -44,7 +44,7 @@ void TestNHWCSimple3x3VALID() { // Transfer output ImageToBuffer(net, "OutputImage", "Output", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); } else { OpDefBuilder("FusedConv2D", "FusedConv2dTest") @@ -81,7 +81,7 @@ void TestNHWCSimple3x3SAME() { if (D == DeviceType::OPENCL) { BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Filter", "FilterImage", kernels::BufferType::CONV2D_FILTER); BufferToImage(net, "Bias", "BiasImage", @@ -101,7 +101,7 @@ void TestNHWCSimple3x3SAME() { // Transfer output ImageToBuffer(net, "OutputImage", "Output", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); } else { OpDefBuilder("FusedConv2D", "FusedConv2dTest") @@ -149,7 +149,7 @@ void TestNHWCSimple3x3WithoutBias() { if (D == DeviceType::OPENCL) { BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Filter", "FilterImage", kernels::BufferType::CONV2D_FILTER); @@ -166,7 +166,7 @@ void TestNHWCSimple3x3WithoutBias() { net.RunOp(D); // Transfer output ImageToBuffer(net, "OutputImage", "Output", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); } else { OpDefBuilder("FusedConv2D", "FusedConv2dTest") .Input("Input") @@ -218,7 +218,7 @@ void TestConv1x1() { if (D == DeviceType::OPENCL) { BufferToImage(net, "Input", "InputImage", - 
kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Filter", "FilterImage", kernels::BufferType::CONV2D_FILTER); BufferToImage(net, "Bias", "BiasImage", @@ -237,7 +237,7 @@ void TestConv1x1() { net.RunOp(D); ImageToBuffer(net, "OutputImage", "Output", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); } else { OpDefBuilder("FusedConv2D", "FusedConv2dTest") .Input("Input") @@ -309,7 +309,7 @@ static void TestComplexConvNxNS12(const std::vector &shape) { // run on gpu BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Filter", "FilterImage", kernels::BufferType::CONV2D_FILTER); BufferToImage(net, "Bias", "BiasImage", @@ -329,7 +329,7 @@ static void TestComplexConvNxNS12(const std::vector &shape) { net.RunOp(D); ImageToBuffer(net, "OutputImage", "OPENCLOutput", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 0.001); }; @@ -395,7 +395,7 @@ static void TestHalfComplexConvNxNS12(const std::vector &shape) { // run on gpu BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Filter", "FilterImage", kernels::BufferType::CONV2D_FILTER); BufferToImage(net, "Bias", "BiasImage", @@ -415,7 +415,7 @@ static void TestHalfComplexConvNxNS12(const std::vector &shape) { net.RunOp(D); ImageToBuffer(net, "OutputImage", "OPENCLOutput", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 0.2); }; @@ -473,7 +473,7 @@ static void TestGeneralConvNxNS12(const std::vector &image_shape, // run on gpu BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Filter", "FilterImage", kernels::BufferType::CONV2D_FILTER); BufferToImage(net, "Bias", "BiasImage", @@ -493,7 +493,7 @@ static void TestGeneralConvNxNS12(const std::vector &image_shape, net.RunOp(D); ImageToBuffer(net, "OutputImage", "OPENCLOutput", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 0.001); }; @@ -550,7 +550,7 @@ static void TestAtrousConvNxN(const std::vector &shape, const int dilat expected.Copy(*net.GetOutput("Output")); // run on gpu - BufferToImage(net, "Input", "InputImage", kernels::BufferType::IN_OUT); + BufferToImage(net, "Input", "InputImage", kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Filter", "FilterImage", kernels::BufferType::CONV2D_FILTER); BufferToImage(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT); @@ -567,7 +567,7 @@ static void TestAtrousConvNxN(const std::vector &shape, const int dilat // Run on device net.RunOp(D); - ImageToBuffer(net, "OutputImage", "OPENCLOutput", kernels::BufferType::IN_OUT); + ImageToBuffer(net, "OutputImage", "OPENCLOutput", kernels::BufferType::IN_OUT_CHANNEL); ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 0.001); }; @@ -632,7 +632,7 @@ static void TestGeneralHalfAtrousConv(const std::vector &image_shape, expected.Copy(*net.GetOutput("Output")); // run on gpu - BufferToImage(net, "Input", "InputImage", kernels::BufferType::IN_OUT); + BufferToImage(net, "Input", "InputImage", kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(net, "Filter", "FilterImage", kernels::BufferType::CONV2D_FILTER); BufferToImage(net, "Bias", "BiasImage", 
kernels::BufferType::ARGUMENT); @@ -649,7 +649,7 @@ static void TestGeneralHalfAtrousConv(const std::vector &image_shape, // Run on device net.RunOp(D); - ImageToBuffer(net, "OutputImage", "OPENCLOutput", kernels::BufferType::IN_OUT); + ImageToBuffer(net, "OutputImage", "OPENCLOutput", kernels::BufferType::IN_OUT_CHANNEL); ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 0.7); }; diff --git a/mace/ops/gemm.h b/mace/ops/gemm.h index 78d7e3b1..b4a96626 100644 --- a/mace/ops/gemm.h +++ b/mace/ops/gemm.h @@ -23,8 +23,9 @@ class GEMMOp : public Operator { MACE_CHECK(A->dim_size() == 4 && 4 == B->dim_size()) << "The dimension of A and B should be 4"; MACE_CHECK(A->dim(0) == B->dim(0)) << "A and B must have same batch size"; - MACE_CHECK(A->dim(3) == B->dim(1)) - << "the number of A's column must be equal to B's row"; + MACE_CHECK(A->dim(2) == B->dim(1)) + << "the number of A's column " << A->dim(2) + << " must be equal to B's row " << B->dim(1); functor_(A, B, C, future); return true; diff --git a/mace/ops/gemm_benchmark.cc b/mace/ops/gemm_benchmark.cc index 76dcc02a..eaffc006 100644 --- a/mace/ops/gemm_benchmark.cc +++ b/mace/ops/gemm_benchmark.cc @@ -16,14 +16,14 @@ static void GEMMBenchmark( OpsTestNet net; // Add input data - net.AddRandomInput("A", {batch, height, 1, channels}); - net.AddRandomInput("B", {batch, channels, 1, out_width}); + net.AddRandomInput("A", {batch, height, channels, 1}); + net.AddRandomInput("B", {batch, channels, out_width, 1}); if (D == DeviceType::OPENCL) { BufferToImage(net, "A", "AImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_WIDTH); BufferToImage(net, "B", "BImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_HEIGHT); OpDefBuilder("GEMM", "GEMMBM") .Input("AImage") @@ -53,17 +53,18 @@ static void GEMMBenchmark( } #define BM_GEMM_MACRO(N, H, C, W, TYPE, DEVICE) \ - static void BM_GEMM_##N##H##C##W##_##TYPE##_##DEVICE(int iters) { \ + static void BM_GEMM_##N##_##H##_##C##_##W##_##TYPE##_##DEVICE(int iters) { \ const int64_t tot = static_cast(iters) * N * C * H * W; \ mace::testing::ItemsProcessed(tot); \ mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ GEMMBenchmark(iters, N, H, C, W); \ } \ - BENCHMARK(BM_GEMM_##N##H##C##W##_##TYPE##_##DEVICE) + BENCHMARK(BM_GEMM_##N##_##H##_##C##_##W##_##TYPE##_##DEVICE) #define BM_GEMM(N, H, C, W, TYPE) \ BM_GEMM_MACRO(N, H, C, W, TYPE, OPENCL); -BM_GEMM(16, 32, 128, 1024, half); -BM_GEMM(36, 32, 128, 256, half); +BM_GEMM(16, 32, 128, 49, half); +BM_GEMM(16, 32, 128, 961, half); +BM_GEMM(16, 32, 128, 3969, half); } // namespace mace diff --git a/mace/ops/gemm_test.cc b/mace/ops/gemm_test.cc index c1d94889..9d01124c 100644 --- a/mace/ops/gemm_test.cc +++ b/mace/ops/gemm_test.cc @@ -25,9 +25,9 @@ void Simple(const std::vector &A_shape, if (D == DeviceType::OPENCL) { BufferToImage(net, "A", "AImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_WIDTH); BufferToImage(net, "B", "BImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_HEIGHT); OpDefBuilder("GEMM", "GEMMTest") .Input("AImage") @@ -39,7 +39,7 @@ void Simple(const std::vector &A_shape, // Transfer output ImageToBuffer(net, "OutputImage", "Output", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_HEIGHT); } else { OpDefBuilder("GEMM", "GEMMTest") .Input("A") @@ -58,37 +58,50 @@ void Simple(const std::vector &A_shape, } TEST_F(GEMMOpTest, SimpleCPU) { - Simple({1, 2, 1, 3}, {1, 2, 3, 4, 5, 6}, - {1, 3, 1, 2}, {1, 2, 3, 4, 5, 6}, - {1, 2, 1, 2}, {22, 28, 49, 64}); - 
Simple({1, 5, 1, 5}, + Simple({1, 2, 3, 1}, {1, 2, 3, 4, 5, 6}, + {1, 3, 2, 1}, {1, 2, 3, 4, 5, 6}, + {1, 2, 2, 1}, {22, 28, 49, 64}); + Simple({1, 5, 5, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}, - {1, 5, 1, 5}, + {1, 5, 5, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}, - {1, 5, 1, 5}, + {1, 5, 5, 1}, {215, 230, 245, 260, 275, 490, 530, 570, 610, 650, 765, 830, 895, 960, 1025, 1040, 1130, 1220, 1310, 1400, 1315, 1430, 1545, 1660, 1775}); } + +TEST_F(GEMMOpTest, SimpleCPUWithBatch) { + Simple({2, 2, 3, 1}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}, + {2, 3, 2, 1}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}, + {2, 2, 2, 1}, {22, 28, 49, 64, 22, 28, 49, 64}); +} + TEST_F(GEMMOpTest, SimpleOPENCL) { - Simple({1, 2, 1, 3}, {1, 2, 3, 4, 5, 6}, - {1, 3, 1, 2}, {1, 2, 3, 4, 5, 6}, - {1, 2, 1, 2}, {22, 28, 49, 64}); - Simple({1, 5, 1, 5}, + Simple({1, 2, 3, 1}, {1, 2, 3, 4, 5, 6}, + {1, 3, 2, 1}, {1, 2, 3, 4, 5, 6}, + {1, 2, 2, 1}, {22, 28, 49, 64}); + Simple({1, 5, 5, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}, - {1, 5, 1, 5}, + {1, 5, 5, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}, - {1, 5, 1, 5}, + {1, 5, 5, 1}, {215, 230, 245, 260, 275, 490, 530, 570, 610, 650, 765, 830, 895, 960, 1025, 1040, 1130, 1220, 1310, 1400, 1315, 1430, 1545, 1660, 1775}); } +TEST_F(GEMMOpTest, SimpleGPUWithBatch) { + Simple({2, 2, 3, 1}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}, + {2, 3, 2, 1}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}, + {2, 2, 2, 1}, {22, 28, 49, 64, 22, 28, 49, 64}); +} + template void Complex(const index_t batch, const index_t height, @@ -106,9 +119,9 @@ void Complex(const index_t batch, // Add input data net.AddRandomInput( - "A", {batch, height, 1, channels}); + "A", {batch, height, channels, 1}); net.AddRandomInput( - "B", {batch, channels, 1, out_width}); + "B", {batch, channels, out_width, 1}); // run cpu net.RunOp(); @@ -119,9 +132,9 @@ void Complex(const index_t batch, // Run on opencl BufferToImage(net, "A", "AImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_WIDTH); BufferToImage(net, "B", "BImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_HEIGHT); OpDefBuilder("GEMM", "GEMMTest") .Input("AImage") @@ -132,10 +145,9 @@ void Complex(const index_t batch, // Run on opencl net.RunOp(DeviceType::OPENCL); - net.Sync(); ImageToBuffer(net, "OutputImage", "OPENCLOutput", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_HEIGHT); if (DataTypeToEnum::value == DataType::DT_HALF) { ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 1e-1); } else { @@ -152,8 +164,8 @@ TEST_F(GEMMOpTest, OPENCLUnAlignedWithoutBatch) { Complex(1, 113, 31, 73); } TEST_F(GEMMOpTest, OPENCLUnAlignedWithBatch) { - Complex(2, 31, 113, 61); - Complex(16, 32, 64, 64); + Complex(2, 3, 3, 3); + Complex(16, 31, 61, 67); Complex(31, 31, 61, 67); } TEST_F(GEMMOpTest, OPENCLHalfAlignedWithoutBatch) { diff --git a/mace/ops/pooling_test.cc b/mace/ops/pooling_test.cc index c802c126..bf4cff8b 100644 --- a/mace/ops/pooling_test.cc +++ b/mace/ops/pooling_test.cc @@ -134,7 +134,7 @@ static void SimpleMaxPooling3S2() { if (D == DeviceType::OPENCL) { BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); OpDefBuilder("Pooling", "PoolingTest") .Input("InputImage") .Output("OutputImage") @@ -146,7 +146,7 @@ static void SimpleMaxPooling3S2() { 
.Finalize(net.NewOperatorDef()); net.RunOp(D); ImageToBuffer(net, "OutputImage", "Output", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); } else { // Run OpDefBuilder("Pooling", "PoolingTest") @@ -198,7 +198,7 @@ static void MaxPooling3S2(const std::vector &input_shape, Tensor expected; expected.Copy(*net.GetOutput("Output")); - BufferToImage(net, "Input", "InputImage", kernels::BufferType::IN_OUT); + BufferToImage(net, "Input", "InputImage", kernels::BufferType::IN_OUT_CHANNEL); OpDefBuilder("Pooling", "PoolingTest") .Input("InputImage") .Output("OutputImage") @@ -211,7 +211,7 @@ static void MaxPooling3S2(const std::vector &input_shape, .Finalize(net.NewOperatorDef()); net.RunOp(D); ImageToBuffer(net, "OutputImage", "OPENCLOutput", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 0.001); } @@ -283,7 +283,7 @@ static void SimpleAvgPoolingTest() { {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); OpDefBuilder("Pooling", "PoolingTest") .Input("InputImage") .Output("OutputImage") @@ -296,7 +296,7 @@ static void SimpleAvgPoolingTest() { // Run net.RunOp(D); ImageToBuffer(net, "OutputImage", "Output", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); // Check auto expected = CreateTensor({1, 1, 4, 1}, {4.5, 6.5, 8.5, 10.5}); @@ -333,7 +333,7 @@ static void AvgPoolingTest(const std::vector &shape, Tensor expected; expected.Copy(*net.GetOutput("Output")); - BufferToImage(net, "Input", "InputImage", kernels::BufferType::IN_OUT); + BufferToImage(net, "Input", "InputImage", kernels::BufferType::IN_OUT_CHANNEL); OpDefBuilder("Pooling", "PoolingTest") .Input("InputImage") .Output("OutputImage") @@ -346,7 +346,7 @@ static void AvgPoolingTest(const std::vector &shape, .Finalize(net.NewOperatorDef()); net.RunOp(D); ImageToBuffer(net, "OutputImage", "OPENCLOutput", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 0.01); } diff --git a/mace/ops/resize_bilinear_benchmark.cc b/mace/ops/resize_bilinear_benchmark.cc index 46b96123..01ffda0e 100644 --- a/mace/ops/resize_bilinear_benchmark.cc +++ b/mace/ops/resize_bilinear_benchmark.cc @@ -27,7 +27,7 @@ static void ResizeBilinearBenchmark(int iters, {output_height, output_width}); if (D == DeviceType::OPENCL) { BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); OpDefBuilder("ResizeBilinear", "ResizeBilinearBenchmark") .Input("InputImage") .Input("OutSize") diff --git a/mace/ops/resize_bilinear_test.cc b/mace/ops/resize_bilinear_test.cc index 06b715a0..129a627a 100644 --- a/mace/ops/resize_bilinear_test.cc +++ b/mace/ops/resize_bilinear_test.cc @@ -92,7 +92,7 @@ void TestRandomResizeBilinear() { if (D == DeviceType::OPENCL) { BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); OpDefBuilder("ResizeBilinear", "ResizeBilinearTest") .Input("InputImage") @@ -104,7 +104,7 @@ void TestRandomResizeBilinear() { net.RunOp(D); ImageToBuffer(net, "OutputImage", "DeviceOutput", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); } else { // TODO support NEON } diff --git a/mace/ops/softmax_benchmark.cc b/mace/ops/softmax_benchmark.cc index 030af807..267074a7 100644 --- a/mace/ops/softmax_benchmark.cc +++ 
b/mace/ops/softmax_benchmark.cc @@ -20,7 +20,7 @@ static void SoftmaxBenchmark( if (D == DeviceType::OPENCL) { BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); OpDefBuilder("Softmax", "SoftmaxBM") .Input("InputImage") diff --git a/mace/ops/softmax_test.cc b/mace/ops/softmax_test.cc index b4f321a6..af8e3afc 100644 --- a/mace/ops/softmax_test.cc +++ b/mace/ops/softmax_test.cc @@ -18,7 +18,7 @@ void Simple() { if (D == DeviceType::OPENCL) { BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); OpDefBuilder("Softmax", "SoftmaxTest") .Input("InputImage") @@ -30,7 +30,7 @@ void Simple() { // Transfer output ImageToBuffer(net, "OutputImage", "Output", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); } else { OpDefBuilder("Softmax", "SoftmaxTest") .Input("Input") @@ -72,7 +72,7 @@ void Complex(const std::vector &logits_shape) { expected.Copy(*net.GetOutput("Output")); BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); OpDefBuilder("Softmax", "SoftmaxTest") .Input("InputImage") @@ -84,7 +84,7 @@ void Complex(const std::vector &logits_shape) { // Transfer output ImageToBuffer(net, "OutputImage", "OPENCLOutput", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 1e-5); } diff --git a/mace/ops/space_to_batch_benchmark.cc b/mace/ops/space_to_batch_benchmark.cc index a2fea8dc..9b3e4d1c 100644 --- a/mace/ops/space_to_batch_benchmark.cc +++ b/mace/ops/space_to_batch_benchmark.cc @@ -16,7 +16,7 @@ static void BMSpaceToBatch( net.AddRandomInput("Input", {batch, height, width, channels}); BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); OpDefBuilder("SpaceToBatchND", "SpaceToBatchNDTest") .Input("InputImage") .Output("OutputImage") diff --git a/mace/ops/space_to_batch_test.cc b/mace/ops/space_to_batch_test.cc index bebbafef..56d37611 100644 --- a/mace/ops/space_to_batch_test.cc +++ b/mace/ops/space_to_batch_test.cc @@ -18,7 +18,7 @@ void RunSpaceToBatch(const std::vector &input_shape, net.AddInputFromArray("Input", input_shape, input_data); BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); OpDefBuilder("SpaceToBatchND", "SpaceToBatchNDTest") .Input("InputImage") .Output("OutputImage") @@ -30,7 +30,7 @@ void RunSpaceToBatch(const std::vector &input_shape, net.RunOp(D); ImageToBuffer(net, "OutputImage", "Output", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); // Check ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-8); } @@ -46,7 +46,7 @@ void RunBatchToSpace(const std::vector &input_shape, net.AddInputFromArray("Input", input_shape, input_data); BufferToImage(net, "Input", "InputImage", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); OpDefBuilder("BatchToSpaceND", "BatchToSpaceNDTest") .Input("InputImage") .Output("OutputImage") @@ -58,7 +58,7 @@ void RunBatchToSpace(const std::vector &input_shape, net.RunOp(D); ImageToBuffer(net, "OutputImage", "Output", - kernels::BufferType::IN_OUT); + kernels::BufferType::IN_OUT_CHANNEL); // Check ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-8); } diff --git a/mace/ops/winograd_inverse_transform.cc b/mace/ops/winograd_inverse_transform.cc new file mode 100644 index 00000000..4f81a1b0 --- 
/dev/null +++ b/mace/ops/winograd_inverse_transform.cc @@ -0,0 +1,22 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/ops/winograd_inverse_transform.h" + +namespace mace { + +void Register_WinogradInverseTransform(OperatorRegistry *op_registry) { + REGISTER_OPERATOR(op_registry, OpKeyBuilder("WinogradInverseTransform") + .Device(DeviceType::OPENCL) + .TypeConstraint("T") + .Build(), + WinogradInverseTransformOp); + REGISTER_OPERATOR(op_registry, OpKeyBuilder("WinogradInverseTransform") + .Device(DeviceType::OPENCL) + .TypeConstraint("T") + .Build(), + WinogradInverseTransformOp); +} + +} // namespace mace diff --git a/mace/ops/winograd_inverse_transform.h b/mace/ops/winograd_inverse_transform.h new file mode 100644 index 00000000..c620246c --- /dev/null +++ b/mace/ops/winograd_inverse_transform.h @@ -0,0 +1,42 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_OPS_WINOGRAD_INVERSE_TRANSFORM_H_ +#define MACE_OPS_WINOGRAD_INVERSE_TRANSFORM_H_ + +#include + +#include "mace/core/operator.h" +#include "mace/kernels/winograd_transform.h" + +namespace mace { + +template +class WinogradInverseTransformOp : public Operator { + public: + WinogradInverseTransformOp(const OperatorDef &op_def, Workspace *ws) + : Operator(op_def, ws), + functor_(OperatorBase::GetSingleArgument("batch", 1), + OperatorBase::GetSingleArgument("height", 0), + OperatorBase::GetSingleArgument("width", 0)) {} + + bool Run(StatsFuture *future) override { + const Tensor *input_tensor = this->Input(INPUT); + Tensor *output_tensor = this->Output(OUTPUT); + + functor_(input_tensor, output_tensor, future); + return true; + } + + private: + kernels::WinogradInverseTransformFunctor functor_; + + protected: + OP_INPUT_TAGS(INPUT); + OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace mace + +#endif // MACE_OPS_WINOGRAD_INVERSE_TRANSFORM_H_ diff --git a/mace/ops/winograd_transform.cc b/mace/ops/winograd_transform.cc new file mode 100644 index 00000000..369a6218 --- /dev/null +++ b/mace/ops/winograd_transform.cc @@ -0,0 +1,22 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/ops/winograd_transform.h" + +namespace mace { + +void Register_WinogradTransform(OperatorRegistry *op_registry) { + REGISTER_OPERATOR(op_registry, OpKeyBuilder("WinogradTransform") + .Device(DeviceType::OPENCL) + .TypeConstraint("T") + .Build(), + WinogradTransformOp); + REGISTER_OPERATOR(op_registry, OpKeyBuilder("WinogradTransform") + .Device(DeviceType::OPENCL) + .TypeConstraint("T") + .Build(), + WinogradTransformOp); +} + +} // namespace mace diff --git a/mace/ops/winograd_transform.h b/mace/ops/winograd_transform.h new file mode 100644 index 00000000..f2cc5f10 --- /dev/null +++ b/mace/ops/winograd_transform.h @@ -0,0 +1,41 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. 
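The registrations above expose WinogradTransform and WinogradInverseTransform as OpenCL operators in float and half variants. The inverse transform takes explicit batch/height/width arguments because its input is the flattened {16, out_channels, P, 1} GEMM result, from which the original NHWC output shape cannot be recovered. A caller wires it up roughly as the tests later in this patch do (a sketch; the tensor names are illustrative):

    // Hedged sketch: building a WinogradInverseTransform node, mirroring the
    // OpDefBuilder usage in winograd_transform_test.cc.
    OpDefBuilder("WinogradInverseTransform", "WinoInverse")
        .Input("WinoGemmImage")           // {16, out_channels, P, 1}, IN_OUT_HEIGHT image
        .AddIntArg("batch", batch)        // original N
        .AddIntArg("height", out_height)  // convolution output H
        .AddIntArg("width", out_width)    // convolution output W
        .Output("OutputImage")            // NHWC result as IN_OUT_CHANNEL image
        .Finalize(net.NewOperatorDef());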
+// + +#ifndef MACE_OPS_WINOGRAD_TRANSFORM_H_ +#define MACE_OPS_WINOGRAD_TRANSFORM_H_ + +#include + +#include "mace/core/operator.h" +#include "mace/kernels/winograd_transform.h" + +namespace mace { + +template +class WinogradTransformOp : public Operator { + public: + WinogradTransformOp(const OperatorDef &op_def, Workspace *ws) + : Operator(op_def, ws), + functor_(static_cast(OperatorBase::GetSingleArgument( + "padding", static_cast(VALID)))) {} + + bool Run(StatsFuture *future) override { + const Tensor *input_tensor = this->Input(INPUT); + Tensor *output_tensor = this->Output(OUTPUT); + + functor_(input_tensor, output_tensor, future); + return true; + } + + private: + kernels::WinogradTransformFunctor functor_; + + protected: + OP_INPUT_TAGS(INPUT); + OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace mace + +#endif // MACE_OPS_WINOGRAD_TRANSFORM_H_ diff --git a/mace/ops/winograd_transform_benchmark.cc b/mace/ops/winograd_transform_benchmark.cc new file mode 100644 index 00000000..ca1e9bea --- /dev/null +++ b/mace/ops/winograd_transform_benchmark.cc @@ -0,0 +1,111 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/core/operator.h" +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +template +static void BMWinogradTransform( + int iters, int batch, int height, int width, int channels) { + mace::testing::StopTiming(); + + OpsTestNet net; + net.AddRandomInput("Input", {batch, height, width, channels}); + + BufferToImage(net, "Input", "InputImage", + kernels::BufferType::IN_OUT_CHANNEL); + OpDefBuilder("WinogradTransform", "WinogradTransformTest") + .Input("InputImage") + .Output("OutputImage") + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + } + net.Sync(); + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + } + net.Sync(); +} + +#define BM_WINOGRAD_TRANSFORM_MACRO(N, H, W, C, TYPE, DEVICE) \ + static void \ + BM_WINOGRAD_TRANSFORM_##N##_##H##_##W##_##C##_##TYPE##_##DEVICE( \ + int iters) { \ + const int64_t tot = static_cast(iters) * N * C * H * W; \ + mace::testing::ItemsProcessed(tot); \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + BMWinogradTransform(iters, N, H, W, C); \ + } \ + BENCHMARK( \ + BM_WINOGRAD_TRANSFORM_##N##_##H##_##W##_##C##_##TYPE##_##DEVICE) + +#define BM_WINOGRAD_TRANSFORM(N, H, W, C, TYPE) \ + BM_WINOGRAD_TRANSFORM_MACRO(N, H, W, C, TYPE, OPENCL); + +BM_WINOGRAD_TRANSFORM(1, 16, 16, 128, half); +BM_WINOGRAD_TRANSFORM(1, 64, 64, 128, half); +BM_WINOGRAD_TRANSFORM(1, 128, 128, 128, half); +BM_WINOGRAD_TRANSFORM(1, 256, 256, 32, half); + +template +static void BMWinogradInverseTransform( + int iters, int batch, int height, int width, int channels) { + mace::testing::StopTiming(); + + index_t p = batch * ((height + 1) / 2) * ((width + 1) / 2); + OpsTestNet net; + net.AddRandomInput("Input", {16, channels, p, 1}); + + BufferToImage(net, "Input", "InputImage", + kernels::BufferType::IN_OUT_HEIGHT); + OpDefBuilder("WinogradInverseTransform", "WinogradInverseTransformTest") + .Input("InputImage") + .AddIntArg("batch", batch) + .AddIntArg("height", height) + .AddIntArg("width", width) + .Output("OutputImage") + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + } + net.Sync(); + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + } + net.Sync(); 
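The inverse-transform benchmark above builds its input directly in transformed form: {16, channels, p, 1} with p = batch * ceil(out_height / 2) * ceil(out_width / 2), which is consistent with an F(2x2, 3x3) Winograd tiling (16 = one 4x4 transformed tile, p = number of 2x2 output tiles). Together with the gemm.h change earlier in this patch (K moved from dim 3 to dim 2), the three tensors line up as a batched matrix product. A small sketch of the shape bookkeeping, assuming that tiling:

    // Shapes the F(2x2, 3x3) path appears to use; illustrative only.
    #include <cstdint>
    #include <vector>

    typedef int64_t index_t;

    struct WinoShapes {
      std::vector<index_t> input;   // B in GEMM: {16, in_ch,  P, 1}
      std::vector<index_t> filter;  // A in GEMM: {16, out_ch, in_ch, 1}
      std::vector<index_t> gemm;    // C = A * B: {16, out_ch, P, 1}
    };

    WinoShapes SketchWinogradShapes(index_t batch, index_t out_h, index_t out_w,
                                    index_t in_ch, index_t out_ch) {
      const index_t p = batch * ((out_h + 1) / 2) * ((out_w + 1) / 2);
      return {{16, in_ch, p, 1}, {16, out_ch, in_ch, 1}, {16, out_ch, p, 1}};
    }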
+} + +#define BM_WINOGRAD_INVERSE_TRANSFORM_MACRO(N, H, W, C, TYPE, DEVICE) \ + static void \ + BM_WINOGRAD_INVERSE_TRANSFORM_##N##_##H##_##W##_##C##_##TYPE##_##DEVICE( \ + int iters) { \ + const int64_t tot = static_cast(iters) * N * C * H * W; \ + mace::testing::ItemsProcessed(tot); \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + BMWinogradInverseTransform(iters, N, H, W, C); \ + } \ + BENCHMARK( \ + BM_WINOGRAD_INVERSE_TRANSFORM_##N##_##H##_##W##_##C##_##TYPE##_##DEVICE) + +#define BM_WINOGRAD_INVERSE_TRANSFORM(N, H, W, C, TYPE) \ + BM_WINOGRAD_INVERSE_TRANSFORM_MACRO(N, H, W, C, TYPE, OPENCL); + +BM_WINOGRAD_INVERSE_TRANSFORM(1, 14, 14, 32, half); +BM_WINOGRAD_INVERSE_TRANSFORM(1, 62, 62, 32, half); +BM_WINOGRAD_INVERSE_TRANSFORM(1, 126, 126, 32, half); + +} // namespace mace \ No newline at end of file diff --git a/mace/ops/winograd_transform_test.cc b/mace/ops/winograd_transform_test.cc new file mode 100644 index 00000000..f2a04656 --- /dev/null +++ b/mace/ops/winograd_transform_test.cc @@ -0,0 +1,334 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" +#include "mace/kernels/conv_pool_2d_util.h" + +namespace mace { + +class WinogradTransformOpTest : public OpsTestBase {}; + +//TEST_F(WinogradTransformOpTest, WinogradInputTransform) { +// srand(time(NULL)); +// +// // generate random input +// index_t batch = 7; +// index_t height = 61; +// index_t width = 71; +// index_t channels = 31; +// +// index_t p = batch * ((height - 1) / 2) * ((width - 1) / 2); +// +// const std::string A_file = "/data/local/tmp/test/A"; +// const std::string C_file = "/data/local/tmp/test/C"; +// const std::vector A_shape = {batch, height, width, channels}; +// const int A_size = std::accumulate(A_shape.begin(), A_shape.end(), 1, std::multiplies()); +// const std::vector C_shape = {16, channels, p, 1}; +// const int C_size = std::accumulate(C_shape.begin(), C_shape.end(), 1, std::multiplies()); +// +// std::vector A_data(A_size, 0.0); +// std::ifstream in_file(A_file, std::ios::in | std::ios::binary); +// if (in_file.is_open()) { +// in_file.read(reinterpret_cast(A_data.data()), +// A_size * sizeof(float)); +// in_file.close(); +// } else { +// VLOG(0) << "open A file failed"; +// } +// auto C_tensor = unique_ptr(new Tensor(GetDeviceAllocator(DeviceType::OPENCL), +// DataTypeToEnum::v())); +// C_tensor->Resize(C_shape); +// std::vector C_data(C_size, 0.0); +// std::ifstream C_in_file(C_file, std::ios::in | std::ios::binary); +// if (C_in_file.is_open()) { +// C_in_file.read(reinterpret_cast(C_data.data()), +// C_size * sizeof(float)); +// C_in_file.close(); +// Tensor::MappingGuard C_mapper(C_tensor.get()); +// float *batch_ptr = C_tensor->mutable_data(); +// MACE_CHECK(static_cast(C_tensor->size()) == +// C_data.size()); +// memcpy(batch_ptr, C_data.data(), C_data.size() * sizeof(float)); +// } else { +// VLOG(0) << "open C file failed"; +// } +// // Construct graph +// OpsTestNet net; +// // Add input data +// net.AddInputFromArray( +// "A", A_shape, A_data); +// +// // Run on opencl +// BufferToImage(net, "A", "AImage", +// kernels::BufferType::IN_OUT_CHANNEL); +// +// OpDefBuilder("WinogradTransform", "WinogradTransformTest") +// .Input("AImage") +// .Output("OutputImage") +// .Finalize(net.NewOperatorDef()); +// +// // Run on opencl +// net.RunOp(DeviceType::OPENCL); +// net.Sync(); +// +// ImageToBuffer(net, "OutputImage", "OPENCLOutput", +// kernels::BufferType::IN_OUT_HEIGHT); +// 
ExpectTensorNear(*(C_tensor.get()), *net.GetOutput("OPENCLOutput"), 1e-4); +//} +// +//TEST_F(WinogradTransformOpTest, FilterTransform) { +// srand(time(NULL)); +// +// // generate random input +// index_t out_chan = 31; +// index_t in_chan = 31; +// index_t height = 3; +// index_t width = 3; +// +// index_t p = (in_chan + 3) / 4; +// +// const std::string A_file = "/data/local/tmp/test/filter_in"; +// const std::string C_file = "/data/local/tmp/test/filter_out"; +// const std::vector A_shape = {out_chan, in_chan, height, width}; +// const int A_size = std::accumulate(A_shape.begin(), A_shape.end(), 1, std::multiplies()); +// const std::vector C_shape = {16, out_chan, in_chan, 1}; +// const int C_size = std::accumulate(C_shape.begin(), C_shape.end(), 1, std::multiplies()); +// +// std::vector A_data(A_size, 0.0); +// std::ifstream in_file(A_file, std::ios::in | std::ios::binary); +// if (in_file.is_open()) { +// in_file.read(reinterpret_cast(A_data.data()), +// A_size * sizeof(float)); +// in_file.close(); +// } else { +// VLOG(0) << "open A file failed"; +// } +// auto C_tensor = unique_ptr(new Tensor(GetDeviceAllocator(DeviceType::OPENCL), +// DataTypeToEnum::v())); +// C_tensor->Resize(C_shape); +// std::vector C_data(C_size, 0.0); +// std::ifstream C_in_file(C_file, std::ios::in | std::ios::binary); +// if (C_in_file.is_open()) { +// C_in_file.read(reinterpret_cast(C_data.data()), +// C_size * sizeof(float)); +// C_in_file.close(); +// Tensor::MappingGuard C_mapper(C_tensor.get()); +// float *batch_ptr = C_tensor->mutable_data(); +// MACE_CHECK(static_cast(C_tensor->size()) == +// C_data.size()); +// memcpy(batch_ptr, C_data.data(), C_data.size() * sizeof(float)); +// } else { +// VLOG(0) << "open C file failed"; +// } +// // Construct graph +// OpsTestNet net; +// // Add input data +// net.AddInputFromArray( +// "A", A_shape, A_data); +// +// // Run on opencl +// +// OpDefBuilder("BufferToImage", "WinogradFilterTransformTest") +// .Input("A") +// .AddIntArg("buffer_type", kernels::WINOGRAD_FILTER) +// .Output("OutputImage") +// .Finalize(net.NewOperatorDef()); +// +// // Run on opencl +// net.RunOp(DeviceType::OPENCL); +// +// ImageToBuffer(net, "OutputImage", "OPENCLOutput", +// kernels::BufferType::WINOGRAD_FILTER); +// ExpectTensorNear(*(C_tensor.get()), *net.GetOutput("OPENCLOutput"), 1e-4); +//} +// +// +//TEST_F(WinogradTransformOpTest, WinogradInverseTransform) { +// srand(time(NULL)); +// +// // generate random input +// index_t n = 7; +// index_t out_height = 59; +// index_t out_width = 69; +// index_t out_chan = 31; +// +// index_t p = n * ((out_height + 1) / 2) * ((out_width + 1) / 2); +// +// const std::string A_file = "/data/local/tmp/test/gemm"; +// const std::string C_file = "/data/local/tmp/test/res"; +// const std::vector A_shape = {16, out_chan, p, 1}; +// const int A_size = std::accumulate(A_shape.begin(), A_shape.end(), 1, std::multiplies()); +// const std::vector C_shape = {n, out_height, out_width, out_chan}; +// const int C_size = std::accumulate(C_shape.begin(), C_shape.end(), 1, std::multiplies()); +// +// std::vector A_data(A_size, 0.0); +// std::ifstream in_file(A_file, std::ios::in | std::ios::binary); +// if (in_file.is_open()) { +// in_file.read(reinterpret_cast(A_data.data()), +// A_size * sizeof(float)); +// in_file.close(); +// } else { +// VLOG(0) << "open A file failed"; +// } +// auto C_tensor = unique_ptr(new Tensor(GetDeviceAllocator(DeviceType::OPENCL), +// DataTypeToEnum::v())); +// C_tensor->Resize(C_shape); +// std::vector C_data(C_size, 0.0); 
+// std::ifstream C_in_file(C_file, std::ios::in | std::ios::binary); +// if (C_in_file.is_open()) { +// C_in_file.read(reinterpret_cast(C_data.data()), +// C_size * sizeof(float)); +// C_in_file.close(); +// Tensor::MappingGuard C_mapper(C_tensor.get()); +// float *batch_ptr = C_tensor->mutable_data(); +// MACE_CHECK(static_cast(C_tensor->size()) == +// C_data.size()); +// memcpy(batch_ptr, C_data.data(), C_data.size() * sizeof(float)); +// } else { +// VLOG(0) << "open C file failed"; +// } +// // Construct graph +// OpsTestNet net; +// // Add input data +// net.AddInputFromArray( +// "A", A_shape, A_data); +// +// // Run on opencl +// BufferToImage(net, "A", "AImage", +// kernels::BufferType::IN_OUT_HEIGHT); +// +// OpDefBuilder("WinogradInverseTransform", "WinogradInverseTransformTest") +// .Input("AImage") +// .AddIntArg("batch", n) +// .AddIntArg("height", out_height) +// .AddIntArg("width", out_width) +// .Output("OutputImage") +// .Finalize(net.NewOperatorDef()); +// +// // Run on opencl +// net.RunOp(DeviceType::OPENCL); +// net.Sync(); +// +// ImageToBuffer(net, "OutputImage", "OPENCLOutput", +// kernels::BufferType::IN_OUT_CHANNEL); +// ExpectTensorNear(*(C_tensor.get()), *net.GetOutput("OPENCLOutput"), 1e-4); +//} + +void TransposeFilter(const std::vector &input, + const std::vector &input_shape, + std::vector &output) { + output.resize(input.size()); + + const float *input_ptr = input.data(); + for (index_t h = 0; h < input_shape[0]; ++h) { + for (index_t w = 0; w < input_shape[1]; ++w) { + for (index_t ic = 0; ic < input_shape[2]; ++ic) { + for (index_t oc = 0; oc < input_shape[3]; ++oc) { + int offset = ((oc * input_shape[2] + ic) * input_shape[0] + h) * input_shape[1] + w; + output[offset] = *input_ptr; + ++input_ptr; + } + } + } + } +} + +template +void WinogradConvolution(const index_t batch, + const index_t height, + const index_t width, + const index_t in_channels, + const index_t out_channels, + const Padding padding) { + srand(time(NULL)); + + // Construct graph + OpsTestNet net; + // Add input data + std::vector filter_data; + std::vector filter_shape = {3, 3, in_channels, out_channels}; + GenerateRandomRealTypeData(filter_shape, filter_data); + net.AddRandomInput("Input", {batch, height, width, in_channels}); + net.AddInputFromArray("Filter", filter_shape, filter_data); + + BufferToImage(net, "Input", "InputImage", + kernels::BufferType::IN_OUT_CHANNEL); + BufferToImage(net, "Filter", "FilterImage", + kernels::BufferType::FILTER); + OpDefBuilder("Conv2D", "Conv2dTest") + .Input("InputImage") + .Input("FilterImage") + .Output("OutputImage") + .AddIntsArg("strides", {1, 1}) + .AddIntArg("padding", padding) + .AddIntsArg("dilations", {1, 1}) + .Finalize(net.NewOperatorDef()); + + net.RunOp(D); + + // Transfer output + ImageToBuffer(net, "OutputImage", "ConvOutput", + kernels::BufferType::IN_OUT_CHANNEL); + Tensor expected; + expected.Copy(*net.GetOutput("ConvOutput")); + auto output_shape = expected.shape(); + + // Winograd convolution + // transform filter + std::vector wino_filter_data; + TransposeFilter(filter_data, filter_shape, wino_filter_data); + net.AddInputFromArray("WinoFilterData", {out_channels, in_channels, 3, 3}, wino_filter_data); + BufferToImage(net, "WinoFilterData", "WinoFilter", kernels::BufferType::WINOGRAD_FILTER); + + // transform input + OpDefBuilder("WinogradTransform", "WinogradTransformTest") + .Input("InputImage") + .Output("WinoInput") + .AddIntArg("padding", padding) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + 
.Finalize(net.NewOperatorDef()); + + // Run on opencl + net.RunOp(D); + + // GEMM + OpDefBuilder("GEMM", "GEMMTest") + .Input("WinoFilter") + .Input("WinoInput") + .Output("WinoGemm") + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); + // Run on opencl + net.RunOp(D); + + // Inverse transform + OpDefBuilder("WinogradInverseTransform", "WinogradInverseTransformTest") + .Input("WinoGemm") + .AddIntArg("batch", batch) + .AddIntArg("height", output_shape[1]) + .AddIntArg("width", output_shape[2]) + .Output("WinoOutputImage") + .Finalize(net.NewOperatorDef()); + + // Run on opencl + net.RunOp(D); + net.Sync(); + + ImageToBuffer(net, "WinoOutputImage", "WinoOutput", + kernels::BufferType::IN_OUT_CHANNEL); + if (DataTypeToEnum::value == DataType::DT_HALF) { + ExpectTensorNear(expected, *net.GetOutput("WinoOutput"), 1e-1); + } else { + ExpectTensorNear(expected, *net.GetOutput("WinoOutput"), 1e-4); + } +} + + +TEST_F(WinogradTransformOpTest, Convolution) { + WinogradConvolution(1, 64, 64, 32, 32, Padding::VALID); +} + +} -- GitLab
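The WinogradConvolution test above checks the whole pipeline against direct Conv2D: TransposeFilter rearranges the HWIO filter used by Conv2D into the OIHW order the WINOGRAD_FILTER buffer-to-image kernel expects, and then the input transform, GEMM and inverse transform run as three separate ops. The per-tile identity behind this path is the standard F(2x2, 3x3) Winograd formula Y = A^T [ (G g G^T) elementwise-times (B^T d B) ] A. The scalar routine below only illustrates that identity for a single 4x4 input tile and 3x3 filter, using the usual Lavin and Gray matrices; it is a sketch, not the MACE OpenCL implementation:

    // Reference F(2x2, 3x3) Winograd tile: d is a 4x4 input patch, g a 3x3
    // filter, y the 2x2 output tile (valid correlation, as in NN convolution).
    static void MatMul(const float *a, const float *b, float *c,
                       int m, int k, int n) {
      // c[m x n] = a[m x k] * b[k x n]
      for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j) {
          float acc = 0.f;
          for (int l = 0; l < k; ++l) acc += a[i * k + l] * b[l * n + j];
          c[i * n + j] = acc;
        }
    }

    void WinogradTileF2x2_3x3(const float d[16], const float g[9], float y[4]) {
      static const float BT[16] = { 1,  0, -1,  0,
                                    0,  1,  1,  0,
                                    0, -1,  1,  0,
                                    0,  1,  0, -1};
      static const float B[16]  = { 1,  0,  0,  0,
                                    0,  1, -1,  1,
                                   -1,  1,  1,  0,
                                    0,  0,  0, -1};
      static const float G[12]  = { 1,     0,     0,
                                    0.5f,  0.5f,  0.5f,
                                    0.5f, -0.5f,  0.5f,
                                    0,     0,     1};
      static const float GT[12] = { 1, 0.5f,  0.5f, 0,
                                    0, 0.5f, -0.5f, 0,
                                    0, 0.5f,  0.5f, 1};
      static const float AT[8]  = { 1, 1,  1,  0,
                                    0, 1, -1, -1};
      static const float A[8]   = { 1,  0,
                                    1,  1,
                                    1, -1,
                                    0, -1};

      float U[16], V[16], M[16], t[16], t2[8];
      MatMul(G, g, t, 4, 3, 3);   MatMul(t, GT, U, 4, 3, 4);   // U = G g G^T
      MatMul(BT, d, t, 4, 4, 4);  MatMul(t, B,  V, 4, 4, 4);   // V = B^T d B
      for (int i = 0; i < 16; ++i) M[i] = U[i] * V[i];          // Hadamard product
      MatMul(AT, M, t2, 2, 4, 4); MatMul(t2, A, y, 2, 4, 2);    // Y = A^T M A
    }

In the OpenCL path the same element-wise product is batched over the 16 tile positions by the GEMM kernel, which is why the transformed tensors carry 16 as their leading dimension.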