From 85258d86df2894dc73a1323860a080e1b9cf0943 Mon Sep 17 00:00:00 2001
From: wuchenghui
Date: Thu, 3 May 2018 19:16:08 +0800
Subject: [PATCH] fix invalid memory read

---
 mace/core/tensor.h                            |  12 +-
 mace/kernels/arm/conv_2d_neon_3x3.cc          |  36 ++---
 mace/kernels/arm/depthwise_conv2d_neon_3x3.cc |  40 +++---
 mace/kernels/conv_2d.h                        |   9 +-
 mace/ops/depthwise_conv2d_test.cc             | 128 +++++++++---------
 5 files changed, 119 insertions(+), 106 deletions(-)

diff --git a/mace/core/tensor.h b/mace/core/tensor.h
index 93f3e93d..8cc2359d 100644
--- a/mace/core/tensor.h
+++ b/mace/core/tensor.h
@@ -28,6 +28,13 @@
 #include "mace/public/mace.h"
 #include "mace/utils/logging.h"
 
+#ifdef MACE_ENABLE_NEON
+// Avoid out-of-bounds memory access
+#define EXTRA_BUFFER_PAD_SIZE 64
+#else
+#define EXTRA_BUFFER_PAD_SIZE 0
+#endif
+
 namespace mace {
 
 #define SINGLE_ARG(...) __VA_ARGS__
@@ -212,10 +219,11 @@ class Tensor {
     image_shape_.clear();
     if (buffer_ != nullptr) {
       MACE_CHECK(!has_opencl_image(), "Cannot resize image, use ResizeImage.");
-      if (raw_size() > buffer_->size()) buffer_->Resize(raw_size());
+      if (raw_size() + EXTRA_BUFFER_PAD_SIZE > buffer_->size())
+        buffer_->Resize(raw_size() + EXTRA_BUFFER_PAD_SIZE);
     } else {
       MACE_CHECK(is_buffer_owner_);
-      buffer_ = new Buffer(allocator_, raw_size());
+      buffer_ = new Buffer(allocator_, raw_size() + EXTRA_BUFFER_PAD_SIZE);
     }
   }
 
diff --git a/mace/kernels/arm/conv_2d_neon_3x3.cc b/mace/kernels/arm/conv_2d_neon_3x3.cc
index b8b38340..fba0a7e2 100644
--- a/mace/kernels/arm/conv_2d_neon_3x3.cc
+++ b/mace/kernels/arm/conv_2d_neon_3x3.cc
@@ -334,7 +334,7 @@ void Conv2dNeonK3x3S1(const float *input,
         float32x4_t vf00, vf01, vf02;
         vf00 = vld1q_f32(filter_ptr0);
         vf01 = vld1q_f32(filter_ptr0 + 3);
-        vf02 = vld1q_f32(filter_ptr0 + 6);
+        vf02 = vld1q_f32(filter_ptr0 + 5);
 
         for (index_t h = 0; h + 1 < out_height; h += 2) {
           for (index_t w = 0; w + 3 < out_width; w += 4) {
@@ -377,9 +377,9 @@ void Conv2dNeonK3x3S1(const float *input,
             vo00 = vfmaq_laneq_f32(vo00, vi10, vf01, 0);
             vo00 = vfmaq_laneq_f32(vo00, vi11, vf01, 1);
             vo00 = vfmaq_laneq_f32(vo00, vi12, vf01, 2);
-            vo00 = vfmaq_laneq_f32(vo00, vi20, vf02, 0);
-            vo00 = vfmaq_laneq_f32(vo00, vi21, vf02, 1);
-            vo00 = vfmaq_laneq_f32(vo00, vi22, vf02, 2);
+            vo00 = vfmaq_laneq_f32(vo00, vi20, vf02, 1);
+            vo00 = vfmaq_laneq_f32(vo00, vi21, vf02, 2);
+            vo00 = vfmaq_laneq_f32(vo00, vi22, vf02, 3);
 
             // outch 0, height 1
             vo01 = vfmaq_laneq_f32(vo01, vi10, vf00, 0);
@@ -388,9 +388,9 @@ void Conv2dNeonK3x3S1(const float *input,
             vo01 = vfmaq_laneq_f32(vo01, vi20, vf01, 0);
             vo01 = vfmaq_laneq_f32(vo01, vi21, vf01, 1);
             vo01 = vfmaq_laneq_f32(vo01, vi22, vf01, 2);
-            vo01 = vfmaq_laneq_f32(vo01, vi30, vf02, 0);
-            vo01 = vfmaq_laneq_f32(vo01, vi31, vf02, 1);
-            vo01 = vfmaq_laneq_f32(vo01, vi32, vf02, 2);
+            vo01 = vfmaq_laneq_f32(vo01, vi30, vf02, 1);
+            vo01 = vfmaq_laneq_f32(vo01, vi31, vf02, 2);
+            vo01 = vfmaq_laneq_f32(vo01, vi32, vf02, 3);
 
             vst1q_f32(out_ptr0, vo00);
             vst1q_f32(out_ptr0 + out_width, vo01);
@@ -413,12 +413,12 @@ void Conv2dNeonK3x3S1(const float *input,
         }  // h
 #elif defined(MACE_ENABLE_NEON)  // arm v7
         // load filter (1 outch x 3 height x 3 width): vf_outch_height
-        float32x2_t vf01, vf23, vf45, vf67, vf89;
+        float32x2_t vf01, vf23, vf45, vf67, vf78;
         vf01 = vld1_f32(filter_ptr0);
         vf23 = vld1_f32(filter_ptr0 + 2);
         vf45 = vld1_f32(filter_ptr0 + 4);
         vf67 = vld1_f32(filter_ptr0 + 6);
-        vf89 = vld1_f32(filter_ptr0 + 8);
+        vf78 = vld1_f32(filter_ptr0 + 7);
 
         for (index_t h = 0; h + 1 < out_height; h += 2) {
           for (index_t w = 0; w + 3 < out_width; w += 4) {
@@ -463,7 +463,7 @@ void Conv2dNeonK3x3S1(const float *input,
             vo00 = vmlaq_lane_f32(vo00, vi12, vf45, 1);
             vo00 = vmlaq_lane_f32(vo00, vi20, vf67, 0);
             vo00 = vmlaq_lane_f32(vo00, vi21, vf67, 1);
-            vo00 = vmlaq_lane_f32(vo00, vi22, vf89, 0);
+            vo00 = vmlaq_lane_f32(vo00, vi22, vf78, 1);
 
             // outch 0, height 1
             vo01 = vmlaq_lane_f32(vo01, vi10, vf01, 0);
@@ -474,7 +474,7 @@ void Conv2dNeonK3x3S1(const float *input,
             vo01 = vmlaq_lane_f32(vo01, vi22, vf45, 1);
             vo01 = vmlaq_lane_f32(vo01, vi30, vf67, 0);
             vo01 = vmlaq_lane_f32(vo01, vi31, vf67, 1);
-            vo01 = vmlaq_lane_f32(vo01, vi32, vf89, 0);
+            vo01 = vmlaq_lane_f32(vo01, vi32, vf78, 1);
 
             vst1q_f32(out_ptr0, vo00);
             vst1q_f32(out_ptr0 + out_width, vo01);
@@ -544,7 +544,7 @@ void Conv2dNeonK3x3S2(const float *input,
       float32x4_t vf00, vf01, vf02;
       vf00 = vld1q_f32(filter_ptr);
       vf01 = vld1q_f32(filter_ptr + 3);
-      vf02 = vld1q_f32(filter_ptr + 6);
+      vf02 = vld1q_f32(filter_ptr + 5);
 
       for (index_t h = 0; h < out_height; ++h) {
         for (index_t w = 0; w + 3 < out_width; w += 4) {
@@ -592,21 +592,21 @@ void Conv2dNeonK3x3S2(const float *input,
           vo = vfmaq_laneq_f32(vo, vi10, vf01, 0);
           vo = vfmaq_laneq_f32(vo, vi11, vf01, 1);
           vo = vfmaq_laneq_f32(vo, vi12, vf01, 2);
-          vo = vfmaq_laneq_f32(vo, vi20, vf02, 0);
-          vo = vfmaq_laneq_f32(vo, vi21, vf02, 1);
-          vo = vfmaq_laneq_f32(vo, vi22, vf02, 2);
+          vo = vfmaq_laneq_f32(vo, vi20, vf02, 1);
+          vo = vfmaq_laneq_f32(vo, vi21, vf02, 2);
+          vo = vfmaq_laneq_f32(vo, vi22, vf02, 3);
 
           vst1q_f32(out_base + out_offset, vo);
         }  // w
       }  // h
 #elif defined(MACE_ENABLE_NEON)  // arm v7
       // load filter (1 outch x 3 height x 3 width): vf_outch_height
-      float32x2_t vf01, vf23, vf45, vf67, vf89;
+      float32x2_t vf01, vf23, vf45, vf67, vf78;
       vf01 = vld1_f32(filter_ptr);
       vf23 = vld1_f32(filter_ptr + 2);
       vf45 = vld1_f32(filter_ptr + 4);
       vf67 = vld1_f32(filter_ptr + 6);
-      vf89 = vld1_f32(filter_ptr + 8);
+      vf78 = vld1_f32(filter_ptr + 7);
 
       for (index_t h = 0; h < out_height; ++h) {
         for (index_t w = 0; w + 3 < out_width; w += 4) {
@@ -656,7 +656,7 @@ void Conv2dNeonK3x3S2(const float *input,
           vo = vmlaq_lane_f32(vo, vi12, vf45, 1);
          vo = vmlaq_lane_f32(vo, vi20, vf67, 0);
           vo = vmlaq_lane_f32(vo, vi21, vf67, 1);
-          vo = vmlaq_lane_f32(vo, vi22, vf89, 0);
+          vo = vmlaq_lane_f32(vo, vi22, vf78, 1);
 
           vst1q_f32(out_base + out_offset, vo);
         }  // w
diff --git a/mace/kernels/arm/depthwise_conv2d_neon_3x3.cc b/mace/kernels/arm/depthwise_conv2d_neon_3x3.cc
index 489a3ce8..fb0f3933 100644
--- a/mace/kernels/arm/depthwise_conv2d_neon_3x3.cc
+++ b/mace/kernels/arm/depthwise_conv2d_neon_3x3.cc
@@ -104,7 +104,7 @@ void DepthwiseConv2dNeonK3x3S1(const float *input,
       float32x4_t vf00, vf01, vf02;
       vf00 = vld1q_f32(filter_ptr);
       vf01 = vld1q_f32(filter_ptr + 3);
-      vf02 = vld1q_f32(filter_ptr + 6);
+      vf02 = vld1q_f32(filter_ptr + 5);
 
       for (h = valid_h_start; h + 1 < valid_h_stop; h += 2) {
         // left
@@ -180,9 +180,9 @@ void DepthwiseConv2dNeonK3x3S1(const float *input,
           vo00 = vfmaq_laneq_f32(vo00, vi10, vf01, 0);
           vo00 = vfmaq_laneq_f32(vo00, vi11, vf01, 1);
           vo00 = vfmaq_laneq_f32(vo00, vi12, vf01, 2);
-          vo00 = vfmaq_laneq_f32(vo00, vi20, vf02, 0);
-          vo00 = vfmaq_laneq_f32(vo00, vi21, vf02, 1);
-          vo00 = vfmaq_laneq_f32(vo00, vi22, vf02, 2);
+          vo00 = vfmaq_laneq_f32(vo00, vi20, vf02, 1);
+          vo00 = vfmaq_laneq_f32(vo00, vi21, vf02, 2);
+          vo00 = vfmaq_laneq_f32(vo00, vi22, vf02, 3);
 
           // outch 0, height 1
           vo01 = vfmaq_laneq_f32(vo01, vi10, vf00, 0);
@@ -191,9 +191,9 @@ void DepthwiseConv2dNeonK3x3S1(const float *input,
           vo01 = vfmaq_laneq_f32(vo01, vi20, vf01, 0);
           vo01 = vfmaq_laneq_f32(vo01, vi21, vf01, 1);
           vo01 = vfmaq_laneq_f32(vo01, vi22, vf01, 2);
-          vo01 = vfmaq_laneq_f32(vo01, vi30, vf02, 0);
-          vo01 = vfmaq_laneq_f32(vo01, vi31, vf02, 1);
-          vo01 = vfmaq_laneq_f32(vo01, vi32, vf02, 2);
+          vo01 = vfmaq_laneq_f32(vo01, vi30, vf02, 1);
+          vo01 = vfmaq_laneq_f32(vo01, vi31, vf02, 2);
+          vo01 = vfmaq_laneq_f32(vo01, vi32, vf02, 3);
 #else
           // outch 0, height 0
           vo00 = vmlaq_lane_f32(vo00, vi00, vget_low_f32(vf00), 0);
@@ -202,9 +202,9 @@ void DepthwiseConv2dNeonK3x3S1(const float *input,
           vo00 = vmlaq_lane_f32(vo00, vi10, vget_low_f32(vf01), 0);
           vo00 = vmlaq_lane_f32(vo00, vi11, vget_low_f32(vf01), 1);
           vo00 = vmlaq_lane_f32(vo00, vi12, vget_high_f32(vf01), 0);
-          vo00 = vmlaq_lane_f32(vo00, vi20, vget_low_f32(vf02), 0);
-          vo00 = vmlaq_lane_f32(vo00, vi21, vget_low_f32(vf02), 1);
-          vo00 = vmlaq_lane_f32(vo00, vi22, vget_high_f32(vf02), 0);
+          vo00 = vmlaq_lane_f32(vo00, vi20, vget_low_f32(vf02), 1);
+          vo00 = vmlaq_lane_f32(vo00, vi21, vget_high_f32(vf02), 0);
+          vo00 = vmlaq_lane_f32(vo00, vi22, vget_high_f32(vf02), 1);
 
           // outch 0, height 1
           vo01 = vmlaq_lane_f32(vo01, vi10, vget_low_f32(vf00), 0);
@@ -213,9 +213,9 @@ void DepthwiseConv2dNeonK3x3S1(const float *input,
           vo01 = vmlaq_lane_f32(vo01, vi20, vget_low_f32(vf01), 0);
           vo01 = vmlaq_lane_f32(vo01, vi21, vget_low_f32(vf01), 1);
           vo01 = vmlaq_lane_f32(vo01, vi22, vget_high_f32(vf01), 0);
-          vo01 = vmlaq_lane_f32(vo01, vi30, vget_low_f32(vf02), 0);
-          vo01 = vmlaq_lane_f32(vo01, vi31, vget_low_f32(vf02), 1);
-          vo01 = vmlaq_lane_f32(vo01, vi32, vget_high_f32(vf02), 0);
+          vo01 = vmlaq_lane_f32(vo01, vi30, vget_low_f32(vf02), 1);
+          vo01 = vmlaq_lane_f32(vo01, vi31, vget_high_f32(vf02), 0);
+          vo01 = vmlaq_lane_f32(vo01, vi32, vget_high_f32(vf02), 1);
 #endif
           vst1q_f32(out_base + out_offset, vo00);
           vst1q_f32(out_base + out_offset + out_width, vo01);
@@ -344,7 +344,7 @@ void DepthwiseConv2dNeonK3x3S2(const float *input,
       float32x4_t vf00, vf01, vf02;
       vf00 = vld1q_f32(filter_ptr);
      vf01 = vld1q_f32(filter_ptr + 3);
-      vf02 = vld1q_f32(filter_ptr + 6);
+      vf02 = vld1q_f32(filter_ptr + 5);
 
      for (h = valid_h_start; h < valid_h_stop; ++h) {
         // left
@@ -409,9 +409,9 @@ void DepthwiseConv2dNeonK3x3S2(const float *input,
          vo = vfmaq_laneq_f32(vo, vi10, vf01, 0);
           vo = vfmaq_laneq_f32(vo, vi11, vf01, 1);
           vo = vfmaq_laneq_f32(vo, vi12, vf01, 2);
-          vo = vfmaq_laneq_f32(vo, vi20, vf02, 0);
-          vo = vfmaq_laneq_f32(vo, vi21, vf02, 1);
-          vo = vfmaq_laneq_f32(vo, vi22, vf02, 2);
+          vo = vfmaq_laneq_f32(vo, vi20, vf02, 1);
+          vo = vfmaq_laneq_f32(vo, vi21, vf02, 2);
+          vo = vfmaq_laneq_f32(vo, vi22, vf02, 3);
 #else
           // outch 0, height 0
           vo = vmlaq_lane_f32(vo, vi00, vget_low_f32(vf00), 0);
@@ -420,9 +420,9 @@ void DepthwiseConv2dNeonK3x3S2(const float *input,
           vo = vmlaq_lane_f32(vo, vi10, vget_low_f32(vf01), 0);
           vo = vmlaq_lane_f32(vo, vi11, vget_low_f32(vf01), 1);
           vo = vmlaq_lane_f32(vo, vi12, vget_high_f32(vf01), 0);
-          vo = vmlaq_lane_f32(vo, vi20, vget_low_f32(vf02), 0);
-          vo = vmlaq_lane_f32(vo, vi21, vget_low_f32(vf02), 1);
-          vo = vmlaq_lane_f32(vo, vi22, vget_high_f32(vf02), 0);
+          vo = vmlaq_lane_f32(vo, vi20, vget_low_f32(vf02), 1);
+          vo = vmlaq_lane_f32(vo, vi21, vget_high_f32(vf02), 0);
+          vo = vmlaq_lane_f32(vo, vi22, vget_high_f32(vf02), 1);
 #endif
           vst1q_f32(out_base + out_offset, vo);
         }  // w
diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h
index cf49d2cb..c9a859ff 100644
--- a/mace/kernels/conv_2d.h
+++ b/mace/kernels/conv_2d.h
@@ -461,7 +461,8 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
         || extra_input_width != input_width) {
       padded_input_size =
           batch * input_channels * (input_height + pad_top + pad_bottom)
-              * (input_width + pad_left + pad_right) * sizeof(float);
+              * (input_width + pad_left + pad_right) * sizeof(float)
+              + EXTRA_BUFFER_PAD_SIZE;
       total_scratch_size += padded_input_size;
     }
     if (extra_output_height != height || extra_output_width != width) {
@@ -482,8 +483,8 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
 
     // decide which convolution function to call
     if (use_winograd) {
-      transformed_input.Resize(transformed_input_shape);
-      transformed_output.Resize(transformed_output_shape);
+      transformed_input.Reshape(transformed_input_shape);
+      transformed_output.Reshape(transformed_output_shape);
       const float *transformed_filter_ptr;
       if (transformed_filter_.dim_size() == 0) {
         if (is_filter_transformed_) {
@@ -652,7 +653,7 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
 
     Tensor *pad_output_ptr = output;
     if (extra_output_height != height || extra_output_width != width) {
-      padded_output.Resize({batch, channels, extra_output_height,
+      padded_output.Reshape({batch, channels, extra_output_height,
                             extra_output_width});
       padded_output.Clear();
       pad_output_ptr = &padded_output;
diff --git a/mace/ops/depthwise_conv2d_test.cc b/mace/ops/depthwise_conv2d_test.cc
index 825396e1..5ce4b181 100644
--- a/mace/ops/depthwise_conv2d_test.cc
+++ b/mace/ops/depthwise_conv2d_test.cc
@@ -114,50 +114,28 @@ TEST_F(DepthwiseConv2dOpTest, SimpleOpenCLHalf) {
 
 namespace {
 template <DeviceType D, typename T>
-void ComplexValidTest() {
+void ComplexValidTest(index_t batch, index_t channel, index_t height,
+                      index_t width, index_t kernel, index_t multiplier,
+                      int stride) {
   testing::internal::LogToStderr();
   // Construct graph
   OpsTestNet net;
 
   // Add input data
-  net.AddInputFromArray<D, float>(
-      "Input", {1, 10, 10, 3},
-      {0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 0.11,
-       0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.2, 0.21, 0.22, 0.23,
-       0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.3, 0.31, 0.32, 0.33, 0.34, 0.35,
-       0.36, 0.37, 0.38, 0.39, 0.4, 0.41, 0.42, 0.43, 0.44, 0.45, 0.46, 0.47,
-       0.48, 0.49, 0.5, 0.51, 0.52, 0.53, 0.54, 0.55, 0.56, 0.57, 0.58, 0.59,
-       0.6, 0.61, 0.62, 0.63, 0.64, 0.65, 0.66, 0.67, 0.68, 0.69, 0.7, 0.71,
-       0.72, 0.73, 0.74, 0.75, 0.76, 0.77, 0.78, 0.79, 0.8, 0.81, 0.82, 0.83,
-       0.84, 0.85, 0.86, 0.87, 0.88, 0.89, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95,
-       0.96, 0.97, 0.98, 0.99, 1.0, 1.01, 1.02, 1.03, 1.04, 1.05, 1.06, 1.07,
-       1.08, 1.09, 1.1, 1.11, 1.12, 1.13, 1.14, 1.15, 1.16, 1.17, 1.18, 1.19,
-       1.2, 1.21, 1.22, 1.23, 1.24, 1.25, 1.26, 1.27, 1.28, 1.29, 1.3, 1.31,
-       1.32, 1.33, 1.34, 1.35, 1.36, 1.37, 1.38, 1.39, 1.4, 1.41, 1.42, 1.43,
-       1.44, 1.45, 1.46, 1.47, 1.48, 1.49, 1.5, 1.51, 1.52, 1.53, 1.54, 1.55,
-       1.56, 1.57, 1.58, 1.59, 1.6, 1.61, 1.62, 1.63, 1.64, 1.65, 1.66, 1.67,
-       1.68, 1.69, 1.7, 1.71, 1.72, 1.73, 1.74, 1.75, 1.76, 1.77, 1.78, 1.79,
-       1.8, 1.81, 1.82, 1.83, 1.84, 1.85, 1.86, 1.87, 1.88, 1.89, 1.9, 1.91,
-       1.92, 1.93, 1.94, 1.95, 1.96, 1.97, 1.98, 1.99, 2.0, 2.01, 2.02, 2.03,
-       2.04, 2.05, 2.06, 2.07, 2.08, 2.09, 2.1, 2.11, 2.12, 2.13, 2.14, 2.15,
-       2.16, 2.17, 2.18, 2.19, 2.2, 2.21, 2.22, 2.23, 2.24, 2.25, 2.26, 2.27,
-       2.28, 2.29, 2.3, 2.31, 2.32, 2.33, 2.34, 2.35, 2.36, 2.37, 2.38, 2.39,
-       2.4, 2.41, 2.42, 2.43, 2.44, 2.45, 2.46, 2.47, 2.48, 2.49, 2.5, 2.51,
-       2.52, 2.53, 2.54, 2.55, 2.56, 2.57, 2.58, 2.59, 2.6, 2.61, 2.62, 2.63,
-       2.64, 2.65, 2.66, 2.67, 2.68, 2.69, 2.7, 2.71, 2.72, 2.73, 2.74, 2.75,
-       2.76, 2.77, 2.78, 2.79, 2.8, 2.81, 2.82, 2.83, 2.84, 2.85, 2.86, 2.87,
-       2.88, 2.89, 2.9, 2.91, 2.92, 2.93, 2.94, 2.95, 2.96, 2.97, 2.98, 2.99});
-  net.AddInputFromArray<D, float>(
-      "Filter", {5, 5, 3, 1},
-      {0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1,
-       0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.2, 0.21,
-       0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.3, 0.31, 0.32,
-       0.33, 0.34, 0.35, 0.36, 0.37, 0.38, 0.39, 0.4, 0.41, 0.42, 0.43,
-       0.44, 0.45, 0.46, 0.47, 0.48, 0.49, 0.5, 0.51, 0.52, 0.53, 0.54,
-       0.55, 0.56, 0.57, 0.58, 0.59, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65,
-       0.66, 0.67, 0.68, 0.69, 0.7, 0.71, 0.72, 0.73, 0.74});
-  net.AddInputFromArray<D, float>("Bias", {6},
-                                  {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
+  std::vector<float> input_data(batch * height * width * channel);
+  GenerateRandomRealTypeData({batch, height, width, channel}, &input_data);
+  net.AddInputFromArray<D, float>("Input", {batch, height, width, channel},
+                                  input_data);
+  std::vector<float> filter_data(kernel * kernel * channel * multiplier);
+  GenerateRandomRealTypeData({kernel, kernel, channel, multiplier},
+                             &filter_data);
+  net.AddInputFromArray<D, float>("Filter",
+                                  {kernel, kernel, channel, multiplier},
+                                  filter_data);
+  std::vector<float> bias_data(channel * multiplier);
+  GenerateRandomRealTypeData({channel * multiplier}, &bias_data);
+  net.AddInputFromArray<D, float>("Bias", {channel * multiplier},
+                                  bias_data);
 
   if (D == DeviceType::CPU) {
     net.TransformDataFormat<DeviceType::CPU, float>("Input",
@@ -173,7 +151,7 @@ void ComplexValidTest() {
         .Input("FilterOIHW")
         .Input("Bias")
         .Output("OutputNCHW")
-        .AddIntsArg("strides", {2, 2})
+        .AddIntsArg("strides", {stride, stride})
        .AddIntArg("padding", Padding::SAME)
         .AddIntsArg("dilations", {1, 1})
         .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
@@ -196,7 +174,7 @@ void ComplexValidTest() {
        .Input("FilterImage")
         .Input("BiasImage")
         .Output("OutputImage")
-        .AddIntsArg("strides", {2, 2})
+        .AddIntsArg("strides", {stride, stride})
         .AddIntArg("padding", Padding::SAME)
         .AddIntsArg("dilations", {1, 1})
         .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
@@ -212,25 +190,43 @@ void ComplexValidTest() {
     MACE_NOT_IMPLEMENTED;
   }
 
-  // Check
+  // Compute expected output with a simple reference implementation
+  index_t out_height = (height - 1) / stride + 1;
+  index_t out_width = (width - 1) / stride + 1;
+  index_t pad_top = ((out_height - 1) * stride + kernel - height) >> 1;
+  index_t pad_left = ((out_width - 1) * stride + kernel - width) >> 1;
+  index_t out_channels = channel * multiplier;
+  std::vector<float> expect(batch * out_height * out_width * out_channels);
+  for (index_t b = 0; b < batch; ++b) {
+    for (index_t h = 0; h < out_height; ++h) {
+      for (index_t w = 0; w < out_width; ++w) {
+        for (index_t m = 0; m < out_channels; ++m) {
+          index_t out_offset =
+              ((b * out_height + h) * out_width + w) * out_channels + m;
+          index_t c = m / multiplier;
+          index_t o = m % multiplier;
+          float sum = 0;
+          for (index_t kh = 0; kh < kernel; ++kh) {
+            for (index_t kw = 0; kw < kernel; ++kw) {
+              index_t ih = h * stride - pad_top + kh;
+              index_t iw = w * stride - pad_left + kw;
+              if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
+                index_t in_offset =
+                    ((b * height + ih) * width + iw) * channel + c;
+                index_t filter_offset =
+                    (((kh * kernel) + kw) * channel + c) * multiplier + o;
+                sum += input_data[in_offset] * filter_data[filter_offset];
+              }
+            }
+          }
+          expect[out_offset] = sum + bias_data[m];
+        }
+      }
+    }
+  }
+
   auto expected = CreateTensor<float>(
-      {1, 5, 5, 3},
-      VectorStaticCast<float>(
-          {4.48200035, 4.63479996, 4.79079962, 5.85899973, 6.05599976,
-           6.25699997, 6.38100004, 6.59000015, 6.80300045, 6.90299988,
-           7.1239996, 7.34899998, 4.03559971, 4.16820002, 4.30319977,
-           8.90999985, 9.1760006, 9.44599915, 11.20499992, 11.54500103,
-           11.89000034, 11.74499989, 12.09999943, 12.46000004, 12.28499985,
-           12.65500069, 13.03000069, 7.00200033, 7.22399998, 7.44900036,
-           13.4100008, 13.79599953, 14.18599987, 16.60500145, 17.09499741,
-           17.59000015, 17.14500046, 17.65000153, 18.15999794, 17.68499947,
-           18.20499992, 18.72999954, 9.97200012, 10.28399944, 10.59899998,
-           17.90999985, 18.41600037, 18.92599869, 22.00500107, 22.64500046,
-           23.28999901, 22.54500008, 23.19999886, 23.8599987, 23.0850029,
-           23.75500107, 24.43000031, 12.94200039, 13.34400082, 13.7489996,
-           6.97500038, 7.29659986, 7.62060022, 8.32049942, 8.72700024,
-           9.13650036, 8.5095005, 8.92500019, 9.34349918, 8.69849968,
-           9.12300014, 9.55049992, 4.55220032, 4.80690002, 5.06340027}));
+      {1, out_height, out_width, out_channels}, expect);
 
   if (DataTypeToEnum<T>::value == DT_FLOAT) {
     ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
@@ -241,15 +237,23 @@
 }  // namespace
 
 TEST_F(DepthwiseConv2dOpTest, ComplexCPU) {
-  ComplexValidTest<DeviceType::CPU, float>();
+  ComplexValidTest<DeviceType::CPU, float>(1, 3, 10, 10, 5, 1, 2);
+}
+
+TEST_F(DepthwiseConv2dOpTest, ComplexCPU3x3s1) {
+  ComplexValidTest<DeviceType::CPU, float>(1, 3, 10, 10, 3, 1, 1);
+}
+
+TEST_F(DepthwiseConv2dOpTest, ComplexCPU3x3s2) {
+  ComplexValidTest<DeviceType::CPU, float>(1, 3, 10, 10, 3, 1, 2);
 }
 
 TEST_F(DepthwiseConv2dOpTest, ComplexOpenCL) {
-  ComplexValidTest<DeviceType::GPU, float>();
+  ComplexValidTest<DeviceType::GPU, float>(1, 3, 10, 10, 5, 1, 2);
 }
 
 TEST_F(DepthwiseConv2dOpTest, ComplexOpenCLHalf) {
-  ComplexValidTest<DeviceType::GPU, half>();
+  ComplexValidTest<DeviceType::GPU, half>(1, 3, 10, 10, 5, 1, 2);
 }
 
 namespace {
--
GitLab
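
Reading note on the NEON hunks above: a 3x3 filter occupies nine contiguous floats (indices 0-8), but vld1q_f32 always reads four floats, so the old vld1q_f32(filter_ptr + 6) touched index 9, one element past the end of the filter. The patch loads from offset 5 instead and bumps every lane index referring to the last filter row by one. The following stand-alone sketch (not part of the patch; it models vld1q_f32 as a plain four-float copy so it runs on any host) illustrates why the two forms are equivalent on the lanes that are actually used:

    #include <array>
    #include <cassert>

    // Minimal stand-in for vld1q_f32: copy four consecutive floats from p.
    static std::array<float, 4> Load4(const float *p) {
      return {p[0], p[1], p[2], p[3]};
    }

    int main() {
      // A 3x3 filter stored contiguously as nine floats (indices 0..8).
      const float filter[9] = {0, 1, 2, 3, 4, 5, 6, 7, 8};

      // Old code: Load4(filter + 6) would read filter[6..9]; filter[9] is past
      // the end of the array, which is the invalid read being fixed.

      // New code: load from offset 5 (reads filter[5..8], all in bounds) and
      // use lanes 1..3 instead of lanes 0..2 for the last filter row.
      const std::array<float, 4> vf02 = Load4(filter + 5);
      assert(vf02[1] == filter[6]);
      assert(vf02[2] == filter[7]);
      assert(vf02[3] == filter[8]);
      return 0;
    }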
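
The EXTRA_BUFFER_PAD_SIZE change applies the same reasoning at the buffer level: a vector load reads a full 16 bytes even when only some lanes are used, so a tensor whose byte size is not a multiple of the vector width can be read slightly past its logical end. Below is a hedged sketch of that idea in plain C++ (a hypothetical PaddedBuffer class, not the actual mace::Buffer API), mirroring the tensor.h hunk:

    #include <cstddef>
    #include <vector>

    // Hypothetical mirror of the Tensor::Resize() change: keep the logical size
    // separate from the allocated size and always allocate a little extra, so a
    // 16-byte NEON load that starts on the last valid element stays inside the
    // allocation.
    class PaddedBuffer {
     public:
      static constexpr std::size_t kExtraPad = 64;  // mirrors EXTRA_BUFFER_PAD_SIZE

      void Resize(std::size_t logical_bytes) {
        logical_size_ = logical_bytes;
        if (logical_bytes + kExtraPad > storage_.size()) {
          storage_.resize(logical_bytes + kExtraPad);
        }
      }

      std::size_t size() const { return logical_size_; }  // what callers see
      char *data() { return storage_.data(); }

     private:
      std::size_t logical_size_ = 0;
      std::vector<char> storage_;
    };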
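
The reference computation added to ComplexValidTest uses the usual SAME-padding arithmetic. As a quick sanity check of those formulas for one of the parameter sets the new tests pass (10x10 input, 5x5 kernel, stride 2; numbers chosen only for illustration):

    #include <cassert>

    int main() {
      const int height = 10, kernel = 5, stride = 2;
      // SAME padding: output size rounds up, and the total pad is split evenly.
      const int out_height = (height - 1) / stride + 1;
      const int pad_top = ((out_height - 1) * stride + kernel - height) >> 1;
      assert(out_height == 5);  // ceil(10 / 2)
      assert(pad_top == 1);     // (8 + 5 - 10) >> 1 == 3 >> 1
      return 0;
    }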