diff --git a/mace/kernels/conv_pool_2d_util.cc b/mace/kernels/conv_pool_2d_util.cc index d979ee44efc47eb1766a799563f7b2cb688b9920..44fc4f707a53c49404af4e457dc4587c11623e43 100644 --- a/mace/kernels/conv_pool_2d_util.cc +++ b/mace/kernels/conv_pool_2d_util.cc @@ -132,6 +132,8 @@ void ConstructInputWithPadding(const float *input, const int padded_left = paddings[1] / 2; output_tensor->Resize(output_shape); + + Tensor::MappingGuard padded_input_mapper(output_tensor); float *output_ptr = output_tensor->mutable_data(); memset(output_ptr, 0, output_tensor->size() * sizeof(float)); diff --git a/mace/kernels/depthwise_conv2d.h b/mace/kernels/depthwise_conv2d.h index dab8cebbdbfce86153896805382ad0fdb046fd10..840ce727570f108f365bfcf0b6402e030e18d7d2 100644 --- a/mace/kernels/depthwise_conv2d.h +++ b/mace/kernels/depthwise_conv2d.h @@ -20,27 +20,27 @@ struct DepthwiseConv2dFunctor { const int *dilations) : strides_(strides), paddings_(paddings), dilations_(dilations) {} - void operator()(const T *input, // NCHW - const index_t *input_shape, - const T *filter, // c_out, c_in, kernel_h, kernel_w - const index_t *filter_shape, - const T *bias, // c_out - T *output, // NCHW - const index_t *output_shape) { + void operator()(const Tensor *input, // NCHW + const Tensor *filter, // c_out, c_in, kernel_h, kernel_w + const Tensor *bias, // c_out + Tensor *output) { + MACE_CHECK_NOTNULL(input); + MACE_CHECK_NOTNULL(filter); + MACE_CHECK_NOTNULL(bias); MACE_CHECK_NOTNULL(output); - index_t batch = output_shape[0]; - index_t channels = output_shape[1]; - index_t height = output_shape[2]; - index_t width = output_shape[3]; + index_t batch = output->dim(0); + index_t channels = output->dim(1); + index_t height = output->dim(2); + index_t width = output->dim(3); - index_t input_batch = input_shape[0]; - index_t input_channels = input_shape[1]; - index_t input_height = input_shape[2]; - index_t input_width = input_shape[3]; + index_t input_batch = input->dim(0); + index_t input_channels = input->dim(1); + index_t input_height = input->dim(2); + index_t input_width = input->dim(3); - index_t kernel_h = filter_shape[2]; - index_t kernel_w = filter_shape[3]; + index_t kernel_h = filter->dim(2); + index_t kernel_w = filter->dim(3); int stride_h = strides_[0]; int stride_w = strides_[1]; @@ -56,20 +56,29 @@ struct DepthwiseConv2dFunctor { index_t padded_h_stop = input_height + paddings_[0] - paddings_[0] / 2; index_t padded_w_stop = input_width + paddings_[1] - paddings_[1] / 2; - index_t kernel_size = filter_shape[1] * kernel_h * kernel_w; - index_t multiplier = channels / input_channels; + index_t kernel_size = kernel_h * kernel_w; + index_t multiplier = filter->dim(0); + + Tensor::MappingGuard input_mapper(input); + Tensor::MappingGuard filter_mapper(filter); + Tensor::MappingGuard bias_mapper(bias); + Tensor::MappingGuard output_mapper(output); + const T *input_ptr = input->data(); + const T *filter_ptr = filter->data(); + const T *bias_ptr = bias->data(); + T *output_ptr = output->mutable_data(); #pragma omp parallel for collapse(2) for (int n = 0; n < batch; ++n) { for (int c = 0; c < channels; ++c) { - T bias_channel = bias ? bias[c] : 0; + T bias_channel = bias_ptr ? bias_ptr[c] : 0; for (int h = 0; h < height; ++h) { for (int w = 0; w < width; ++w) { index_t offset = n * channels * height * width + c * height * width + h * width + w; - output[offset] = bias_channel; + output_ptr[offset] = bias_channel; T sum = 0; - const T *filter_ptr = filter + c * kernel_size; + const T *filter_base = filter_ptr + c * kernel_size; for (int kh = 0; kh < kernel_h; ++kh) { for (int kw = 0; kw < kernel_w; ++kw) { int inh = padded_h_start + h * stride_h + dilation_h * kh; @@ -79,19 +88,17 @@ struct DepthwiseConv2dFunctor { MACE_CHECK(inh >= padded_h_start && inh < padded_h_stop && inw >= padded_w_start && inw < padded_w_stop, "Out of range read from input: ", inh, ", ", inw); - // else padding with 0: - // sum += 0; } else { index_t input_offset = n * input_channels * input_height * input_width + (c / multiplier) * input_height * input_width + inh * input_width + inw; - sum += input[input_offset] * *filter_ptr; + sum += input_ptr[input_offset] * *filter_base; } - ++filter_ptr; + ++filter_base; } } - output[offset] += sum; + output_ptr[offset] += sum; } } } @@ -105,13 +112,18 @@ struct DepthwiseConv2dFunctor { template <> void DepthwiseConv2dFunctor::operator()( - const float *input, - const index_t *input_shape, - const float *filter, - const index_t *filter_shape, - const float *bias, - float *output, - const index_t *output_shape); + const Tensor *input, + const Tensor *filter, + const Tensor *bias, + Tensor *output); + +template <> +void DepthwiseConv2dFunctor::operator()( + const Tensor *input, + const Tensor *filter, + const Tensor *bias, + Tensor *output); + } // namespace kernels } // namespace mace diff --git a/mace/kernels/neon/conv_2d_neon_3x3.cc b/mace/kernels/neon/conv_2d_neon_3x3.cc index 9a88aa18222a77a1f35a5987712ac71b4516861c..4d36651fddd129a21b6c12b743f945cefa337d30 100644 --- a/mace/kernels/neon/conv_2d_neon_3x3.cc +++ b/mace/kernels/neon/conv_2d_neon_3x3.cc @@ -29,9 +29,9 @@ void Conv2dNeonK3x3S1(const float *input, // NCHW int input_height = input_shape[2]; int input_width = input_shape[3]; int multiplier = - filter_shape == nullptr ? 0 : (filter_shape[0] / input_channels); + filter_shape == nullptr ? 0 : filter_shape[0]; int filter_in_channels = - filter_shape == nullptr ? input_channels : filter_shape[1]; + filter_shape == nullptr ? input_channels : 1; #pragma omp parallel for collapse(2) for (int b = 0; b < output_batch; ++b) { for (int oc = 0; oc < output_channels; ++oc) { @@ -232,9 +232,9 @@ void Conv2dNeonK3x3S2(const float *input, // NCHW int input_height = input_shape[2]; int input_width = input_shape[3]; int multiplier = - filter_shape == nullptr ? 0 : (filter_shape[0] / input_channels); + filter_shape == nullptr ? 0 : filter_shape[0]; int filter_in_channels = - filter_shape == nullptr ? input_channels : filter_shape[1]; + filter_shape == nullptr ? input_channels : 1; #pragma omp parallel for collapse(2) for (int b = 0; b < output_batch; ++b) { diff --git a/mace/kernels/neon/depthwise_conv_neon.cc b/mace/kernels/neon/depthwise_conv_neon.cc index 75f01707a2fcbec70d393bcb1c605152cd6b207f..cbae961b8d180736be591c311aae8d507777710b 100644 --- a/mace/kernels/neon/depthwise_conv_neon.cc +++ b/mace/kernels/neon/depthwise_conv_neon.cc @@ -26,13 +26,10 @@ extern void Conv2dNeonK3x3S2(const float *input, template <> void DepthwiseConv2dFunctor::operator()( - const float *input, // NCHW - const index_t *input_shape, - const float *filter, // c_out, c_in, kernel_h, kernel_w - const index_t *filter_shape, - const float *bias, // c_out - float *output, // NCHW - const index_t *output_shape) { + const Tensor *input, + const Tensor *filter, + const Tensor *bias, + Tensor *output) { typedef void (*Conv2dNeonFunction)( const float *input, const index_t *input_shape, const float *filter, const index_t *filter_shape, const float *bias, float *output, @@ -45,8 +42,8 @@ void DepthwiseConv2dFunctor::operator()( {nullptr, nullptr}, {nullptr, nullptr}}; // not implement yet - index_t kernel_h = filter_shape[2]; - index_t kernel_w = filter_shape[3]; + index_t kernel_h = filter->dim(2); + index_t kernel_w = filter->dim(3); if (kernel_h != kernel_w || kernel_h > 5 || strides_[0] != strides_[1] || strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1 || selector[kernel_h - 1][strides_[0] - 1] == nullptr) { @@ -56,20 +53,27 @@ void DepthwiseConv2dFunctor::operator()( << " is not implemented yet, using slow version"; DepthwiseConv2dFunctor(strides_, paddings_, dilations_)( - input, input_shape, filter, filter_shape, bias, output, output_shape); + input, filter, bias, output); return; } + const float *input_ptr = input->data(); + const index_t *input_shape = input->shape().data(); + const float *filter_ptr = filter->data(); + const index_t *filter_shape = filter->shape().data(); + const float *bias_ptr = bias->data(); + float *output_ptr = output->mutable_data(); + const index_t *output_shape = output->shape().data(); // Keep this alive during kernel execution Tensor padded_input; if (paddings_[0] > 0 || paddings_[1] > 0) { - ConstructInputWithPadding(input, input_shape, paddings_.data(), + ConstructInputWithPadding(input_ptr, input_shape, paddings_.data(), &padded_input); - input = padded_input.data(); + input_ptr = padded_input.data(); input_shape = padded_input.shape().data(); } auto conv2d_neon_func = selector[kernel_h - 1][strides_[0] - 1]; - conv2d_neon_func(input, input_shape, filter, filter_shape, bias, output, + conv2d_neon_func(input_ptr, input_shape, filter_ptr, filter_shape, bias_ptr, output_ptr, output_shape); } diff --git a/mace/kernels/opencl/cl/depthwise_conv_3x3.cl b/mace/kernels/opencl/cl/depthwise_conv_3x3.cl new file mode 100644 index 0000000000000000000000000000000000000000..084156e13788e53769458ca901cbbc19cbc84f10 --- /dev/null +++ b/mace/kernels/opencl/cl/depthwise_conv_3x3.cl @@ -0,0 +1,89 @@ +inline float4 conv1x3(const float *input_ptr, + const float *filter_ptr) { + float8 input = vload8(0, input_ptr); + float4 row0 = convert_float4(input.s0123); + float4 row1 = convert_float4(input.s1234); + float4 row2 = convert_float4(input.s2345); + return (float4)filter_ptr[0] * row0 + (float4)filter_ptr[1] * row1 + + (float4)filter_ptr[2] * row2; +} + +inline float4 conv3x3x4(const float *input_ptr, + const float *filter_ptr, + const int row_width) { + float4 res; + res = conv1x3(input_ptr + 0 * row_width, filter_ptr + 0 * 3); + res += conv1x3(input_ptr + 1 * row_width, filter_ptr + 1 * 3); + res += conv1x3(input_ptr + 2 * row_width, filter_ptr + 2 * 3); + + return res; +} + +inline float conv3x3(const float *input_ptr, + const float *filter_ptr, + const int row_width) { + float res = input_ptr[0] * filter_ptr[0] + input_ptr[1] * filter_ptr[1] + input_ptr[2] * filter_ptr[2]; + input_ptr += row_width; + filter_ptr += 3; + res += input_ptr[0] * filter_ptr[0] + input_ptr[1] * filter_ptr[1] + input_ptr[2] * filter_ptr[2]; + input_ptr += row_width; + filter_ptr += 3; + res += input_ptr[0] * filter_ptr[0] + input_ptr[1] * filter_ptr[1] + input_ptr[2] * filter_ptr[2]; + + return res; +} + +void kernel depthwise_conv_3x3_s1(global const float *input, /* n, c, h, w */ + global const float *filter, /* m, i, kh, kw */ + global const float *bias, /* o */ + global float *output, /* n, c, h, w */ + private const int in_chan_num, + private const int out_chan_num, + private const int in_height, + private const int in_width, + private const int out_height, + private const int out_width) { + int batch = get_global_id(0); + int out_chan_blk = get_global_id(1); + int out_pixel_blk = get_global_id(2); + + const int in_pixel = in_height * in_width; + const int out_pixel = out_height * out_width; + const int multiplier = out_chan_num / in_chan_num; + + const int round_out_width = (out_width + 3) / 4; + const int out_pixel_height = out_pixel_blk / round_out_width; + const int out_pixel_width = out_pixel_blk % round_out_width; + + const int out_chan_begin = out_chan_blk * 4; + const int out_chan_end = min(out_chan_begin + 4, out_chan_num); + const int out_pixel_begin = out_pixel_height * out_width + out_pixel_width * 4; + const int out_pixel_end = min(out_pixel_begin + 4, (out_pixel_height + 1) * out_width); + const int in_pixel_begin = out_pixel_height * in_width + out_pixel_width * 4; + + const int in_offset = batch * in_chan_num * in_pixel; + const int out_offset = batch * out_chan_num * out_pixel; + const float *input_base = input + in_offset + in_pixel_begin; + float *output_base = output + out_offset + out_pixel_begin; + + int pixels = out_pixel_end - out_pixel_begin; + + for (int i = out_chan_begin; i < out_chan_end; ++i) { + float bias_value = bias[i]; + const float *input_ptr = input_base + (i / multiplier) * in_pixel; + const float *filter_ptr = filter + i * 9; + float *output_ptr = output_base + i * out_pixel; + if (pixels < 4) { + for (int out_idx = 0; out_idx < pixels; ++out_idx) { + output_ptr[out_idx] = bias_value; + output_ptr[out_idx] += conv3x3(input_ptr, filter_ptr, in_width); + input_ptr += 1; + } + } else { + float4 res = conv3x3x4(input_ptr, filter_ptr, in_width); + res += (float4)bias_value; + vstore4(res, 0, output_ptr); + } + } + +} diff --git a/mace/kernels/opencl/depthwise_conv_opencl.cc b/mace/kernels/opencl/depthwise_conv_opencl.cc new file mode 100644 index 0000000000000000000000000000000000000000..ca6a5b44682b2d0cd0c37ecd3f24a4daaf487dc3 --- /dev/null +++ b/mace/kernels/opencl/depthwise_conv_opencl.cc @@ -0,0 +1,57 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/kernels/depthwise_conv2d.h" + +namespace mace { +namespace kernels { + +extern void DepthwiseConvOpenclK3x3S1(const Tensor *input, const Tensor *filter, + const Tensor *bias, Tensor *output); + +template <> +void DepthwiseConv2dFunctor::operator()(const Tensor *input, + const Tensor *filter, + const Tensor *bias, + Tensor *output) { + typedef void (*Conv2dOpenclFunction)(const Tensor *input, const Tensor *filter, + const Tensor *bias, Tensor *output); + // Selection matrix: kernel_size x stride_size + static const Conv2dOpenclFunction selector[5][2] = { + {nullptr, nullptr}, + {nullptr, nullptr}, + {DepthwiseConvOpenclK3x3S1, nullptr}, + {nullptr, nullptr}, + {nullptr, nullptr}}; + + index_t kernel_h = filter->shape()[2]; + index_t kernel_w = filter->shape()[3]; + if (kernel_h != kernel_w || kernel_h > 5 || strides_[0] != strides_[1] || + strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1 || + selector[kernel_h - 1][strides_[0] - 1] == nullptr) { + LOG(WARNING) << "OpenCL conv2d kernel with " + << "filter" << kernel_h << "x" << kernel_w << "," + << " stride " << strides_[0] << "x" << strides_[1] + << " is not implemented yet, using slow version"; + // TODO(heliangliang) The CPU/NEON kernel should map the buffer + DepthwiseConv2dFunctor(strides_, paddings_, dilations_)( + input, filter, bias, output); + return; + } + + auto conv2d_func = selector[kernel_h - 1][strides_[0] - 1]; + if (paddings_[0] > 0 || paddings_[1] > 0) { + Tensor padded_input(GetDeviceAllocator(DeviceType::OPENCL), DataTypeToEnum::v()); + Tensor::MappingGuard input_mapper(input); + ConstructInputWithPadding(input->data(), input->shape().data(), paddings_.data(), + &padded_input); + conv2d_func(&padded_input, filter, bias, output); + }else { + conv2d_func(input, filter, bias, output); + } + +} + +} // namespace kernels +} // namespace mace diff --git a/mace/kernels/opencl/depthwise_conv_opencl_3x3.cc b/mace/kernels/opencl/depthwise_conv_opencl_3x3.cc new file mode 100644 index 0000000000000000000000000000000000000000..e76858a51ef3c7839988a32684f12892cdb46524 --- /dev/null +++ b/mace/kernels/opencl/depthwise_conv_opencl_3x3.cc @@ -0,0 +1,57 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/core/common.h" +#include "mace/core/runtime/opencl/opencl_runtime.h" +#include "mace/kernels/conv_2d.h" + +namespace mace { +namespace kernels { + +extern void DepthwiseConvOpenclK3x3S1(const Tensor *input, + const Tensor *filter, + const Tensor *bias, + Tensor *output) { + const index_t batch = output->dim(0); + const index_t channels = output->dim(1); + const index_t height = output->dim(2); + const index_t width = output->dim(3); + + const index_t input_batch = input->dim(0); + const index_t input_channels = input->dim(1); + const index_t input_height = input->dim(2); + const index_t input_width = input->dim(3); + + MACE_CHECK(input_batch == batch); + + auto runtime = OpenCLRuntime::Get(); + auto program = runtime->program(); + auto conv_2d = cl::KernelFunctor(program, "depthwise_conv_3x3_s1"); + const index_t pixels = height * width; + const index_t channel_blocks = (channels + 3) / 4; + const index_t pixel_blocks = (width + 3) / 4 * height; + + cl_int error; + conv_2d(cl::EnqueueArgs(runtime->command_queue(), + cl::NDRange(static_cast(batch), + static_cast(channel_blocks), + static_cast(pixel_blocks)), + cl::NDRange(1, 1, 256)), + *(static_cast(input->buffer())), + *(static_cast(filter->buffer())), + *(static_cast(bias->buffer())), + *(static_cast(output->buffer())), + static_cast(input_channels), + static_cast(channels), + static_cast(input_height), + static_cast(input_width), + static_cast(height), + static_cast(width), + error); + MACE_CHECK(error == CL_SUCCESS); +}; + +} // namespace kernels +} // namespace mace diff --git a/mace/ops/conv_2d_test.cc b/mace/ops/conv_2d_test.cc index 1575d2f2105c687bce3c2ecfd6a3244cc412c167..149711950eb0c29e69d715a74ed8e5b7dca5579b 100644 --- a/mace/ops/conv_2d_test.cc +++ b/mace/ops/conv_2d_test.cc @@ -171,7 +171,7 @@ TEST_F(Conv2dOpTest, Conv1x1) { } // TODO we need more tests -TEST_F(Conv2dOpTest, IdleConvNxNS12) { +TEST_F(Conv2dOpTest, AlignedConvNxNS12) { testing::internal::LogToStderr(); auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w, Padding type) { @@ -222,7 +222,7 @@ TEST_F(Conv2dOpTest, IdleConvNxNS12) { } } -TEST_F(Conv2dOpTest, DisgustConvNxNS12) { +TEST_F(Conv2dOpTest, UnalignedConvNxNS12) { testing::internal::LogToStderr(); auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w, Padding type) { diff --git a/mace/ops/depthwise_conv2d.cc b/mace/ops/depthwise_conv2d.cc index 6d66a6888b424e5846d9c11cc0d77798e7636421..992a6f2aa4584b6a9c5a1378885237fd19af6725 100644 --- a/mace/ops/depthwise_conv2d.cc +++ b/mace/ops/depthwise_conv2d.cc @@ -14,4 +14,7 @@ REGISTER_NEON_OPERATOR(DepthwiseConv2d, DepthwiseConv2dOp); #endif // __ARM_NEON +REGISTER_OPENCL_OPERATOR(DepthwiseConv2d, + DepthwiseConv2dOp); + } // namespace mace diff --git a/mace/ops/depthwise_conv2d.h b/mace/ops/depthwise_conv2d.h index 58c126fcf056e7baef4223bc788a582c08020e76..d4812def1fb1aaa534148e0951de5f06cec60564 100644 --- a/mace/ops/depthwise_conv2d.h +++ b/mace/ops/depthwise_conv2d.h @@ -26,10 +26,9 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase { bool Run() override { const Tensor *input = this->Input(INPUT); const Tensor *filter = this->Input(FILTER); - const T *bias_data = nullptr; + const Tensor *bias = nullptr; if (this->InputSize() >= 3) { - const Tensor *bias = this->Input(BIAS); - bias_data = bias->data(); + bias = this->Input(BIAS); } Tensor *output = this->Output(OUTPUT); @@ -47,9 +46,7 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase { output->Resize(output_shape); functor_.paddings_ = paddings; - functor_(input->data(), input->shape().data(), filter->data(), - filter_shape.data(), bias_data, output->mutable_data(), - output->shape().data()); + functor_(input, filter, bias, output); return true; } diff --git a/mace/ops/depthwise_conv2d_test.cc b/mace/ops/depthwise_conv2d_test.cc index 34ecb048201460c8fe6ea63797e031162233fff9..5a588950cde2c3cebb5983bb2e6c98872f673c05 100644 --- a/mace/ops/depthwise_conv2d_test.cc +++ b/mace/ops/depthwise_conv2d_test.cc @@ -9,10 +9,11 @@ using namespace mace; class DepthwiseConv2dOpTest : public OpsTestBase {}; -TEST_F(DepthwiseConv2dOpTest, Simple_VALID) { +template +void SimpleValidTest() { testing::internal::LogToStderr(); // Construct graph - auto &net = test_net(); + OpsTestNet net; OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest") .Input("Input") .Input("Filter") @@ -26,15 +27,15 @@ TEST_F(DepthwiseConv2dOpTest, Simple_VALID) { net.AddIntsArg("dilations", {1, 1}); // Add input data - net.AddInputFromArray("Input", {1, 2, 2, 3}, - {1, 3, 5, 7, 9, 11, 2, 4, 6, 8, 10, 12}); - net.AddInputFromArray( + net.AddInputFromArray("Input", {1, 2, 2, 3}, + {1, 3, 5, 7, 9, 11, 2, 4, 6, 8, 10, 12}); + net.AddInputFromArray( "Filter", {2, 2, 2, 2}, {1.0f, 5.0f, 9.0f, 13.0f, 2.0f, 6.0f, 10.0f, 14.0f, 3.0f, 7.0f, 11.0f, 15.0f, 4.0f, 8.0f, 12.0f, 16.0f}); - net.AddInputFromArray("Bias", {4}, {.1f, .2f, .3f, .4f}); + net.AddInputFromArray("Bias", {4}, {.1f, .2f, .3f, .4f}); // Run - net.RunOp(); + net.RunOp(D); // Check auto expected = CreateTensor( @@ -42,22 +43,26 @@ TEST_F(DepthwiseConv2dOpTest, Simple_VALID) { {196.1f, 252.1f, 216.2f, 280.2f, 272.3f, 344.3f, 296.4f, 376.4f}); ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); + +} + +TEST_F(DepthwiseConv2dOpTest, SimpleCPU) { + SimpleValidTest(); } -TEST_F(DepthwiseConv2dOpTest, ConvNxNS12) { +template +void TestNxNS12(const index_t height, const index_t width) { testing::internal::LogToStderr(); auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w, Padding type) { srand(time(NULL)); // generate random input - index_t batch = 2 + rand() % 10; - index_t input_channels = 3 + rand() % 10; - index_t height = 107; - index_t width = 113; - index_t multiplier = 3 + rand() % 10; + index_t batch = 1; + index_t input_channels = 3; + index_t multiplier = 2; // Construct graph - auto &net = test_net(); + OpsTestNet net; OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest") .Input("Input") .Input("Filter") @@ -71,19 +76,18 @@ TEST_F(DepthwiseConv2dOpTest, ConvNxNS12) { net.AddIntsArg("dilations", {1, 1}); // Add input data - net.AddRandomInput("Input", {batch, input_channels, height, width}); - net.AddRandomInput("Filter", - {multiplier, input_channels, kernel_h, kernel_w}); - net.AddRandomInput("Bias", {multiplier * input_channels}); - // run cpu - net.RunOp(); + net.AddRandomInput("Input", {batch, input_channels, height, width}); + net.AddRandomInput("Filter", {multiplier, input_channels, kernel_h, kernel_w}); + net.AddRandomInput("Bias", {multiplier * input_channels}); + // Run on device + net.RunOp(D); // Check Tensor expected; expected.Copy(*net.GetOutput("Output")); - // Run NEON - net.RunOp(DeviceType::NEON); + // run cpu + net.RunOp(); ExpectTensorNear(expected, *net.GetOutput("Output"), 1e-3); }; @@ -93,4 +97,31 @@ TEST_F(DepthwiseConv2dOpTest, ConvNxNS12) { func(kernel_size, kernel_size, stride, stride, SAME); } } + +} + +TEST_F(DepthwiseConv2dOpTest, NeonSimpleNxNS12) { + TestNxNS12(4, 4); +} + +TEST_F(DepthwiseConv2dOpTest, OpenCLSimpleNxNS12) { + TestNxNS12(4, 4); +} + +TEST_F(DepthwiseConv2dOpTest, NeonAlignedNxNS12) { + TestNxNS12(64, 64); + TestNxNS12(128, 128); +} + +TEST_F(DepthwiseConv2dOpTest, OpenCLAlignedNxNS12) { + TestNxNS12(64, 64); + TestNxNS12(128, 128); +} + +TEST_F(DepthwiseConv2dOpTest, NeonUnalignedNxNS12) { + TestNxNS12(107, 113); +} + +TEST_F(DepthwiseConv2dOpTest, OpenCLUnalignedNxNS12) { + TestNxNS12(107, 113); } diff --git a/mace/ops/depthwise_conv_2d_benchmark.cc b/mace/ops/depthwise_conv_2d_benchmark.cc index f801c075e1e004e783a955c95869c3b2f41f692d..2534cdad9504bb67075fd947fa7e532d25900734 100644 --- a/mace/ops/depthwise_conv_2d_benchmark.cc +++ b/mace/ops/depthwise_conv_2d_benchmark.cc @@ -38,20 +38,22 @@ static void DepthwiseConv2d(int iters, net.AddIntsArg("dilations", {1, 1}); // Add input data - net.AddRandomInput("Input", {batch, channels, height, width}); - net.AddRandomInput("Filter", + net.AddRandomInput("Input", {batch, channels, height, width}); + net.AddRandomInput("Filter", {output_channels, channels, kernel_h, kernel_w}); - net.AddRandomInput("Bias", {output_channels}); + net.AddRandomInput("Bias", {output_channels}); // Warm-up for (int i = 0; i < 5; ++i) { net.RunOp(D); } + net.Sync(); mace::testing::StartTiming(); while (iters--) { net.RunOp(D); } + net.Sync(); } #define BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, STRIDE, P, OC, TYPE, \ @@ -70,7 +72,8 @@ static void DepthwiseConv2d(int iters, #define BM_DEPTHWISE_CONV_2D(N, C, H, W, KH, KW, S, P, OC, TYPE) \ BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, CPU); \ - BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, NEON); + BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, NEON);\ + BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, OPENCL); BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128, float); BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 128, float);