diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc
index 077cbd812dfa734d2fcc1c548c6179fee4b3bdbe..6eec3f91aa78d05a88ea9c9151e0c0c7fa756678 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc
@@ -15,6 +15,7 @@
  */
 #include "src/runtime/kernel/arm/fp32/convolution_depthwise.h"
+#include "src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.h"
 #include "schema/model_generated.h"
 #include "src/kernel_registry.h"
 #include "include/errorcode.h"
@@ -182,7 +183,15 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector(opParameter);
+  // if (param->kernel_h_ == 3 && param->kernel_w_ == 3 && param->stride_h_ == 1 && param->stride_w_ == 1 &&
+  //     param->dilation_h_ == 1 && param->dilation_w_ == 1) {
+  //   kernel = new (std::nothrow) kernel::ConvolutionDepthwise3x3CPUKernel(opParameter, inputs, outputs, ctx);
+  // } else {
+  //   kernel = new (std::nothrow) kernel::ConvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx);
+  // }
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel is nullptr.";
     return nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6f2322255b4adaa64349d8393996903403feb691
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.cc
@@ -0,0 +1,199 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_DepthwiseConv2D; + +namespace mindspore::kernel { +int ConvolutionDepthwise3x3CPUKernel::InitWeightBias() { + // init weight: o, h, w, i; o == group, i == 1 + auto weight_tensor = inputs_[kWeightIndex]; + auto origin_weight = reinterpret_cast(weight_tensor->Data()); + // o h w 1 -> o/4 h w 1 4 + int OC4 = UP_DIV(conv_param_->output_channel_, C4NUM); + int weight_c4_size = OC4 * C4NUM * 9; + auto tmp_weight = reinterpret_cast(malloc(weight_c4_size * sizeof(float))); + if (tmp_weight == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(tmp_weight, 0, weight_c4_size * sizeof(float)); + PackNCHWToNC4HW4Fp32(origin_weight, tmp_weight, 1, conv_param_->kernel_h_ * conv_param_->kernel_w_, + conv_param_->output_channel_); + + // weight transform + int packed_weight_size = OC4 * C4NUM * 16; + packed_weight_ = reinterpret_cast(malloc(packed_weight_size * sizeof(float))); + if (packed_weight_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(packed_weight_, 0, packed_weight_size * sizeof(float)); + ConvDw3x3Fp32FilterTrans(packed_weight_, tmp_weight, OC4); + + // init bias + bias_data_ = reinterpret_cast(malloc(C4NUM * OC4 * sizeof(float))); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(bias_data_, 0, C4NUM * OC4 * sizeof(float)); + if (inputs_.size() == kInputSize2) { + auto ori_bias = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + memcpy(bias_data_, ori_bias, conv_param_->output_channel_ * sizeof(float)); + } + return RET_OK; +} + +int ConvolutionDepthwise3x3CPUKernel::InitBuffer() { + if (conv_param_->input_channel_ % C4NUM != 0) { + need_align_ = true; + int IC4 = UP_DIV(conv_param_->input_channel_, C4NUM); + int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C4NUM * IC4; + packed_input_ = reinterpret_cast(malloc(pack_input_size * sizeof(float))); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(packed_input_, 0, pack_input_size * sizeof(float)); + + int OC4 = UP_DIV(conv_param_->output_channel_, C4NUM); + int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C4NUM * OC4; + packed_output_ = reinterpret_cast(malloc(pack_output_size * sizeof(float))); + if (packed_output_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + } + + // malloc transform buffer + trans_size_ = UP_DIV(conv_param_->output_w_, 2) * UP_DIV(conv_param_->output_h_, 2) * 16 * C4NUM; + size_t trans_buffer_size = thread_count_ * trans_size_ * sizeof(float); + trans_buffer_ = reinterpret_cast(malloc(trans_buffer_size)); + if (trans_buffer_ == nullptr) { + MS_LOG(ERROR) << "malloc trans buffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionDepthwise3x3CPUKernel::Init() { + // conv base init + ConvolutionBaseCPUKernel::Init(); + + auto ret = InitWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Depthwise3x3 fp32 initWeightBias error!"; + return ret; + } + + // init 
threadNum; + conv_param_->thread_num_ = MSMIN(thread_count_, UP_DIV(conv_param_->output_channel_, C4NUM)); + + ret = InitBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Depthwise3x3 fp32 initBuffer error!"; + return ret; + } + + // malloc one block buffer + block_buffer_ = reinterpret_cast(malloc(thread_count_ * 16 * C4NUM * sizeof(float))); + if (block_buffer_ == nullptr) { + MS_LOG(ERROR) << "malloc block buffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionDepthwise3x3CPUKernel::ReSize() { + if (need_align_) { + free(packed_input_); + free(packed_output_); + } + free(trans_buffer_); + + // conv base init + ConvolutionBaseCPUKernel::Init(); + + auto ret = InitBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Depthwise3x3 fp32 initBuffer error!"; + return ret; + } + return RET_OK; +} + +int ConvolutionDepthwise3x3CPUKernel::Execute(int task_id) { + auto trans_buf = trans_buffer_ + task_id * trans_size_; + auto block_buf = block_buffer_ + task_id * 16 * C4NUM; + ConvDw3x3Fp32(packed_output_, packed_input_, packed_weight_, reinterpret_cast(bias_data_), trans_buf, + block_buf, conv_param_, task_id); + return RET_OK; +} + +int ConvDw3x3Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv_dw_3x3 = reinterpret_cast(cdata); + auto ret = conv_dw_3x3->Execute(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionDepthwise3x3Run error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionDepthwise3x3CPUKernel::Run() { + if (conv_param_->input_channel_ != conv_param_->output_channel_) { + MS_LOG(ERROR) << "Only support input channel equals output channel."; + return RET_ERROR; + } + auto input_tensor = inputs_.at(kInputIndex); + auto input_addr = reinterpret_cast(input_tensor->Data()); + + // pack input: to nhwc4 + if (need_align_) { + PackNHWCToNHWC4Fp32(input_addr, packed_input_, conv_param_->input_batch_, + conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_); + } else { + packed_input_ = input_addr; + } + + auto output_addr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + if (!need_align_) { + packed_output_ = output_addr; + } + + auto ret = LiteBackendParallelLaunch(ConvDw3x3Run, this, conv_param_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvDw3x3Run error: error_code[" << ret << "]"; + return RET_ERROR; + } + + if (need_align_) { + PackNHWC4ToNHWCFp32(packed_output_, output_addr, conv_param_->output_batch_, + conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_); + } + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.h new file mode 100644 index 0000000000000000000000000000000000000000..63f3d35cd2e00ae2b801f5ab36c7ce3c0e53aecd --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.h @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_H_ +#define MINDSPORE_LITE_SRC_BACKEND_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.h" + +namespace mindspore::kernel { +class ConvolutionDepthwise3x3CPUKernel : public ConvolutionBaseCPUKernel { + public: + ConvolutionDepthwise3x3CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + + ~ConvolutionDepthwise3x3CPUKernel() override { + free(packed_weight_); + if (need_align_) { + free(packed_input_); + free(packed_output_); + } + free(block_buffer_); + free(trans_buffer_); + }; + + int Init() override; + int ReSize() override; + int Run() override; + + int InitWeightBias(); + int InitBuffer(); + int Execute(int task_id); + + private: + float *packed_weight_; + float *packed_input_; + float *packed_output_; + float *block_buffer_; + float *trans_buffer_; + int trans_size_; + bool need_align_ = false; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_BACKEND_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.cc index 638abeb2fbbe57adcdc3dba81158791ee7f21f66..0b8123c2e10e675c7840b9730c494e2baa2e53ab 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.cc +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.cc @@ -16,6 +16,7 @@ #include "src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.h" #include "src/runtime/kernel/arm/nnacl/fp32/common_func.h" +#include "src/runtime/kernel/arm/nnacl/winograd_transform.h" #ifdef ENABLE_ARM64 #include #endif @@ -212,6 +213,372 @@ void ConvDwC4Fp32(float *output_data, const float *input_data, const float *weig } /*conv depthwise fp32 end*/ +/*conv depthwise 3x3 fp32 begin*/ +void ConvDw3x3Fp32FilterTrans(float *trans_weight, float *weight, int oc4) { + for (int c = 0; c < oc4; c++) { + float *src = weight + c * C4NUM * 9; + float *dst = trans_weight + c * C4NUM * 16; +#ifdef ENABLE_ARM + float32x4_t g00 = vld1q_f32(src); + float32x4_t g01 = vld1q_f32(src + 4); + float32x4_t g02 = vld1q_f32(src + 2 * 4); + float32x4_t g10 = vld1q_f32(src + 3 * 4); + float32x4_t g11 = vld1q_f32(src + 4 * 4); + float32x4_t g12 = vld1q_f32(src + 5 * 4); + float32x4_t g20 = vld1q_f32(src + 6 * 4); + float32x4_t g21 = vld1q_f32(src + 7 * 4); + float32x4_t g22 = vld1q_f32(src + 8 * 4); + + float32x4_t dst00 = g00; + float32x4_t dst01 = g01; + float32x4_t dst02 = g02; + + float32x4_t dst10 = vaddq_f32(vmulq_n_f32(g00, 0.5), vmulq_n_f32(g10, 0.5)); + dst10 = vaddq_f32(dst10, vmulq_n_f32(g20, 0.5)); + float32x4_t dst11 = vaddq_f32(vmulq_n_f32(g01, 0.5), vmulq_n_f32(g11, 0.5)); + dst11 = vaddq_f32(dst11, vmulq_n_f32(g21, 0.5)); + float32x4_t dst12 = vaddq_f32(vmulq_n_f32(g02, 0.5), vmulq_n_f32(g12, 0.5)); + dst12 = vaddq_f32(dst12, vmulq_n_f32(g22, 0.5)); + + float32x4_t dst20 = vsubq_f32(vmulq_n_f32(g00, 0.5), vmulq_n_f32(g10, 0.5)); + dst20 = vaddq_f32(dst20, vmulq_n_f32(g20, 0.5)); + float32x4_t dst21 = vsubq_f32(vmulq_n_f32(g01, 0.5), vmulq_n_f32(g11, 0.5)); + dst21 = vaddq_f32(dst21, vmulq_n_f32(g21, 0.5)); + float32x4_t dst22 = 
vsubq_f32(vmulq_n_f32(g02, 0.5), vmulq_n_f32(g12, 0.5)); + dst22 = vaddq_f32(dst22, vmulq_n_f32(g22, 0.5)); + + float32x4_t dst30 = g20; + float32x4_t dst31 = g21; + float32x4_t dst32 = g22; + + float32x4_t m00 = dst00; + float32x4_t m01 = vaddq_f32(vmulq_n_f32(dst00, 0.5), vmulq_n_f32(dst01, 0.5)); + m01 = vaddq_f32(m01, vmulq_n_f32(dst02, 0.5)); + float32x4_t m02 = vsubq_f32(vmulq_n_f32(dst00, 0.5), vmulq_n_f32(dst01, 0.5)); + m02 = vaddq_f32(m02, vmulq_n_f32(dst02, 0.5)); + float32x4_t m03 = dst02; + + float32x4_t m10 = dst10; + float32x4_t m11 = vaddq_f32(vmulq_n_f32(dst10, 0.5), vmulq_n_f32(dst11, 0.5)); + m11 = vaddq_f32(m11, vmulq_n_f32(dst12, 0.5)); + float32x4_t m12 = vsubq_f32(vmulq_n_f32(dst10, 0.5), vmulq_n_f32(dst11, 0.5)); + m12 = vaddq_f32(m12, vmulq_n_f32(dst12, 0.5)); + float32x4_t m13 = dst12; + + float32x4_t m20 = dst20; + float32x4_t m21 = vaddq_f32(vmulq_n_f32(dst20, 0.5), vmulq_n_f32(dst21, 0.5)); + m21 = vaddq_f32(m21, vmulq_n_f32(dst22, 0.5)); + float32x4_t m22 = vsubq_f32(vmulq_n_f32(dst20, 0.5), vmulq_n_f32(dst21, 0.5)); + m22 = vaddq_f32(m22, vmulq_n_f32(dst22, 0.5)); + float32x4_t m23 = dst22; + + float32x4_t m30 = dst30; + float32x4_t m31 = vaddq_f32(vmulq_n_f32(dst30, 0.5), vmulq_n_f32(dst31, 0.5)); + m31 = vaddq_f32(m31, vmulq_n_f32(dst32, 0.5)); + float32x4_t m32 = vsubq_f32(vmulq_n_f32(dst30, 0.5), vmulq_n_f32(dst31, 0.5)); + m32 = vaddq_f32(m32, vmulq_n_f32(dst32, 0.5)); + float32x4_t m33 = dst32; + + vst1q_f32(dst, m00); + vst1q_f32(dst + 4, m01); + vst1q_f32(dst + 8, m02); + vst1q_f32(dst + 12, m03); + vst1q_f32(dst + 16, m10); + vst1q_f32(dst + 20, m11); + vst1q_f32(dst + 24, m12); + vst1q_f32(dst + 28, m13); + vst1q_f32(dst + 32, m20); + vst1q_f32(dst + 36, m21); + vst1q_f32(dst + 40, m22); + vst1q_f32(dst + 44, m23); + vst1q_f32(dst + 48, m30); + vst1q_f32(dst + 52, m31); + vst1q_f32(dst + 56, m32); + vst1q_f32(dst + 60, m33); +#else + for (int j = 0; j < C4NUM; j++) { + float *local_ptr = src + j; + float dst00 = local_ptr[0]; + float dst01 = (local_ptr + 4)[0]; + float dst02 = (local_ptr + 8)[0]; + + float dst10 = 0.5f * local_ptr[0] + 0.5f * (local_ptr + 12)[0] + 0.5f * (local_ptr + 24)[0]; + float dst11 = 0.5f * (local_ptr + 4)[0] + 0.5f * (local_ptr + 16)[0] + 0.5f * (local_ptr + 28)[0]; + float dst12 = 0.5f * (local_ptr + 8)[0] + 0.5f * (local_ptr + 20)[0] + 0.5f * (local_ptr + 32)[0]; + + float dst20 = 0.5f * local_ptr[0] - 0.5f * (local_ptr + 12)[0] + 0.5f * (local_ptr + 24)[0]; + float dst21 = 0.5f * (local_ptr + 4)[0] - 0.5f * (local_ptr + 16)[0] + 0.5f * (local_ptr + 28)[0]; + float dst22 = 0.5f * (local_ptr + 8)[0] - 0.5f * (local_ptr + 20)[0] + 0.5f * (local_ptr + 32)[0]; + + float dst30 = (local_ptr + 24)[0]; + float dst31 = (local_ptr + 28)[0]; + float dst32 = (local_ptr + 32)[0]; + + float m00 = dst00; + float m01 = 0.5f * dst00 + 0.5f * dst01 + 0.5f * dst02; + float m02 = 0.5f * dst00 - 0.5f * dst01 + 0.5f * dst02; + float m03 = dst02; + + float m10 = dst10; + float m11 = 0.5f * dst10 + 0.5f * dst11 + 0.5f * dst12; + float m12 = 0.5f * dst10 - 0.5f * dst11 + 0.5f * dst12; + float m13 = dst12; + + float m20 = dst20; + float m21 = 0.5f * dst20 + 0.5f * dst21 + 0.5f * dst22; + float m22 = 0.5f * dst20 - 0.5f * dst21 + 0.5f * dst22; + float m23 = dst22; + + float m30 = dst30; + float m31 = 0.5f * dst30 + 0.5f * dst31 + 0.5f * dst32; + float m32 = 0.5f * dst30 - 0.5f * dst31 + 0.5f * dst32; + float m33 = dst32; + + *(dst + j) = m00; + *(dst + j + 4) = m01; + *(dst + j + 8) = m02; + *(dst + j + 12) = m03; + + *(dst + j + 16) = m10; + *(dst 
+ j + 20) = m11; + *(dst + j + 24) = m12; + *(dst + j + 28) = m13; + + *(dst + j + 32) = m20; + *(dst + j + 36) = m21; + *(dst + j + 40) = m22; + *(dst + j + 44) = m23; + + *(dst + j + 48) = m30; + *(dst + j + 52) = m31; + *(dst + j + 56) = m32; + *(dst + j + 60) = m33; + } +#endif + } +} + +void ConvDw3x3Fp32InputTrans(const float *input_data, float *trans_input, float *block_buffer, int out_h_block, + int out_w_block, const ConvParameter *conv_param) { + int ic4 = UP_DIV(conv_param->input_channel_, C4NUM); + int input_unit = 4; + memset(trans_input, 0, out_h_block * out_h_block * 16 * C4NUM * sizeof(float)); + + for (int oh = 0; oh < out_h_block; oh++) { + int ih = oh * 2 - conv_param->pad_h_; + int real_h_start = ih > 0 ? 0 : -ih; + int real_h_end = (ih + input_unit) < conv_param->input_h_ ? input_unit : (conv_param->input_h_ - ih); + for (int ow = 0; ow < out_w_block; ow++) { + int iw = ow * 2 - conv_param->pad_w_; + int real_w_start = iw > 0 ? 0 : -iw; + int real_w_end = (iw + input_unit) < conv_param->input_w_ ? input_unit : (conv_param->input_w_ - iw); + + memset(block_buffer, 0, 16 * C4NUM * sizeof(float)); + int src_plane_offset = ic4 * C4NUM * (ih * conv_param->input_w_ + iw); + for (int h = real_h_start; h < real_h_end; h++) { + int src_h_offset = src_plane_offset + (h * conv_param->input_w_) * ic4 * C4NUM; + int dst_h_offset = (h * input_unit) * C4NUM; + for (int w = real_w_start; w < real_w_end; w++) { + int src_w_offset = src_h_offset + w * ic4 * C4NUM; + int dst_w_offset = dst_h_offset + w * C4NUM; + float *src_addr = (float *)(input_data) + src_w_offset; + float *dst_addr = block_buffer + dst_w_offset; +#ifdef ENABLE_NEON + vst1q_f32(dst_addr, vld1q_f32(src_addr)); +#else + for (int k = 0; k < C4NUM; k++) { + (dst_addr + k)[0] = (src_addr + k)[0]; + } +#endif + } + } + int trans_offset = (oh * out_w_block + ow) * 16 * C4NUM; + Conv3x3Fp32InputUnit(block_buffer, trans_input + trans_offset, C4NUM); + } + } +} + +// todo yangruoqi: implement assembly +void ConvDw3x3Fp32Winograd(float *trans_buffer, const float *weight, int out_h_block, int out_w_block) { + int unit = 4; + for (int oh = 0; oh < out_h_block; oh++) { + float *buf_oh = trans_buffer + oh * out_w_block * 16 * C4NUM; + for (int ow = 0; ow < out_w_block; ow++) { + float *buf_ow = buf_oh + ow * 16 * C4NUM; + for (int kh = 0; kh < unit; kh++) { + float *buf_kh = buf_ow + kh * unit * C4NUM; + const float *weight_kh = weight + kh * unit * C4NUM; + for (int kw = 0; kw < unit; kw++) { + float *buf_kw = buf_kh + kw * C4NUM; + const float *weight_kw = weight_kh + kw * C4NUM; + for (int c = 0; c < C4NUM; c++) { + buf_kw[c] = buf_kw[c] * weight_kw[c]; + } + } + } + } + } +} + +void ConvDw3x3Fp32OutputUnit(float *src_buf, float *dst_output, const float *bias, int channel, int output_w, + bool h_in_range, bool w_in_range, bool is_relu, bool is_relu6) { +#ifdef ENABLE_ARM + float32x4_t bias_ptr = vld1q_f32(bias); + + float32x4_t s00 = vld1q_f32(src_buf); + float32x4_t s01 = vld1q_f32(src_buf + 4); + float32x4_t s02 = vld1q_f32(src_buf + 8); + float32x4_t s03 = vld1q_f32(src_buf + 12); + + float32x4_t s10 = vld1q_f32(src_buf + 16); + float32x4_t s11 = vld1q_f32(src_buf + 20); + float32x4_t s12 = vld1q_f32(src_buf + 24); + float32x4_t s13 = vld1q_f32(src_buf + 28); + + float32x4_t s20 = vld1q_f32(src_buf + 32); + float32x4_t s21 = vld1q_f32(src_buf + 36); + float32x4_t s22 = vld1q_f32(src_buf + 40); + float32x4_t s23 = vld1q_f32(src_buf + 44); + + float32x4_t s30 = vld1q_f32(src_buf + 48); + float32x4_t s31 = vld1q_f32(src_buf + 
52); + float32x4_t s32 = vld1q_f32(src_buf + 56); + float32x4_t s33 = vld1q_f32(src_buf + 60); + + float32x4_t t00 = vaddq_f32(vaddq_f32(s00, s10), s20); + float32x4_t t01 = vaddq_f32(vaddq_f32(s01, s11), s21); + float32x4_t t02 = vaddq_f32(vaddq_f32(s02, s12), s22); + float32x4_t t03 = vaddq_f32(vaddq_f32(s03, s13), s23); + + float32x4_t t10 = vsubq_f32(vsubq_f32(s10, s20), s30); + float32x4_t t11 = vsubq_f32(vsubq_f32(s11, s21), s31); + float32x4_t t12 = vsubq_f32(vsubq_f32(s12, s22), s32); + float32x4_t t13 = vsubq_f32(vsubq_f32(s13, s23), s33); + + float32x4_t d00 = vaddq_f32(vaddq_f32(vaddq_f32(t00, t01), t02), bias_ptr); + float32x4_t d01 = vaddq_f32(vsubq_f32(vsubq_f32(t01, t02), t03), bias_ptr); + float32x4_t d10 = vaddq_f32(vaddq_f32(vaddq_f32(t10, t11), t12), bias_ptr); + float32x4_t d11 = vaddq_f32(vsubq_f32(vsubq_f32(t11, t12), t13), bias_ptr); + + vst1q_f32(dst_output, d00); + if (w_in_range) { + vst1q_f32(dst_output + channel, d01); + } + if (h_in_range) { + vst1q_f32(dst_output + output_w * channel, d10); + if (w_in_range) { + vst1q_f32(dst_output + output_w * channel + channel, d11); + } + } +#else + for (int i = 0; i < C4NUM; i++) { + const float *local_ptr = src_buf + i; + const float *bias_ptr = bias + i; + + float s00 = local_ptr[0]; + float s01 = (local_ptr + 4)[0]; + float s02 = (local_ptr + 8)[0]; + float s03 = (local_ptr + 12)[0]; + + float s10 = (local_ptr + 16)[0]; + float s11 = (local_ptr + 20)[0]; + float s12 = (local_ptr + 24)[0]; + float s13 = (local_ptr + 28)[0]; + + float s20 = (local_ptr + 32)[0]; + float s21 = (local_ptr + 36)[0]; + float s22 = (local_ptr + 40)[0]; + float s23 = (local_ptr + 44)[0]; + + float s30 = (local_ptr + 48)[0]; + float s31 = (local_ptr + 52)[0]; + float s32 = (local_ptr + 56)[0]; + float s33 = (local_ptr + 60)[0]; + + float t00 = s00 + s10 + s20; + float t01 = s01 + s11 + s21; + float t02 = s02 + s12 + s22; + float t03 = s03 + s13 + s23; + + float t10 = s10 - s20 - s30; + float t11 = s11 - s21 - s31; + float t12 = s12 - s22 - s32; + float t13 = s13 - s23 - s33; + + float d00 = t00 + t01 + t02 + bias_ptr[0]; + float d01 = t01 - t02 - t03 + bias_ptr[0]; + float d10 = t10 + t11 + t12 + bias_ptr[0]; + float d11 = t11 - t12 - t13 + bias_ptr[0]; + + (dst_output + i)[0] = d00; + if (w_in_range) { + (dst_output + i + channel)[0] = d01; + } + if (h_in_range) { + (dst_output + i + output_w * channel)[0] = d10; + if (w_in_range) { + (dst_output + i + output_w * channel + channel)[0] = d11; + } + } + } +#endif +} + +void ConvDw3x3Fp32OutputTrans(float *trans_buffer, float *output_data, const float *bias, int out_h_block, + int out_w_block, const ConvParameter *conv_param) { + int oc4 = UP_DIV(conv_param->output_channel_, C4NUM); + bool h_in_range = true; + for (int oh = 0; oh < out_h_block; oh++) { + int real_oh = 2 * oh; + if ((oh + 1) * 2 > conv_param->output_h_) { + h_in_range = false; + } + bool w_in_range = true; + float *buf_oh = trans_buffer + oh * out_w_block * 16 * C4NUM; + float *output_oh = output_data + real_oh * conv_param->output_w_ * oc4 * C4NUM; + + for (int ow = 0; ow < out_w_block; ow++) { + int real_ow = 2 * ow; + if ((ow + 1) * 2 > conv_param->output_w_) { + w_in_range = false; + } + float *buf_ow = buf_oh + ow * 16 * C4NUM; + float *output_ow = output_oh + real_ow * oc4 * C4NUM; + + ConvDw3x3Fp32OutputUnit(buf_ow, output_ow, bias, oc4 * C4NUM, conv_param->output_w_, h_in_range, w_in_range, + conv_param->is_relu_, conv_param->is_relu6_); + } + } +} + +void ConvDw3x3Fp32(float *output_data, const float *input_data, const 
float *weight_data, const float *bias_data, + float *trans_buffer, float *block_buffer, const ConvParameter *conv_param, int task_id) { + int thread_count = conv_param->thread_num_; + int output_channel = conv_param->output_channel_; + int oc4 = UP_DIV(output_channel, C4NUM); + int out_h_block = UP_DIV(conv_param->output_h_, 2); + int out_w_block = UP_DIV(conv_param->output_w_, 2); + + int input_batch = conv_param->input_batch_; + for (int batch = 0; batch < input_batch; batch++) { + const float *input = input_data + batch * conv_param->input_h_ * conv_param->input_w_ * + UP_DIV(conv_param->input_channel_, C4NUM) * C4NUM; + float *output = output_data + batch * conv_param->output_h_ * conv_param->output_w_ * + UP_DIV(conv_param->output_channel_, C4NUM) * C4NUM; + for (int oc = task_id; oc < oc4; oc += thread_count) { + const float *weight = weight_data + oc * 16 * C4NUM; + const float *bias = bias_data + oc * C4NUM; + + ConvDw3x3Fp32InputTrans(input + oc * C4NUM, trans_buffer, block_buffer, out_h_block, out_w_block, conv_param); + + ConvDw3x3Fp32Winograd(trans_buffer, weight, out_h_block, out_w_block); + + ConvDw3x3Fp32OutputTrans(trans_buffer, output + oc * C4NUM, bias, out_h_block, out_w_block, conv_param); + } + } +} +/*conv depthwise 3x3 fp32 end*/ + /*deconv depthwise fp32 begin*/ void DeconvDepthwiseBorderPixel(float *dst, const float *src, const float *weight, int height, int width, int in_kh_step, int in_kw_step, int kernel_w) { diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.h index 8ab5332f1902ff8d4e0d32dcd59ddff58f304e80..e83b6b6dcf4ee809a16b5d5c440bb6c7a34372b7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.h @@ -42,8 +42,12 @@ void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_par void ConvDwC4Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); +void ConvDw3x3Fp32FilterTrans(float *trans_weight, float *weight, int oc4); + +void ConvDw3x3Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + float *trans_buffer, float *block_buffer, const ConvParameter *conv_param, int task_id); + void DeconvDwC4Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_CONV_DEPTHWISE_H_ - diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32_tests.cc new file mode 100644 index 0000000000000000000000000000000000000000..3394ecb5afad962d9565dbb44d3c7f8d08a56e5d --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32_tests.cc @@ -0,0 +1,198 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "src/common/file_utils.h" +#include "mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/src/ops/ops.h" + +namespace mindspore { +class TestConvolutionDwFp32 : public mindspore::Common { + public: + TestConvolutionDwFp32() {} +}; + +void InitConvDwParam(ConvParameter *conv_param) { + conv_param->input_batch_ = 1; + conv_param->input_h_ = 288; + conv_param->input_w_ = 288; + conv_param->input_channel_ = 25; + + conv_param->output_batch_ = 1; + conv_param->output_h_ = 288; + conv_param->output_w_ = 288; + conv_param->output_channel_ = 25; + + conv_param->kernel_h_ = 3; + conv_param->kernel_w_ = 3; + + conv_param->stride_h_ = 1; + conv_param->stride_w_ = 1; + + conv_param->dilation_h_ = 1; + conv_param->dilation_w_ = 1; + + conv_param->pad_h_ = 1; + conv_param->pad_w_ = 1; +} + +void InitConvDwCreator(std::vector *inputs, std::vector *outputs, + const ConvParameter *conv_param) { + // prepare input, format NHWC + size_t input_size; + std::string input_path = "./test_data/convDw/convDwfp32_input.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + + auto *input = new lite::tensor::Tensor; + input->set_data_type(kNumberTypeFloat32); + input->SetFormat(schema::Format_NHWC); + input->set_shape({conv_param->input_batch_, conv_param->input_h_, conv_param->input_w_, conv_param->input_channel_}); + input->MallocData(); + memcpy(input->Data(), input_data, input_size); + + // prepare weight, format co kh kw ci, ci = 1 + size_t weight_size; + std::string weight_path = "./test_data/convDw/convDwfp32_weight.bin"; + auto weight_data = reinterpret_cast(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size)); + + auto *weight = new lite::tensor::Tensor; + weight->set_data_type(kNumberTypeFloat32); + weight->set_shape({conv_param->output_channel_, conv_param->kernel_h_, conv_param->kernel_w_, 1}); + weight->MallocData(); + memcpy(weight->Data(), weight_data, weight_size); + + // prepare bias + auto *bias = new lite::tensor::Tensor; + bias->set_data_type(kNumberTypeFloat32); + bias->set_shape({conv_param->output_channel_}); + bias->MallocData(); + memset(bias->Data(), 0, bias->ElementsNum() * sizeof(float)); + + inputs->push_back(input); + inputs->push_back(weight); + inputs->push_back(bias); + + auto *output = new lite::tensor::Tensor; + output->set_data_type(kNumberTypeFloat32); + output->set_shape( + {conv_param->output_batch_, conv_param->output_h_, conv_param->output_w_, conv_param->output_channel_}); + output->SetFormat(schema::Format_NHWC); + output->MallocData(); + memset(output->Data(), 0, output->ElementsNum() * sizeof(float)); + outputs->push_back(output); +} + +TEST_F(TestConvolutionDwFp32, ConvDwFp32Accuracy) { + // prepare stage + auto conv_param = new ConvParameter(); + InitConvDwParam(conv_param); + + // init ctx + auto ctx = new Context(); + ctx->thread_num_ = 4; + + // init tensor + std::vector inputs; + std::vector outputs; + 
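+  // InitConvDwCreator (defined above) fills inputs with the NHWC input tensor, the (co, kh, kw, 1) weight
+  // and a zero bias read from ./test_data/convDw, and fills outputs with a matching NHWC output tensor.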
InitConvDwCreator(&inputs, &outputs, conv_param); + + // register op + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_DepthwiseConv2D}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = creator(inputs, outputs, reinterpret_cast(conv_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + // op run + kernel->Run(); + + std::cout << "==================output data=================" << std::endl; + auto output_ptr = reinterpret_cast(outputs[0]->Data()); + for (int i = 0; i < 20; i++) { + std::cout << output_ptr[i] << ", "; + } + std::cout << std::endl; + + // read output data, format NHWC + size_t output_size; + std::string output_path = "./test_data/convDw/convDwfp32_output.bin"; + auto correct_data = reinterpret_cast(mindspore::lite::ReadFile(output_path.c_str(), &output_size)); + + // compare + CompareOutputData(output_ptr, correct_data, outputs[0]->ElementsNum(), 0.0001); + + delete conv_param; + for (int i = 0; i < inputs.size(); i++) { + delete inputs[i]; + } + for (int i = 0; i < outputs.size(); i++) { + delete outputs[i]; + } + delete kernel; + delete correct_data; + MS_LOG(INFO) << "TestConvolutionDwFp32 accuracy passed"; +} + +TEST_F(TestConvolutionDwFp32, ConvDwFp32Performance) { + // prepare stage + auto conv_param = new ConvParameter(); + InitConvDwParam(conv_param); + + // init ctx + auto ctx = new Context(); + ctx->thread_num_ = 1; + + // init tensor + std::vector inputs; + std::vector outputs; + InitConvDwCreator(&inputs, &outputs, conv_param); + + // register op + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_DepthwiseConv2D}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = creator(inputs, outputs, reinterpret_cast(conv_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + + /* running warm up */ + for (int i = 0; i < 3; i++) { + kernel->Run(); + } + + /* running time cost */ + int loop_count = 10; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + kernel->Run(); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + uint64_t time_avg = cost / loop_count; + printf("Convolution_depthwise fp32 average time : %f ms\n", time_avg / 1000.0f); + + delete conv_param; + for (int i = 0; i < inputs.size(); i++) { + delete inputs[i]; + } + for (int i = 0; i < outputs.size(); i++) { + delete outputs[i]; + } + delete kernel; + MS_LOG(INFO) << "TestConvolutionDwFp32 performance passed"; +} +} // namespace mindspore
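Note: the kernel added above computes each depthwise channel with the Winograd F(2x2, 3x3) algorithm: ConvDw3x3Fp32FilterTrans applies G to the 3x3 filter, Conv3x3Fp32InputUnit applies B^T to each 4x4 input tile, ConvDw3x3Fp32Winograd multiplies the transformed tile by the transformed filter element-wise, and ConvDw3x3Fp32OutputUnit applies A^T to produce a 2x2 output tile. The standalone scalar sketch below is illustrative only and not part of the patch; it checks that scheme for a single channel and a single tile against direct 3x3 convolution. The G and A^T coefficients are read off the transforms above; B^T is assumed to be the standard F(2x2, 3x3) input-transform matrix, and the MatMul/Transpose helpers are local to this example.

// winograd_dw3x3_reference.cc (illustrative only, not part of this patch)
#include <cmath>
#include <cstdio>

namespace {
// F(2x2, 3x3) transform matrices. G and At match ConvDw3x3Fp32FilterTrans and
// ConvDw3x3Fp32OutputUnit; Bt is assumed to be the standard input transform.
const float Bt[4][4] = {{1, 0, -1, 0}, {0, 1, 1, 0}, {0, -1, 1, 0}, {0, 1, 0, -1}};
const float G[4][3] = {{1, 0, 0}, {0.5f, 0.5f, 0.5f}, {0.5f, -0.5f, 0.5f}, {0, 0, 1}};
const float At[2][4] = {{1, 1, 1, 0}, {0, 1, -1, -1}};

// out[r x c] = a[r x k] * b[k x c], all row-major.
void MatMul(const float *a, const float *b, float *out, int r, int k, int c) {
  for (int i = 0; i < r; i++) {
    for (int j = 0; j < c; j++) {
      float acc = 0.0f;
      for (int p = 0; p < k; p++) acc += a[i * k + p] * b[p * c + j];
      out[i * c + j] = acc;
    }
  }
}

// dst[c x r] = transpose of src[r x c].
void Transpose(const float *src, float *dst, int r, int c) {
  for (int i = 0; i < r; i++) {
    for (int j = 0; j < c; j++) dst[j * r + i] = src[i * c + j];
  }
}
}  // namespace

int main() {
  // One 4x4 input tile d and one 3x3 depthwise filter g for a single channel.
  float d[16], g[9];
  for (int i = 0; i < 16; i++) d[i] = 0.1f * i;
  for (int i = 0; i < 9; i++) g[i] = 0.05f * (i + 1);

  // U = G * g * G^T : transformed filter (what ConvDw3x3Fp32FilterTrans precomputes).
  float Gg[12], Gt[12], U[16];
  MatMul(&G[0][0], g, Gg, 4, 3, 3);
  Transpose(&G[0][0], Gt, 4, 3);
  MatMul(Gg, Gt, U, 4, 3, 4);

  // V = B^T * d * B : transformed input tile.
  float Btd[16], B[16], V[16];
  MatMul(&Bt[0][0], d, Btd, 4, 4, 4);
  Transpose(&Bt[0][0], B, 4, 4);
  MatMul(Btd, B, V, 4, 4, 4);

  // M = U * V element-wise (the product done by ConvDw3x3Fp32Winograd),
  // then Y = A^T * M * A : the 2x2 output tile (ConvDw3x3Fp32OutputUnit, minus bias/relu).
  float M[16], AtM[8], A[8], Y[4];
  for (int i = 0; i < 16; i++) M[i] = U[i] * V[i];
  MatMul(&At[0][0], M, AtM, 2, 4, 4);
  Transpose(&At[0][0], A, 2, 4);
  MatMul(AtM, A, Y, 2, 4, 2);

  // Direct 3x3 convolution (stride 1) over the same tile for comparison.
  for (int oh = 0; oh < 2; oh++) {
    for (int ow = 0; ow < 2; ow++) {
      float ref = 0.0f;
      for (int kh = 0; kh < 3; kh++) {
        for (int kw = 0; kw < 3; kw++) ref += d[(oh + kh) * 4 + (ow + kw)] * g[kh * 3 + kw];
      }
      printf("y[%d][%d] winograd=%f direct=%f diff=%e\n", oh, ow, Y[oh * 2 + ow], ref,
             std::fabs(Y[oh * 2 + ow] - ref));
    }
  }
  return 0;
}

In the kernel itself the same per-tile product runs over C4NUM channels at a time, which is why ConvDw3x3Fp32FilterTrans packs each group of four output channels into a 16 x C4NUM block and why trans_buffer_ holds UP_DIV(output_h_, 2) * UP_DIV(output_w_, 2) tiles of 16 * C4NUM floats per thread.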