Commit 59511a06 authored by mindspore-ci-bot, committed by Gitee

!5304 [MS][LITE][CPU] malloc using memory pool

Merge pull request !5304 from fuzhiye/tmp
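This merge request moves the convolution kernels' per-inference scratch buffers (packed inputs, NHWC4 staging, tile and transform buffers) off plain `malloc`/`free` in `ReSize()` and onto the context allocator's memory pool: they are acquired in `InitTmpBuffer()` when `Run()` starts and returned in `FreeTmpBuffer()` when it finishes. It also precomputes the valid kernel-tap bounds in the im2col packing loops instead of border-testing every tap, and reads weight shapes from the filter tensor instead of `conv_param_`.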
@@ -54,35 +54,32 @@ void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float1
int in_h = conv_param->input_h_;
int in_w = conv_param->input_w_;
int out_w = conv_param->output_w_;
int channel_block = UP_DIV(in_channel, 4);
int ic4 = UP_DIV(in_channel, 4);
memset(packed_input, 0, kernel_w * kernel_h * ic4 * C4NUM * 16 * sizeof(float16_t));
for (int i = 0; i < real_cal_num; i++) {
int block_start = block_index + i;
int input_h = block_start / out_w * stride_h - pad_h;
int input_w = block_start % out_w * stride_w - pad_w;
for (int j = 0; j < kernel_h; j++) {
int input_y = input_h + j * dilation_h;
if (input_y < 0 || input_y >= in_h) {
continue;
}
int input_y_stride = input_y * in_w * channel_block * C4NUM;
for (int n = 0; n < kernel_w; n++) {
int input_x = input_w + n * dilation_w;
if (input_x < 0 || input_x >= in_w) {
continue;
}
int input_x_stride = input_y_stride + input_x * channel_block * C4NUM;
int input_plane_offset = (j * kernel_w + n) * 16 * C4NUM * channel_block + i * C4NUM;
for (int m = 0; m < channel_block; m++) {
int input_stride = input_h * in_w * ic4 * C4NUM + input_w * ic4 * C4NUM;
int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h));
int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
for (int j = kh_s; j < kh_e; j++) {
int input_y_stride = j * dilation_h * in_w * ic4 * C4NUM + input_stride;
for (int n = kw_s; n < kw_e; n++) {
int input_x_stride = input_y_stride + n * dilation_w * ic4 * C4NUM;
int input_plane_offset = (j * kernel_w + n) * 16 * C4NUM * ic4 + i * C4NUM;
for (int m = 0; m < ic4; m++) {
int channel_block_stride = input_x_stride + m * C4NUM;
int channel_block_offset = input_plane_offset + m * 16 * C4NUM;
#ifdef ENABLE_ARM64
vst1_f16(packed_input + channel_block_offset, vld1_f16(input_data + channel_block_stride));
#else
(packed_input + channel_block_offset)[0] = (input_data + channel_block_stride)[0];
(packed_input + channel_block_offset)[1] = (input_data + channel_block_stride)[1];
(packed_input + channel_block_offset)[2] = (input_data + channel_block_stride)[2];
(packed_input + channel_block_offset)[3] = (input_data + channel_block_stride)[3];
for (int l = 0; l < C4NUM; ++l) {
(packed_input + channel_block_offset)[l] = (input_data + channel_block_stride)[l];
}
#endif
} // channel_block loop
} // kernel_w loop
......
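The rewritten packing loop replaces the per-tap border test (`if (input_y < 0 || input_y >= in_h) continue;`) with precomputed bounds: `kh_s`/`kh_e` are the first and one-past-last kernel rows whose sampled input row lands inside the image, and `kw_s`/`kw_e` likewise for columns, so the inner loops visit only in-image taps. A minimal standalone check of that bound derivation, with `UP_DIV`, `MSMAX`, and `MSMIN` assumed to be the usual MindSpore Lite macros (round-up division, max, min):

```cpp
#include <cstdio>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define MSMAX(a, b) ((a) > (b) ? (a) : (b))
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))

int main() {
  const int in_h = 5, kernel_h = 3;
  for (int dilation_h = 1; dilation_h <= 3; ++dilation_h) {
    for (int input_h = -6; input_h < in_h; ++input_h) {  // top tap row, may start in padding
      // j >= ceil(-input_h / dilation_h)   <=>   input_h + j * dilation_h >= 0
      int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
      // j < ceil((in_h - input_h) / dilation_h)   <=>   input_h + j * dilation_h < in_h
      int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h));
      for (int j = 0; j < kernel_h; ++j) {
        int input_y = input_h + j * dilation_h;
        bool inside = input_y >= 0 && input_y < in_h;
        bool selected = j >= kh_s && j < kh_e;
        if (inside != selected)
          std::printf("mismatch: input_h=%d dil=%d j=%d\n", input_h, dilation_h, j);
      }
    }
  }
  std::printf("clamped bounds match the per-tap border test\n");
  return 0;
}
```

Besides dropping two branches from the hot loops, the clamped ranges mean whole padded border rows and columns are skipped without ever being iterated.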
@@ -309,23 +309,21 @@ void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, floa
int in_w = conv_param->input_w_;
int out_w = conv_param->output_w_;
int ic4 = UP_DIV(in_channel, C4NUM);
memset(packed_input, 0, kernel_h * kernel_w * ic4 * C4NUM * TILE_NUM * sizeof(float));
for (int i = 0; i < real_cal_num; i++) {
int block_start = block_index + i;
int input_h = block_start / out_w * stride_h - pad_h;
int input_w = block_start % out_w * stride_w - pad_w;
for (int j = 0; j < kernel_h; j++) {
int input_y = input_h + j * dilation_h;
if (input_y < 0 || input_y >= in_h) {
continue;
}
int input_y_stride = input_y * in_w * ic4 * C4NUM;
for (int n = 0; n < kernel_w; n++) {
int input_x = input_w + n * dilation_w;
if (input_x < 0 || input_x >= in_w) {
continue;
}
int input_x_stride = input_y_stride + input_x * ic4 * C4NUM;
int input_stride = input_h * in_w * ic4 * C4NUM + input_w * ic4 * C4NUM;
int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h));
int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
for (int j = kh_s; j < kh_e; j++) {
int input_y_stride = j * dilation_h * in_w * ic4 * C4NUM + input_stride;
for (int n = kw_s; n < kw_e; n++) {
int input_x_stride = input_y_stride + n * dilation_w * ic4 * C4NUM;
int input_plane_offset = (j * kernel_w + n) * C8NUM * C4NUM * ic4 + i * C4NUM;
for (int m = 0; m < ic4; m++) {
int channel_block_stride = input_x_stride + m * C4NUM;
......
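Note the `memset` now sitting at the top of each pack function: the clamped loops never write packed positions that correspond to padding, and a pool-recycled buffer arrives with stale contents rather than zeros, so each tile is zero-filled immediately before packing instead of once at allocation time.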
@@ -95,8 +95,16 @@ int Convolution3x3FP16CPUKernel::InitTmpBuffer() {
const int tile_num = 16;
const int k_plane = 36;
int oC8 = UP_DIV(conv_param_->output_channel_, C8NUM);
int iC8 = UP_DIV(conv_param_->input_channel_, C8NUM);
MS_ASSERT(ctx_->allocator != nullptr);
size_t tile_buffer_size = thread_count_ * tile_num * k_plane * iC8 * C8NUM * sizeof(float16_t);
tile_buffer_ = reinterpret_cast<float16_t *>(malloc(tile_buffer_size));
if (tile_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc tile_buffer_ failed.";
return RET_ERROR;
}
size_t block_unit_buffer_size = thread_count_ * k_plane * C8NUM * sizeof(float16_t);
block_unit_buffer_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(block_unit_buffer_size));
if (block_unit_buffer_ == nullptr) {
@@ -152,10 +160,6 @@ int Convolution3x3FP16CPUKernel::ReSize() {
return ret;
}
if (tile_buffer_ != nullptr) {
free(tile_buffer_);
tile_buffer_ = nullptr;
}
if (nhwc4_input_ != nullptr) {
free(nhwc4_input_);
nhwc4_input_ = nullptr;
@@ -166,10 +170,8 @@ int Convolution3x3FP16CPUKernel::ReSize() {
MS_LOG(ERROR) << "ConvolutionBase init failed.";
return ret;
}
const int tile_num = 16;
const int k_plane = 36;
int iC8 = UP_DIV(conv_param_->input_channel_, C8NUM);
size_t nhwc8_input_size =
iC8 * C8NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t);
nhwc4_input_ = malloc(nhwc8_input_size);
@@ -179,14 +181,6 @@ int Convolution3x3FP16CPUKernel::ReSize() {
}
memset(nhwc4_input_, 0, nhwc8_input_size);
size_t tile_buffer_size = thread_count_ * tile_num * k_plane * iC8 * C8NUM * sizeof(float16_t);
tile_buffer_ = reinterpret_cast<float16_t *>(malloc(tile_buffer_size));
if (tile_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc tile_buffer_ failed.";
return RET_ERROR;
}
memset(tile_buffer_, 0, tile_buffer_size);
return RET_OK;
}
......
@@ -39,10 +39,6 @@ class Convolution3x3FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
free(transformed_filter_addr_);
transformed_filter_addr_ = nullptr;
}
if (tile_buffer_ != nullptr) {
free(tile_buffer_);
tile_buffer_ = nullptr;
}
}
int Init() override;
@@ -56,6 +52,10 @@ class Convolution3x3FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
private:
void FreeTmpBuffer() {
if (tile_buffer_ != nullptr) {
ctx_->allocator->Free(tile_buffer_);
tile_buffer_ = nullptr;
}
if (block_unit_buffer_ != nullptr) {
ctx_->allocator->Free(block_unit_buffer_);
block_unit_buffer_ = nullptr;
......
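`tile_buffer_` and the other block buffers now come from `ctx_->allocator` rather than `malloc`, so `FreeTmpBuffer()` returns the blocks to a pool for the next kernel (or the next `Run()`) instead of to the OS. A toy pool illustrating why per-run acquire/release is cheap (an assumption about the real `mindspore::lite::Allocator`, which this stand-in only mimics: freed blocks are recycled by size):

```cpp
#include <cstdlib>
#include <unordered_map>
#include <vector>

// Toy stand-in for ctx_->allocator: freed blocks are kept and handed back
// to the next Malloc of the same size, so per-Run acquire/release is cheap.
class PoolAllocator {
 public:
  void *Malloc(size_t size) {
    auto &bucket = free_lists_[size];
    if (!bucket.empty()) {        // reuse a recycled block: no system call
      void *p = bucket.back();
      bucket.pop_back();
      live_[p] = size;
      return p;
    }
    void *p = std::malloc(size);  // pool miss: grow
    if (p != nullptr) live_[p] = size;
    return p;
  }
  void Free(void *p) {            // recycle instead of releasing to the OS
    auto it = live_.find(p);
    if (it == live_.end()) return;
    free_lists_[it->second].push_back(p);
    live_.erase(it);
  }
  ~PoolAllocator() {
    for (auto &kv : free_lists_)
      for (void *p : kv.second) std::free(p);
    for (auto &kv : live_) std::free(kv.first);
  }

 private:
  std::unordered_map<size_t, std::vector<void *>> free_lists_;
  std::unordered_map<void *, size_t> live_;
};
```

Because every kernel in a sequentially executed graph draws its scratch memory from the same pool during its own `Run()`, peak scratch usage is roughly the largest single kernel's demand rather than the sum over all kernels.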
@@ -38,10 +38,13 @@ using mindspore::schema::PrimitiveType_Conv2D;
namespace mindspore::kernel {
int ConvolutionFP16CPUKernel::InitWeightBias() {
int kernel_h = conv_param_->kernel_h_;
int kernel_w = conv_param_->kernel_w_;
int in_channel = conv_param_->input_channel_;
int out_channel = conv_param_->output_channel_;
auto filter_tensor = in_tensors_.at(kWeightIndex);
int kernel_h = filter_tensor->Height();
int kernel_w = filter_tensor->Width();
int in_channel = filter_tensor->Channel();
int out_channel = filter_tensor->Batch();
conv_param_->input_channel_ = in_channel;
conv_param_->output_channel_ = out_channel;
int oc8 = UP_DIV(out_channel, C8NUM);
int ic4 = UP_DIV(in_channel, C4NUM);
int kernel_plane = kernel_h * kernel_w;
@@ -81,38 +84,34 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
}
int ConvolutionFP16CPUKernel::InitTmpBuffer() {
int kernel_h = conv_param_->kernel_h_;
int kernel_w = conv_param_->kernel_w_;
int in_batch = conv_param_->input_batch_;
int in_channel = conv_param_->input_channel_;
int out_channel = conv_param_->output_channel_;
int channel_block = UP_DIV(in_channel, C4NUM);
int kernel_plane = kernel_h * kernel_w;
// malloc packed_inputs
int cal_num = 16;
int output_count = conv_param_->output_h_ * conv_param_->output_w_;
int output_tile_count = UP_DIV(output_count, cal_num);
int kernel_plane = conv_param_->kernel_h_ * conv_param_->kernel_w_;
int unit_size = kernel_plane * channel_block * C4NUM;
int packed_input_size = output_tile_count * cal_num * unit_size;
packed_input_ = reinterpret_cast<float16_t *>(malloc(in_batch * packed_input_size * sizeof(float16_t)));
packed_input_ =
reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(in_batch * packed_input_size * sizeof(float16_t)));
if (packed_input_ == nullptr) {
MS_LOG(ERROR) << "malloc packed_input_ failed.";
return RET_ERROR;
}
memset(packed_input_, 0, in_batch * packed_input_size * sizeof(float16_t));
size_t nhwc4_input_size = channel_block * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ *
size_t nhwc4_input_size = channel_block * C4NUM * in_batch * conv_param_->input_h_ *
conv_param_->input_w_ * sizeof(float16_t);
nhwc4_input_ = malloc(nhwc4_input_size);
nhwc4_input_ = ctx_->allocator->Malloc(nhwc4_input_size);
if (nhwc4_input_ == nullptr) {
MS_LOG(ERROR) << "malloc nhwc4_input_ failed.";
return RET_ERROR;
}
memset(nhwc4_input_, 0, nhwc4_input_size);
tmp_output_block_ = reinterpret_cast<float16_t *>(malloc(thread_count_ * cal_num * out_channel * sizeof(float16_t)));
tmp_output_block_ =
reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(thread_count_ * cal_num * out_channel * sizeof(float16_t)));
if (tmp_output_block_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_output_block_ failed.";
return RET_ERROR;
@@ -136,6 +135,12 @@ int ConvolutionFP16CPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
}
auto ret = InitWeightBias();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init weight bias failed.";
return RET_ERROR;
}
ConfigInputOutput();
return ReSize();
}
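`InitWeightBias()` moves from `ReSize()` into `Init()`: the packed weights depend only on the filter tensor, whose `Batch()`/`Height()`/`Width()`/`Channel()` now also supply `kernel_h`, `kernel_w`, and the channel counts (written back into `conv_param_`), so there is no reason to repack them on every input-shape change. `ReSize()` keeps only the shape-dependent reinitialization.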
@@ -146,28 +151,11 @@ int ConvolutionFP16CPUKernel::ReSize() {
return ret;
}
FreeTmpBuffer();
if (nhwc4_input_ != nullptr) {
free(nhwc4_input_);
nhwc4_input_ = nullptr;
}
ret = ConvolutionBaseCPUKernel::Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionBase init fail!ret: " << ret;
return ret;
}
ret = InitWeightBias();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init weight bias failed.";
return RET_ERROR;
}
ret = InitTmpBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init tmp buffer failed.";
return RET_ERROR;
}
ConfigInputOutput();
return RET_OK;
}
@@ -200,6 +188,12 @@ int ConvolutionFP16CPUKernel::Run() {
return ret;
}
ret = InitTmpBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init tmp buffer failed.";
return RET_ERROR;
}
int in_batch = conv_param_->input_batch_;
int in_h = conv_param_->input_h_;
int in_w = conv_param_->input_w_;
@@ -209,9 +203,10 @@ int ConvolutionFP16CPUKernel::Run() {
int error_code = ParallelLaunch(THREAD_POOL_DEFAULT, ConvolutionFp16Impl, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "conv fp16 error error_code[" << error_code << "]";
FreeTmpBuffer();
return RET_ERROR;
}
FreeTmpBuffer();
ConvolutionBaseFP16CPUKernel::IfCastOutput();
ConvolutionBaseFP16CPUKernel::FreeTmpBuffer();
return RET_OK;
......
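`Run()` now owns the scratch lifecycle: it acquires the pool buffers via `InitTmpBuffer()` just before launching, and `FreeTmpBuffer()` is called on both the failure and success paths, so a `ParallelLaunch` error can no longer leak pool blocks. A condensed, self-contained sketch of that ordering (illustrative stubs, not the real kernel class):

```cpp
#include <cstdio>

constexpr int RET_OK = 0;
constexpr int RET_ERROR = 1;

// Illustrative stand-in for the kernel: scratch is pool-backed and scoped
// to a single Run(); weights were already packed once in Init().
struct ConvKernelSketch {
  bool scratch_live = false;

  int InitTmpBuffer() {             // ctx_->allocator->Malloc(...) per buffer
    scratch_live = true;
    return RET_OK;
  }
  void FreeTmpBuffer() {            // ctx_->allocator->Free(...) per buffer
    scratch_live = false;
  }
  int Compute() { return RET_OK; }  // stands in for ParallelLaunch(...)

  int Run() {
    if (InitTmpBuffer() != RET_OK) {
      std::printf("Init tmp buffer failed.\n");
      return RET_ERROR;
    }
    int error_code = Compute();
    if (error_code != RET_OK) {
      FreeTmpBuffer();              // release on the error path too
      return RET_ERROR;
    }
    FreeTmpBuffer();                // and on success
    return RET_OK;
  }
};

int main() {
  ConvKernelSketch k;
  k.Run();
  k.Run();                          // the pool recycles the same blocks next run
  return k.scratch_live ? 1 : 0;    // scratch never outlives a Run()
}
```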
@@ -29,7 +29,16 @@ class ConvolutionFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
const mindspore::lite::PrimitiveC *primitive)
: ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~ConvolutionFP16CPUKernel() override { FreeTmpBuffer(); }
~ConvolutionFP16CPUKernel() override {
if (fp16_weight_ != nullptr) {
free(fp16_weight_);
fp16_weight_ = nullptr;
}
if (packed_weight_ != nullptr) {
free(packed_weight_);
packed_weight_ = nullptr;
}
}
int Init() override;
int ReSize() override;
@@ -41,21 +50,16 @@ class ConvolutionFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
private:
void FreeTmpBuffer() {
if (fp16_weight_ != nullptr) {
free(fp16_weight_);
fp16_weight_ = nullptr;
if (nhwc4_input_ != nullptr) {
ctx_->allocator->Free(nhwc4_input_);
nhwc4_input_ = nullptr;
}
if (packed_input_ != nullptr) {
free(packed_input_);
ctx_->allocator->Free(packed_input_);
packed_input_ = nullptr;
}
if (packed_weight_ != nullptr) {
free(packed_weight_);
packed_weight_ = nullptr;
}
if (tmp_output_block_ != nullptr) {
free(tmp_output_block_);
ctx_->allocator->Free(tmp_output_block_);
tmp_output_block_ = nullptr;
}
}
......
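The ownership split here mirrors the allocation split: the destructor releases with `free()` only what was obtained with `malloc()` (the fp16 and packed weights), while `FreeTmpBuffer()` returns the pool-backed buffers with `ctx_->allocator->Free()`. Keeping each pointer with its matching deallocator is what makes the migration safe; passing a pool block to `free()`, or a `malloc`ed block to the pool, would corrupt one allocator or the other.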
@@ -218,6 +218,14 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() {
int output_h = conv_param_->output_h_;
int output_w = conv_param_->output_w_;
int oc8 = UP_DIV(channel_out, C8NUM);
int ic8 = UP_DIV(conv_param_->input_channel_, C8NUM);
size_t tile_buffer_size = thread_count_ * cal_num * input_unit_ * input_unit_ * ic8 * C8NUM * sizeof(float16_t);
trans_input_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(tile_buffer_size));
if (trans_input_ == nullptr) {
MS_LOG(ERROR) << "malloc trans_input_ failed.";
return RET_ERROR;
}
gemm_out_ = reinterpret_cast<float16_t *>(
ctx_->allocator->Malloc(thread_count_ * cal_num * input_unit_ * input_unit_ * oc8 * C8NUM * sizeof(float16_t)));
@@ -296,10 +304,6 @@ int ConvolutionWinogradFP16CPUKernel::ReSize() {
free(nhwc4_input_);
nhwc4_input_ = nullptr;
}
if (trans_input_ != nullptr) {
free(trans_input_);
trans_input_ = nullptr;
}
ret = ConvolutionBaseCPUKernel::Init();
if (ret != RET_OK) {
@@ -311,10 +315,8 @@ int ConvolutionWinogradFP16CPUKernel::ReSize() {
conv_param_->input_unit_ = input_unit_;
conv_param_->output_unit_ = output_unit_;
int cal_num = 16;
int channel_in = conv_param_->input_channel_;
int ic8 = UP_DIV(channel_in, C8NUM);
size_t nhwc8_input_size =
ic8 * C8NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t);
nhwc4_input_ = malloc(nhwc8_input_size);
@@ -324,14 +326,6 @@ int ConvolutionWinogradFP16CPUKernel::ReSize() {
}
memset(nhwc4_input_, 0, nhwc8_input_size);
size_t tile_buffer_size = thread_count_ * cal_num * input_unit_ * input_unit_ * ic8 * C8NUM * sizeof(float16_t);
trans_input_ = reinterpret_cast<float16_t *>(malloc(tile_buffer_size));
if (trans_input_ == nullptr) {
MS_LOG(ERROR) << "malloc trans_input_ failed.";
return RET_ERROR;
}
memset(trans_input_, 0, tile_buffer_size);
ret = ConfigInputOutput();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConfigInputOutput failed.";
......
@@ -38,10 +38,6 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
free(fp16_weight_);
fp16_weight_ = nullptr;
}
if (trans_input_ != nullptr) {
free(trans_input_);
trans_input_ = nullptr;
}
if (trans_weight_ != nullptr) {
delete trans_weight_;
trans_weight_ = nullptr;
@@ -60,6 +56,10 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
private:
void FreeTmpBuffer() {
if (trans_input_ != nullptr) {
ctx_->allocator->Free(trans_input_);
trans_input_ = nullptr;
}
if (tmp_data_ != nullptr) {
ctx_->allocator->Free(tmp_data_);
tmp_data_ = nullptr;
@@ -86,7 +86,7 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
TmpBufferAddressFp16 tmp_buffer_address_list_[4];
};
int WinogradFilterTransformFp16(const float16_t *weight_data, Matrix *trans_weight, int kernel_unit, int input_unit,
ConvParameter *conv_param, int oc_block);
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_WINOGRAD_FP16_H_
@@ -83,6 +83,26 @@ int ConvolutionCPUKernel::InitTmpBuffer() {
int out_channel = conv_param_->output_channel_;
MS_ASSERT(ctx_->allocator != nullptr);
int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM);
size_t nhwc4_input_size =
ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float);
nhwc4_input_ = ctx_->allocator->Malloc(nhwc4_input_size);
if (nhwc4_input_ == nullptr) {
MS_LOG(ERROR) << "malloc nhwc4 input failed.";
return RET_ERROR;
}
int output_count = conv_param_->output_h_ * conv_param_->output_w_;
int output_tile_count = UP_DIV(output_count, TILE_NUM);
int unit_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * ic4 * C4NUM;
int packed_input_size = output_tile_count * TILE_NUM * unit_size;
packed_input_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(conv_param_->input_batch_ * packed_input_size * sizeof(float)));
if (packed_input_ == nullptr) {
MS_LOG(ERROR) << "malloc packed input failed.";
return RET_ERROR;
}
tmp_output_block_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(thread_count_ * TILE_NUM * out_channel * sizeof(float)));
if (tmp_output_block_ == nullptr) {
@@ -124,40 +144,11 @@ int ConvolutionCPUKernel::ReSize() {
return ret;
}
if (nhwc4_input_ != nullptr) {
free(nhwc4_input_);
nhwc4_input_ = nullptr;
}
if (packed_input_ != nullptr) {
free(packed_input_);
packed_input_ = nullptr;
}
ret = ConvolutionBaseCPUKernel::Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionBase init failed.";
return RET_ERROR;
}
int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM);
size_t nhwc4_input_size =
ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float);
nhwc4_input_ = malloc(nhwc4_input_size);
if (nhwc4_input_ == nullptr) {
MS_LOG(ERROR) << "malloc nhwc4 input failed.";
return RET_ERROR;
}
memset(nhwc4_input_, 0, nhwc4_input_size);
int output_count = conv_param_->output_h_ * conv_param_->output_w_;
int output_tile_count = UP_DIV(output_count, TILE_NUM);
int unit_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * ic4 * C4NUM;
int packed_input_size = output_tile_count * TILE_NUM * unit_size;
packed_input_ = reinterpret_cast<float *>(malloc(conv_param_->input_batch_ * packed_input_size * sizeof(float)));
if (packed_input_ == nullptr) {
MS_LOG(ERROR) << "malloc packed input failed.";
return RET_ERROR;
}
memset(packed_input_, 0, conv_param_->input_batch_ * packed_input_size * sizeof(float));
return RET_OK;
}
......
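The scratch sizing in `InitTmpBuffer()` rounds the output pixel count up to whole tiles, so the final partial tile still gets a full-size slot in `packed_input_`. A worked example of that arithmetic (illustrative numbers; `TILE_NUM` is assumed to be 8 here, but the rounding argument is the same for any tile width):

```cpp
#include <cstdio>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define C4NUM 4

int main() {
  // Illustrative shapes, not taken from the diff:
  const int TILE_NUM = 8;
  int output_h = 14, output_w = 14, kernel_h = 3, kernel_w = 3, in_channel = 17;
  int ic4 = UP_DIV(in_channel, C4NUM);                            // 5 channel blocks
  int output_tile_count = UP_DIV(output_h * output_w, TILE_NUM);  // 196 px -> 25 tiles
  int unit_size = kernel_h * kernel_w * ic4 * C4NUM;              // 180 floats per pixel
  int packed_input_size = output_tile_count * TILE_NUM * unit_size;
  std::printf("%d floats per batch\n", packed_input_size);        // 25 * 8 * 180 = 36000
  return 0;
}
```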
@@ -35,10 +35,6 @@ class ConvolutionCPUKernel : public ConvolutionBaseCPUKernel {
free(packed_weight_);
packed_weight_ = nullptr;
}
if (packed_input_ != nullptr) {
free(packed_input_);
packed_input_ = nullptr;
}
}
int Init() override;
@@ -55,6 +51,14 @@ class ConvolutionCPUKernel : public ConvolutionBaseCPUKernel {
ctx_->allocator->Free(tmp_output_block_);
tmp_output_block_ = nullptr;
}
if (nhwc4_input_ != nullptr) {
ctx_->allocator->Free(nhwc4_input_);
nhwc4_input_ = nullptr;
}
if (packed_input_ != nullptr) {
ctx_->allocator->Free(packed_input_);
packed_input_ = nullptr;
}
}
float *packed_input_ = nullptr;
float *packed_weight_ = nullptr;
......
@@ -44,6 +44,10 @@ void ProcessFilterUint8(int8_t *origin_weight, int16_t *dst_weight, ConvParamete
}
void Convolution3x3Int8CPUKernel::FreeTmpBuffer() {
if (tile_buffer_ != nullptr) {
ctx_->allocator->Free(tile_buffer_);
tile_buffer_ = nullptr;
}
if (block_unit_buffer_ != nullptr) {
ctx_->allocator->Free(block_unit_buffer_);
block_unit_buffer_ = nullptr;
@@ -67,10 +71,6 @@ Convolution3x3Int8CPUKernel::~Convolution3x3Int8CPUKernel() {
free(input_data_);
input_data_ = nullptr;
}
if (tile_buffer_ != nullptr) {
free(tile_buffer_);
tile_buffer_ = nullptr;
}
FreeQuantParam();
}
@@ -115,8 +115,16 @@ int Convolution3x3Int8CPUKernel::InitTmpBuffer() {
int output_batch = conv_param_->output_batch_;
int output_w = conv_param_->output_w_;
int output_h = conv_param_->output_h_;
int ic8 = UP_DIV(conv_param_->input_channel_, C8NUM);
MS_ASSERT(ctx_->allocator != nullptr);
size_t tile_buffer_size = thread_count_ * TILE_NUM * C16NUM * ic8 * C8NUM * sizeof(int16_t);
tile_buffer_ = reinterpret_cast<int16_t *>(ctx_->allocator->Malloc(tile_buffer_size));
if (tile_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc tile_buffer_ failed.";
return RET_ERROR;
}
size_t block_unit_buffer_size = thread_count_ * 4 * 4 * C8NUM * sizeof(int16_t);
block_unit_buffer_ = reinterpret_cast<int16_t *>(ctx_->allocator->Malloc(block_unit_buffer_size));
if (block_unit_buffer_ == nullptr) {
@@ -175,10 +183,6 @@ int Convolution3x3Int8CPUKernel::ReSize() {
free(input_data_);
input_data_ = nullptr;
}
if (tile_buffer_ != nullptr) {
free(tile_buffer_);
tile_buffer_ = nullptr;
}
ret = ConvolutionBaseCPUKernel::Init();
if (ret != RET_OK) {
@@ -196,13 +200,6 @@ int Convolution3x3Int8CPUKernel::ReSize() {
}
memset(input_data_, 0, c8_input_size);
size_t tile_buffer_size = thread_count_ * TILE_NUM * C16NUM * ic8 * C8NUM * sizeof(int16_t);
tile_buffer_ = reinterpret_cast<int16_t *>(malloc(tile_buffer_size));
if (tile_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc tile_buffer_ failed.";
return RET_ERROR;
}
memset(tile_buffer_, 0, tile_buffer_size);
return RET_OK;
}
......