Commit 7c0db5e2 authored by mindspore-ci-bot, committed by Gitee

!4648 optimize fp16 conv3x3 input transform

Merge pull request !4648 from fuzhiye/tmp
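This merge switches the fp16 3x3 and Winograd convolution kernels from 4-channel (C4NUM / NHWC4) input packing to 8-channel (C8NUM / NHWC8) packing: buffer sizes move from `iC4 * C4NUM` to `iC8 * C8NUM`, the layout-dependent `convert_func_` call is replaced by a direct `PackNHWCToNHWC8Fp16`, and the NHWC8 pack routine is reimplemented in the `const void *` family. The padding arithmetic behind the size changes, as a minimal sketch assuming nnacl's usual macro definitions (the `UP_DIV`/`C4NUM`/`C8NUM` values below are assumptions, not quoted from this commit):

```cpp
#include <cstdio>

#define C4NUM 4
#define C8NUM 8
#define UP_DIV(x, y) (((x) + (y)-1) / (y))

int main() {
  int input_channel = 17;                  // a channel count divisible by neither 4 nor 8
  int iC4 = UP_DIV(input_channel, C4NUM);  // 5 blocks -> 20 padded channels
  int iC8 = UP_DIV(input_channel, C8NUM);  // 3 blocks -> 24 padded channels
  printf("C4 pads %d -> %d, C8 pads %d -> %d\n", input_channel, iC4 * C4NUM, input_channel, iC8 * C8NUM);
  return 0;
}
```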
......@@ -36,15 +36,15 @@ void ProcessFilterFp16(float16_t *origin_weight, float16_t *dst_weight, ConvPara
auto input_channel = conv_param->input_channel_;
auto output_channel = conv_param->output_channel_;
auto kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_;
- int iC4 = UP_DIV(input_channel, C4NUM);
+ int iC8 = UP_DIV(input_channel, C8NUM);
int oC8 = UP_DIV(output_channel, C8NUM);
- size_t tmp_size = oC8 * C8NUM * iC4 * C4NUM * kernel_plane * sizeof(float16_t);
+ size_t tmp_size = oC8 * C8NUM * iC8 * C8NUM * kernel_plane * sizeof(float16_t);
auto tmp_addr = reinterpret_cast<float16_t *>(malloc(tmp_size));
memset(tmp_addr, 0, tmp_size);
PackWeightToC4Fp16(origin_weight, tmp_addr, conv_param);
- Conv3x3Fp16FilterTransform(tmp_addr, dst_weight, iC4, output_channel, kernel_plane);
+ Conv3x3Fp16FilterTransform(tmp_addr, dst_weight, iC8 * 2, output_channel, kernel_plane);
free(tmp_addr);
}
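Note that ProcessFilterFp16 still packs with PackWeightToC4Fp16 but now passes `iC8 * 2` where Conv3x3Fp16FilterTransform expects a C4 block count. That is sound because both packings keep a pixel's padded channels contiguous, so a C8-padded buffer is byte-identical to a C4-padded buffer with twice as many blocks, and the zeroed tail blocks (tmp_addr is memset above) simply contribute zeros. A trivial self-check of that arithmetic (illustrative, not from the commit):

```cpp
#include <cassert>

int main() {
  const int channel = 9;              // pads to 12 under C4, 16 under C8
  const int ic8 = (channel + 7) / 8;  // 2 blocks of 8
  const int ic4_view = ic8 * 2;       // the same bytes seen as 4 blocks of 4
  assert(ic8 * 8 == ic4_view * 4);    // identical per-pixel stride either way
  return 0;
}
```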
......@@ -52,10 +52,10 @@ void ProcessFilterFp16(float16_t *origin_weight, float16_t *dst_weight, ConvPara
int Convolution3x3FP16CPUKernel::InitWeightBias() {
auto input_channel = conv_param_->input_channel_;
int output_channel = conv_param_->output_channel_;
- int iC4 = UP_DIV(input_channel, C4NUM);
+ int iC8 = UP_DIV(input_channel, C8NUM);
int oC8 = UP_DIV(output_channel, C8NUM);
// init weight
- size_t transformed_size = iC4 * C4NUM * oC8 * C8NUM * 36 * sizeof(float16_t);
+ size_t transformed_size = iC8 * C8NUM * oC8 * C8NUM * 36 * sizeof(float16_t);
transformed_filter_addr_ = reinterpret_cast<float16_t *>(malloc(transformed_size));
if (transformed_filter_addr_ == nullptr) {
MS_LOG(ERROR) << "malloc transformed_filter_addr_ failed.";
......@@ -92,11 +92,11 @@ int Convolution3x3FP16CPUKernel::InitWeightBias() {
int Convolution3x3FP16CPUKernel::InitTmpBuffer() {
const int tile_num = 16;
const int k_plane = 36;
- int iC4 = UP_DIV(conv_param_->input_channel_, C4NUM);
+ int iC8 = UP_DIV(conv_param_->input_channel_, C8NUM);
int oC8 = UP_DIV(conv_param_->output_channel_, C8NUM);
/*=============================tile_buffer_============================*/
- size_t tile_buffer_size = thread_count_ * tile_num * k_plane * iC4 * C4NUM * sizeof(float16_t);
+ size_t tile_buffer_size = thread_count_ * tile_num * k_plane * iC8 * C8NUM * sizeof(float16_t);
tile_buffer_ = reinterpret_cast<float16_t *>(malloc(tile_buffer_size));
if (tile_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc tile_buffer_ failed.";
......@@ -105,7 +105,7 @@ int Convolution3x3FP16CPUKernel::InitTmpBuffer() {
memset(tile_buffer_, 0, tile_buffer_size);
/*=============================block_unit_buffer_============================*/
- size_t block_unit_buffer_size = thread_count_ * k_plane * C4NUM * sizeof(float16_t);
+ size_t block_unit_buffer_size = thread_count_ * k_plane * C8NUM * sizeof(float16_t);
block_unit_buffer_ = reinterpret_cast<float16_t *>(malloc(block_unit_buffer_size));
if (block_unit_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc block_unit_buffer_ failed.";
......@@ -133,14 +133,14 @@ int Convolution3x3FP16CPUKernel::InitTmpBuffer() {
memset(tmp_out_, 0, tmp_out_size);
/*=============================nhwc4_input_============================*/
- size_t nhwc4_input_size =
- iC4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t);
- nhwc4_input_ = malloc(nhwc4_input_size);
+ size_t nhwc8_input_size =
+ iC8 * C8NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t);
+ nhwc4_input_ = malloc(nhwc8_input_size);
if (nhwc4_input_ == nullptr) {
MS_LOG(ERROR) << "malloc nhwc4_input_ failed.";
return RET_ERROR;
}
- memset(nhwc4_input_, 0, nhwc4_input_size);
+ memset(nhwc4_input_, 0, nhwc8_input_size);
return RET_OK;
}
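To make the new sizes concrete, here is a worked computation of the fp16 buffers above under assumed dimensions (the numbers are illustrative, not from the commit; `unsigned short` stands in for float16_t so the sketch runs on any host). Note also that the member keeps its old `nhwc4_input_` name even though it now holds NHWC8-packed data.

```cpp
#include <cstdio>

#define C8NUM 8
#define UP_DIV(x, y) (((x) + (y)-1) / (y))

int main() {
  // Illustrative dimensions (assumptions, not from the commit).
  const int thread_count = 4, tile_num = 16, k_plane = 36;
  const int in_batch = 1, in_h = 56, in_w = 56, in_channel = 17;
  const int iC8 = UP_DIV(in_channel, C8NUM);  // 3
  size_t fp16 = sizeof(unsigned short);       // stand-in for sizeof(float16_t)
  size_t tile_buffer = (size_t)thread_count * tile_num * k_plane * iC8 * C8NUM * fp16;
  size_t block_unit = (size_t)thread_count * k_plane * C8NUM * fp16;
  size_t nhwc8_input = (size_t)iC8 * C8NUM * in_batch * in_h * in_w * fp16;
  printf("tile_buffer_=%zu  block_unit_buffer_=%zu  nhwc4_input_=%zu bytes\n", tile_buffer, block_unit, nhwc8_input);
  return 0;
}
```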
......@@ -189,7 +189,6 @@ int Convolution3x3FP16CPUKernel::ReSize() {
MS_LOG(ERROR) << "Init tmp buffer failed.";
return RET_ERROR;
}
- ConfigInputOutput();
return RET_OK;
}
......@@ -225,7 +224,7 @@ int Convolution3x3FP16CPUKernel::Run() {
int in_h = conv_param_->input_h_;
int in_w = conv_param_->input_w_;
int in_channel = conv_param_->input_channel_;
- convert_func_(reinterpret_cast<void *>(execute_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel);
+ PackNHWCToNHWC8Fp16(reinterpret_cast<void *>(execute_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel);
int error_code = LiteBackendParallelLaunch(Convolution3x3Fp16Impl, this, thread_count_);
if (error_code != RET_OK) {
......
......@@ -150,10 +150,11 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
int ConvolutionWinogradFP16CPUKernel::MallocFilterMatrix(int oc_block, int oc_block_num) {
int channel_in = conv_param_->input_channel_;
- int ic4 = UP_DIV(channel_in, BLOCK);
+ int ic8 = UP_DIV(channel_in, C8NUM);
+ int ic4 = ic8 * 2;
// set data
- auto trans_matrix_data_size = input_unit_ * input_unit_ * ic4 * C4NUM * oc_block_num * oc_block * sizeof(float);
+ auto trans_matrix_data_size = input_unit_ * input_unit_ * ic8 * C8NUM * oc_block_num * oc_block * sizeof(float);
auto matrix_buffer = malloc(trans_matrix_data_size);
if (matrix_buffer == nullptr) {
MS_LOG(ERROR) << "malloc matrix_buffer failed.";
......@@ -191,11 +192,11 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() {
int channel_out = conv_param_->output_channel_;
int output_h = conv_param_->output_h_;
int output_w = conv_param_->output_w_;
- int ic4 = UP_DIV(channel_in, C4NUM);
+ int ic8 = UP_DIV(channel_in, C8NUM);
int oc8 = UP_DIV(channel_out, C8NUM);
/*=============================trans_input_============================*/
- size_t tile_buffer_size = thread_count_ * cal_num * input_unit_ * input_unit_ * ic4 * C4NUM * sizeof(float16_t);
+ size_t tile_buffer_size = thread_count_ * cal_num * input_unit_ * input_unit_ * ic8 * C8NUM * sizeof(float16_t);
trans_input_ = reinterpret_cast<float16_t *>(malloc(tile_buffer_size));
if (trans_input_ == nullptr) {
MS_LOG(ERROR) << "malloc trans_input_ failed.";
......@@ -223,12 +224,12 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() {
/*=============================tmp_data_============================*/
tmp_data_ =
- reinterpret_cast<float16_t *>(malloc(thread_count_ * C4NUM * input_unit_ * input_unit_ * sizeof(float16_t)));
+ reinterpret_cast<float16_t *>(malloc(thread_count_ * C8NUM * input_unit_ * input_unit_ * sizeof(float16_t)));
if (tmp_data_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_data_ failed.";
return RET_ERROR;
}
- memset(tmp_data_, 0, C4NUM * input_unit_ * input_unit_ * sizeof(float16_t));
+ memset(tmp_data_, 0, C8NUM * input_unit_ * input_unit_ * sizeof(float16_t));
tmp_buffer_address_list_[0] = trans_input_;
tmp_buffer_address_list_[1] = gemm_out_;
......@@ -236,24 +237,18 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() {
tmp_buffer_address_list_[3] = tmp_data_;
/*=============================nhwc4_input_============================*/
- size_t nhwc4_input_size =
- ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t);
- nhwc4_input_ = malloc(nhwc4_input_size);
+ size_t nhwc8_input_size =
+ ic8 * C8NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t);
+ nhwc4_input_ = malloc(nhwc8_input_size);
if (nhwc4_input_ == nullptr) {
MS_LOG(ERROR) << "malloc nhwc4_input_ failed.";
return RET_ERROR;
}
- memset(nhwc4_input_, 0, nhwc4_input_size);
+ memset(nhwc4_input_, 0, nhwc8_input_size);
return RET_OK;
}
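One detail worth flagging, carried over unchanged in spirit from the C4 version: `tmp_data_` is allocated with a `thread_count_` factor, but the memset above clears only a single thread's slice. If the whole allocation is meant to start zeroed, the fuller clear would be (a suggestion, not part of this commit):

```cpp
memset(tmp_data_, 0, thread_count_ * C8NUM * input_unit_ * input_unit_ * sizeof(float16_t));
```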
int ConvolutionWinogradFP16CPUKernel::ConfigInputOutput() {
- auto input_tensor = in_tensors_.at(kInputIndex);
- auto ret = CheckLayout(input_tensor);
- if (ret != RET_OK) {
- MS_LOG(ERROR) << "Check layout failed.";
- return RET_ERROR;
- }
auto output_tensor = out_tensors_.at(kOutputIndex);
output_tensor->SetFormat(schema::Format_NHWC);
......@@ -348,7 +343,7 @@ int ConvolutionWinogradFP16CPUKernel::Run() {
int in_h = conv_param_->input_h_;
int in_w = conv_param_->input_w_;
int in_channel = conv_param_->input_channel_;
- convert_func_(execute_input_, nhwc4_input_, in_batch, in_h * in_w, in_channel);
+ PackNHWCToNHWC8Fp16(execute_input_, nhwc4_input_, in_batch, in_h * in_w, in_channel);
int error_code = LiteBackendParallelLaunch(ConvolutionWinogradFp16Impl, this, thread_count_);
if (error_code != RET_OK) {
......
......@@ -382,7 +382,8 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16
const int tile_num = 16;
const int output_unit = 4;
const int k_plane = 36;
- int ic4 = UP_DIV(conv_param->input_channel_, C4NUM);
+ int ic8 = UP_DIV(conv_param->input_channel_, C8NUM);
+ int ic4 = ic8 * 2;
int oc8 = UP_DIV(conv_param->output_channel_, C8NUM);
int out_w_block = UP_DIV(conv_param->output_w_, C4NUM);
......@@ -390,7 +391,7 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16
int output_count = out_w_block * out_h_block;
int output_tile_count = UP_DIV(output_count, tile_num);
int tile_buffer_offset = tile_num * k_plane * ic4 * C4NUM;
- int block_unit_buffer_offset = k_plane * C4NUM;
+ int block_unit_buffer_offset = k_plane * C8NUM;
int tmp_dst_buffer_offset = tile_num * k_plane * oc8 * C8NUM;
int input_batch = conv_param->input_batch_;
......@@ -541,7 +542,7 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa
int input_unit = conv_param->input_unit_;
int in_batch = conv_param->input_batch_;
int in_channel = conv_param->input_channel_;
- int ic4 = UP_DIV(in_channel, C4NUM);
+ int ic8 = UP_DIV(in_channel, C8NUM);
int out_unit = conv_param->output_unit_;
int out_w_block = UP_DIV(conv_param->output_w_, out_unit);
int out_h_block = UP_DIV(conv_param->output_h_, out_unit);
......@@ -557,16 +558,16 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa
float16_t *gemm_out = buffer_list[1];
float16_t *tmp_out_data = buffer_list[2];
float16_t *tmp_data = buffer_list[3];
- int trans_input_offset = tile_num * input_unit_square * ic4 * C4NUM;
+ int trans_input_offset = tile_num * input_unit_square * ic8 * C8NUM;
int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM;
- int tmp_data_offset = input_unit_square * C4NUM;
+ int tmp_data_offset = input_unit_square * C8NUM;
// step 1 : filter transform (pre-processed offline)
// step 2 : input transform (online)
for (int b = 0; b < in_batch; b++) {
- int in_batch_offset = b * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_;
+ int in_batch_offset = b * ic8 * C8NUM * conv_param->input_h_ * conv_param->input_w_;
int tmp_out_batch_offset = b * out_w_block * out_h_block * out_unit * out_unit * oc8 * C8NUM;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) {
- int out_tile_index = thread_id * TILE_NUM;
+ int out_tile_index = thread_id * tile_num;
int cal_num = output_count - thread_id * tile_num;
cal_num = cal_num > tile_num ? tile_num : cal_num;
WinogradInputTransformFp16(input_data + in_batch_offset, trans_input + task_id * trans_input_offset,
......@@ -574,7 +575,7 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa
input_trans_func);
// step 3 : gemm
IndirectGemmFp16_16x8(gemm_out + task_id * gemm_out_offset, trans_input + task_id * trans_input_offset,
- trans_weight, NULL, input_unit_square, ic4, oc8 * C8NUM, output_offset, 1, 1, 0, 0);
+ trans_weight, NULL, input_unit_square, ic8 * 2, oc8 * C8NUM, output_offset, 1, 1, 0, 0);
// step 4 : output transform
WinogradOutputTransformFp16(gemm_out + task_id * gemm_out_offset, tmp_out_data + tmp_out_batch_offset, bias_data,
......
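The loop above also aligns the tile index with the local `tile_num` (16) instead of the `TILE_NUM` macro. A self-contained rendering of just the tile-partitioning arithmetic from that hunk, with illustrative values (distilled from the shown code, not new API):

```cpp
#include <cstdio>

#define UP_DIV(x, y) (((x) + (y)-1) / (y))

// Each task handles tiles task_id, task_id + thread_num, ... and the last
// tile is clamped to the pixels that remain.
void PartitionTiles(int output_count, int tile_num, int task_id, int thread_num) {
  int output_tile_count = UP_DIV(output_count, tile_num);
  for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) {
    int out_tile_index = thread_id * tile_num;          // first output pixel of this tile
    int cal_num = output_count - out_tile_index;        // pixels still unprocessed
    cal_num = cal_num > tile_num ? tile_num : cal_num;  // at most one full tile
    printf("task %d: pixels [%d, %d)\n", task_id, out_tile_index, out_tile_index + cal_num);
  }
}

int main() {
  for (int t = 0; t < 2; ++t) PartitionTiles(50, 16, t, 2);  // 50 pixels over 2 threads
  return 0;
}
```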
......@@ -161,7 +161,8 @@ void PackWeightToC8Fp16(const float16_t *origin_weight_data, float16_t *packed_w
void PackWeightToC4Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param) {
// origin weight format : ohwi
int input_channel = conv_param->input_channel_;
- int ic4 = UP_DIV(input_channel, C4NUM);
+ int ic8 = UP_DIV(input_channel, C8NUM);
+ int ic4 = ic8 * 2;
int output_channel = conv_param->output_channel_;
int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_;
......@@ -240,6 +241,26 @@ void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int c
}
}
+ void PackNHWCToNHWC8Fp16(const void *src, void *dst, int batch, int plane, int channel) {
+ int ic8 = UP_DIV(channel, C8NUM);
+ int nhwc8_batch_unit_offset = ic8 * C8NUM * plane;
+ int ic_remainder_ = channel % C8NUM;
+ if (ic_remainder_ != 0) {
+ int nhwc8_batch_offset = 0;
+ for (int b = 0; b < batch; b++) {
+ int batch_offset = b * channel * plane;
+ for (int i = 0; i < plane; i++) {
+ memcpy((float16_t *)dst + nhwc8_batch_offset + i * ic8 * C8NUM, (float16_t *)src + batch_offset + i * channel,
+ channel * sizeof(float16_t));
+ }
+ nhwc8_batch_offset += nhwc8_batch_unit_offset;
+ }
+ } else {
+ size_t ori_input_size = batch * plane * channel * sizeof(float16_t);
+ memcpy(dst, src, ori_input_size);
+ }
+ }
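Unlike the removed float16_t* overload below, this version short-circuits to a single bulk memcpy when the channel count is already a multiple of 8, and otherwise copies only the valid channels per pixel, so the caller must pre-zero dst to get zero padding (both kernels above do this via memset on nhwc4_input_). A minimal usage sketch (sizes are hypothetical; float16_t assumes an ARM fp16 toolchain, as in the surrounding code):

```cpp
#include <cstring>
#include <arm_neon.h>  // float16_t on ARM toolchains (assumption for this sketch)

void PackNHWCToNHWC8Fp16(const void *src, void *dst, int batch, int plane, int channel);

void Example() {
  const int batch = 1, plane = 4, channel = 3;   // channel % 8 != 0 -> strided copy path
  const int c8 = ((channel + 7) / 8) * 8;        // padded to 8 channels
  float16_t src[batch * plane * channel] = {};   // NHWC source
  float16_t dst[batch * plane * c8];             // NHWC8 destination
  memset(dst, 0, sizeof(dst));                   // zero the pad lanes before packing
  PackNHWCToNHWC8Fp16(src, dst, batch, plane, channel);
}
```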
void PackNHWC4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) {
int c4 = UP_DIV(channel, C4NUM);
int ic_remainder_ = channel % C4NUM;
......@@ -399,19 +420,6 @@ void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, i
}
}
- void PackNHWCToNHWC8Fp16(float16_t *src, float16_t *dst, int batch, int plane, int channel) {
- int c8_channel = UP_DIV(channel, C8NUM) * C8NUM;
- for (int b = 0; b < batch; b++) {
- float16_t *dst_batch = dst + b * plane * c8_channel;
- float16_t *src_batch = src + b * plane * channel;
- for (int i = 0; i < plane; i++) {
- float16_t *dst_plane = dst_batch + i * c8_channel;
- float16_t *src_plane = src_batch + i * channel;
- memcpy(dst_plane, src_plane, channel * sizeof(float16_t));
- }
- }
- }
void PackNHWC8ToNHWCFp16(float16_t *src, float16_t *dst, int batch, int plane, int channel) {
int c8_channel = UP_DIV(channel, C8NUM) * C8NUM;
for (int b = 0; b < batch; b++) {
......
......@@ -43,6 +43,8 @@ void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int
void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel);
+ void PackNHWCToNHWC8Fp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWC4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel);
......@@ -63,8 +65,6 @@ void PackNHWCFp32ToC8HWN8Fp16(float *src, float16_t *dst, int batch, int plane,
void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel);
- void PackNHWCToNHWC8Fp16(float16_t *src, float16_t *dst, int batch, int plane, int channel);
void PackNHWC8ToNHWCFp16(float16_t *src, float16_t *dst, int batch, int plane, int channel);
#ifdef __cplusplus
}
......