Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
6d4d692f
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6d4d692f
编写于
8月 07, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 07, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4033 add arm cpu op: embedding_lookup
Merge pull request !4033 from 陶云浩/lite
上级
f8e4ab86
8ef07907
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
458 addition
and
8 deletion
+458
-8
mindspore/lite/schema/ops.fbs
mindspore/lite/schema/ops.fbs
+1
-2
mindspore/lite/src/model_impl.cc
mindspore/lite/src/model_impl.cc
+2
-0
mindspore/lite/src/ops/embedding_lookup.cc
mindspore/lite/src/ops/embedding_lookup.cc
+63
-0
mindspore/lite/src/ops/ops.cc
mindspore/lite/src/ops/ops.cc
+2
-0
mindspore/lite/src/ops/ops.h
mindspore/lite/src/ops/ops.h
+7
-0
mindspore/lite/src/populate_parameter.cc
mindspore/lite/src/populate_parameter.cc
+19
-0
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h
+4
-4
mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.cc
...pore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.cc
+130
-0
mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.h
...spore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.h
+49
-0
mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.cc
...spore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.cc
+2
-2
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.cc
...ite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.cc
+60
-0
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h
...lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h
+34
-0
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/embedding_lookup_fp32_test.cc
...src/runtime/kernel/arm/fp32/embedding_lookup_fp32_test.cc
+85
-0
未找到文件。
mindspore/lite/schema/ops.fbs
浏览文件 @
6d4d692f
...
...
@@ -727,8 +727,7 @@ table AddN {
table EmbeddingLookup {
ids: [int];
maxNorm: float;
maxNorm: float = 0.0;
}
table EmbeddingLookupSparse {
...
...
mindspore/lite/src/model_impl.cc
浏览文件 @
6d4d692f
...
...
@@ -216,6 +216,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) {
return
new
lite
::
MatMul
(
const_cast
<
schema
::
Primitive
*>
(
srcPrim
));
case
schema
::
PrimitiveType_QuantDTypeCast
:
return
new
lite
::
QuantDTypeCast
(
const_cast
<
schema
::
Primitive
*>
(
srcPrim
));
case
schema
::
PrimitiveType_EmbeddingLookup
:
return
new
lite
::
EmbeddingLookup
(
const_cast
<
schema
::
Primitive
*>
(
srcPrim
));
default:
break
;
}
...
...
mindspore/lite/src/ops/embedding_lookup.cc
0 → 100644
浏览文件 @
6d4d692f
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/ops.h"
#include "include/errorcode.h"
#include "src/ir/tensor.h"
#include "utils/log_adapter.h"
namespace
mindspore
::
lite
{
int
EmbeddingLookup
::
InferShape
(
std
::
vector
<
tensor
::
Tensor
*>
inputs_
,
std
::
vector
<
tensor
::
Tensor
*>
outputs_
)
{
MS_ASSERT
(
this
->
primitive
!=
nullptr
);
if
(
inputs_
.
size
()
<
kDoubleNum
)
{
MS_LOG
(
ERROR
)
<<
"Embedding Lookup should have at least two inputs"
;
return
RET_INPUT_TENSOR_ERROR
;
}
if
(
outputs_
.
size
()
!=
kSingleNum
)
{
MS_LOG
(
ERROR
)
<<
"Embedding Lookup should have one outputs"
;
return
RET_INPUT_TENSOR_ERROR
;
}
auto
params_
=
inputs_
.
front
();
MS_ASSERT
(
params_
!=
nullptr
);
auto
ids
=
inputs_
.
back
();
MS_ASSERT
(
ids
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
auto
embedding_shape
=
params_
->
shape
();
embedding_shape
.
erase
(
embedding_shape
.
begin
());
std
::
vector
<
int
>
output_shape
(
ids
->
shape
());
for
(
size_t
i
=
0
;
i
<
embedding_shape
.
size
();
++
i
)
{
output_shape
.
push_back
(
embedding_shape
.
at
(
i
));
}
for
(
int
i
=
1
;
i
<
inputs_
.
size
()
-
1
;
++
i
)
{
auto
embedding_shape_t
=
inputs_
.
at
(
i
)
->
shape
();
embedding_shape_t
.
erase
(
embedding_shape_t
.
begin
());
if
(
embedding_shape_t
!=
embedding_shape
)
{
MS_LOG
(
ERROR
)
<<
"The embedded layers should have the same shape"
;
return
RET_INPUT_TENSOR_ERROR
;
}
}
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
params_
->
data_type
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/ops.cc
浏览文件 @
6d4d692f
...
...
@@ -141,6 +141,8 @@ Primitive *Primitive::CreatePrimitive(schema::Primitive *primitive) {
return
new
lite
::
QuantDTypeCast
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_MatMul
:
return
new
lite
::
MatMul
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_EmbeddingLookup
:
return
new
lite
::
EmbeddingLookup
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
default:
break
;
}
...
...
mindspore/lite/src/ops/ops.h
浏览文件 @
6d4d692f
...
...
@@ -778,6 +778,13 @@ class Lstm : public Primitive {
const
schema
::
Lstm
*
GetAttribute
()
const
{
return
this
->
primitive
->
value_as_Lstm
();
}
int
InferShape
(
std
::
vector
<
tensor
::
Tensor
*>
inputs
,
std
::
vector
<
tensor
::
Tensor
*>
outputs
)
override
;
};
class
EmbeddingLookup
:
public
Primitive
{
public:
explicit
EmbeddingLookup
(
schema
::
Primitive
*
primitive
)
:
Primitive
(
primitive
)
{}
const
schema
::
EmbeddingLookup
*
GetAttribute
()
const
{
return
this
->
primitive
->
value_as_EmbeddingLookup
();
}
int
InferShape
(
std
::
vector
<
tensor
::
Tensor
*>
inputs_
,
std
::
vector
<
tensor
::
Tensor
*>
outputs_
)
override
;
};
}
// namespace lite
}
// namespace mindspore
#endif // MINDSPORE_LITE_SRC_OPS_OPS_H_
mindspore/lite/src/populate_parameter.cc
浏览文件 @
6d4d692f
...
...
@@ -69,6 +69,7 @@
#include "src/runtime/kernel/arm/nnacl/fp32/space_to_batch.h"
#include "src/runtime/kernel/arm/nnacl/int8/quant_dtype_cast.h"
#include "src/runtime/kernel/arm/nnacl/fp32/lstm.h"
#include "src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h"
namespace
mindspore
::
kernel
{
OpParameter
*
PopulateBatchNorm
(
const
lite
::
Primitive
*
primitive
)
{
...
...
@@ -1209,6 +1210,23 @@ OpParameter *PopulateLstmParameter(const lite::Primitive *primitive) {
return
reinterpret_cast
<
OpParameter
*>
(
lstm_param
);
}
OpParameter
*
PopulateEmbeddingLookupParameter
(
const
lite
::
Primitive
*
primitive
)
{
EmbeddingLookupParameter
*
embedding_lookup_parameter
=
new
(
std
::
nothrow
)
EmbeddingLookupParameter
();
if
(
embedding_lookup_parameter
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new EmbeddingLookupParameter failed"
;
return
nullptr
;
}
embedding_lookup_parameter
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
param
=
primitive
->
Value
()
->
value_as_EmbeddingLookup
();
embedding_lookup_parameter
->
max_norm_
=
param
->
maxNorm
();
if
(
embedding_lookup_parameter
->
max_norm_
<
0
)
{
MS_LOG
(
ERROR
)
<<
"Embedding lookup max norm should be positive number, got "
<<
embedding_lookup_parameter
->
max_norm_
;
return
nullptr
;
}
return
reinterpret_cast
<
OpParameter
*>
(
embedding_lookup_parameter
);
}
PopulateParameterRegistry
::
PopulateParameterRegistry
()
{
populate_parameter_funcs_
[
schema
::
PrimitiveType_SoftMax
]
=
PopulateSoftmaxParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Activation
]
=
PopulateActivationParameter
;
...
...
@@ -1286,6 +1304,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() {
populate_parameter_funcs_
[
schema
::
PrimitiveType_PriorBox
]
=
PopulatePriorBoxParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_QuantDTypeCast
]
=
PopulateQuantDTypeCastParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Lstm
]
=
PopulateLstmParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_EmbeddingLookup
]
=
PopulateEmbeddingLookupParameter
;
}
PopulateParameterRegistry
*
PopulateParameterRegistry
::
GetInstance
()
{
...
...
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h
浏览文件 @
6d4d692f
...
...
@@ -137,12 +137,12 @@ class ArithmeticCPUKernel : public LiteKernel {
arithmetic_broadcast_run_
=
BroadcastNotEqual
;
break
;
case
PrimitiveType_Less
:
arithmetic_run_
=
Element
Equal
;
arithmetic_broadcast_run_
=
Broadcast
Equal
;
arithmetic_run_
=
Element
Less
;
arithmetic_broadcast_run_
=
Broadcast
Less
;
break
;
case
PrimitiveType_LessEqual
:
arithmetic_run_
=
Element
Not
Equal
;
arithmetic_broadcast_run_
=
Broadcast
Not
Equal
;
arithmetic_run_
=
Element
Less
Equal
;
arithmetic_broadcast_run_
=
Broadcast
Less
Equal
;
break
;
case
PrimitiveType_Greater
:
arithmetic_run_
=
ElementGreater
;
...
...
mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.cc
0 → 100644
浏览文件 @
6d4d692f
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32/embedding_lookup.h"
#include "include/errorcode.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
using
mindspore
::
lite
::
KernelRegistrar
;
using
mindspore
::
lite
::
RET_ERROR
;
using
mindspore
::
lite
::
RET_OK
;
using
mindspore
::
schema
::
PrimitiveType_EmbeddingLookup
;
namespace
mindspore
::
kernel
{
int
EmbeddingLookupCPUKernel
::
Init
()
{
embedding_lookup_parameter_
=
reinterpret_cast
<
EmbeddingLookupParameter
*>
(
opParameter
);
embedding_lookup_parameter_
->
thread_num
=
thread_count_
;
embedding_lookup_parameter_
->
ids_size_
=
inputs_
.
back
()
->
ElementsNum
();
embedding_lookup_parameter_
->
layer_size_
=
1
;
auto
in_shape
=
inputs_
.
front
()
->
shape
();
for
(
int
i
=
1
;
i
<
in_shape
.
size
();
++
i
)
{
embedding_lookup_parameter_
->
layer_size_
*=
in_shape
[
i
];
}
embedding_lookup_parameter_
->
layer_num_
=
0
;
for
(
int
i
=
0
;
i
<
inputs_
.
size
()
-
1
;
++
i
)
{
embedding_lookup_parameter_
->
layer_num_
+=
inputs_
[
i
]
->
shape
()[
0
];
}
input_addr_
=
reinterpret_cast
<
float
*>
(
std
::
malloc
(
sizeof
(
float
)
*
embedding_lookup_parameter_
->
layer_size_
*
embedding_lookup_parameter_
->
layer_num_
));
if
(
input_addr_
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Create memory failed"
;
return
mindspore
::
lite
::
RET_MEMORY_FAILED
;
}
embedding_lookup_parameter_
->
is_regulated_
=
reinterpret_cast
<
bool
*>
(
std
::
malloc
(
sizeof
(
bool
)
*
embedding_lookup_parameter_
->
layer_num_
));
if
(
embedding_lookup_parameter_
->
is_regulated_
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Create memory failed"
;
return
mindspore
::
lite
::
RET_MEMORY_FAILED
;
}
for
(
int
i
=
0
;
i
<
embedding_lookup_parameter_
->
layer_num_
;
++
i
)
{
embedding_lookup_parameter_
->
is_regulated_
[
i
]
=
embedding_lookup_parameter_
->
max_norm_
==
0
;
}
return
RET_OK
;
}
int
EmbeddingLookupCPUKernel
::
ReSize
()
{
return
RET_OK
;
}
int
EmbeddingLookupCPUKernel
::
DoExcute
(
int
task_id
)
{
int
error_code
=
EmbeddingLookup
(
input_addr_
,
ids_addr_
,
output_addr_
,
embedding_lookup_parameter_
,
task_id
);
if
(
error_code
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"embedding lookup error error_code["
<<
error_code
<<
"]"
;
return
RET_ERROR
;
}
return
RET_OK
;
}
int
EmbeddingLookupRun
(
int
task_id
,
LiteParallelGroupEnv
*
penv
,
void
*
cdata
)
{
auto
EmbeddingLookupData
=
reinterpret_cast
<
EmbeddingLookupCPUKernel
*>
(
cdata
);
auto
ret
=
EmbeddingLookupData
->
DoExcute
(
task_id
);
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"EmbeddingLookupRun error task_id["
<<
task_id
<<
"] error_code["
<<
ret
<<
"]"
;
return
RET_ERROR
;
}
return
RET_OK
;
}
int
EmbeddingLookupCPUKernel
::
Run
()
{
int
dest_loc
=
0
;
for
(
int
i
=
0
;
i
<
inputs_
.
size
()
-
1
;
i
++
)
{
auto
input_t
=
reinterpret_cast
<
float
*>
(
inputs_
.
at
(
i
)
->
Data
());
memcpy
(
input_addr_
+
dest_loc
,
input_t
,
sizeof
(
float
)
*
inputs_
.
at
(
i
)
->
ElementsNum
());
dest_loc
+=
inputs_
.
at
(
i
)
->
ElementsNum
();
}
output_addr_
=
reinterpret_cast
<
float
*>
(
outputs_
.
front
()
->
Data
());
ids_addr_
=
reinterpret_cast
<
int
*>
(
inputs_
.
back
()
->
Data
());
auto
ret
=
LiteBackendParallelLaunch
(
EmbeddingLookupRun
,
this
,
embedding_lookup_parameter_
->
thread_num
);
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"EmbeddingLookup error: error_code["
<<
ret
<<
"]"
;
return
RET_ERROR
;
}
return
RET_OK
;
}
kernel
::
LiteKernel
*
CpuEmbeddingLookupFp32KernelCreator
(
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
outputs
,
OpParameter
*
parameter
,
const
lite
::
Context
*
ctx
,
const
KernelKey
&
desc
)
{
if
(
parameter
==
nullptr
||
ctx
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"parameter or ctx is nullptr"
;
return
nullptr
;
}
MS_ASSERT
(
desc
.
type
==
PrimitiveType_EmbeddingLookup
);
auto
*
kernel
=
new
(
std
::
nothrow
)
EmbeddingLookupCPUKernel
(
parameter
,
inputs
,
outputs
,
ctx
);
if
(
kernel
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Create Kernel failed, name: "
<<
parameter
->
name_
;
return
nullptr
;
}
auto
ret
=
kernel
->
Init
();
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"Init Kernel failed, name: "
<<
parameter
->
name_
<<
", type: "
<<
schema
::
EnumNamePrimitiveType
(
static_cast
<
schema
::
PrimitiveType
>
(
parameter
->
type_
));
delete
kernel
;
return
nullptr
;
}
return
kernel
;
}
REG_KERNEL
(
kCPU
,
kNumberTypeFloat32
,
PrimitiveType_EmbeddingLookup
,
CpuEmbeddingLookupFp32KernelCreator
)
}
// namespace mindspore::kernel
mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.h
0 → 100644
浏览文件 @
6d4d692f
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_EMBEDDING_LOOKUP_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_EMBEDDING_LOOKUP_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h"
namespace
mindspore
::
kernel
{
class
EmbeddingLookupCPUKernel
:
public
LiteKernel
{
public:
explicit
EmbeddingLookupCPUKernel
(
OpParameter
*
parameter
,
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
outputs
,
const
lite
::
Context
*
ctx
)
:
LiteKernel
(
parameter
,
inputs
,
outputs
),
ctx_
(
ctx
),
thread_count_
(
ctx
->
thread_num_
)
{}
~
EmbeddingLookupCPUKernel
()
override
{};
int
Init
()
override
;
int
ReSize
()
override
;
int
Run
()
override
;
int
DoExcute
(
int
task_id
);
protected:
int
thread_count_
;
const
lite
::
Context
*
ctx_
;
EmbeddingLookupParameter
*
embedding_lookup_parameter_
;
private:
float
*
input_addr_
;
float
*
output_addr_
;
int
*
ids_addr_
;
};
}
// namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_EMBEDDING_LOOKUP_H_
mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.cc
浏览文件 @
6d4d692f
...
...
@@ -81,10 +81,10 @@ int ArithmeticInt8CPUKernel::Init() {
arithmetic_run_
=
ElementNotEqual
;
break
;
case
PrimitiveType_Less
:
arithmetic_run_
=
Element
Equal
;
arithmetic_run_
=
Element
Less
;
break
;
case
PrimitiveType_LessEqual
:
arithmetic_run_
=
Element
Not
Equal
;
arithmetic_run_
=
Element
Less
Equal
;
break
;
case
PrimitiveType_Greater
:
arithmetic_run_
=
ElementGreater
;
...
...
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.cc
0 → 100644
浏览文件 @
6d4d692f
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h"
#include <string.h>
#include "include/errorcode.h"
#include "src/runtime/kernel/arm/nnacl/errorcode.h"
#include "mindspore/core/utils/log_adapter.h"
void
l2_regulate
(
float
*
data
,
int
size
,
float
max_norm
)
{
float
sum
=
0
;
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
sum
+=
data
[
i
];
}
if
(
sum
!=
0
)
{
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
data
[
i
]
*=
max_norm
/
sum
;
}
}
return
;
}
int
CopyData
(
float
*
input_data
,
int
*
ids
,
float
*
output_data
,
int
num
,
EmbeddingLookupParameter
*
parameter
)
{
if
(
ids
[
num
]
>=
parameter
->
layer_num_
||
ids
[
num
]
<
0
)
{
MS_LOG
(
ERROR
)
<<
"Embedding lookup index out of range"
;
return
NNACL_ERRCODE_INDEX_OUT_OF_RANGE
;
}
float
*
out_data
=
output_data
+
num
*
parameter
->
layer_size_
;
float
*
in_data
=
input_data
+
ids
[
num
]
*
parameter
->
layer_size_
;
if
(
!
parameter
->
is_regulated_
[
ids
[
num
]])
{
l2_regulate
(
in_data
,
parameter
->
layer_size_
,
parameter
->
max_norm_
);
parameter
->
is_regulated_
[
ids
[
num
]]
=
true
;
}
memcpy
(
out_data
,
in_data
,
sizeof
(
float
)
*
parameter
->
layer_size_
);
return
NNACL_OK
;
}
int
EmbeddingLookup
(
float
*
input_data
,
int
*
ids
,
float
*
output_data
,
EmbeddingLookupParameter
*
parameter
,
int
task_id
)
{
for
(
size_t
i
=
task_id
;
i
<
parameter
->
ids_size_
;
i
+=
parameter
->
thread_num
)
{
int
ret
=
CopyData
(
input_data
,
ids
,
output_data
,
i
,
parameter
);
if
(
ret
!=
NNACL_OK
)
{
return
ret
;
}
}
return
NNACL_OK
;
}
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h
0 → 100644
浏览文件 @
6d4d692f
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_EMBEDDING_LOOKUP_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_EMBEDDING_LOOKUP_H_
#include "src/runtime/kernel/arm/nnacl/op_base.h"
struct
EmbeddingLookupParameter
{
OpParameter
op_parameter_
;
bool
*
is_regulated_
;
float
max_norm_
;
int
ids_size_
;
int
layer_size_
;
int
layer_num_
;
int
thread_num
;
};
int
EmbeddingLookup
(
float
*
input_data
,
int
*
ids
,
float
*
output_data
,
EmbeddingLookupParameter
*
parameter
,
int
task_id
);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_EMBEDDING_LOOKUP_H_
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/embedding_lookup_fp32_test.cc
0 → 100644
浏览文件 @
6d4d692f
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include "src/runtime/kernel/arm/fp32/embedding_lookup.h"
#include "src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h"
#include "src/common/file_utils.h"
#include "common/common_test.h"
#include "utils/log_adapter.h"
namespace
mindspore
{
using
mindspore
::
lite
::
tensor
::
Tensor
;
class
TestEmbeddingLookupFp32
:
public
mindspore
::
Common
{
public:
TestEmbeddingLookupFp32
()
{}
};
void
ElTestInit
(
std
::
vector
<
Tensor
*>
*
inputs_
,
std
::
vector
<
Tensor
*>
*
outputs_
,
EmbeddingLookupParameter
*
embedding_lookup_param
)
{
Tensor
*
in_t_first
=
new
Tensor
(
kNumberTypeFloat32
,
{
6
,
2
},
schema
::
Format_NHWC
,
static_cast
<
schema
::
NodeType
>
(
1
));
in_t_first
->
MallocData
();
float
in_first
[]
=
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
};
memcpy
(
in_t_first
->
Data
(),
in_first
,
sizeof
(
float
)
*
in_t_first
->
ElementsNum
());
inputs_
->
push_back
(
in_t_first
);
Tensor
*
in_t_second
=
new
Tensor
(
kNumberTypeFloat32
,
{
4
,
2
},
schema
::
Format_NHWC
,
static_cast
<
schema
::
NodeType
>
(
1
));
in_t_second
->
MallocData
();
float
in_second
[]
=
{
1.1
,
2.2
,
3.3
,
4.4
,
5.5
,
6.6
,
7.7
,
8.8
};
memcpy
(
in_t_second
->
Data
(),
in_second
,
sizeof
(
float
)
*
in_t_second
->
ElementsNum
());
inputs_
->
push_back
(
in_t_second
);
Tensor
*
ids_t
=
new
Tensor
(
kNumberTypeFloat32
,
{
2
,
3
},
schema
::
Format_NHWC
,
static_cast
<
schema
::
NodeType
>
(
1
));
ids_t
->
MallocData
();
int
ids
[]
=
{
1
,
9
,
2
,
4
,
6
,
7
};
memcpy
(
ids_t
->
Data
(),
ids
,
sizeof
(
int
)
*
ids_t
->
ElementsNum
());
inputs_
->
push_back
(
ids_t
);
Tensor
*
outputs_t
=
new
Tensor
(
kNumberTypeInt32
,
{
2
,
3
,
2
},
schema
::
Format_NHWC
,
static_cast
<
schema
::
NodeType
>
(
1
));
outputs_t
->
MallocData
();
outputs_
->
push_back
(
outputs_t
);
embedding_lookup_param
->
max_norm_
=
1
;
}
TEST_F
(
TestEmbeddingLookupFp32
,
ElTest
)
{
std
::
vector
<
Tensor
*>
inputs_
;
std
::
vector
<
Tensor
*>
outputs_
;
auto
embedding_lookup_param_
=
new
EmbeddingLookupParameter
();
ElTestInit
(
&
inputs_
,
&
outputs_
,
embedding_lookup_param_
);
lite
::
Context
*
ctx
=
new
lite
::
Context
;
ctx
->
thread_num_
=
2
;
kernel
::
EmbeddingLookupCPUKernel
*
el
=
new
kernel
::
EmbeddingLookupCPUKernel
(
reinterpret_cast
<
OpParameter
*>
(
embedding_lookup_param_
),
inputs_
,
outputs_
,
ctx
);
el
->
Init
();
el
->
Run
();
std
::
cout
<<
"output shape:"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
outputs_
.
front
()
->
shape
().
size
();
++
i
)
{
std
::
cout
<<
outputs_
.
front
()
->
shape
()[
i
]
<<
' '
;
}
std
::
cout
<<
std
::
endl
;
float
*
out
=
reinterpret_cast
<
float
*>
(
outputs_
.
front
()
->
Data
());
for
(
int
i
=
0
;
i
<
outputs_
.
front
()
->
ElementsNum
();
++
i
)
{
std
::
cout
<<
out
[
i
]
<<
' '
;
}
std
::
cout
<<
std
::
endl
;
}
}
// namespace mindspore
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录