From 6115504361936fff130e3b4f408f65d44cd5eed6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?= Date: Wed, 20 Sep 2017 18:57:14 +0800 Subject: [PATCH] Make bias optional --- mace/kernels/conv_2d.h | 4 +++- mace/kernels/depthwise_conv2d.h | 4 +++- mace/kernels/neon/conv_2d_neon_1x1.cc | 4 +--- mace/kernels/neon/conv_2d_neon_3x3.cc | 2 +- mace/kernels/neon/conv_2d_neon_5x5.cc | 5 ++--- mace/ops/conv_2d.h | 9 +++++++-- mace/ops/depthwise_conv2d.h | 8 ++++++-- 7 files changed, 23 insertions(+), 13 deletions(-) diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h index 536b28ad..fbe7953a 100644 --- a/mace/kernels/conv_2d.h +++ b/mace/kernels/conv_2d.h @@ -73,10 +73,12 @@ class Conv2dFunctor { #pragma omp parallel for collapse(2) for (int n = 0; n < batch; ++n) { for (int c = 0; c < channels; ++c) { + T bias_channel = bias ? bias[c] : 0; for (int h = 0; h < height; ++h) { for (int w = 0; w < width; ++w) { index_t offset = n * channels * height * width + c * height * width + h * width + w; + output[offset] = bias_channel; T sum = 0; const T *filter_ptr = filter + c * kernel_size; for (int inc = 0; inc < input_channels; ++inc) { @@ -102,8 +104,8 @@ class Conv2dFunctor { ++filter_ptr; } } - output[offset] = sum + bias[c]; } + output[offset] += sum; } } } diff --git a/mace/kernels/depthwise_conv2d.h b/mace/kernels/depthwise_conv2d.h index 472733af..c9be5c92 100644 --- a/mace/kernels/depthwise_conv2d.h +++ b/mace/kernels/depthwise_conv2d.h @@ -75,10 +75,12 @@ class DepthwiseConv2dFunctor { #pragma omp parallel for collapse(2) for (int n = 0; n < batch; ++n) { for (int c = 0; c < channels; ++c) { + T bias_channel = bias ? bias[c] : 0; for (int h = 0; h < height; ++h) { for (int w = 0; w < width; ++w) { index_t offset = n * channels * height * width + c * height * width + h * width + w; + output[offset] = bias_channel; T sum = 0; const T *filter_ptr = filter + c * kernel_size; for (int kh = 0; kh < kernel_h; ++kh) { @@ -103,7 +105,7 @@ class DepthwiseConv2dFunctor { ++filter_ptr; } } - output[offset] = sum + bias[c]; + output[offset] += sum; } } } diff --git a/mace/kernels/neon/conv_2d_neon_1x1.cc b/mace/kernels/neon/conv_2d_neon_1x1.cc index a82505e7..922b3265 100644 --- a/mace/kernels/neon/conv_2d_neon_1x1.cc +++ b/mace/kernels/neon/conv_2d_neon_1x1.cc @@ -47,9 +47,7 @@ void Conv2dNeonK1x1S1(const float *input, // NCHW // Fill with bias float *output_ptr = channel_output_start; - for (index_t ptr = 0; ptr < total_pixels; ++ptr) { - output_ptr[ptr] = bias[c]; // TODO can we avoid this? - } + std::fill(output_ptr, output_ptr + total_pixels, bias ? bias[c] : 0); index_t inc = 0; // Process 4 input channels in batch diff --git a/mace/kernels/neon/conv_2d_neon_3x3.cc b/mace/kernels/neon/conv_2d_neon_3x3.cc index 6b62cb59..ac5636a8 100644 --- a/mace/kernels/neon/conv_2d_neon_3x3.cc +++ b/mace/kernels/neon/conv_2d_neon_3x3.cc @@ -28,7 +28,7 @@ namespace kernels { input_ptr += (oc / multiplier) * input_height * input_width; \ } \ float *output_ptr = output_ptr_base + oc * output_height * output_width; \ - std::fill(output_ptr, output_ptr + output_height * output_width, bias[oc]); \ + std::fill(output_ptr, output_ptr + output_height * output_width, bias ? bias[oc] : 0); \ for (int ic = 0; ic < filter_in_channels; ++ic) { \ float32x4_t n_filter_v[3] = {vld1q_f32(filter_ptr), vld1q_f32(filter_ptr+3), vld1q_f32(filter_ptr+6)}; diff --git a/mace/kernels/neon/conv_2d_neon_5x5.cc b/mace/kernels/neon/conv_2d_neon_5x5.cc index 02c5ced2..88120f13 100644 --- a/mace/kernels/neon/conv_2d_neon_5x5.cc +++ b/mace/kernels/neon/conv_2d_neon_5x5.cc @@ -45,9 +45,8 @@ void Conv2dNeonK5x5S1(const float *input, // NCHW const float *input_ptr = input + n * input_total_pixels_per_batch; // Fill with bias - for (index_t i = 0; i < output_total_pixels_per_channel; ++i) { - output_ptr[i] = bias[c]; - } + std::fill(output_ptr, output_ptr + output_total_pixels_per_channel, + bias ? bias[c] : 0); for (index_t inc = 0; inc < input_channels; ++inc) { float *outptr = output_ptr; diff --git a/mace/ops/conv_2d.h b/mace/ops/conv_2d.h index 5c15ca83..89b91402 100644 --- a/mace/ops/conv_2d.h +++ b/mace/ops/conv_2d.h @@ -27,7 +27,12 @@ class Conv2dOp : public ConvPool2dOpBase { bool Run() override { const Tensor *input = this->Input(INPUT); const Tensor *filter = this->Input(FILTER); - const Tensor *bias = this->Input(BIAS); + const T *bias_data = nullptr; + if (this->InputSize() >= 3) { + const Tensor *bias = this->Input(BIAS); + bias_data = bias->data(); + } + Tensor *output = this->Output(OUTPUT); std::vector output_shape(4); @@ -36,7 +41,7 @@ class Conv2dOp : public ConvPool2dOpBase { output->Resize(output_shape); functor_(input->data(), input->shape().data(), filter->data(), - filter->shape().data(), bias->data(), output->mutable_data(), + filter->shape().data(), bias_data, output->mutable_data(), output->shape().data()); return true; diff --git a/mace/ops/depthwise_conv2d.h b/mace/ops/depthwise_conv2d.h index cc220f3c..9e5dc745 100644 --- a/mace/ops/depthwise_conv2d.h +++ b/mace/ops/depthwise_conv2d.h @@ -26,7 +26,11 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase { bool Run() override { const Tensor *input = this->Input(INPUT); const Tensor *filter = this->Input(FILTER); - const Tensor *bias = this->Input(BIAS); + const T *bias_data = nullptr; + if (this->InputSize() >= 3) { + const Tensor *bias = this->Input(BIAS); + bias_data = bias->data(); + } Tensor *output = this->Output(OUTPUT); // resize filter shape. @@ -38,7 +42,7 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase { output->Resize(output_shape); functor_(input->data(), input->shape().data(), filter->data(), - filter_shape.data(), bias->data(), output->mutable_data(), + filter_shape.data(), bias_data, output->mutable_data(), output->shape().data()); return true; -- GitLab