Commit 60551b1f authored by mindspore-ci-bot, committed by Gitee

!4608 fix concat int8 memory leak

Merge pull request !4608 from zhaozhenlong/lite/issue/concat_int8_mem_leak_fix
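
The gist of the change: `input_data_`, `concat_param_->input_shapes_`, and `concat_param_->quant_arg_.in_args_` were taken from `ctx_->allocator` and freed inside `Run()`; the fix switches them to plain `malloc` allocations owned by the kernel, released once in `~ConcatInt8CPUKernel()`. Below is a minimal sketch of that ownership pattern (hypothetical `OwnedBuffers` class and method names, not the actual MindSpore kernel):

```cpp
#include <cstdint>
#include <cstdlib>

// Sketch: buffers are malloc'ed in Init()/Resize() and released exactly once
// in the destructor, instead of being drawn from a per-run allocator and
// freed inside Run().
class OwnedBuffers {
 public:
  ~OwnedBuffers() {
    // Single release point; the null check keeps destruction safe even if
    // Init() failed or was never called.
    if (input_data_ != nullptr) {
      free(input_data_);
    }
  }

  bool Init(size_t input_num) {
    input_data_ = static_cast<int8_t **>(malloc(sizeof(int8_t *) * input_num));
    return input_data_ != nullptr;  // mirrors the kernel's RET_ERROR path
  }

  bool Resize(size_t input_num) {
    // Resize may run many times; free the previous block before replacing it
    // so repeated resizes cannot leak. free(nullptr) is a no-op.
    free(input_data_);
    input_data_ = static_cast<int8_t **>(malloc(sizeof(int8_t *) * input_num));
    return input_data_ != nullptr;
  }

 private:
  int8_t **input_data_ = nullptr;  // in-class init, like concat_param_ below
};

int main() {
  OwnedBuffers kernel;
  if (!kernel.Init(4)) return 1;
  kernel.Resize(8);  // previous block freed before the new allocation
  return 0;          // destructor releases the final block exactly once
}
```

The `= nullptr` in-class initializer added to `concat_param_` in the base class plays the same role: it keeps the destructor's null checks well-defined if construction bails out before `Init()` runs.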
@@ -46,7 +46,7 @@ class ConcatBaseCPUKernel : public LiteKernel {
   int thread_count_;
   int axis_;
   const Context *ctx_;
-  ConcatParameter *concat_param_;
+  ConcatParameter *concat_param_ = nullptr;
 };
 }  // namespace mindspore::kernel
......
@@ -28,9 +28,15 @@ namespace mindspore::kernel {
 int ConcatInt8CPUKernel::Init() {
   ConcatBaseCPUKernel::Init();
   concat_param_->input_shapes_ = nullptr;
   auto input_num = in_tensors_.size();
+  input_data_ = reinterpret_cast<int8_t **>(malloc(sizeof(int8_t *) * input_num));
+  if (input_data_ == nullptr) {
+    MS_LOG(ERROR) << "Null pointer reference: inputs_array.";
+    return RET_ERROR;
+  }
   concat_param_->quant_arg_.in_args_ =
-      reinterpret_cast<QuantArg *>(ctx_->allocator->Malloc(sizeof(QuantArg) * input_num));
+      reinterpret_cast<QuantArg *>(malloc(sizeof(QuantArg) * input_num));
   if (concat_param_->quant_arg_.in_args_ == nullptr) {
     MS_LOG(ERROR) << "Null pointer reference: quant_concat_parm_->in_quant_args_.";
     return RET_ERROR;
@@ -61,11 +67,11 @@ int ConcatInt8CPUKernel::ReSize() {
     return ret;
   }
   if (concat_param_->input_shapes_ != nullptr) {
-    ctx_->allocator->Free(concat_param_->input_shapes_);
+    // free(concat_param_->input_shapes_);
   }
   auto input_num = in_tensors_.size();
   concat_param_->input_num_ = input_num;
-  concat_param_->input_shapes_ = reinterpret_cast<const int **>(ctx_->allocator->Malloc(sizeof(int *) * input_num));
+  concat_param_->input_shapes_ = reinterpret_cast<const int **>(malloc(sizeof(int *) * input_num));
   for (size_t i = 0; i < input_num; i++) {
     concat_param_->input_shapes_[i] = reinterpret_cast<const int *>(in_tensors_.at(i)->shape().data());
   }
@@ -96,11 +102,7 @@ int ConcatInt8CPUKernel::Run() {
   auto input_num = concat_param_->input_num_;
   count_unit_ = thread_count_ > 1 ? UP_DIV(before_axis_size, thread_count_) : before_axis_size;
   concat_param_->count_unit_ = count_unit_;
-  input_data_ = reinterpret_cast<int8_t **>(ctx_->allocator->Malloc(sizeof(int8_t *) * input_num));
-  if (input_data_ == nullptr) {
-    MS_LOG(ERROR) << "Null pointer reference: inputs_array.";
-    return RET_ERROR;
-  }
   for (size_t i = 0; i < input_num; i++) {
     input_data_[i] = static_cast<int8_t *>(in_tensors_.at(i)->Data());
   }
@@ -108,10 +110,6 @@ int ConcatInt8CPUKernel::Run() {
   ret = LiteBackendParallelLaunch(ConcatInt8Run, this, thread_count_);
-  ctx_->allocator->Free(input_data_);
-  ctx_->allocator->Free(concat_param_->input_shapes_);
-  ctx_->allocator->Free(concat_param_->quant_arg_.in_args_);
   return ret;
 }
......
@@ -32,7 +32,17 @@ class ConcatInt8CPUKernel : public ConcatBaseCPUKernel {
                      const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
                      const lite::Primitive *primitive)
       : ConcatBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
-  ~ConcatInt8CPUKernel() override {}
+  ~ConcatInt8CPUKernel() override {
+    if (input_data_ != nullptr) {
+      free(input_data_);
+    }
+    if (concat_param_->input_shapes_ != nullptr) {
+      free(concat_param_->input_shapes_);
+    }
+    if (concat_param_->quant_arg_.in_args_ != nullptr) {
+      free(concat_param_->quant_arg_.in_args_);
+    }
+  }

   int Init() override;
   int ReSize() override;
......
@@ -35,7 +35,7 @@ int Int8ElementFloor(int8_t *input, int8_t *output, int element_size, ArithSelfQ
     } else if (output_tmp < para.output_activation_min_) {
       output[i] = para.output_activation_min_;
     } else {
-      output[i] = (output_tmp);
+      output[i] = (int8_t)output_tmp;
     }
   }
   return NNACL_OK;
@@ -54,7 +54,7 @@ int Int8ElementRound(int8_t *input, int8_t *output, int element_size, ArithSelfQ
     } else if (output_tmp < para.output_activation_min_) {
       output[i] = para.output_activation_min_;
     } else {
-      output[i] = (output_tmp);
+      output[i] = (int8_t)output_tmp;
     }
   }
   return NNACL_OK;
@@ -73,7 +73,7 @@ int Int8ElementCeil(int8_t *input, int8_t *output, int element_size, ArithSelfQu
     } else if (output_tmp < para.output_activation_min_) {
       output[i] = para.output_activation_min_;
     } else {
-      output[i] = (output_tmp);
+      output[i] = (int8_t)output_tmp;
     }
   }
   return NNACL_OK;
@@ -92,7 +92,7 @@ int Int8ElementAbs(int8_t *input, int8_t *output, int element_size, ArithSelfQua
    } else if (output_tmp < para.output_activation_min_) {
       output[i] = para.output_activation_min_;
     } else {
-      output[i] = (output_tmp);
+      output[i] = (int8_t)output_tmp;
     }
   }
   return NNACL_OK;
@@ -111,7 +111,7 @@ int Int8ElementSin(int8_t *input, int8_t *output, int element_size, ArithSelfQua
     } else if (output_tmp < para.output_activation_min_) {
       output[i] = para.output_activation_min_;
     } else {
-      output[i] = (output_tmp);
+      output[i] = (int8_t)output_tmp;
     }
   }
   return NNACL_OK;
@@ -130,7 +130,7 @@ int Int8ElementCos(int8_t *input, int8_t *output, int element_size, ArithSelfQua
     } else if (output_tmp < para.output_activation_min_) {
       output[i] = para.output_activation_min_;
     } else {
-      output[i] = (output_tmp);
+      output[i] = (int8_t)output_tmp;
     }
   }
   return NNACL_OK;
@@ -149,7 +149,7 @@ int Int8ElementLog(int8_t *input, int8_t *output, int element_size, ArithSelfQua
     } else if (output_tmp < para.output_activation_min_) {
       output[i] = para.output_activation_min_;
     } else {
-      output[i] = (output_tmp);
+      output[i] = (int8_t)output_tmp;
     }
   }
   return NNACL_OK;
@@ -172,7 +172,7 @@ int Int8ElementSqrt(int8_t *input, int8_t *output, int element_size, ArithSelfQu
     } else if (output_tmp < para.output_activation_min_) {
       output[i] = para.output_activation_min_;
     } else {
-      output[i] = (output_tmp);
+      output[i] = (int8_t)output_tmp;
     }
   }
   return NNACL_OK;
@@ -195,7 +195,7 @@ int Int8ElementRsqrt(int8_t *input, int8_t *output, int element_size, ArithSelfQ
     } else if (output_tmp < para.output_activation_min_) {
       output[i] = para.output_activation_min_;
     } else {
-      output[i] = (output_tmp);
+      output[i] = (int8_t)output_tmp;
     }
   }
   return NNACL_OK;
@@ -230,6 +230,7 @@ void SquareInt8NEON(int8_t *input_data, int8_t *output_data, int64_t element_siz
     int16x8_t res_s16 = vcombine_s16(sum_low, sum_high);
     int8x8_t res_u8_n0 = vqmovn_s16(res_s16);
     vst1_s8(output_data, res_u8_n0);
+    output_data += 8;
   }
 }
 #endif
@@ -253,7 +254,7 @@ int Int8ElementSquare(int8_t *input, int8_t *output, int element_size, ArithSelf
     } else if (output_tmp < para.output_activation_min_) {
       output[index] = para.output_activation_min_;
     } else {
-      output[index] = (output_tmp);
+      output[index] = (int8_t)output_tmp;
     }
   }
   return NNACL_OK;
......
@@ -22,36 +22,36 @@ void Int8Concat(int8_t **inputs, int8_t *output, ConcatParameter *para, int axis
   float output_scale = para->quant_arg_.out_args_.scale_;
   const float output_inverse_scale = 1.f / output_scale;
   int input_num = para->input_num_;
-  int count_unit_ = para->count_unit_;
-  int after_axis_size = para->after_axis_size;
+  int64_t count_unit_ = para->count_unit_;
+  int64_t after_axis_size = para->after_axis_size;
   const int *output_shape = para->output_shapes_;
   int out_copy_size = output_shape[axis] * after_axis_size;
   QuantArg *input_quant = para->quant_arg_.in_args_;
   int output_zp = para->quant_arg_.out_args_.zp_;
-  int max_int8 = para->quant_arg_.output_activation_max_;
-  int min_int8 = para->quant_arg_.output_activation_min_;
+  int8_t max_int8 = para->quant_arg_.output_activation_max_;
+  int8_t min_int8 = para->quant_arg_.output_activation_min_;
   int64_t start = task_id * count_unit_;
   int64_t end = start + real_dst_count;
+  output += start * out_copy_size;
   for (int k = start; k < end; k++) {
     for (int i = 0; i < input_num; i++) {
       const int *input_shape = para->input_shapes_[i];
-      int in_copy_size = input_shape[axis] * after_axis_size;
+      int64_t in_copy_size = input_shape[axis] * after_axis_size;
       int8_t *input_ptr = inputs[i] + k * in_copy_size;
-      int8_t *output_ptr = output + k * out_copy_size;
       if (input_quant[i].scale_ == output_scale && input_quant[i].zp_ == output_zp) {
-        memcpy(output_ptr, input_ptr, in_copy_size);
+        memcpy(output, input_ptr, in_copy_size);
       } else {
         float scale = input_quant[i].scale_ * output_inverse_scale;
         float bias = -input_quant[i].zp_ * scale;
         for (int j = 0; j < in_copy_size; j++) {
           int32_t output_tmp = round(input_ptr[j] * scale + bias) + output_zp;
           if (output_tmp > max_int8) {
-            output_ptr[j] = max_int8;
+            output[j] = max_int8;
           } else if (output_tmp < min_int8) {
-            output_ptr[j] = min_int8;
+            output[j] = min_int8;
           } else {
-            output_ptr[j] = (output_tmp);
+            output[j] = (int8_t)output_tmp;
           }
         }
       }
......
@@ -68,7 +68,7 @@ void Crop1D(const int8_t *input, int8_t *output, int task_id, CropParameter *par
     } else if (output_tmp < para->quant_arg.output_activation_min_) {
       out_ptr[i] = para->quant_arg.output_activation_min_;
     } else {
-      out_ptr[i] = output_tmp;
+      out_ptr[i] = (int8_t)output_tmp;
     }
   }
 }
@@ -110,7 +110,7 @@ void Crop2D(const int8_t *input, int8_t *output, int task_id, CropParameter *par
     } else if (output_tmp < para->quant_arg.output_activation_min_) {
       out_ptr[i] = para->quant_arg.output_activation_min_;
     } else {
-      out_ptr[i] = (output_tmp);
+      out_ptr[i] = (int8_t)output_tmp;
     }
   }
 }
@@ -164,7 +164,7 @@ void Crop3D(const int8_t *input, int8_t *output, int task_id, CropParameter *par
     } else if (output_tmp < para->quant_arg.output_activation_min_) {
       out_ptr[i] = para->quant_arg.output_activation_min_;
     } else {
-      out_ptr[i] = (output_tmp);
+      out_ptr[i] = (int8_t)output_tmp;
     }
   }
 }
@@ -225,7 +225,7 @@ void Int8Crop4D(const int8_t *input, int8_t *output, int task_id, CropParameter
     } else if (output_tmp < para->quant_arg.output_activation_min_) {
       out_ptr[i] = para->quant_arg.output_activation_min_;
     } else {
-      out_ptr[i] = (output_tmp);
+      out_ptr[i] = (int8_t)output_tmp;
     }
   }
 }
......
@@ -80,7 +80,7 @@ void Mul(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t
     } else if (mul_result < para.output_activation_min_) {
       output_data[index] = para.output_activation_min_;
     } else {
-      output_data[index] = (mul_result);
+      output_data[index] = (int8_t)mul_result;
     }
   }
   return;
......
@@ -33,7 +33,7 @@ void Int8Reshape(int8_t *input_ptr, int8_t *output_ptr, int64_t real_dst_count,
     } else if (output_tmp < para.output_activation_min_) {
       output_ptr[i] = para.output_activation_min_;
     } else {
-      output_ptr[i] = output_tmp;
+      output_ptr[i] = (int8_t)output_tmp;
     }
   }
 }
......
@@ -62,7 +62,7 @@ int Int8DoSplit(int8_t *in_data, int8_t **out_data, const int *input_shape, int
     } else if (output_tmp < param->quant_arg_.output_activation_min_) {
       dst[j] = param->quant_arg_.output_activation_min_;
     } else {
-      dst[j] = output_tmp;
+      dst[j] = (int8_t)output_tmp;
     }
   }
 }
......
@@ -53,8 +53,8 @@ typedef struct ConvQuantArg {
 typedef struct ConcatQuantArg {
   QuantArg *in_args_;
   QuantArg out_args_;
-  int output_activation_min_;
-  int output_activation_max_;
+  int8_t output_activation_min_;
+  int8_t output_activation_max_;
 } ConcatQuantArg;

 typedef struct SqueezeQuantArg {
......