提交 ac3905c2 编写于 作者: F fuzhiye

extract post process func

上级 b5ed5466
......@@ -307,8 +307,8 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
}
// step 4 : output transform
WinogradOutputTransform(gemm_out + task_id * gemm_out_offset, tmp_out_data + tmp_out_batch_offset, bias_data,
cal_num, out_tile_index, out_w_block, conv_param, output_trans_func);
WinogradOutputTransform(dst_ptr, tmp_out_data + tmp_out_batch_offset, bias_data, cal_num, out_tile_index,
out_w_block, conv_param, output_trans_func);
}
}
}
......@@ -449,8 +449,8 @@ void UnPackWinogradRelu6Output(const float *src, float *dst, int batch, int heig
}
// fp32 conv3x3
void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, float *output_data,
TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, GEMM_FUNC_FP32 gemm_func) {
void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, TmpBufferAddress *buffer_list,
int task_id, ConvParameter *conv_param, GEMM_FUNC_FP32 gemm_func) {
int thread_count = conv_param->thread_num_;
int ic4 = UP_DIV(conv_param->input_channel_, C4NUM);
int output_channel = conv_param->output_channel_;
......@@ -461,6 +461,7 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
int output_count = out_w_block * out_h_block;
int output_tile_count = UP_DIV(output_count, C12NUM);
const int input_unit_square = 4 * 4;
float *tile_buffer = buffer_list[0];
float *block_unit_buffer = buffer_list[1];
float *tmp_dst_buffer = buffer_list[2];
......@@ -491,8 +492,8 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
MatMulOpt(tmp_col_ptr, transed_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0,
ic4 * C4NUM, real_cal_num, oc8 * C8NUM, input_unit_square, 2);
}
Conv3x3Fp32OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, nc4hw4_out + nc4hw4_buffer_offset,
bias_data, start_index, real_cal_num, out_w_block, conv_param);
Conv3x3Fp32OutputTransform(dst_ptr, nc4hw4_out + nc4hw4_buffer_offset, bias_data, start_index, real_cal_num,
out_w_block, conv_param);
}
}
}
......@@ -65,8 +65,8 @@ void UnPackWinogradRelu6Output(const float *src, float *dst, int batch, int heig
int output_unit);
// fp32 conv3x3
void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, float *output_data,
TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, GEMM_FUNC_FP32 gemm_func);
void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, TmpBufferAddress *buffer_list,
int task_id, ConvParameter *conv_param, GEMM_FUNC_FP32 gemm_func);
#ifdef __cplusplus
}
#endif
......
......@@ -61,6 +61,6 @@ typedef struct OpParameter {
int thread_num_;
} OpParameter;
typedef enum ActType { ActType_No, ActType_Relu, ActType_Relu6 } ActType;
typedef enum ActType { ActType_No, ActType_Relu, ActType_Relu6, ActType_Prelu } ActType;
#endif // MINDSPORE_LITE_NNACL_OP_BASE_H_
......@@ -761,6 +761,8 @@ void PackNC4HW4ToNHWCRelu6Fp32(const void *src, void *dst, int batch, int plane,
}
}
void PackNC4HW4ToNHWCPreluFp32(const void *src, void *dst, const void *slope, int batch, int plane, int channel) {}
void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel) {
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {
......
......@@ -81,6 +81,8 @@ void PackNC4HW4ToNHWCReluFp32(const void *src, void *dst, int batch, int plane,
void PackNC4HW4ToNHWCRelu6Fp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNHWCPreluFp32(const void *src, void *dst, const void *slope, int batch, int plane, int channel);
void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel);
......
......@@ -207,6 +207,28 @@ static int Convolution3x3Fp16Impl(int task_id, LiteParallelGroupEnv *penv, void
return RET_OK;
}
int Convolution3x3FP16CPUKernel::PostProcess() {
auto act_type = conv_param_->act_type_;
switch (act_type) {
case ActType_No:
UnPack3x3OutputFp16(tmp_out_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_);
break;
case ActType_Relu:
UnPack3x3ReluOutputFp16(tmp_out_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_);
break;
case ActType_Relu6:
UnPack3x3Relu6OutputFp16(tmp_out_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_);
break;
default:
MS_LOG(ERROR) << "Unsupport activation type.";
return RET_ERROR;
}
return RET_OK;
}
int Convolution3x3FP16CPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
......@@ -236,20 +258,11 @@ int Convolution3x3FP16CPUKernel::Run() {
return RET_ERROR;
}
// get real output
bool relu = conv_param_->act_type_ == ActType_Relu;
bool relu6 = conv_param_->act_type_ == ActType_Relu6;
if (relu) {
UnPack3x3ReluOutputFp16(tmp_out_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_);
} else if (relu6) {
UnPack3x3Relu6OutputFp16(tmp_out_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_);
} else {
UnPack3x3OutputFp16(tmp_out_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_);
ret = PostProcess();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Post process failed.";
return ret;
}
ConvolutionBaseFP16CPUKernel::IfCastOutput();
ConvolutionBaseFP16CPUKernel::FreeTmpBuffer();
FreeTmpBuffer();
......
......@@ -52,6 +52,7 @@ class Convolution3x3FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
int InitWeightBias();
int InitTmpBuffer();
void ConfigInputOutput();
int PostProcess();
private:
void FreeTmpBuffer() {
......
......@@ -358,6 +358,28 @@ static int ConvolutionWinogradFp16Impl(int task_id, LiteParallelGroupEnv *penv,
return RET_OK;
}
int ConvolutionWinogradFP16CPUKernel::PostProcess() {
auto act_type = conv_param_->act_type_;
switch (act_type) {
case ActType_No:
UnPackWinogradOutputFp16(tmp_out_data_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_, output_unit_);
break;
case ActType_Relu:
UnPackWinogradReluOutputFp16(tmp_out_data_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_, output_unit_);
break;
case ActType_Relu6:
UnPackWinogradRelu6OutputFp16(tmp_out_data_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_, output_unit_);
break;
default:
MS_LOG(ERROR) << "Unsupport activation type.";
return RET_ERROR;
}
return RET_OK;
}
int ConvolutionWinogradFP16CPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
......@@ -390,16 +412,10 @@ int ConvolutionWinogradFP16CPUKernel::Run() {
return RET_ERROR;
}
// get real output
if (conv_param_->act_type_ == ActType_Relu) {
UnPackWinogradReluOutputFp16(tmp_out_data_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_, output_unit_);
} else if (conv_param_->act_type_ == ActType_Relu6) {
UnPackWinogradRelu6OutputFp16(tmp_out_data_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_, output_unit_);
} else {
UnPackWinogradOutputFp16(tmp_out_data_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_, output_unit_);
ret = PostProcess();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Post process failed.";
return ret;
}
ConvolutionBaseFP16CPUKernel::IfCastOutput();
ConvolutionBaseFP16CPUKernel::FreeTmpBuffer();
......
......@@ -56,6 +56,7 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
int MallocFilterMatrix(int oc_block, int oc_block_num);
int InitTmpBuffer();
int ConfigInputOutput();
int PostProcess();
private:
void FreeTmpBuffer() {
......
......@@ -207,9 +207,8 @@ int Convolution3x3CPUKernel::RunImpl(int task_id) {
MS_LOG(ERROR) << "gemm_func is nullptr.";
return RET_ERROR;
}
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->Data());
Conv3x3Fp32(reinterpret_cast<float *>(nhwc4_input_), transformed_filter_addr_, reinterpret_cast<float *>(bias_data_),
output_addr, tmp_buffer_address_list_, task_id, conv_param_, gemm_func_);
tmp_buffer_address_list_, task_id, conv_param_, gemm_func_);
return RET_OK;
}
......@@ -223,6 +222,29 @@ int Convolution3x3Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
return RET_OK;
}
int Convolution3x3CPUKernel::PostProcess() {
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->Data());
auto act_type = conv_param_->act_type_;
switch (act_type) {
case ActType_No:
PackNC4HW4ToNHWCFp32(nc4hw4_out_, output_addr, conv_param_->output_batch_,
conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
break;
case ActType_Relu:
PackNC4HW4ToNHWCReluFp32(nc4hw4_out_, output_addr, conv_param_->output_batch_,
conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
break;
case ActType_Relu6:
PackNC4HW4ToNHWCRelu6Fp32(nc4hw4_out_, output_addr, conv_param_->output_batch_,
conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
break;
default:
MS_LOG(ERROR) << "Unsupport activation type.";
return RET_ERROR;
}
return RET_OK;
}
int Convolution3x3CPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
......@@ -247,18 +269,10 @@ int Convolution3x3CPUKernel::Run() {
return RET_ERROR;
}
auto is_relu = conv_param_->act_type_ == ActType_Relu;
auto is_relu6 = conv_param_->act_type_ == ActType_Relu6;
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->Data());
if (is_relu) {
PackNC4HW4ToNHWCReluFp32(nc4hw4_out_, output_addr, conv_param_->output_batch_,
conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
} else if (is_relu6) {
PackNC4HW4ToNHWCRelu6Fp32(nc4hw4_out_, output_addr, conv_param_->output_batch_,
conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
} else {
PackNC4HW4ToNHWCFp32(nc4hw4_out_, output_addr, conv_param_->output_batch_,
conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
ret = PostProcess();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Post process failed.";
return ret;
}
FreeTmpBuffer();
return RET_OK;
......
......@@ -45,6 +45,7 @@ class Convolution3x3CPUKernel : public ConvolutionBaseCPUKernel {
int InitWeightBias();
int InitTmpBuffer();
void ConfigInputOutput();
int PostProcess();
private:
void FreeTmpBuffer() {
......
......@@ -353,18 +353,43 @@ int ConvolutionWinogradImpl(int task_id, LiteParallelGroupEnv *penv, void *cdata
return RET_OK;
}
int ConvolutionWinogradCPUKernel::PostProcess() {
auto out_tensor = out_tensors_.front();
auto out_data = reinterpret_cast<float *>(out_tensor->Data());
auto act_type = conv_param_->act_type_;
switch (act_type) {
case ActType_No:
UnPackWinogradOutput(tmp_out_data_, out_data, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_, output_unit_);
break;
case ActType_Relu:
UnPackWinogradReluOutput(tmp_out_data_, out_data, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_, output_unit_);
break;
case ActType_Relu6:
UnPackWinogradRelu6Output(tmp_out_data_, out_data, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_, output_unit_);
break;
default:
MS_LOG(ERROR) << "Unsupport activation type.";
return RET_ERROR;
}
return RET_OK;
}
int ConvolutionWinogradCPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
// malloc tmp buffer
auto ret = InitTmpBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init tmp buffer failed.";
return RET_ERROR;
}
auto input_tensor = in_tensors_.at(kInputIndex);
auto ori_input_data = input_tensor->Data();
PackNHWCToNHWC4Fp32(ori_input_data, nhwc4_input_, conv_param_->input_batch_,
......@@ -377,18 +402,10 @@ int ConvolutionWinogradCPUKernel::Run() {
return RET_ERROR;
}
// get real output
auto out_tensor = out_tensors_.front();
auto out_data = reinterpret_cast<float *>(out_tensor->Data());
if (conv_param_->act_type_ == ActType_Relu) {
UnPackWinogradReluOutput(tmp_out_data_, out_data, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_, output_unit_);
} else if (conv_param_->act_type_ == ActType_Relu6) {
UnPackWinogradRelu6Output(tmp_out_data_, out_data, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_, output_unit_);
} else {
UnPackWinogradOutput(tmp_out_data_, out_data, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_, output_unit_);
ret = PostProcess();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Post process failed.";
return ret;
}
FreeTmpBuffer();
return RET_OK;
......
......@@ -51,6 +51,7 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel {
int MallocFilterMatrix(int oc_block, int oc_block_num);
int InitTmpBuffer();
int ConfigInputOutput();
int PostProcess();
private:
void FreeTmpBuffer() {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部