未验证 提交 3aa9809f 编写于 作者: A AddSalt8227 提交者: GitHub

fix op bugs with split and batchnorm (#1236)

* fix retinaface_timvx memory leak

* fix rgb error in tm_yolact.cpp

* 1. fix eltwise uint8 bug
2. add hardsigmoid uint8

* fix op bugs with split and batchnorm
上级 01713180
......@@ -64,7 +64,23 @@ int batchnorm_run(struct tensor* output_tensor, struct tensor* input_tensor, flo
{
int batch_number = input_tensor->dims[0];
int channel_num = input_tensor->dims[1];
int channel_size = (input_tensor->dims[2]) * (input_tensor->dims[3]);
int channel_size;
if (4 == input_tensor->dim_num)
{
channel_size = (input_tensor->dims[2]) * (input_tensor->dims[3]);
}
else if (3 == input_tensor->dim_num)
{
channel_size = (input_tensor->dims[2]);
}
else if (2 == input_tensor->dim_num)
{
channel_size = 1;
}
else
{
return -1;
}
int img_size = channel_num * channel_size;
const float* input = (const float*)input_tensor->data;
......
......@@ -69,6 +69,11 @@ int ref_split_uint8(struct tensor* input_tensor, struct tensor* output_tensor, s
{
uint8_t* input_data = (uint8_t*)input_tensor->data;
uint8_t* output_data = (uint8_t*)output_tensor->data;
float input_scale = input_tensor->scale;
float output_scale = output_tensor->scale;
int32_t input_zero = input_tensor->zero_point;
int32_t output_zero = output_tensor->zero_point;
float rescale = input_scale / output_scale;
if (split_param->is_caffe)
{
......@@ -84,7 +89,15 @@ int ref_split_uint8(struct tensor* input_tensor, struct tensor* output_tensor, s
{
int in_offset = (n * in_slice + *slice_index) * slice_size;
int out_offset = n * out_slice * slice_size;
memcpy(output_data + out_offset, input_data + in_offset, (size_t)slice_size * out_slice * sizeof(uint8_t));
for (size_t i = 0; i < slice_size * out_slice * sizeof(uint8_t); i++)
{
int udata = roundf((input_data[in_offset + i] - input_zero) * rescale + output_zero);
if (udata > 255)
udata = 255;
else if (udata < 0)
udata = 0;
output_data[i] = udata;
}
}
*slice_index += out_slice;
......@@ -148,6 +161,7 @@ static int run(struct node_ops* node_ops, struct exec_node* exec_node, struct ex
struct split_param* split_param = (struct split_param*)ir_node->op.param_mem;
/* the following code needs to be checked! */
/* int8 may need dequantization and requantization */
int slice_axis = split_param->axis;
int num_slices = 1;
int slice_size = 1;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册