提交 c35ac033 编写于 作者: S sunsuodong

fix gather

上级 97563d5a
......@@ -37,13 +37,6 @@ int GatherCPUKernel::Init() {
return ReSize();
}
GatherCPUKernel::~GatherCPUKernel() {
if (indices_data_ != nullptr) {
free(indices_data_);
indices_data_ = nullptr;
}
}
int GatherCPUKernel::ReSize() { return RET_OK; }
int GatherCPUKernel::DoGather(int task_id) {
......@@ -105,28 +98,45 @@ int GatherCPUKernel::Run() {
}
auto indices_tensor = in_tensors_.at(1);
indices_data_ = reinterpret_cast<int *>(malloc(indices_tensor->Size()));
if (indices_data_ == nullptr) {
MS_LOG(ERROR) << "Memory allocation failed";
return RET_ERROR;
int indices_num = indices_tensor->ElementsNum();
bool isIndicesInt32 = indices_tensor->data_type() == kNumberTypeInt32;
int ret = AssignIndicesData(isIndicesInt32, indices_num, indices_tensor);
if (ret != RET_OK) {
MS_LOG(ERROR) << "AssignIndicesData failed, error_code[" << ret << "]";
return ret;
}
auto in_shape = in_tensors_.at(0)->shape();
int indices_element_size = indices_tensor->ElementsNum();
auto axis = (reinterpret_cast<GatherParameter *>(op_parameter_))->axis_;;
auto indices_ptr = reinterpret_cast<float *>(indices_tensor->Data());
const int limit = in_shape[axis];
for (int i = 0; i < indices_element_size; ++i) {
indices_data_[i] = static_cast<int>(indices_ptr[i]);
if (indices_data_[i] >= limit) {
MS_LOG(ERROR) << " indice data: " << indices_data_[i] << " is not in [ 0, " << limit - 1 << " ]";
ret = ParallelLaunch(THREAD_POOL_DEFAULT, GatherRun, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Gather function error error_code[" << ret << "]";
}
if (!isIndicesInt32) {
context_->allocator->Free(indices_data_);
indices_data_ = nullptr;
}
return ret;
}
int GatherCPUKernel::AssignIndicesData(bool isIndicesInt32, int indices_num, lite::tensor::Tensor *indices_tensor) {
if (!isIndicesInt32) {
indices_data_ = reinterpret_cast<int32_t *>(context_->allocator->Malloc(sizeof(int32_t) * indices_num));
if (indices_data_ == nullptr) {
MS_LOG(ERROR) << "Memory allocation failed";
return RET_ERROR;
}
if (indices_tensor->data_type() == kNumberTypeInt64) {
for (int i = 0; i < indices_num; i++) {
indices_data_[i] = reinterpret_cast<int64_t *>(indices_tensor->Data())[i];
}
} else {
for (int i = 0; i < indices_num; i++) {
indices_data_[i] = reinterpret_cast<float *>(indices_tensor->Data())[i];
}
}
} else {
indices_data_ = reinterpret_cast<int32_t *>(indices_tensor->Data());
}
int error_code = ParallelLaunch(THREAD_POOL_DEFAULT, GatherRun, this, op_parameter_->thread_num_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Gather function error error_code[" << error_code << "]";
}
return error_code;
return RET_OK;
}
kernel::LiteKernel *CpuGatherFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
......
......@@ -28,7 +28,7 @@ class GatherCPUKernel : public LiteKernel {
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
~GatherCPUKernel() override;
~GatherCPUKernel() = default;
int Init() override;
int ReSize() override;
......@@ -37,6 +37,7 @@ class GatherCPUKernel : public LiteKernel {
private:
int *indices_data_ = nullptr;
int AssignIndicesData(bool isIndicesInt32, int indices_num, lite::tensor::Tensor *indices_tensor);
};
} // namespace mindspore::kernel
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册