Commit 33784fef authored by mindspore-ci-bot, committed by Gitee

!4214 modify arm cpu op: embedding lookup and elu

Merge pull request !4214 from 陶云浩/modify
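The diff below moves shape-dependent setup out of Init() and into ReSize(). Read together, the new Elu kernel code amounts to the following (a reconstruction from the hunks that follow, using the InferShapeDone()/ReSize() convention they show; not verified against the repository):

```cpp
// Init() keeps only shape-independent setup; if shapes are not inferred yet it
// returns early, and ReSize() recomputes the shape-dependent fields later.
int EluCPUKernel::Init() {
  elu_parameter_ = reinterpret_cast<EluParameter *>(opParameter);
  elu_parameter_->thread_num_ = thread_count_;
  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}

int EluCPUKernel::ReSize() {
  // element count depends on the (re-)inferred input shape
  elu_parameter_->in_size_ = inputs_.front()->ElementsNum();
  return RET_OK;
}
```

The EmbeddingLookup kernel in the same change follows the same Init()/ReSize() split.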
......@@ -28,12 +28,18 @@ namespace mindspore::kernel {
int EluCPUKernel::Init() {
elu_parameter_ = reinterpret_cast<EluParameter *>(opParameter);
elu_parameter_->thread_num_ = thread_count_;
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int EluCPUKernel::ReSize() {
elu_parameter_->in_size_ = inputs_.front()->ElementsNum();
return RET_OK;
}
int EluCPUKernel::ReSize() { return RET_OK; }
int EluCPUKernel::DoExcute(int task_id) { Elu(input_addr, output_addr, elu_parameter_, task_id); }
int EluRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
......@@ -47,6 +53,11 @@ int EluRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
}
int EluCPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
input_addr = reinterpret_cast<float *>(inputs_.front()->Data());
output_addr = reinterpret_cast<float *>(outputs_.front()->Data());
......
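Run() above hands the per-thread work to Elu() from nnacl. A hypothetical sketch of what that per-thread body can look like, given the EluParameter fields declared later in this diff (alpha_, thread_num_, in_size_); the actual nnacl kernel may partition the work differently:

```cpp
#include <math.h>

// ELU: f(x) = x for x > 0, alpha * (exp(x) - 1) otherwise.
// Each task handles a contiguous slice of the in_size_ elements.
int Elu(float *input_data, float *output_data, EluParameter *parameter, int task_id) {
  int stride = (parameter->in_size_ + parameter->thread_num_ - 1) / parameter->thread_num_;
  int begin = task_id * stride;
  int end = begin + stride;
  if (end > parameter->in_size_) {
    end = parameter->in_size_;
  }
  for (int i = begin; i < end; ++i) {
    float x = input_data[i];
    output_data[i] = x > 0.0f ? x : parameter->alpha_ * (expf(x) - 1.0f);
  }
  return 0;  // stand-in for the nnacl success code
}
```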
......@@ -26,12 +26,16 @@ using mindspore::schema::PrimitiveType_EmbeddingLookup;
namespace mindspore::kernel {
int EmbeddingLookupCPUKernel::Init() {
if (context_->infer_shape_interrupt_ && !context_->running_) {
SetNeedReInit();
return RET_OK;
}
embedding_lookup_parameter_ = reinterpret_cast<EmbeddingLookupParameter *>(opParameter);
embedding_lookup_parameter_->thread_num = thread_count_;
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int EmbeddingLookupCPUKernel::ReSize() {
embedding_lookup_parameter_->ids_size_ = inputs_.back()->ElementsNum();
embedding_lookup_parameter_->layer_size_ = 1;
......@@ -45,18 +49,34 @@ int EmbeddingLookupCPUKernel::Init() {
embedding_lookup_parameter_->layer_num_ += inputs_[i]->shape()[0];
}
input_addr_ = reinterpret_cast<float *>(
std::malloc(sizeof(float) * embedding_lookup_parameter_->layer_size_ * embedding_lookup_parameter_->layer_num_));
if (input_addr_ != nullptr) {
free(input_addr_);
}
if (context_ != nullptr && context_->allocator != nullptr) {
input_addr_ = reinterpret_cast<float *>(context_->allocator->Malloc(
sizeof(float) * embedding_lookup_parameter_->layer_size_ * embedding_lookup_parameter_->layer_num_));
} else {
input_addr_ = reinterpret_cast<float *>(
malloc(sizeof(float) * embedding_lookup_parameter_->layer_size_ * embedding_lookup_parameter_->layer_num_));
}
if (input_addr_ == nullptr) {
MS_LOG(ERROR) << "Create memory failed";
return mindspore::lite::RET_MEMORY_FAILED;
MS_LOG(ERROR) << "Malloc buffer failed";
return RET_ERROR;
}
embedding_lookup_parameter_->is_regulated_ =
reinterpret_cast<bool *>(std::malloc(sizeof(bool) * embedding_lookup_parameter_->layer_num_));
if (embedding_lookup_parameter_->is_regulated_ != nullptr) {
free(embedding_lookup_parameter_->is_regulated_);
}
if (context_ != nullptr && context_->allocator != nullptr) {
embedding_lookup_parameter_->is_regulated_ =
reinterpret_cast<bool *>(context_->allocator->Malloc(sizeof(bool) * embedding_lookup_parameter_->layer_num_));
} else {
embedding_lookup_parameter_->is_regulated_ =
reinterpret_cast<bool *>(malloc(sizeof(bool) * embedding_lookup_parameter_->layer_num_));
}
if (embedding_lookup_parameter_->is_regulated_ == nullptr) {
MS_LOG(ERROR) << "Create memory failed";
return mindspore::lite::RET_MEMORY_FAILED;
MS_LOG(ERROR) << "Malloc buffer failed";
return RET_ERROR;
}
for (int i = 0; i < embedding_lookup_parameter_->layer_num_; ++i) {
......@@ -66,8 +86,6 @@ int EmbeddingLookupCPUKernel::Init() {
return RET_OK;
}
int EmbeddingLookupCPUKernel::ReSize() { return RET_OK; }
int EmbeddingLookupCPUKernel::DoExcute(int task_id) {
int error_code = EmbeddingLookup(input_addr_, ids_addr_, output_addr_, embedding_lookup_parameter_, task_id);
if (error_code != RET_OK) {
......
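The ReSize() hunks above interleave the old and new allocation code. Condensed, the pattern is: free whatever the previous ReSize() allocated, reallocate for the new shape, and fail cleanly if allocation does not succeed. A sketch of that pattern, assuming the plain malloc/free path (the hunk also shows a context_->allocator branch):

```cpp
int EmbeddingLookupCPUKernel::ReSize() {
  embedding_lookup_parameter_->ids_size_ = inputs_.back()->ElementsNum();
  // ... recompute layer_size_ and layer_num_ from the weight tensors ...

  // Release the buffer from the previous ReSize() so repeated shape changes
  // do not leak, then allocate for the current shape.
  if (input_addr_ != nullptr) {
    free(input_addr_);
    input_addr_ = nullptr;
  }
  size_t bytes = sizeof(float) * embedding_lookup_parameter_->layer_size_ *
                 embedding_lookup_parameter_->layer_num_;
  input_addr_ = reinterpret_cast<float *>(malloc(bytes));
  if (input_addr_ == nullptr) {
    MS_LOG(ERROR) << "Malloc buffer failed";
    return RET_ERROR;
  }
  // the is_regulated_ flag array is reallocated the same way
  return RET_OK;
}
```

The destructor change in the next hunk releases these same buffers when the kernel is destroyed.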
......@@ -28,7 +28,14 @@ class EmbeddingLookupCPUKernel : public LiteKernel {
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const lite::Primitive *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {}
~EmbeddingLookupCPUKernel() override{};
~EmbeddingLookupCPUKernel() override {
if (input_addr_ != nullptr) {
free(input_addr_);
}
if (embedding_lookup_parameter_->is_regulated_ != nullptr) {
free(embedding_lookup_parameter_->is_regulated_);
}
};
int Init() override;
int ReSize() override;
......
......@@ -15,7 +15,6 @@
*/
#include "src/runtime/kernel/arm/nnacl/fp32/elu.h"
#include <string.h>
#include "include/errorcode.h"
#include "src/runtime/kernel/arm/nnacl/errorcode.h"
#include "mindspore/core/utils/log_adapter.h"
......
......@@ -19,12 +19,12 @@
#include "src/runtime/kernel/arm/nnacl/op_base.h"
struct EluParameter {
typedef struct {
OpParameter op_parameter_;
float alpha_;
int thread_num_;
int in_size_;
};
} EluParameter;
int Elu(float *input_data, float *output_data, EluParameter *parameter, int task_id);
......
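The header change above replaces `struct EluParameter { ... };` with `typedef struct { ... } EluParameter;`. Since the nnacl sources are plain C, a plausible motivation is that the typedef lets C translation units use the bare type name (without it, C code would have to write `struct EluParameter`). A small hypothetical caller, written to compile as either C or C++:

```cpp
#include "src/runtime/kernel/arm/nnacl/fp32/elu.h"

// Hypothetical single-threaded use of the nnacl Elu() entry point declared above.
int run_elu_once(float *in, float *out, int size) {
  EluParameter param;     // bare name works from C only because of the typedef
  param.alpha_ = 1.0f;    // slope applied to negative inputs
  param.thread_num_ = 1;  // one worker, so task_id 0 covers the whole range
  param.in_size_ = size;
  return Elu(in, out, &param, 0);
}
```

The EmbeddingLookupParameter change at the end of this diff applies the same struct-to-typedef conversion.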
......@@ -15,7 +15,6 @@
*/
#include "src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h"
#include <string.h>
#include "include/errorcode.h"
#include "src/runtime/kernel/arm/nnacl/errorcode.h"
#include "mindspore/core/utils/log_adapter.h"
......
......@@ -19,15 +19,15 @@
#include "src/runtime/kernel/arm/nnacl/op_base.h"
struct EmbeddingLookupParameter {
OpParameter op_parameter_;
bool *is_regulated_;
float max_norm_;
int ids_size_;
int layer_size_;
int layer_num_;
int thread_num;
};
typedef struct {
OpParameter op_parameter_;
bool *is_regulated_;
float max_norm_;
int ids_size_;
int layer_size_;
int layer_num_;
int thread_num;
} EmbeddingLookupParameter;
int EmbeddingLookup(float *input_data, int *ids, float *output_data, EmbeddingLookupParameter *parameter, int task_id);
......
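For orientation, a hypothetical sketch of what the per-thread EmbeddingLookup() body declared above can look like given these fields: each id selects one row of layer_size_ floats out of layer_num_ rows, and work over the ids_size_ ids is split across thread_num tasks (max_norm_ and is_regulated_ presumably drive optional norm clipping of source rows, omitted here); the real nnacl kernel may differ:

```cpp
#include <string.h>

// Gather: output row j <- embedding row ids[j], for the slice owned by task_id.
int EmbeddingLookup(float *input_data, int *ids, float *output_data,
                    EmbeddingLookupParameter *parameter, int task_id) {
  int stride = (parameter->ids_size_ + parameter->thread_num - 1) / parameter->thread_num;
  int begin = task_id * stride;
  int end = begin + stride;
  if (end > parameter->ids_size_) {
    end = parameter->ids_size_;
  }
  for (int j = begin; j < end; ++j) {
    int id = ids[j];
    if (id < 0 || id >= parameter->layer_num_) {
      return 1;  // invalid id; the real kernel returns an nnacl error code
    }
    memcpy(output_data + j * parameter->layer_size_,
           input_data + id * parameter->layer_size_,
           sizeof(float) * parameter->layer_size_);
  }
  return 0;
}
```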