提交 94f8cf67 编写于 作者: Y yangruoqi713

[MS][LITE] add arm fp32 op: conv depthwise 3x3, add testcase for conv depthwise

上级 1f28a7c0
......@@ -15,6 +15,7 @@
*/
#include "src/runtime/kernel/arm/fp32/convolution_depthwise.h"
#include "src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
......@@ -182,7 +183,15 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::tensor::T
const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
auto kernel = new (std::nothrow) kernel::ConvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx);
kernel::LiteKernel *kernel;
kernel = new (std::nothrow) kernel::ConvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx);
// auto param = reinterpret_cast<ConvParameter *>(opParameter);
// if (param->kernel_h_ == 3 && param->kernel_w_ == 3 && param->stride_h_ == 1 && param->stride_w_ == 1 &&
// param->dilation_h_ == 1 && param->dilation_w_ == 1) {
// kernel = new (std::nothrow) kernel::ConvolutionDepthwise3x3CPUKernel(opParameter, inputs, outputs, ctx);
// } else {
// kernel = new (std::nothrow) kernel::ConvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx);
// }
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_DepthwiseConv2D;
namespace mindspore::kernel {
int ConvolutionDepthwise3x3CPUKernel::InitWeightBias() {
// init weight: o, h, w, i; o == group, i == 1
auto weight_tensor = inputs_[kWeightIndex];
auto origin_weight = reinterpret_cast<float *>(weight_tensor->Data());
// o h w 1 -> o/4 h w 1 4
int OC4 = UP_DIV(conv_param_->output_channel_, C4NUM);
int weight_c4_size = OC4 * C4NUM * 9;
auto tmp_weight = reinterpret_cast<float *>(malloc(weight_c4_size * sizeof(float)));
if (tmp_weight == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
}
memset(tmp_weight, 0, weight_c4_size * sizeof(float));
PackNCHWToNC4HW4Fp32(origin_weight, tmp_weight, 1, conv_param_->kernel_h_ * conv_param_->kernel_w_,
conv_param_->output_channel_);
// weight transform
int packed_weight_size = OC4 * C4NUM * 16;
packed_weight_ = reinterpret_cast<float *>(malloc(packed_weight_size * sizeof(float)));
if (packed_weight_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
}
memset(packed_weight_, 0, packed_weight_size * sizeof(float));
ConvDw3x3Fp32FilterTrans(packed_weight_, tmp_weight, OC4);
// init bias
bias_data_ = reinterpret_cast<float *>(malloc(C4NUM * OC4 * sizeof(float)));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
}
memset(bias_data_, 0, C4NUM * OC4 * sizeof(float));
if (inputs_.size() == kInputSize2) {
auto ori_bias = reinterpret_cast<float *>(inputs_.at(kBiasIndex)->Data());
memcpy(bias_data_, ori_bias, conv_param_->output_channel_ * sizeof(float));
}
return RET_OK;
}
int ConvolutionDepthwise3x3CPUKernel::InitBuffer() {
if (conv_param_->input_channel_ % C4NUM != 0) {
need_align_ = true;
int IC4 = UP_DIV(conv_param_->input_channel_, C4NUM);
int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C4NUM * IC4;
packed_input_ = reinterpret_cast<float *>(malloc(pack_input_size * sizeof(float)));
if (packed_input_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
}
memset(packed_input_, 0, pack_input_size * sizeof(float));
int OC4 = UP_DIV(conv_param_->output_channel_, C4NUM);
int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C4NUM * OC4;
packed_output_ = reinterpret_cast<float *>(malloc(pack_output_size * sizeof(float)));
if (packed_output_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
}
}
// malloc transform buffer
trans_size_ = UP_DIV(conv_param_->output_w_, 2) * UP_DIV(conv_param_->output_h_, 2) * 16 * C4NUM;
size_t trans_buffer_size = thread_count_ * trans_size_ * sizeof(float);
trans_buffer_ = reinterpret_cast<float *>(malloc(trans_buffer_size));
if (trans_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc trans buffer failed.";
return RET_ERROR;
}
return RET_OK;
}
int ConvolutionDepthwise3x3CPUKernel::Init() {
// conv base init
ConvolutionBaseCPUKernel::Init();
auto ret = InitWeightBias();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Depthwise3x3 fp32 initWeightBias error!";
return ret;
}
// init threadNum;
conv_param_->thread_num_ = MSMIN(thread_count_, UP_DIV(conv_param_->output_channel_, C4NUM));
ret = InitBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Depthwise3x3 fp32 initBuffer error!";
return ret;
}
// malloc one block buffer
block_buffer_ = reinterpret_cast<float *>(malloc(thread_count_ * 16 * C4NUM * sizeof(float)));
if (block_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc block buffer failed.";
return RET_ERROR;
}
return RET_OK;
}
int ConvolutionDepthwise3x3CPUKernel::ReSize() {
if (need_align_) {
free(packed_input_);
free(packed_output_);
}
free(trans_buffer_);
// conv base init
ConvolutionBaseCPUKernel::Init();
auto ret = InitBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Depthwise3x3 fp32 initBuffer error!";
return ret;
}
return RET_OK;
}
int ConvolutionDepthwise3x3CPUKernel::Execute(int task_id) {
auto trans_buf = trans_buffer_ + task_id * trans_size_;
auto block_buf = block_buffer_ + task_id * 16 * C4NUM;
ConvDw3x3Fp32(packed_output_, packed_input_, packed_weight_, reinterpret_cast<float *>(bias_data_), trans_buf,
block_buf, conv_param_, task_id);
return RET_OK;
}
int ConvDw3x3Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
auto conv_dw_3x3 = reinterpret_cast<ConvolutionDepthwise3x3CPUKernel *>(cdata);
auto ret = conv_dw_3x3->Execute(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionDepthwise3x3Run error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}
int ConvolutionDepthwise3x3CPUKernel::Run() {
if (conv_param_->input_channel_ != conv_param_->output_channel_) {
MS_LOG(ERROR) << "Only support input channel equals output channel.";
return RET_ERROR;
}
auto input_tensor = inputs_.at(kInputIndex);
auto input_addr = reinterpret_cast<float *>(input_tensor->Data());
// pack input: to nhwc4
if (need_align_) {
PackNHWCToNHWC4Fp32(input_addr, packed_input_, conv_param_->input_batch_,
conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_);
} else {
packed_input_ = input_addr;
}
auto output_addr = reinterpret_cast<float *>(outputs_.at(kOutputIndex)->Data());
if (!need_align_) {
packed_output_ = output_addr;
}
auto ret = LiteBackendParallelLaunch(ConvDw3x3Run, this, conv_param_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvDw3x3Run error: error_code[" << ret << "]";
return RET_ERROR;
}
if (need_align_) {
PackNHWC4ToNHWCFp32(packed_output_, output_addr, conv_param_->output_batch_,
conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
}
return RET_OK;
}
} // namespace mindspore::kernel
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_BACKEND_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_H_
#define MINDSPORE_LITE_SRC_BACKEND_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include "src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.h"
namespace mindspore::kernel {
class ConvolutionDepthwise3x3CPUKernel : public ConvolutionBaseCPUKernel {
public:
ConvolutionDepthwise3x3CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {}
~ConvolutionDepthwise3x3CPUKernel() override {
free(packed_weight_);
if (need_align_) {
free(packed_input_);
free(packed_output_);
}
free(block_buffer_);
free(trans_buffer_);
};
int Init() override;
int ReSize() override;
int Run() override;
int InitWeightBias();
int InitBuffer();
int Execute(int task_id);
private:
float *packed_weight_;
float *packed_input_;
float *packed_output_;
float *block_buffer_;
float *trans_buffer_;
int trans_size_;
bool need_align_ = false;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_BACKEND_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_H_
......@@ -16,6 +16,7 @@
#include "src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.h"
#include "src/runtime/kernel/arm/nnacl/fp32/common_func.h"
#include "src/runtime/kernel/arm/nnacl/winograd_transform.h"
#ifdef ENABLE_ARM64
#include <arm_neon.h>
#endif
......@@ -212,6 +213,372 @@ void ConvDwC4Fp32(float *output_data, const float *input_data, const float *weig
}
/*conv depthwise fp32 end*/
/*conv depthwise 3x3 fp32 begin*/
void ConvDw3x3Fp32FilterTrans(float *trans_weight, float *weight, int oc4) {
for (int c = 0; c < oc4; c++) {
float *src = weight + c * C4NUM * 9;
float *dst = trans_weight + c * C4NUM * 16;
#ifdef ENABLE_ARM
float32x4_t g00 = vld1q_f32(src);
float32x4_t g01 = vld1q_f32(src + 4);
float32x4_t g02 = vld1q_f32(src + 2 * 4);
float32x4_t g10 = vld1q_f32(src + 3 * 4);
float32x4_t g11 = vld1q_f32(src + 4 * 4);
float32x4_t g12 = vld1q_f32(src + 5 * 4);
float32x4_t g20 = vld1q_f32(src + 6 * 4);
float32x4_t g21 = vld1q_f32(src + 7 * 4);
float32x4_t g22 = vld1q_f32(src + 8 * 4);
float32x4_t dst00 = g00;
float32x4_t dst01 = g01;
float32x4_t dst02 = g02;
float32x4_t dst10 = vaddq_f32(vmulq_n_f32(g00, 0.5), vmulq_n_f32(g10, 0.5));
dst10 = vaddq_f32(dst10, vmulq_n_f32(g20, 0.5));
float32x4_t dst11 = vaddq_f32(vmulq_n_f32(g01, 0.5), vmulq_n_f32(g11, 0.5));
dst11 = vaddq_f32(dst11, vmulq_n_f32(g21, 0.5));
float32x4_t dst12 = vaddq_f32(vmulq_n_f32(g02, 0.5), vmulq_n_f32(g12, 0.5));
dst12 = vaddq_f32(dst12, vmulq_n_f32(g22, 0.5));
float32x4_t dst20 = vsubq_f32(vmulq_n_f32(g00, 0.5), vmulq_n_f32(g10, 0.5));
dst20 = vaddq_f32(dst20, vmulq_n_f32(g20, 0.5));
float32x4_t dst21 = vsubq_f32(vmulq_n_f32(g01, 0.5), vmulq_n_f32(g11, 0.5));
dst21 = vaddq_f32(dst21, vmulq_n_f32(g21, 0.5));
float32x4_t dst22 = vsubq_f32(vmulq_n_f32(g02, 0.5), vmulq_n_f32(g12, 0.5));
dst22 = vaddq_f32(dst22, vmulq_n_f32(g22, 0.5));
float32x4_t dst30 = g20;
float32x4_t dst31 = g21;
float32x4_t dst32 = g22;
float32x4_t m00 = dst00;
float32x4_t m01 = vaddq_f32(vmulq_n_f32(dst00, 0.5), vmulq_n_f32(dst01, 0.5));
m01 = vaddq_f32(m01, vmulq_n_f32(dst02, 0.5));
float32x4_t m02 = vsubq_f32(vmulq_n_f32(dst00, 0.5), vmulq_n_f32(dst01, 0.5));
m02 = vaddq_f32(m02, vmulq_n_f32(dst02, 0.5));
float32x4_t m03 = dst02;
float32x4_t m10 = dst10;
float32x4_t m11 = vaddq_f32(vmulq_n_f32(dst10, 0.5), vmulq_n_f32(dst11, 0.5));
m11 = vaddq_f32(m11, vmulq_n_f32(dst12, 0.5));
float32x4_t m12 = vsubq_f32(vmulq_n_f32(dst10, 0.5), vmulq_n_f32(dst11, 0.5));
m12 = vaddq_f32(m12, vmulq_n_f32(dst12, 0.5));
float32x4_t m13 = dst12;
float32x4_t m20 = dst20;
float32x4_t m21 = vaddq_f32(vmulq_n_f32(dst20, 0.5), vmulq_n_f32(dst21, 0.5));
m21 = vaddq_f32(m21, vmulq_n_f32(dst22, 0.5));
float32x4_t m22 = vsubq_f32(vmulq_n_f32(dst20, 0.5), vmulq_n_f32(dst21, 0.5));
m22 = vaddq_f32(m22, vmulq_n_f32(dst22, 0.5));
float32x4_t m23 = dst22;
float32x4_t m30 = dst30;
float32x4_t m31 = vaddq_f32(vmulq_n_f32(dst30, 0.5), vmulq_n_f32(dst31, 0.5));
m31 = vaddq_f32(m31, vmulq_n_f32(dst32, 0.5));
float32x4_t m32 = vsubq_f32(vmulq_n_f32(dst30, 0.5), vmulq_n_f32(dst31, 0.5));
m32 = vaddq_f32(m32, vmulq_n_f32(dst32, 0.5));
float32x4_t m33 = dst32;
vst1q_f32(dst, m00);
vst1q_f32(dst + 4, m01);
vst1q_f32(dst + 8, m02);
vst1q_f32(dst + 12, m03);
vst1q_f32(dst + 16, m10);
vst1q_f32(dst + 20, m11);
vst1q_f32(dst + 24, m12);
vst1q_f32(dst + 28, m13);
vst1q_f32(dst + 32, m20);
vst1q_f32(dst + 36, m21);
vst1q_f32(dst + 40, m22);
vst1q_f32(dst + 44, m23);
vst1q_f32(dst + 48, m30);
vst1q_f32(dst + 52, m31);
vst1q_f32(dst + 56, m32);
vst1q_f32(dst + 60, m33);
#else
for (int j = 0; j < C4NUM; j++) {
float *local_ptr = src + j;
float dst00 = local_ptr[0];
float dst01 = (local_ptr + 4)[0];
float dst02 = (local_ptr + 8)[0];
float dst10 = 0.5f * local_ptr[0] + 0.5f * (local_ptr + 12)[0] + 0.5f * (local_ptr + 24)[0];
float dst11 = 0.5f * (local_ptr + 4)[0] + 0.5f * (local_ptr + 16)[0] + 0.5f * (local_ptr + 28)[0];
float dst12 = 0.5f * (local_ptr + 8)[0] + 0.5f * (local_ptr + 20)[0] + 0.5f * (local_ptr + 32)[0];
float dst20 = 0.5f * local_ptr[0] - 0.5f * (local_ptr + 12)[0] + 0.5f * (local_ptr + 24)[0];
float dst21 = 0.5f * (local_ptr + 4)[0] - 0.5f * (local_ptr + 16)[0] + 0.5f * (local_ptr + 28)[0];
float dst22 = 0.5f * (local_ptr + 8)[0] - 0.5f * (local_ptr + 20)[0] + 0.5f * (local_ptr + 32)[0];
float dst30 = (local_ptr + 24)[0];
float dst31 = (local_ptr + 28)[0];
float dst32 = (local_ptr + 32)[0];
float m00 = dst00;
float m01 = 0.5f * dst00 + 0.5f * dst01 + 0.5f * dst02;
float m02 = 0.5f * dst00 - 0.5f * dst01 + 0.5f * dst02;
float m03 = dst02;
float m10 = dst10;
float m11 = 0.5f * dst10 + 0.5f * dst11 + 0.5f * dst12;
float m12 = 0.5f * dst10 - 0.5f * dst11 + 0.5f * dst12;
float m13 = dst12;
float m20 = dst20;
float m21 = 0.5f * dst20 + 0.5f * dst21 + 0.5f * dst22;
float m22 = 0.5f * dst20 - 0.5f * dst21 + 0.5f * dst22;
float m23 = dst22;
float m30 = dst30;
float m31 = 0.5f * dst30 + 0.5f * dst31 + 0.5f * dst32;
float m32 = 0.5f * dst30 - 0.5f * dst31 + 0.5f * dst32;
float m33 = dst32;
*(dst + j) = m00;
*(dst + j + 4) = m01;
*(dst + j + 8) = m02;
*(dst + j + 12) = m03;
*(dst + j + 16) = m10;
*(dst + j + 20) = m11;
*(dst + j + 24) = m12;
*(dst + j + 28) = m13;
*(dst + j + 32) = m20;
*(dst + j + 36) = m21;
*(dst + j + 40) = m22;
*(dst + j + 44) = m23;
*(dst + j + 48) = m30;
*(dst + j + 52) = m31;
*(dst + j + 56) = m32;
*(dst + j + 60) = m33;
}
#endif
}
}
void ConvDw3x3Fp32InputTrans(const float *input_data, float *trans_input, float *block_buffer, int out_h_block,
int out_w_block, const ConvParameter *conv_param) {
int ic4 = UP_DIV(conv_param->input_channel_, C4NUM);
int input_unit = 4;
memset(trans_input, 0, out_h_block * out_h_block * 16 * C4NUM * sizeof(float));
for (int oh = 0; oh < out_h_block; oh++) {
int ih = oh * 2 - conv_param->pad_h_;
int real_h_start = ih > 0 ? 0 : -ih;
int real_h_end = (ih + input_unit) < conv_param->input_h_ ? input_unit : (conv_param->input_h_ - ih);
for (int ow = 0; ow < out_w_block; ow++) {
int iw = ow * 2 - conv_param->pad_w_;
int real_w_start = iw > 0 ? 0 : -iw;
int real_w_end = (iw + input_unit) < conv_param->input_w_ ? input_unit : (conv_param->input_w_ - iw);
memset(block_buffer, 0, 16 * C4NUM * sizeof(float));
int src_plane_offset = ic4 * C4NUM * (ih * conv_param->input_w_ + iw);
for (int h = real_h_start; h < real_h_end; h++) {
int src_h_offset = src_plane_offset + (h * conv_param->input_w_) * ic4 * C4NUM;
int dst_h_offset = (h * input_unit) * C4NUM;
for (int w = real_w_start; w < real_w_end; w++) {
int src_w_offset = src_h_offset + w * ic4 * C4NUM;
int dst_w_offset = dst_h_offset + w * C4NUM;
float *src_addr = (float *)(input_data) + src_w_offset;
float *dst_addr = block_buffer + dst_w_offset;
#ifdef ENABLE_NEON
vst1q_f32(dst_addr, vld1q_f32(src_addr));
#else
for (int k = 0; k < C4NUM; k++) {
(dst_addr + k)[0] = (src_addr + k)[0];
}
#endif
}
}
int trans_offset = (oh * out_w_block + ow) * 16 * C4NUM;
Conv3x3Fp32InputUnit(block_buffer, trans_input + trans_offset, C4NUM);
}
}
}
// todo yangruoqi: implement assembly
void ConvDw3x3Fp32Winograd(float *trans_buffer, const float *weight, int out_h_block, int out_w_block) {
int unit = 4;
for (int oh = 0; oh < out_h_block; oh++) {
float *buf_oh = trans_buffer + oh * out_w_block * 16 * C4NUM;
for (int ow = 0; ow < out_w_block; ow++) {
float *buf_ow = buf_oh + ow * 16 * C4NUM;
for (int kh = 0; kh < unit; kh++) {
float *buf_kh = buf_ow + kh * unit * C4NUM;
const float *weight_kh = weight + kh * unit * C4NUM;
for (int kw = 0; kw < unit; kw++) {
float *buf_kw = buf_kh + kw * C4NUM;
const float *weight_kw = weight_kh + kw * C4NUM;
for (int c = 0; c < C4NUM; c++) {
buf_kw[c] = buf_kw[c] * weight_kw[c];
}
}
}
}
}
}
void ConvDw3x3Fp32OutputUnit(float *src_buf, float *dst_output, const float *bias, int channel, int output_w,
bool h_in_range, bool w_in_range, bool is_relu, bool is_relu6) {
#ifdef ENABLE_ARM
float32x4_t bias_ptr = vld1q_f32(bias);
float32x4_t s00 = vld1q_f32(src_buf);
float32x4_t s01 = vld1q_f32(src_buf + 4);
float32x4_t s02 = vld1q_f32(src_buf + 8);
float32x4_t s03 = vld1q_f32(src_buf + 12);
float32x4_t s10 = vld1q_f32(src_buf + 16);
float32x4_t s11 = vld1q_f32(src_buf + 20);
float32x4_t s12 = vld1q_f32(src_buf + 24);
float32x4_t s13 = vld1q_f32(src_buf + 28);
float32x4_t s20 = vld1q_f32(src_buf + 32);
float32x4_t s21 = vld1q_f32(src_buf + 36);
float32x4_t s22 = vld1q_f32(src_buf + 40);
float32x4_t s23 = vld1q_f32(src_buf + 44);
float32x4_t s30 = vld1q_f32(src_buf + 48);
float32x4_t s31 = vld1q_f32(src_buf + 52);
float32x4_t s32 = vld1q_f32(src_buf + 56);
float32x4_t s33 = vld1q_f32(src_buf + 60);
float32x4_t t00 = vaddq_f32(vaddq_f32(s00, s10), s20);
float32x4_t t01 = vaddq_f32(vaddq_f32(s01, s11), s21);
float32x4_t t02 = vaddq_f32(vaddq_f32(s02, s12), s22);
float32x4_t t03 = vaddq_f32(vaddq_f32(s03, s13), s23);
float32x4_t t10 = vsubq_f32(vsubq_f32(s10, s20), s30);
float32x4_t t11 = vsubq_f32(vsubq_f32(s11, s21), s31);
float32x4_t t12 = vsubq_f32(vsubq_f32(s12, s22), s32);
float32x4_t t13 = vsubq_f32(vsubq_f32(s13, s23), s33);
float32x4_t d00 = vaddq_f32(vaddq_f32(vaddq_f32(t00, t01), t02), bias_ptr);
float32x4_t d01 = vaddq_f32(vsubq_f32(vsubq_f32(t01, t02), t03), bias_ptr);
float32x4_t d10 = vaddq_f32(vaddq_f32(vaddq_f32(t10, t11), t12), bias_ptr);
float32x4_t d11 = vaddq_f32(vsubq_f32(vsubq_f32(t11, t12), t13), bias_ptr);
vst1q_f32(dst_output, d00);
if (w_in_range) {
vst1q_f32(dst_output + channel, d01);
}
if (h_in_range) {
vst1q_f32(dst_output + output_w * channel, d10);
if (w_in_range) {
vst1q_f32(dst_output + output_w * channel + channel, d11);
}
}
#else
for (int i = 0; i < C4NUM; i++) {
const float *local_ptr = src_buf + i;
const float *bias_ptr = bias + i;
float s00 = local_ptr[0];
float s01 = (local_ptr + 4)[0];
float s02 = (local_ptr + 8)[0];
float s03 = (local_ptr + 12)[0];
float s10 = (local_ptr + 16)[0];
float s11 = (local_ptr + 20)[0];
float s12 = (local_ptr + 24)[0];
float s13 = (local_ptr + 28)[0];
float s20 = (local_ptr + 32)[0];
float s21 = (local_ptr + 36)[0];
float s22 = (local_ptr + 40)[0];
float s23 = (local_ptr + 44)[0];
float s30 = (local_ptr + 48)[0];
float s31 = (local_ptr + 52)[0];
float s32 = (local_ptr + 56)[0];
float s33 = (local_ptr + 60)[0];
float t00 = s00 + s10 + s20;
float t01 = s01 + s11 + s21;
float t02 = s02 + s12 + s22;
float t03 = s03 + s13 + s23;
float t10 = s10 - s20 - s30;
float t11 = s11 - s21 - s31;
float t12 = s12 - s22 - s32;
float t13 = s13 - s23 - s33;
float d00 = t00 + t01 + t02 + bias_ptr[0];
float d01 = t01 - t02 - t03 + bias_ptr[0];
float d10 = t10 + t11 + t12 + bias_ptr[0];
float d11 = t11 - t12 - t13 + bias_ptr[0];
(dst_output + i)[0] = d00;
if (w_in_range) {
(dst_output + i + channel)[0] = d01;
}
if (h_in_range) {
(dst_output + i + output_w * channel)[0] = d10;
if (w_in_range) {
(dst_output + i + output_w * channel + channel)[0] = d11;
}
}
}
#endif
}
void ConvDw3x3Fp32OutputTrans(float *trans_buffer, float *output_data, const float *bias, int out_h_block,
int out_w_block, const ConvParameter *conv_param) {
int oc4 = UP_DIV(conv_param->output_channel_, C4NUM);
bool h_in_range = true;
for (int oh = 0; oh < out_h_block; oh++) {
int real_oh = 2 * oh;
if ((oh + 1) * 2 > conv_param->output_h_) {
h_in_range = false;
}
bool w_in_range = true;
float *buf_oh = trans_buffer + oh * out_w_block * 16 * C4NUM;
float *output_oh = output_data + real_oh * conv_param->output_w_ * oc4 * C4NUM;
for (int ow = 0; ow < out_w_block; ow++) {
int real_ow = 2 * ow;
if ((ow + 1) * 2 > conv_param->output_w_) {
w_in_range = false;
}
float *buf_ow = buf_oh + ow * 16 * C4NUM;
float *output_ow = output_oh + real_ow * oc4 * C4NUM;
ConvDw3x3Fp32OutputUnit(buf_ow, output_ow, bias, oc4 * C4NUM, conv_param->output_w_, h_in_range, w_in_range,
conv_param->is_relu_, conv_param->is_relu6_);
}
}
}
void ConvDw3x3Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
float *trans_buffer, float *block_buffer, const ConvParameter *conv_param, int task_id) {
int thread_count = conv_param->thread_num_;
int output_channel = conv_param->output_channel_;
int oc4 = UP_DIV(output_channel, C4NUM);
int out_h_block = UP_DIV(conv_param->output_h_, 2);
int out_w_block = UP_DIV(conv_param->output_w_, 2);
int input_batch = conv_param->input_batch_;
for (int batch = 0; batch < input_batch; batch++) {
const float *input = input_data + batch * conv_param->input_h_ * conv_param->input_w_ *
UP_DIV(conv_param->input_channel_, C4NUM) * C4NUM;
float *output = output_data + batch * conv_param->output_h_ * conv_param->output_w_ *
UP_DIV(conv_param->output_channel_, C4NUM) * C4NUM;
for (int oc = task_id; oc < oc4; oc += thread_count) {
const float *weight = weight_data + oc * 16 * C4NUM;
const float *bias = bias_data + oc * C4NUM;
ConvDw3x3Fp32InputTrans(input + oc * C4NUM, trans_buffer, block_buffer, out_h_block, out_w_block, conv_param);
ConvDw3x3Fp32Winograd(trans_buffer, weight, out_h_block, out_w_block);
ConvDw3x3Fp32OutputTrans(trans_buffer, output + oc * C4NUM, bias, out_h_block, out_w_block, conv_param);
}
}
}
/*conv depthwise 3x3 fp32 end*/
/*deconv depthwise fp32 begin*/
void DeconvDepthwiseBorderPixel(float *dst, const float *src, const float *weight, int height, int width,
int in_kh_step, int in_kw_step, int kernel_w) {
......
......@@ -42,8 +42,12 @@ void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_par
void ConvDwC4Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id);
void ConvDw3x3Fp32FilterTrans(float *trans_weight, float *weight, int oc4);
void ConvDw3x3Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
float *trans_buffer, float *block_buffer, const ConvParameter *conv_param, int task_id);
void DeconvDwC4Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_CONV_DEPTHWISE_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "utils/log_adapter.h"
#include "common/common_test.h"
#include "src/common/file_utils.h"
#include "mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h"
#include "mindspore/lite/src/kernel_registry.h"
#include "mindspore/lite/src/ops/ops.h"
namespace mindspore {
class TestConvolutionDwFp32 : public mindspore::Common {
public:
TestConvolutionDwFp32() {}
};
void InitConvDwParam(ConvParameter *conv_param) {
conv_param->input_batch_ = 1;
conv_param->input_h_ = 288;
conv_param->input_w_ = 288;
conv_param->input_channel_ = 25;
conv_param->output_batch_ = 1;
conv_param->output_h_ = 288;
conv_param->output_w_ = 288;
conv_param->output_channel_ = 25;
conv_param->kernel_h_ = 3;
conv_param->kernel_w_ = 3;
conv_param->stride_h_ = 1;
conv_param->stride_w_ = 1;
conv_param->dilation_h_ = 1;
conv_param->dilation_w_ = 1;
conv_param->pad_h_ = 1;
conv_param->pad_w_ = 1;
}
void InitConvDwCreator(std::vector<lite::tensor::Tensor *> *inputs, std::vector<lite::tensor::Tensor *> *outputs,
const ConvParameter *conv_param) {
// prepare input, format NHWC
size_t input_size;
std::string input_path = "./test_data/convDw/convDwfp32_input.bin";
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
auto *input = new lite::tensor::Tensor;
input->set_data_type(kNumberTypeFloat32);
input->SetFormat(schema::Format_NHWC);
input->set_shape({conv_param->input_batch_, conv_param->input_h_, conv_param->input_w_, conv_param->input_channel_});
input->MallocData();
memcpy(input->Data(), input_data, input_size);
// prepare weight, format co kh kw ci, ci = 1
size_t weight_size;
std::string weight_path = "./test_data/convDw/convDwfp32_weight.bin";
auto weight_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size));
auto *weight = new lite::tensor::Tensor;
weight->set_data_type(kNumberTypeFloat32);
weight->set_shape({conv_param->output_channel_, conv_param->kernel_h_, conv_param->kernel_w_, 1});
weight->MallocData();
memcpy(weight->Data(), weight_data, weight_size);
// prepare bias
auto *bias = new lite::tensor::Tensor;
bias->set_data_type(kNumberTypeFloat32);
bias->set_shape({conv_param->output_channel_});
bias->MallocData();
memset(bias->Data(), 0, bias->ElementsNum() * sizeof(float));
inputs->push_back(input);
inputs->push_back(weight);
inputs->push_back(bias);
auto *output = new lite::tensor::Tensor;
output->set_data_type(kNumberTypeFloat32);
output->set_shape(
{conv_param->output_batch_, conv_param->output_h_, conv_param->output_w_, conv_param->output_channel_});
output->SetFormat(schema::Format_NHWC);
output->MallocData();
memset(output->Data(), 0, output->ElementsNum() * sizeof(float));
outputs->push_back(output);
}
TEST_F(TestConvolutionDwFp32, ConvDwFp32Accuracy) {
// prepare stage
auto conv_param = new ConvParameter();
InitConvDwParam(conv_param);
// init ctx
auto ctx = new Context();
ctx->thread_num_ = 4;
// init tensor
std::vector<lite::tensor::Tensor *> inputs;
std::vector<lite::tensor::Tensor *> outputs;
InitConvDwCreator(&inputs, &outputs, conv_param);
// register op
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_DepthwiseConv2D};
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
ASSERT_NE(creator, nullptr);
kernel::LiteKernel *kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), ctx, desc);
ASSERT_NE(kernel, nullptr);
// op run
kernel->Run();
std::cout << "==================output data=================" << std::endl;
auto output_ptr = reinterpret_cast<float *>(outputs[0]->Data());
for (int i = 0; i < 20; i++) {
std::cout << output_ptr[i] << ", ";
}
std::cout << std::endl;
// read output data, format NHWC
size_t output_size;
std::string output_path = "./test_data/convDw/convDwfp32_output.bin";
auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size));
// compare
CompareOutputData(output_ptr, correct_data, outputs[0]->ElementsNum(), 0.0001);
delete conv_param;
for (int i = 0; i < inputs.size(); i++) {
delete inputs[i];
}
for (int i = 0; i < outputs.size(); i++) {
delete outputs[i];
}
delete kernel;
delete correct_data;
MS_LOG(INFO) << "TestConvolutionDwFp32 accuracy passed";
}
TEST_F(TestConvolutionDwFp32, ConvDwFp32Performance) {
// prepare stage
auto conv_param = new ConvParameter();
InitConvDwParam(conv_param);
// init ctx
auto ctx = new Context();
ctx->thread_num_ = 1;
// init tensor
std::vector<lite::tensor::Tensor *> inputs;
std::vector<lite::tensor::Tensor *> outputs;
InitConvDwCreator(&inputs, &outputs, conv_param);
// register op
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_DepthwiseConv2D};
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
ASSERT_NE(creator, nullptr);
kernel::LiteKernel *kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), ctx, desc);
ASSERT_NE(kernel, nullptr);
/* running warm up */
for (int i = 0; i < 3; i++) {
kernel->Run();
}
/* running time cost */
int loop_count = 10;
auto time_start = mindspore::lite::GetTimeUs();
for (int i = 0; i < loop_count; i++) {
kernel->Run();
}
auto time_end = mindspore::lite::GetTimeUs();
auto cost = time_end - time_start;
uint64_t time_avg = cost / loop_count;
printf("Convolution_depthwise fp32 average time : %f ms\n", time_avg / 1000.0f);
delete conv_param;
for (int i = 0; i < inputs.size(); i++) {
delete inputs[i];
}
for (int i = 0; i < outputs.size(); i++) {
delete outputs[i];
}
delete kernel;
MS_LOG(INFO) << "TestConvolutionDwFp32 performance passed";
}
} // namespace mindspore
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册