// Review summary of the patch below (GatherCPUKernel, fp32 gather kernel).
//
// gather.cc
//   * Deletes the out-of-line destructor that released indices_data_ with free().
//   * Run(): drops the malloc() of indices_data_, the per-element static_cast
//     copy of the indices, and the check that logged an error and returned
//     RET_ERROR when an index was >= in_shape[axis].
//   * Run(): now calls AssignIndicesData() and returns its code on failure,
//     then calls ParallelLaunch(). Afterwards, when the indices tensor is not
//     kNumberTypeInt32, it passes indices_data_ to context_->allocator->Free()
//     and resets the pointer to nullptr. The ParallelLaunch() code is returned.
//   * Adds AssignIndicesData(): for non-int32 indices it requests
//     sizeof(int32_t) * indices_num bytes from context_->allocator->Malloc()
//     and copies the tensor data element by element (one loop for
//     kNumberTypeInt64, one for every other data type); for int32 indices it
//     points indices_data_ at the tensor's own Data() buffer. Returns
//     RET_ERROR when the allocation yields nullptr, RET_OK otherwise.
//
// gather.h
//   * Destructor declaration changes from `override;` to `= default;`.
//   * Declares the new private member function AssignIndicesData().
//
// NOTE(review): the removed `indices_data_[i] >= limit` check has no
//   replacement in the added lines; DoGather() is not visible here -- confirm
//   that indices are still range-checked somewhere before use.
// NOTE(review): every reinterpret_cast/static_cast and `std::vector &` in this
//   text lacks its `<...>` template argument, presumably lost in extraction.
//   As printed, the two copy loops in AssignIndicesData() are textually
//   identical, so the source element types cannot be determined from here.
// NOTE(review): the diff text is collapsed onto two physical lines, split
//   inside the statement `return ret;`.
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc index 52879bdc44fb7ffff305429c8231244bfe1a22b7..84b3df489de6f8ff3cf303c24609871a9de8c810 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc @@ -37,13 +37,6 @@ int GatherCPUKernel::Init() { return ReSize(); } -GatherCPUKernel::~GatherCPUKernel() { - if (indices_data_ != nullptr) { - free(indices_data_); - indices_data_ = nullptr; - } -} - int GatherCPUKernel::ReSize() { return RET_OK; } int GatherCPUKernel::DoGather(int task_id) { @@ -105,28 +98,45 @@ int GatherCPUKernel::Run() { } auto indices_tensor = in_tensors_.at(1); - indices_data_ = reinterpret_cast(malloc(indices_tensor->Size())); - if (indices_data_ == nullptr) { - MS_LOG(ERROR) << "Memory allocation failed"; - return RET_ERROR; + int indices_num = indices_tensor->ElementsNum(); + bool isIndicesInt32 = indices_tensor->data_type() == kNumberTypeInt32; + int ret = AssignIndicesData(isIndicesInt32, indices_num, indices_tensor); + if (ret != RET_OK) { + MS_LOG(ERROR) << "AssignIndicesData failed, error_code[" << ret << "]"; + return ret; } - auto in_shape = in_tensors_.at(0)->shape(); - int indices_element_size = indices_tensor->ElementsNum(); - auto axis = (reinterpret_cast(op_parameter_))->axis_;; - auto indices_ptr = reinterpret_cast(indices_tensor->Data()); - const int limit = in_shape[axis]; - for (int i = 0; i < indices_element_size; ++i) { - indices_data_[i] = static_cast(indices_ptr[i]); - if (indices_data_[i] >= limit) { - MS_LOG(ERROR) << " indice data: " << indices_data_[i] << " is not in [ 0, " << limit - 1 << " ]"; + + ret = ParallelLaunch(THREAD_POOL_DEFAULT, GatherRun, this, op_parameter_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Gather function error error_code[" << ret << "]"; + } + if (!isIndicesInt32) { + context_->allocator->Free(indices_data_); + indices_data_ = nullptr; + } + return 
ret; +} + +int GatherCPUKernel::AssignIndicesData(bool isIndicesInt32, int indices_num, lite::tensor::Tensor *indices_tensor) { + if (!isIndicesInt32) { + indices_data_ = reinterpret_cast(context_->allocator->Malloc(sizeof(int32_t) * indices_num)); + if (indices_data_ == nullptr) { + MS_LOG(ERROR) << "Memory allocation failed"; return RET_ERROR; } + if (indices_tensor->data_type() == kNumberTypeInt64) { + for (int i = 0; i < indices_num; i++) { + indices_data_[i] = reinterpret_cast(indices_tensor->Data())[i]; + } + } else { + for (int i = 0; i < indices_num; i++) { + indices_data_[i] = reinterpret_cast(indices_tensor->Data())[i]; + } + } + } else { + indices_data_ = reinterpret_cast(indices_tensor->Data()); } - int error_code = ParallelLaunch(THREAD_POOL_DEFAULT, GatherRun, this, op_parameter_->thread_num_); - if (error_code != RET_OK) { - MS_LOG(ERROR) << "Gather function error error_code[" << error_code << "]"; - } - return error_code; + return RET_OK; } kernel::LiteKernel *CpuGatherFp32KernelCreator(const std::vector &inputs, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.h b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.h index 334d93a648425b7a7bdd6245040ad2ddbccab5a5..e7c27e93957405895a3b2cb0e33bf638feb0167b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.h @@ -28,7 +28,7 @@ class GatherCPUKernel : public LiteKernel { const std::vector &outputs, const lite::Context *ctx, const mindspore::lite::PrimitiveC *primitive) : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} - ~GatherCPUKernel() override; + ~GatherCPUKernel() = default; int Init() override; int ReSize() override; @@ -37,6 +37,7 @@ class GatherCPUKernel : public LiteKernel { private: int *indices_data_ = nullptr; + int AssignIndicesData(bool isIndicesInt32, int indices_num, lite::tensor::Tensor *indices_tensor); }; } // namespace mindspore::kernel