diff --git a/mindspore/lite/nnacl/fp32/activation.c b/mindspore/lite/nnacl/fp32/activation.c index 17e340751ea471b3746dba1d477ce804cd6a5c72..dbd61a726077b8f452d0900ed77bd7a8ad1f0dd6 100644 --- a/mindspore/lite/nnacl/fp32/activation.c +++ b/mindspore/lite/nnacl/fp32/activation.c @@ -43,8 +43,19 @@ int LRelu(const float *src, int length, float *dst, float alpha) { } int Sigmoid(const float *src, int length, float *dst) { + const float upper_bound = 16.619047164916992188f; + const float lower_bound = -9.0f; for (int i = 0; i < length; ++i) { - dst[i] = 1.0f / (1.0f + exp(-src[i])); + float input_val = src[i]; + float result; + if (input_val > upper_bound) { + result = 1.0f; + } else if (input_val < lower_bound) { + result = exp(input_val); + } else { + result = 1.0f / (1.0f + exp(-input_val)); + } + dst[i] = result; } return NNACL_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc index 1ed0a2bd84ad963884746fe70453890543dd5eea..cd44d271c04a25b5c4145ff8755e42b78f56e042 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc @@ -31,8 +31,6 @@ using mindspore::schema::PrimitiveType_Gather; namespace mindspore::kernel { int GatherCPUKernel::Init() { - axis_ = (reinterpret_cast(op_parameter_))->axis_; - batchDims_ = (reinterpret_cast(op_parameter_))->batchDims_; if (!InferShapeDone()) { return RET_OK; } @@ -47,7 +45,7 @@ int GatherCPUKernel::DoGather(int task_id) { auto out_tensor = out_tensors_.at(0); auto input_ptr = reinterpret_cast(input_tensor->Data()); - auto indices_ptr = reinterpret_cast(indices_tensor->Data()); + auto indices_ptr = reinterpret_cast(indices_tensor->Data()); auto output_ptr = reinterpret_cast(out_tensor->Data()); auto input_int32 = reinterpret_cast(input_tensor->Data()); @@ -56,26 +54,25 @@ int GatherCPUKernel::DoGather(int task_id) { auto in_shape = input_tensor->shape(); int in_rank = in_shape.size(); int indices_element_size = indices_tensor->ElementsNum(); + auto axis = (reinterpret_cast(op_parameter_))->axis_; - const int limit = in_shape[axis_]; + const int limit = in_shape[axis]; for (int i = 0; i < indices_element_size; ++i) { - if (indices_ptr[i] >= limit) { - MS_LOG(ERROR) << " indice data: " << indices_ptr[i] << " is not in [ 0, " << limit - 1 << " ]"; + indices_data_[i] = static_cast(indices_ptr[i]); + if (indices_data_[i] >= limit) { + MS_LOG(ERROR) << " indice data: " << indices_data_[i] << " is not in [ 0, " << limit - 1 << " ]"; return RET_ERROR; } } - int outer_size = 1; - for (int i = 0; i < axis_; ++i) { + int outer_size = 1, inner_size = 1; + for (int i = 0; i < axis; ++i) { outer_size *= in_shape[i]; } - - int inner_size = 1; - for (int i = axis_ + 1; i < in_rank; ++i) { + for (int i = axis + 1; i < in_rank; ++i) { inner_size *= in_shape[i]; } - - int stride = UP_DIV(outer_size, thread_count_); + int stride = UP_DIV(outer_size, op_parameter_->thread_num_); int count = MSMIN(stride, outer_size - stride * task_id); auto thread_stride = stride * task_id; @@ -83,17 +80,13 @@ int GatherCPUKernel::DoGather(int task_id) { if (input_tensor->data_type() == kNumberTypeInt32) { input_int32 += thread_stride * limit; output_int32 += thread_stride * indices_element_size; - error_code = GatherInt32(input_int32, count, inner_size, limit, indices_ptr, indices_element_size, output_int32); + error_code = GatherInt32(input_int32, count, inner_size, limit, indices_data_, indices_element_size, output_int32); } else { input_ptr += thread_stride * limit; output_ptr += thread_stride * indices_element_size; - error_code = Gather(input_ptr, count, inner_size, limit, indices_ptr, indices_element_size, output_ptr); - } - - if (error_code != RET_OK) { - return RET_ERROR; + error_code = Gather(input_ptr, count, inner_size, limit, indices_data_, indices_element_size, output_ptr); } - return RET_OK; + return error_code; } int GatherRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { @@ -101,9 +94,8 @@ int GatherRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { auto error_code = gather_kernel->DoGather(task_id); if (error_code != RET_OK) { MS_LOG(ERROR) << "GatherRun error task_id[" << task_id << "] error_code[" << error_code << "]"; - return RET_ERROR; } - return RET_OK; + return error_code; } int GatherCPUKernel::Run() { @@ -112,12 +104,19 @@ int GatherCPUKernel::Run() { MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; return prepare_ret; } - int error_code = LiteBackendParallelLaunch(GatherRun, this, thread_count_); + + auto indices_tensor = in_tensors_.at(1); + indices_data_ = reinterpret_cast(context_->allocator->Malloc(indices_tensor->ElementsNum() * sizeof(int))); + if (indices_data_ == nullptr) { + MS_LOG(ERROR) << "Memory allocation failed"; + context_->allocator->Free(indices_data_); + return RET_ERROR; + } + int error_code = LiteBackendParallelLaunch(GatherRun, this, op_parameter_->thread_num_); if (error_code != RET_OK) { MS_LOG(ERROR) << "Gather function error error_code[" << error_code << "]"; - return RET_ERROR; } - return RET_OK; + return error_code; } kernel::LiteKernel *CpuGatherFp32KernelCreator(const std::vector &inputs, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.h b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.h index 90e4b8d89e43924a68d92509cb68c71f6ea15e0f..7af9703e541f75f3f73f22462fea70c113d85ee8 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.h @@ -27,7 +27,7 @@ class GatherCPUKernel : public LiteKernel { GatherCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::Context *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {} + : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} ~GatherCPUKernel() override = default; int Init() override; @@ -36,9 +36,7 @@ class GatherCPUKernel : public LiteKernel { int DoGather(int task_id); private: - int thread_count_; - int batchDims_; - int axis_; + int *indices_data_; }; } // namespace mindspore::kernel