提交 6d4d692f 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4033 add arm cpu op: embedding_lookup

Merge pull request !4033 from 陶云浩/lite
...@@ -727,8 +727,7 @@ table AddN { ...@@ -727,8 +727,7 @@ table AddN {
table EmbeddingLookup { table EmbeddingLookup {
ids: [int]; maxNorm: float = 0.0;
maxNorm: float;
} }
table EmbeddingLookupSparse { table EmbeddingLookupSparse {
......
...@@ -216,6 +216,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) { ...@@ -216,6 +216,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) {
return new lite::MatMul(const_cast<schema::Primitive *>(srcPrim)); return new lite::MatMul(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_QuantDTypeCast: case schema::PrimitiveType_QuantDTypeCast:
return new lite::QuantDTypeCast(const_cast<schema::Primitive *>(srcPrim)); return new lite::QuantDTypeCast(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_EmbeddingLookup:
return new lite::EmbeddingLookup(const_cast<schema::Primitive *>(srcPrim));
default: default:
break; break;
} }
......
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/ops.h"
#include "include/errorcode.h"
#include "src/ir/tensor.h"
#include "utils/log_adapter.h"
namespace mindspore::lite {
int EmbeddingLookup::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
if (inputs_.size() < kDoubleNum) {
MS_LOG(ERROR) << "Embedding Lookup should have at least two inputs";
return RET_INPUT_TENSOR_ERROR;
}
if (outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "Embedding Lookup should have one outputs";
return RET_INPUT_TENSOR_ERROR;
}
auto params_ = inputs_.front();
MS_ASSERT(params_ != nullptr);
auto ids = inputs_.back();
MS_ASSERT(ids != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
auto embedding_shape = params_->shape();
embedding_shape.erase(embedding_shape.begin());
std::vector<int> output_shape(ids->shape());
for (size_t i = 0; i < embedding_shape.size(); ++i) {
output_shape.push_back(embedding_shape.at(i));
}
for (int i = 1; i < inputs_.size() - 1; ++i) {
auto embedding_shape_t = inputs_.at(i)->shape();
embedding_shape_t.erase(embedding_shape_t.begin());
if (embedding_shape_t != embedding_shape) {
MS_LOG(ERROR) << "The embedded layers should have the same shape";
return RET_INPUT_TENSOR_ERROR;
}
}
output->set_shape(output_shape);
output->set_data_type(params_->data_type());
return RET_OK;
}
} // namespace mindspore::lite
...@@ -141,6 +141,8 @@ Primitive *Primitive::CreatePrimitive(schema::Primitive *primitive) { ...@@ -141,6 +141,8 @@ Primitive *Primitive::CreatePrimitive(schema::Primitive *primitive) {
return new lite::QuantDTypeCast(const_cast<schema::Primitive *>(primitive)); return new lite::QuantDTypeCast(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_MatMul: case schema::PrimitiveType_MatMul:
return new lite::MatMul(const_cast<schema::Primitive *>(primitive)); return new lite::MatMul(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_EmbeddingLookup:
return new lite::EmbeddingLookup(const_cast<schema::Primitive *>(primitive));
default: default:
break; break;
} }
......
...@@ -778,6 +778,13 @@ class Lstm : public Primitive { ...@@ -778,6 +778,13 @@ class Lstm : public Primitive {
const schema::Lstm *GetAttribute() const { return this->primitive->value_as_Lstm(); } const schema::Lstm *GetAttribute() const { return this->primitive->value_as_Lstm(); }
int InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) override; int InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) override;
}; };
class EmbeddingLookup : public Primitive {
public:
explicit EmbeddingLookup(schema::Primitive *primitive) : Primitive(primitive) {}
const schema::EmbeddingLookup *GetAttribute() const { return this->primitive->value_as_EmbeddingLookup(); }
int InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) override;
};
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_LITE_SRC_OPS_OPS_H_ #endif // MINDSPORE_LITE_SRC_OPS_OPS_H_
...@@ -69,6 +69,7 @@ ...@@ -69,6 +69,7 @@
#include "src/runtime/kernel/arm/nnacl/fp32/space_to_batch.h" #include "src/runtime/kernel/arm/nnacl/fp32/space_to_batch.h"
#include "src/runtime/kernel/arm/nnacl/int8/quant_dtype_cast.h" #include "src/runtime/kernel/arm/nnacl/int8/quant_dtype_cast.h"
#include "src/runtime/kernel/arm/nnacl/fp32/lstm.h" #include "src/runtime/kernel/arm/nnacl/fp32/lstm.h"
#include "src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h"
namespace mindspore::kernel { namespace mindspore::kernel {
OpParameter *PopulateBatchNorm(const lite::Primitive *primitive) { OpParameter *PopulateBatchNorm(const lite::Primitive *primitive) {
...@@ -1209,6 +1210,23 @@ OpParameter *PopulateLstmParameter(const lite::Primitive *primitive) { ...@@ -1209,6 +1210,23 @@ OpParameter *PopulateLstmParameter(const lite::Primitive *primitive) {
return reinterpret_cast<OpParameter *>(lstm_param); return reinterpret_cast<OpParameter *>(lstm_param);
} }
OpParameter *PopulateEmbeddingLookupParameter(const lite::Primitive *primitive) {
EmbeddingLookupParameter *embedding_lookup_parameter = new (std::nothrow) EmbeddingLookupParameter();
if (embedding_lookup_parameter == nullptr) {
MS_LOG(ERROR) << "new EmbeddingLookupParameter failed";
return nullptr;
}
embedding_lookup_parameter->op_parameter_.type_ = primitive->Type();
auto param = primitive->Value()->value_as_EmbeddingLookup();
embedding_lookup_parameter->max_norm_ = param->maxNorm();
if (embedding_lookup_parameter->max_norm_ < 0) {
MS_LOG(ERROR) << "Embedding lookup max norm should be positive number, got "
<< embedding_lookup_parameter->max_norm_;
return nullptr;
}
return reinterpret_cast<OpParameter *>(embedding_lookup_parameter);
}
PopulateParameterRegistry::PopulateParameterRegistry() { PopulateParameterRegistry::PopulateParameterRegistry() {
populate_parameter_funcs_[schema::PrimitiveType_SoftMax] = PopulateSoftmaxParameter; populate_parameter_funcs_[schema::PrimitiveType_SoftMax] = PopulateSoftmaxParameter;
populate_parameter_funcs_[schema::PrimitiveType_Activation] = PopulateActivationParameter; populate_parameter_funcs_[schema::PrimitiveType_Activation] = PopulateActivationParameter;
...@@ -1286,6 +1304,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() { ...@@ -1286,6 +1304,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() {
populate_parameter_funcs_[schema::PrimitiveType_PriorBox] = PopulatePriorBoxParameter; populate_parameter_funcs_[schema::PrimitiveType_PriorBox] = PopulatePriorBoxParameter;
populate_parameter_funcs_[schema::PrimitiveType_QuantDTypeCast] = PopulateQuantDTypeCastParameter; populate_parameter_funcs_[schema::PrimitiveType_QuantDTypeCast] = PopulateQuantDTypeCastParameter;
populate_parameter_funcs_[schema::PrimitiveType_Lstm] = PopulateLstmParameter; populate_parameter_funcs_[schema::PrimitiveType_Lstm] = PopulateLstmParameter;
populate_parameter_funcs_[schema::PrimitiveType_EmbeddingLookup] = PopulateEmbeddingLookupParameter;
} }
PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() { PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() {
......
...@@ -137,12 +137,12 @@ class ArithmeticCPUKernel : public LiteKernel { ...@@ -137,12 +137,12 @@ class ArithmeticCPUKernel : public LiteKernel {
arithmetic_broadcast_run_ = BroadcastNotEqual; arithmetic_broadcast_run_ = BroadcastNotEqual;
break; break;
case PrimitiveType_Less: case PrimitiveType_Less:
arithmetic_run_ = ElementEqual; arithmetic_run_ = ElementLess;
arithmetic_broadcast_run_ = BroadcastEqual; arithmetic_broadcast_run_ = BroadcastLess;
break; break;
case PrimitiveType_LessEqual: case PrimitiveType_LessEqual:
arithmetic_run_ = ElementNotEqual; arithmetic_run_ = ElementLessEqual;
arithmetic_broadcast_run_ = BroadcastNotEqual; arithmetic_broadcast_run_ = BroadcastLessEqual;
break; break;
case PrimitiveType_Greater: case PrimitiveType_Greater:
arithmetic_run_ = ElementGreater; arithmetic_run_ = ElementGreater;
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32/embedding_lookup.h"
#include "include/errorcode.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_EmbeddingLookup;
namespace mindspore::kernel {
int EmbeddingLookupCPUKernel::Init() {
embedding_lookup_parameter_ = reinterpret_cast<EmbeddingLookupParameter *>(opParameter);
embedding_lookup_parameter_->thread_num = thread_count_;
embedding_lookup_parameter_->ids_size_ = inputs_.back()->ElementsNum();
embedding_lookup_parameter_->layer_size_ = 1;
auto in_shape = inputs_.front()->shape();
for (int i = 1; i < in_shape.size(); ++i) {
embedding_lookup_parameter_->layer_size_ *= in_shape[i];
}
embedding_lookup_parameter_->layer_num_ = 0;
for (int i = 0; i < inputs_.size() - 1; ++i) {
embedding_lookup_parameter_->layer_num_ += inputs_[i]->shape()[0];
}
input_addr_ = reinterpret_cast<float *>(
std::malloc(sizeof(float) * embedding_lookup_parameter_->layer_size_ * embedding_lookup_parameter_->layer_num_));
if (input_addr_ == nullptr) {
MS_LOG(ERROR) << "Create memory failed";
return mindspore::lite::RET_MEMORY_FAILED;
}
embedding_lookup_parameter_->is_regulated_ =
reinterpret_cast<bool *>(std::malloc(sizeof(bool) * embedding_lookup_parameter_->layer_num_));
if (embedding_lookup_parameter_->is_regulated_ == nullptr) {
MS_LOG(ERROR) << "Create memory failed";
return mindspore::lite::RET_MEMORY_FAILED;
}
for (int i = 0; i < embedding_lookup_parameter_->layer_num_; ++i) {
embedding_lookup_parameter_->is_regulated_[i] = embedding_lookup_parameter_->max_norm_ == 0;
}
return RET_OK;
}
int EmbeddingLookupCPUKernel::ReSize() { return RET_OK; }
int EmbeddingLookupCPUKernel::DoExcute(int task_id) {
int error_code = EmbeddingLookup(input_addr_, ids_addr_, output_addr_, embedding_lookup_parameter_, task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "embedding lookup error error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}
int EmbeddingLookupRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
auto EmbeddingLookupData = reinterpret_cast<EmbeddingLookupCPUKernel *>(cdata);
auto ret = EmbeddingLookupData->DoExcute(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "EmbeddingLookupRun error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}
int EmbeddingLookupCPUKernel::Run() {
int dest_loc = 0;
for (int i = 0; i < inputs_.size() - 1; i++) {
auto input_t = reinterpret_cast<float *>(inputs_.at(i)->Data());
memcpy(input_addr_ + dest_loc, input_t, sizeof(float) * inputs_.at(i)->ElementsNum());
dest_loc += inputs_.at(i)->ElementsNum();
}
output_addr_ = reinterpret_cast<float *>(outputs_.front()->Data());
ids_addr_ = reinterpret_cast<int *>(inputs_.back()->Data());
auto ret = LiteBackendParallelLaunch(EmbeddingLookupRun, this, embedding_lookup_parameter_->thread_num);
if (ret != RET_OK) {
MS_LOG(ERROR) << "EmbeddingLookup error: error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}
kernel::LiteKernel *CpuEmbeddingLookupFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *parameter, const lite::Context *ctx,
const KernelKey &desc) {
if (parameter == nullptr || ctx == nullptr) {
MS_LOG(ERROR) << "parameter or ctx is nullptr";
return nullptr;
}
MS_ASSERT(desc.type == PrimitiveType_EmbeddingLookup);
auto *kernel = new (std::nothrow) EmbeddingLookupCPUKernel(parameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "Create Kernel failed, name: " << parameter->name_;
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init Kernel failed, name: " << parameter->name_
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
delete kernel;
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_EmbeddingLookup, CpuEmbeddingLookupFp32KernelCreator)
} // namespace mindspore::kernel
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_EMBEDDING_LOOKUP_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_EMBEDDING_LOOKUP_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h"
namespace mindspore::kernel {
class EmbeddingLookupCPUKernel : public LiteKernel {
public:
explicit EmbeddingLookupCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
: LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) {}
~EmbeddingLookupCPUKernel() override{};
int Init() override;
int ReSize() override;
int Run() override;
int DoExcute(int task_id);
protected:
int thread_count_;
const lite::Context *ctx_;
EmbeddingLookupParameter *embedding_lookup_parameter_;
private:
float *input_addr_;
float *output_addr_;
int *ids_addr_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_EMBEDDING_LOOKUP_H_
...@@ -81,10 +81,10 @@ int ArithmeticInt8CPUKernel::Init() { ...@@ -81,10 +81,10 @@ int ArithmeticInt8CPUKernel::Init() {
arithmetic_run_ = ElementNotEqual; arithmetic_run_ = ElementNotEqual;
break; break;
case PrimitiveType_Less: case PrimitiveType_Less:
arithmetic_run_ = ElementEqual; arithmetic_run_ = ElementLess;
break; break;
case PrimitiveType_LessEqual: case PrimitiveType_LessEqual:
arithmetic_run_ = ElementNotEqual; arithmetic_run_ = ElementLessEqual;
break; break;
case PrimitiveType_Greater: case PrimitiveType_Greater:
arithmetic_run_ = ElementGreater; arithmetic_run_ = ElementGreater;
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h"
#include <string.h>
#include "include/errorcode.h"
#include "src/runtime/kernel/arm/nnacl/errorcode.h"
#include "mindspore/core/utils/log_adapter.h"
void l2_regulate(float *data, int size, float max_norm) {
float sum = 0;
for (int i = 0; i < size; ++i) {
sum += data[i];
}
if (sum != 0) {
for (int i = 0; i < size; ++i) {
data[i] *= max_norm / sum;
}
}
return;
}
int CopyData(float *input_data, int *ids, float *output_data, int num, EmbeddingLookupParameter *parameter) {
if (ids[num] >= parameter->layer_num_ || ids[num] < 0) {
MS_LOG(ERROR) << "Embedding lookup index out of range";
return NNACL_ERRCODE_INDEX_OUT_OF_RANGE;
}
float *out_data = output_data + num * parameter->layer_size_;
float *in_data = input_data + ids[num] * parameter->layer_size_;
if (!parameter->is_regulated_[ids[num]]) {
l2_regulate(in_data, parameter->layer_size_, parameter->max_norm_);
parameter->is_regulated_[ids[num]] = true;
}
memcpy(out_data, in_data, sizeof(float) * parameter->layer_size_);
return NNACL_OK;
}
int EmbeddingLookup(float *input_data, int *ids, float *output_data, EmbeddingLookupParameter *parameter, int task_id) {
for (size_t i = task_id; i < parameter->ids_size_; i += parameter->thread_num) {
int ret = CopyData(input_data, ids, output_data, i, parameter);
if (ret != NNACL_OK) {
return ret;
}
}
return NNACL_OK;
}
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_EMBEDDING_LOOKUP_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_EMBEDDING_LOOKUP_H_
#include "src/runtime/kernel/arm/nnacl/op_base.h"
struct EmbeddingLookupParameter {
OpParameter op_parameter_;
bool *is_regulated_;
float max_norm_;
int ids_size_;
int layer_size_;
int layer_num_;
int thread_num;
};
int EmbeddingLookup(float *input_data, int *ids, float *output_data, EmbeddingLookupParameter *parameter, int task_id);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_EMBEDDING_LOOKUP_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include "src/runtime/kernel/arm/fp32/embedding_lookup.h"
#include "src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h"
#include "src/common/file_utils.h"
#include "common/common_test.h"
#include "utils/log_adapter.h"
namespace mindspore {
using mindspore::lite::tensor::Tensor;
class TestEmbeddingLookupFp32 : public mindspore::Common {
public:
TestEmbeddingLookupFp32() {}
};
void ElTestInit(std::vector<Tensor *> *inputs_, std::vector<Tensor *> *outputs_,
EmbeddingLookupParameter *embedding_lookup_param) {
Tensor *in_t_first = new Tensor(kNumberTypeFloat32, {6, 2}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
in_t_first->MallocData();
float in_first[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
memcpy(in_t_first->Data(), in_first, sizeof(float) * in_t_first->ElementsNum());
inputs_->push_back(in_t_first);
Tensor *in_t_second = new Tensor(kNumberTypeFloat32, {4, 2}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
in_t_second->MallocData();
float in_second[] = {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8};
memcpy(in_t_second->Data(), in_second, sizeof(float) * in_t_second->ElementsNum());
inputs_->push_back(in_t_second);
Tensor *ids_t = new Tensor(kNumberTypeFloat32, {2, 3}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
ids_t->MallocData();
int ids[] = {1, 9, 2, 4, 6, 7};
memcpy(ids_t->Data(), ids, sizeof(int) * ids_t->ElementsNum());
inputs_->push_back(ids_t);
Tensor *outputs_t = new Tensor(kNumberTypeInt32, {2, 3, 2}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
outputs_t->MallocData();
outputs_->push_back(outputs_t);
embedding_lookup_param->max_norm_ = 1;
}
TEST_F(TestEmbeddingLookupFp32, ElTest) {
std::vector<Tensor *> inputs_;
std::vector<Tensor *> outputs_;
auto embedding_lookup_param_ = new EmbeddingLookupParameter();
ElTestInit(&inputs_, &outputs_, embedding_lookup_param_);
lite::Context *ctx = new lite::Context;
ctx->thread_num_ = 2;
kernel::EmbeddingLookupCPUKernel *el = new kernel::EmbeddingLookupCPUKernel(
reinterpret_cast<OpParameter *>(embedding_lookup_param_), inputs_, outputs_, ctx);
el->Init();
el->Run();
std::cout << "output shape:" << std::endl;
for (int i = 0; i < outputs_.front()->shape().size(); ++i) {
std::cout << outputs_.front()->shape()[i] << ' ';
}
std::cout << std::endl;
float *out = reinterpret_cast<float *>(outputs_.front()->Data());
for (int i = 0; i < outputs_.front()->ElementsNum(); ++i) {
std::cout << out[i] << ' ';
}
std::cout << std::endl;
}
} // namespace mindspore
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册