From 8ef079077bcaa7adfb8736718f956b22c6a3c450 Mon Sep 17 00:00:00 2001
From: tao_yunhao
Date: Fri, 7 Aug 2020 09:10:50 +0800
Subject: [PATCH] add arm cpu op: embedding lookup

---
 mindspore/lite/schema/ops.fbs                 |   3 +-
 mindspore/lite/src/model_impl.cc              |   2 +
 mindspore/lite/src/ops/embedding_lookup.cc    |  63 +++++++++
 mindspore/lite/src/ops/ops.cc                 |   2 +
 mindspore/lite/src/ops/ops.h                  |   7 +
 mindspore/lite/src/populate_parameter.cc      |  19 +++
 .../src/runtime/kernel/arm/fp32/arithmetic.h  |   8 +-
 .../kernel/arm/fp32/embedding_lookup.cc       | 130 ++++++++++++++++++
 .../kernel/arm/fp32/embedding_lookup.h        |  49 +++++++
 .../kernel/arm/int8/arithmetic_int8.cc        |   4 +-
 .../kernel/arm/nnacl/fp32/embedding_lookup.cc |  60 ++++++++
 .../kernel/arm/nnacl/fp32/embedding_lookup.h  |  34 +++++
 .../arm/fp32/embedding_lookup_fp32_test.cc    |  85 ++++++++++++
 13 files changed, 458 insertions(+), 8 deletions(-)
 create mode 100644 mindspore/lite/src/ops/embedding_lookup.cc
 create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.cc
 create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.h
 create mode 100644 mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.cc
 create mode 100644 mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h
 create mode 100644 mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/embedding_lookup_fp32_test.cc

diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs
index 78278b0c3..e4672717d 100644
--- a/mindspore/lite/schema/ops.fbs
+++ b/mindspore/lite/schema/ops.fbs
@@ -723,8 +723,7 @@ table AddN {
 }
 
 table EmbeddingLookup {
-    ids: [int];
-    maxNorm: float;
+    maxNorm: float = 0.0;
 }
 
 table EmbeddingLookupSparse {
diff --git a/mindspore/lite/src/model_impl.cc b/mindspore/lite/src/model_impl.cc
index eef368204..abead7cba 100644
--- a/mindspore/lite/src/model_impl.cc
+++ b/mindspore/lite/src/model_impl.cc
@@ -216,6 +216,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) {
       return new lite::MatMul(const_cast<schema::Primitive *>(srcPrim));
     case schema::PrimitiveType_QuantDTypeCast:
       return new lite::QuantDTypeCast(const_cast<schema::Primitive *>(srcPrim));
+    case schema::PrimitiveType_EmbeddingLookup:
+      return new lite::EmbeddingLookup(const_cast<schema::Primitive *>(srcPrim));
     default:
       break;
   }
diff --git a/mindspore/lite/src/ops/embedding_lookup.cc b/mindspore/lite/src/ops/embedding_lookup.cc
new file mode 100644
index 000000000..3a2519761
--- /dev/null
+++ b/mindspore/lite/src/ops/embedding_lookup.cc
@@ -0,0 +1,63 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/ops/ops.h"
+#include "include/errorcode.h"
+#include "src/ir/tensor.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore::lite {
+int EmbeddingLookup::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
+  MS_ASSERT(this->primitive != nullptr);
+  if (inputs_.size() < kDoubleNum) {
+    MS_LOG(ERROR) << "Embedding Lookup should have at least two inputs";
+    return RET_INPUT_TENSOR_ERROR;
+  }
+
+  if (outputs_.size() != kSingleNum) {
+    MS_LOG(ERROR) << "Embedding Lookup should have one output";
+    return RET_INPUT_TENSOR_ERROR;
+  }
+
+  auto params_ = inputs_.front();
+  MS_ASSERT(params_ != nullptr);
+  auto ids = inputs_.back();
+  MS_ASSERT(ids != nullptr);
+  auto output = outputs_.front();
+  MS_ASSERT(output != nullptr);
+
+  auto embedding_shape = params_->shape();
+  embedding_shape.erase(embedding_shape.begin());
+
+  std::vector<int> output_shape(ids->shape());
+  for (size_t i = 0; i < embedding_shape.size(); ++i) {
+    output_shape.push_back(embedding_shape.at(i));
+  }
+
+  for (int i = 1; i < inputs_.size() - 1; ++i) {
+    auto embedding_shape_t = inputs_.at(i)->shape();
+    embedding_shape_t.erase(embedding_shape_t.begin());
+    if (embedding_shape_t != embedding_shape) {
+      MS_LOG(ERROR) << "The embedded layers should have the same shape";
+      return RET_INPUT_TENSOR_ERROR;
+    }
+  }
+
+  output->set_shape(output_shape);
+  output->set_data_type(params_->data_type());
+  return RET_OK;
+}
+}  // namespace mindspore::lite
diff --git a/mindspore/lite/src/ops/ops.cc b/mindspore/lite/src/ops/ops.cc
index 06da5561f..9f771cbc3 100644
--- a/mindspore/lite/src/ops/ops.cc
+++ b/mindspore/lite/src/ops/ops.cc
@@ -141,6 +141,8 @@ Primitive *Primitive::CreatePrimitive(schema::Primitive *primitive) {
       return new lite::QuantDTypeCast(const_cast<schema::Primitive *>(primitive));
     case schema::PrimitiveType_MatMul:
       return new lite::MatMul(const_cast<schema::Primitive *>(primitive));
+    case schema::PrimitiveType_EmbeddingLookup:
+      return new lite::EmbeddingLookup(const_cast<schema::Primitive *>(primitive));
     default:
       break;
   }
diff --git a/mindspore/lite/src/ops/ops.h b/mindspore/lite/src/ops/ops.h
index f199b93c5..302f085b9 100644
--- a/mindspore/lite/src/ops/ops.h
+++ b/mindspore/lite/src/ops/ops.h
@@ -778,6 +778,13 @@ class Lstm : public Primitive {
   const schema::Lstm *GetAttribute() const { return this->primitive->value_as_Lstm(); }
   int InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) override;
 };
+
+class EmbeddingLookup : public Primitive {
+ public:
+  explicit EmbeddingLookup(schema::Primitive *primitive) : Primitive(primitive) {}
+  const schema::EmbeddingLookup *GetAttribute() const { return this->primitive->value_as_EmbeddingLookup(); }
+  int InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) override;
+};
 }  // namespace lite
 }  // namespace mindspore
 #endif  // MINDSPORE_LITE_SRC_OPS_OPS_H_
diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc
index ced813304..524c35588 100644
--- a/mindspore/lite/src/populate_parameter.cc
+++ b/mindspore/lite/src/populate_parameter.cc
@@ -69,6 +69,7 @@
 #include "src/runtime/kernel/arm/nnacl/fp32/space_to_batch.h"
 #include "src/runtime/kernel/arm/nnacl/int8/quant_dtype_cast.h"
 #include "src/runtime/kernel/arm/nnacl/fp32/lstm.h"
+#include "src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h"
 
 namespace mindspore::kernel {
 OpParameter *PopulateBatchNorm(const lite::Primitive *primitive) {
@@ -1192,6 +1193,23 @@ OpParameter *PopulateLstmParameter(const lite::Primitive *primitive) {
   return reinterpret_cast<OpParameter *>(lstm_param);
 }
 
+OpParameter *PopulateEmbeddingLookupParameter(const lite::Primitive *primitive) {
+  EmbeddingLookupParameter *embedding_lookup_parameter = new (std::nothrow) EmbeddingLookupParameter();
+  if (embedding_lookup_parameter == nullptr) {
+    MS_LOG(ERROR) << "new EmbeddingLookupParameter failed";
+    return nullptr;
+  }
+  embedding_lookup_parameter->op_parameter_.type_ = primitive->Type();
+  auto param = primitive->Value()->value_as_EmbeddingLookup();
+  embedding_lookup_parameter->max_norm_ = param->maxNorm();
+  if (embedding_lookup_parameter->max_norm_ < 0) {
+    MS_LOG(ERROR) << "Embedding lookup max norm should be non-negative, got "
+                  << embedding_lookup_parameter->max_norm_;
+    return nullptr;
+  }
+  return reinterpret_cast<OpParameter *>(embedding_lookup_parameter);
+}
+
 PopulateParameterRegistry::PopulateParameterRegistry() {
   populate_parameter_funcs_[schema::PrimitiveType_SoftMax] = PopulateSoftmaxParameter;
   populate_parameter_funcs_[schema::PrimitiveType_Activation] = PopulateActivationParameter;
@@ -1269,6 +1287,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() {
   populate_parameter_funcs_[schema::PrimitiveType_PriorBox] = PopulatePriorBoxParameter;
   populate_parameter_funcs_[schema::PrimitiveType_QuantDTypeCast] = PopulateQuantDTypeCastParameter;
   populate_parameter_funcs_[schema::PrimitiveType_Lstm] = PopulateLstmParameter;
+  populate_parameter_funcs_[schema::PrimitiveType_EmbeddingLookup] = PopulateEmbeddingLookupParameter;
 }
 
 PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() {
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h
index 968f5fbb4..ab38f7e82 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h
@@ -100,12 +100,12 @@ class ArithmeticCPUKernel : public LiteKernel {
         arithmetic_broadcast_run_ = BroadcastNotEqual;
         break;
       case PrimitiveType_Less:
-        arithmetic_run_ = ElementEqual;
-        arithmetic_broadcast_run_ = BroadcastEqual;
+        arithmetic_run_ = ElementLess;
+        arithmetic_broadcast_run_ = BroadcastLess;
         break;
       case PrimitiveType_LessEqual:
-        arithmetic_run_ = ElementNotEqual;
-        arithmetic_broadcast_run_ = BroadcastNotEqual;
+        arithmetic_run_ = ElementLessEqual;
+        arithmetic_broadcast_run_ = BroadcastLessEqual;
         break;
       case PrimitiveType_Greater:
         arithmetic_run_ = ElementGreater;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.cc
new file mode 100644
index 000000000..0904f90b5
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.cc
@@ -0,0 +1,130 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/runtime/kernel/arm/fp32/embedding_lookup.h"
+#include "include/errorcode.h"
+#include "src/kernel_registry.h"
+#include "src/runtime/runtime_api.h"
+
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
+using mindspore::schema::PrimitiveType_EmbeddingLookup;
+
+namespace mindspore::kernel {
+int EmbeddingLookupCPUKernel::Init() {
+  embedding_lookup_parameter_ = reinterpret_cast<EmbeddingLookupParameter *>(opParameter);
+  embedding_lookup_parameter_->thread_num = thread_count_;
+  embedding_lookup_parameter_->ids_size_ = inputs_.back()->ElementsNum();
+
+  embedding_lookup_parameter_->layer_size_ = 1;
+  auto in_shape = inputs_.front()->shape();
+  for (int i = 1; i < in_shape.size(); ++i) {
+    embedding_lookup_parameter_->layer_size_ *= in_shape[i];
+  }
+
+  embedding_lookup_parameter_->layer_num_ = 0;
+  for (int i = 0; i < inputs_.size() - 1; ++i) {
+    embedding_lookup_parameter_->layer_num_ += inputs_[i]->shape()[0];
+  }
+
+  input_addr_ = reinterpret_cast<float *>(
+    std::malloc(sizeof(float) * embedding_lookup_parameter_->layer_size_ * embedding_lookup_parameter_->layer_num_));
+  if (input_addr_ == nullptr) {
+    MS_LOG(ERROR) << "Create memory failed";
+    return mindspore::lite::RET_MEMORY_FAILED;
+  }
+
+  embedding_lookup_parameter_->is_regulated_ =
+    reinterpret_cast<bool *>(std::malloc(sizeof(bool) * embedding_lookup_parameter_->layer_num_));
+  if (embedding_lookup_parameter_->is_regulated_ == nullptr) {
+    MS_LOG(ERROR) << "Create memory failed";
+    return mindspore::lite::RET_MEMORY_FAILED;
+  }
+
+  for (int i = 0; i < embedding_lookup_parameter_->layer_num_; ++i) {
+    embedding_lookup_parameter_->is_regulated_[i] = embedding_lookup_parameter_->max_norm_ == 0;
+  }
+
+  return RET_OK;
+}
+
+int EmbeddingLookupCPUKernel::ReSize() { return RET_OK; }
+
+int EmbeddingLookupCPUKernel::DoExcute(int task_id) {
+  int error_code = EmbeddingLookup(input_addr_, ids_addr_, output_addr_, embedding_lookup_parameter_, task_id);
+  if (error_code != RET_OK) {
+    MS_LOG(ERROR) << "embedding lookup error_code[" << error_code << "]";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int EmbeddingLookupRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
+  auto EmbeddingLookupData = reinterpret_cast<EmbeddingLookupCPUKernel *>(cdata);
+  auto ret = EmbeddingLookupData->DoExcute(task_id);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "EmbeddingLookupRun error task_id[" << task_id << "] error_code[" << ret << "]";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int EmbeddingLookupCPUKernel::Run() {
+  int dest_loc = 0;
+  for (int i = 0; i < inputs_.size() - 1; i++) {
+    auto input_t = reinterpret_cast<float *>(inputs_.at(i)->Data());
+    memcpy(input_addr_ + dest_loc, input_t, sizeof(float) * inputs_.at(i)->ElementsNum());
+    dest_loc += inputs_.at(i)->ElementsNum();
+  }
+  output_addr_ = reinterpret_cast<float *>(outputs_.front()->Data());
+  ids_addr_ = reinterpret_cast<int *>(inputs_.back()->Data());
+
+  auto ret = LiteBackendParallelLaunch(EmbeddingLookupRun, this, embedding_lookup_parameter_->thread_num);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "EmbeddingLookup error: error_code[" << ret << "]";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+kernel::LiteKernel *CpuEmbeddingLookupFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
+                                                        const std::vector<lite::tensor::Tensor *> &outputs,
+                                                        OpParameter *parameter, const lite::Context *ctx,
+                                                        const KernelKey &desc) {
+  if (parameter == nullptr || ctx == nullptr) {
+    MS_LOG(ERROR) << "parameter or ctx is nullptr";
+    return nullptr;
+  }
+  MS_ASSERT(desc.type == PrimitiveType_EmbeddingLookup);
+  auto *kernel = new (std::nothrow) EmbeddingLookupCPUKernel(parameter, inputs, outputs, ctx);
+  if (kernel == nullptr) {
+    MS_LOG(ERROR) << "Create Kernel failed, name: " << parameter->name_;
+    return nullptr;
+  }
+
+  auto ret = kernel->Init();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Init Kernel failed, name: " << parameter->name_
+                  << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
+    delete kernel;
+    return nullptr;
+  }
+  return kernel;
+}
+
+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_EmbeddingLookup, CpuEmbeddingLookupFp32KernelCreator)
+}  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.h b/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.h
new file mode 100644
index 000000000..6afa0d562
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.h
@@ -0,0 +1,49 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_EMBEDDING_LOOKUP_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_EMBEDDING_LOOKUP_H_
+
+#include <vector>
+#include "src/lite_kernel.h"
+#include "src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h"
+
+namespace mindspore::kernel {
+class EmbeddingLookupCPUKernel : public LiteKernel {
+ public:
+  explicit EmbeddingLookupCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
+                                    const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
+      : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) {}
+  ~EmbeddingLookupCPUKernel() override{};
+
+  int Init() override;
+  int ReSize() override;
+  int Run() override;
+  int DoExcute(int task_id);
+
+ protected:
+  int thread_count_;
+  const lite::Context *ctx_;
+  EmbeddingLookupParameter *embedding_lookup_parameter_;
+
+ private:
+  float *input_addr_;
+  float *output_addr_;
+  int *ids_addr_;
+};
+}  // namespace mindspore::kernel
+
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_EMBEDDING_LOOKUP_H_
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.cc
index a0f470467..3c01ca389 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.cc
@@ -81,10 +81,10 @@ int ArithmeticInt8CPUKernel::Init() {
       arithmetic_run_ = ElementNotEqual;
       break;
     case PrimitiveType_Less:
-      arithmetic_run_ = ElementEqual;
+      arithmetic_run_ = ElementLess;
       break;
     case PrimitiveType_LessEqual:
-      arithmetic_run_ = ElementNotEqual;
+      arithmetic_run_ = ElementLessEqual;
       break;
     case PrimitiveType_Greater:
       arithmetic_run_ = ElementGreater;
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.cc
new file mode 100644
index 000000000..964041fa3
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.cc
@@ -0,0 +1,60 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h"
+#include <string.h>
+#include "include/errorcode.h"
+#include "src/runtime/kernel/arm/nnacl/errorcode.h"
+#include "mindspore/core/utils/log_adapter.h"
+
+void l2_regulate(float *data, int size, float max_norm) {
+  float sum = 0;
+  for (int i = 0; i < size; ++i) {
+    sum += data[i];
+  }
+  if (sum != 0) {
+    for (int i = 0; i < size; ++i) {
+      data[i] *= max_norm / sum;
+    }
+  }
+  return;
+}
+
+int CopyData(float *input_data, int *ids, float *output_data, int num, EmbeddingLookupParameter *parameter) {
+  if (ids[num] >= parameter->layer_num_ || ids[num] < 0) {
+    MS_LOG(ERROR) << "Embedding lookup index out of range";
+    return NNACL_ERRCODE_INDEX_OUT_OF_RANGE;
+  }
+  float *out_data = output_data + num * parameter->layer_size_;
+  float *in_data = input_data + ids[num] * parameter->layer_size_;
+  if (!parameter->is_regulated_[ids[num]]) {
+    l2_regulate(in_data, parameter->layer_size_, parameter->max_norm_);
+    parameter->is_regulated_[ids[num]] = true;
+  }
+
+  memcpy(out_data, in_data, sizeof(float) * parameter->layer_size_);
+  return NNACL_OK;
+}
+
+int EmbeddingLookup(float *input_data, int *ids, float *output_data, EmbeddingLookupParameter *parameter,
+                    int task_id) {
+  for (size_t i = task_id; i < parameter->ids_size_; i += parameter->thread_num) {
+    int ret = CopyData(input_data, ids, output_data, i, parameter);
+    if (ret != NNACL_OK) {
+      return ret;
+    }
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h
new file mode 100644
index 000000000..fa9f0ce5d
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_EMBEDDING_LOOKUP_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_EMBEDDING_LOOKUP_H_
+
+#include "src/runtime/kernel/arm/nnacl/op_base.h"
+
+struct EmbeddingLookupParameter {
+  OpParameter op_parameter_;
+  bool *is_regulated_;
+  float max_norm_;
+  int ids_size_;
+  int layer_size_;
+  int layer_num_;
+  int thread_num;
+};
+
+int EmbeddingLookup(float *input_data, int *ids, float *output_data, EmbeddingLookupParameter *parameter,
+                    int task_id);
+
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_EMBEDDING_LOOKUP_H_
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/embedding_lookup_fp32_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/embedding_lookup_fp32_test.cc
new file mode 100644
index 000000000..1f6d2997c
--- /dev/null
+++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/embedding_lookup_fp32_test.cc
@@ -0,0 +1,85 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <iostream>
+#include "src/runtime/kernel/arm/fp32/embedding_lookup.h"
+#include "src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h"
+#include "src/common/file_utils.h"
+#include "common/common_test.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+using mindspore::lite::tensor::Tensor;
+
+class TestEmbeddingLookupFp32 : public mindspore::Common {
+ public:
+  TestEmbeddingLookupFp32() {}
+};
+
+void ElTestInit(std::vector<Tensor *> *inputs_, std::vector<Tensor *> *outputs_,
+                EmbeddingLookupParameter *embedding_lookup_param) {
+  Tensor *in_t_first = new Tensor(kNumberTypeFloat32, {6, 2}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
+  in_t_first->MallocData();
+  float in_first[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+  memcpy(in_t_first->Data(), in_first, sizeof(float) * in_t_first->ElementsNum());
+  inputs_->push_back(in_t_first);
+
+  Tensor *in_t_second = new Tensor(kNumberTypeFloat32, {4, 2}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
+  in_t_second->MallocData();
+  float in_second[] = {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8};
+  memcpy(in_t_second->Data(), in_second, sizeof(float) * in_t_second->ElementsNum());
+  inputs_->push_back(in_t_second);
+
+  Tensor *ids_t = new Tensor(kNumberTypeInt32, {2, 3}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
+  ids_t->MallocData();
+  int ids[] = {1, 9, 2, 4, 6, 7};
+  memcpy(ids_t->Data(), ids, sizeof(int) * ids_t->ElementsNum());
+  inputs_->push_back(ids_t);
+
+  Tensor *outputs_t = new Tensor(kNumberTypeFloat32, {2, 3, 2}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
+  outputs_t->MallocData();
+  outputs_->push_back(outputs_t);
+
+  embedding_lookup_param->max_norm_ = 1;
+}
+
+TEST_F(TestEmbeddingLookupFp32, ElTest) {
+  std::vector<Tensor *> inputs_;
+  std::vector<Tensor *> outputs_;
+  auto embedding_lookup_param_ = new EmbeddingLookupParameter();
+  ElTestInit(&inputs_, &outputs_, embedding_lookup_param_);
+
+  lite::Context *ctx = new lite::Context;
+  ctx->thread_num_ = 2;
+  kernel::EmbeddingLookupCPUKernel *el = new kernel::EmbeddingLookupCPUKernel(
+    reinterpret_cast<OpParameter *>(embedding_lookup_param_), inputs_, outputs_, ctx);
+
+  el->Init();
+  el->Run();
+
+  std::cout << "output shape:" << std::endl;
+  for (int i = 0; i < outputs_.front()->shape().size(); ++i) {
+    std::cout << outputs_.front()->shape()[i] << ' ';
+  }
+  std::cout << std::endl;
+  float *out = reinterpret_cast<float *>(outputs_.front()->Data());
+  for (int i = 0; i < outputs_.front()->ElementsNum(); ++i) {
+    std::cout << out[i] << ' ';
+  }
+  std::cout << std::endl;
+}
+
+}  // namespace mindspore
-- 
GitLab
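
Reviewer note (not part of the patch): below is a minimal standalone sketch of the arithmetic the new kernel applies, using the same table rows, ids, and maxNorm as the unit test above. It illustrates that l2_regulate scales each row by maxNorm divided by the plain sum of its elements (not by its Euclidean norm), and that the kernel applies this at most once per row before copying it to the output. The file name and layout here are hypothetical and for illustration only.

// embedding_lookup_sketch.cc -- illustration of the patch's lookup-and-regulate semantics.
#include <cstdio>
#include <vector>

int main() {
  // Rows of both embedding tables concatenated, as EmbeddingLookupCPUKernel::Run copies
  // them into input_addr_: 6 rows from the first input, then 4 rows from the second.
  std::vector<std::vector<float>> table = {
    {1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12},         // first input, shape {6, 2}
    {1.1f, 2.2f}, {3.3f, 4.4f}, {5.5f, 6.6f}, {7.7f, 8.8f}};   // second input, shape {4, 2}
  std::vector<int> ids = {1, 9, 2, 4, 6, 7};  // ids tensor from the test, shape {2, 3}
  const float max_norm = 1.0f;                // maxNorm from the test; 0 would skip regulation

  for (int id : ids) {
    // l2_regulate as implemented above: scale the row by max_norm / (sum of its elements).
    float sum = 0.0f;
    for (float v : table[id]) sum += v;
    std::printf("id %d:", id);
    for (float v : table[id]) std::printf(" %g", sum != 0.0f ? v * max_norm / sum : v);
    std::printf("\n");
  }
  return 0;
}

Running this prints one regulated row per id (for example id 1 gives roughly 0.428571 0.571429), which is the sequence the unit test's final loop writes to std::cout.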