提交 dac39a7c 编写于 作者: C chenjianping

fix slice parallel run bug

上级 089774c6
......@@ -47,7 +47,7 @@ void PadSliceParameterTo4D(SliceParameter *param) {
param->param_length_ = DIMENSION_4D;
}
void DoSlice(const float *input, float *output, SliceParameter *param) {
void DoSlice(const float *input, float *output, SliceParameter *param, int thread_id) {
int32_t out_dim1 = param->size_[1];
int32_t out_dim2 = param->size_[2];
int32_t out_dim3 = param->size_[3];
......@@ -55,7 +55,6 @@ void DoSlice(const float *input, float *output, SliceParameter *param) {
size_t out_stride1 = out_stride2 * out_dim2;
size_t out_stride0 = out_stride1 * out_dim1;
size_t count_per_thread = UP_DIV(out_dim1, param->op_parameter_.thread_num_);
int thread_id = param->thread_id_;
size_t thread_stride = thread_id * count_per_thread;
size_t copy_size = param->size_[3] * sizeof(float);
size_t in_stride2 = param->shape_[3];
......
......@@ -23,7 +23,7 @@
extern "C" {
#endif
void PadSliceParameterTo4D(SliceParameter *param);
void DoSlice(const float *input, float *output, SliceParameter *param);
void DoSlice(const float *input, float *output, SliceParameter *param, int thread_id);
void DoSliceNoParallel(const float *input, float *output, SliceParameter *param);
#ifdef __cplusplus
}
......
......@@ -66,7 +66,7 @@ int SliceInt8NoParallel(const int8_t *input, int8_t *output, SliceParameter *par
return 0;
}
int SliceInt8(const int8_t *input, int8_t *output, SliceParameter *param) {
int SliceInt8(const int8_t *input, int8_t *output, SliceParameter *param, int thread_id) {
double input_scale = param->quant_arg_.in_args_.scale_;
int input_zp = param->quant_arg_.in_args_.zp_;
double output_scale = param->quant_arg_.out_args_.scale_;
......@@ -81,7 +81,6 @@ int SliceInt8(const int8_t *input, int8_t *output, SliceParameter *param) {
int out_stride1 = out_stride2 * out_dim2;
int out_stride0 = out_stride1 * out_dim1;
int count_per_thread = UP_DIV(out_dim1, param->op_parameter_.thread_num_);
int thread_id = param->thread_id_;
int thread_stride = thread_id * count_per_thread;
int unit_size = param->size_[3] * sizeof(int8_t);
int in_stride2 = param->shape_[3];
......
......@@ -23,7 +23,7 @@
extern "C" {
#endif
int SliceInt8NoParallel(const int8_t *input, int8_t *output, SliceParameter *param);
int SliceInt8(const int8_t *input, int8_t *output, SliceParameter *param);
int SliceInt8(const int8_t *input, int8_t *output, SliceParameter *param, int thread_id);
#ifdef __cplusplus
}
#endif
......
......@@ -30,7 +30,6 @@ typedef struct SliceParameter {
int32_t size_[SLICE_SHAPE_MAX_SIZE];
int32_t shape_[SLICE_SHAPE_MAX_SIZE];
int32_t param_length_;
int32_t thread_id_;
} SliceParameter;
#endif // MINDSPORE_LITE_NNACL_SLICE_PARAMETER_H_
......@@ -78,7 +78,7 @@ int SliceCPUKernel::SliceParallelRun(int thread_id) {
const float *input_data = reinterpret_cast<const float *>(in_tensors_[0]->Data());
float *output_data = reinterpret_cast<float *>(out_tensors_[0]->Data());
SliceParameter *param = reinterpret_cast<SliceParameter *>(op_parameter_);
DoSlice(input_data, output_data, param);
DoSlice(input_data, output_data, param, thread_id);
return RET_OK;
}
......
......@@ -60,8 +60,7 @@ int SliceInt8CPUKernel::DoSlice(int task_id) {
const int8_t *input_data = reinterpret_cast<const int8_t *>(in_tensors_[0]->Data());
int8_t *output_data = reinterpret_cast<int8_t *>(out_tensors_[0]->Data());
param_->thread_id_ = task_id;
auto ret = SliceInt8(input_data, output_data, param_);
auto ret = SliceInt8(input_data, output_data, param_, task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "SliceInt8 error ,task_id[" << task_id << "] error_code[" << ret << "]";
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册