Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
780459b1
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
780459b1
编写于
6月 18, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 18, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2212 optimize embedingLookup
Merge pull request !2212 from dengwentao/embedding
上级
518e9552
98aa1b36
变更
2
隐藏空白更改
内联
并排
Showing
2 changed files
with
77 additions
and
60 deletions
+77
-60
mindspore/ccsrc/kernel/CMakeLists.txt
mindspore/ccsrc/kernel/CMakeLists.txt
+0
-2
mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.cc
mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.cc
+77
-58
未找到文件。
mindspore/ccsrc/kernel/CMakeLists.txt
浏览文件 @
780459b1
...
...
@@ -26,8 +26,6 @@ if (ENABLE_CPU)
list
(
REMOVE_ITEM CPU_SRC_LIST
"cpu/allgather_cpu_kernel.cc"
)
list
(
REMOVE_ITEM CPU_SRC_LIST
"cpu/reduce_scatter_cpu_kernel.cc"
)
list
(
REMOVE_ITEM CPU_SRC_LIST
"cpu/embedding_look_up_comm_grad_cpu_kernel.cc"
)
list
(
REMOVE_ITEM CPU_SRC_LIST
"cpu/embedding_look_up_cpu_kernel.cc"
)
list
(
REMOVE_ITEM CPU_SRC_LIST
"cpu/sub_cpu_kernel.cc"
)
endif
()
endif
()
...
...
mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.cc
浏览文件 @
780459b1
...
...
@@ -28,30 +28,31 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
input_shape_
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
0
);
input_lens_
=
1
;
for
(
auto
shape
:
input_shape_
)
{
MS_LOG
(
DEBUG
)
<<
"input shape: "
<<
shape
;
MS_LOG
(
INFO
)
<<
"input shape: "
<<
shape
;
input_lens_
=
input_lens_
*
shape
;
}
MS_LOG
(
DEBUG
)
<<
"input lens: "
<<
input_lens_
;
MS_LOG
(
INFO
)
<<
"input lens: "
<<
input_lens_
;
indices_shape_
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
1
);
indices_lens_
=
1
;
for
(
auto
shape
:
indices_shape_
)
{
MS_LOG
(
DEBUG
)
<<
"indice shape: "
<<
shape
;
MS_LOG
(
INFO
)
<<
"indice shape: "
<<
shape
;
indices_lens_
=
indices_lens_
*
shape
;
}
MS_LOG
(
DEBUG
)
<<
"indice lens: "
<<
indices_lens_
;
MS_LOG
(
INFO
)
<<
"indice lens: "
<<
indices_lens_
;
output_shape_
=
AnfAlgo
::
GetOutputInferShape
(
kernel_node
,
0
);
for
(
auto
shape
:
output_shape_
)
{
MS_LOG
(
DEBUG
)
<<
"output shape: "
<<
shape
;
MS_LOG
(
INFO
)
<<
"output shape: "
<<
shape
;
}
auto
output_type
=
AnfAlgo
::
GetOutputInferDataType
(
kernel_node
,
0
);
MS_LOG
(
DEBUG
)
<<
"output type: "
<<
output_type
;
MS_LOG
(
INFO
)
<<
"output type: "
<<
output_type
;
axis_
=
4
-
input_shape_
.
size
();
MS_LOG
(
DEBUG
)
<<
"axis_: "
<<
axis_
;
MS_LOG
(
INFO
)
<<
"axis_: "
<<
axis_
;
reduce_scatter_flag_
=
AnfAlgo
::
GetNodeAttr
<
bool
>
(
kernel_node
,
"reduce_scatter_flag"
);
MS_LOG
(
DEBUG
)
<<
"reduce_scatter_flag: "
<<
reduce_scatter_flag_
;
MS_LOG
(
INFO
)
<<
"reduce_scatter_flag: "
<<
reduce_scatter_flag_
;
#ifdef ENABLE_MPI
if
(
reduce_scatter_flag_
)
{
size_t
gatherv2_out_lens
=
1
;
for
(
int
i
=
0
;
i
<
SizeToInt
(
input_shape_
.
size
());
i
++
)
{
...
...
@@ -66,7 +67,7 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
}
}
gatherv2_out_lens_
=
gatherv2_out_lens
*
sizeof
(
float
);
MS_LOG
(
DEBUG
)
<<
"gatherv2 out lens: "
<<
gatherv2_out_lens_
;
MS_LOG
(
INFO
)
<<
"gatherv2 out lens: "
<<
gatherv2_out_lens_
;
gather_v2_out_
=
malloc
(
gatherv2_out_lens_
);
if
(
gather_v2_out_
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"EmbeddingLookUpCPUKernel malloc failed, malloc lens: "
<<
gatherv2_out_lens_
;
...
...
@@ -77,10 +78,15 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
}
split_num_
=
AnfAlgo
::
GetNodeAttr
<
int
>
(
kernel_node
,
"split_num"
);
MS_LOG
(
DEBUG
)
<<
"split_num: "
<<
split_num_
;
MS_LOG
(
INFO
)
<<
"split_num: "
<<
split_num_
;
}
#else
if
(
reduce_scatter_flag_
)
{
MS_LOG
(
EXCEPTION
)
<<
"Not Enable MPI, please build version with -M on when set reduce_scatter_flag true"
;
}
#endif
offset_
=
AnfAlgo
::
GetNodeAttr
<
int
>
(
kernel_node
,
"offset"
);
MS_LOG
(
DEBUG
)
<<
"offset: "
<<
offset_
;
MS_LOG
(
INFO
)
<<
"offset: "
<<
offset_
;
CPUKernelUtils
::
ExpandDimsTo4
(
&
input_shape_
);
CPUKernelUtils
::
ExpandDimsTo4
(
&
output_shape_
);
}
...
...
@@ -97,13 +103,8 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
auto
output_addr
=
reinterpret_cast
<
float
*>
(
outputs
[
0
]
->
addr
);
MS_LOG
(
DEBUG
)
<<
"output addr: "
<<
output_addr
<<
"output size: "
<<
outputs
[
0
]
->
size
;
float
*
gather_out_addr
=
reduce_scatter_flag_
?
reinterpret_cast
<
float
*>
(
gather_v2_out_
)
:
output_addr
;
if
(
!
reduce_scatter_flag_
)
{
auto
ret
=
memset_s
(
gather_out_addr
,
outputs
[
0
]
->
size
,
0
,
outputs
[
0
]
->
size
);
if
(
ret
!=
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"EmbeddingLookUpCPUKernel memset out buff failed"
;
}
}
MS_LOG
(
DEBUG
)
<<
"gatherv2 out addr: "
<<
gather_out_addr
;
size_t
dim0
=
input_shape_
[
0
];
size_t
dim1
=
input_shape_
[
1
];
size_t
dim2
=
input_shape_
[
2
];
...
...
@@ -130,6 +131,7 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
LookUpTable
(
inputs
,
0
,
0
,
0
,
&
gather_out_addr
);
}
#ifdef ENABLE_MPI
if
(
reduce_scatter_flag_
)
{
size_t
one_split_lens
=
gatherv2_out_lens_
/
split_num_
/
sizeof
(
float
);
size_t
reduce_scatter_out_lens
=
one_split_lens
/
8
;
...
...
@@ -140,6 +142,8 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
one_split_lens
/
8
,
"sum"
);
}
}
#endif
#if defined(_WIN32) || defined(_WIN64)
auto
end_time
=
std
::
chrono
::
steady_clock
::
now
();
std
::
chrono
::
duration
<
double
,
std
::
ratio
<
1
,
1000000
>>
cost
=
end_time
-
start_time
;
...
...
@@ -153,67 +157,82 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
return
true
;
}
// Copies a contiguous range of (destination, source) buffer pairs.
//
// For every index i in [start, end), copies `lens` bytes from
// (*mem_src_addr_list)[i] to (*mem_dest_addr_list)[i] using memcpy_s; `lens`
// is also passed as the destination capacity.
//
// Parameters:
//   mem_dest_addr_list - pointer to the list of destination addresses.
//   mem_src_addr_list  - pointer to the list of source addresses.
//   start, end         - half-open index range [start, end) handled by this call.
//   lens               - number of bytes copied for each pair.
//
// A memcpy_s result other than EOK is reported through MS_LOG(EXCEPTION).
//
// NOTE(review): both lists are indexed with operator[] and the pointers are
// dereferenced unchecked, so the caller is presumably expected to pass
// non-null lists and keep `end` within both list sizes — confirm at call sites.
void memcpy_task(std::vector<float *> *mem_dest_addr_list, std::vector<float *> *mem_src_addr_list, size_t start,
                 size_t end, size_t lens) {
  for (size_t i = start; i < end; i++) {
    auto ret = memcpy_s((*mem_dest_addr_list)[i], lens, (*mem_src_addr_list)[i], lens);
    if (ret != EOK) {
      MS_LOG(EXCEPTION) << "memory copy failed.";
    }
  }
}
void
EmbeddingLookUpCPUKernel
::
LookUpTable
(
const
std
::
vector
<
kernel
::
AddressPtr
>
&
inputs
,
size_t
dim0
,
size_t
dim1
,
size_t
dim2
,
float
**
output_addr
)
{
auto
input_addr
=
reinterpret_cast
<
float
*>
(
inputs
[
0
]
->
addr
);
auto
indices_addr
=
reinterpret_cast
<
int
*>
(
inputs
[
1
]
->
addr
);
size_t
num
=
CPUKernelUtils
::
GetElementNumOnAxis
(
input_shape_
,
axis_
);
void
LookUpTable_task
(
float
*
input_addr
,
float
*
output_addr
,
int
*
indices_addr
,
size_t
indices_lens
,
size_t
num
,
size_t
dim0
,
size_t
dim1
,
size_t
dim2
,
int
offset
,
size_t
axis
,
std
::
vector
<
size_t
>
input_shape
,
size_t
input_lens
)
{
size_t
lens
=
num
*
sizeof
(
float
);
std
::
vector
<
float
*>
mem_dest_addr_list
;
std
::
vector
<
float
*>
mem_src_addr_list
;
for
(
size_t
i
=
0
;
i
<
indices_lens_
;
++
i
)
{
int
indices
=
indices_addr
[
i
]
-
offset_
;
for
(
size_t
i
=
0
;
i
<
indices_lens
;
++
i
)
{
int
indices
=
indices_addr
[
i
]
-
offset
;
if
(
indices
>=
0
)
{
size_t
index
=
IntToSize
(
indices
);
if
(
index
<
input_shape
_
[
axis_
])
{
if
(
index
<
input_shape
[
axis
])
{
size_t
pos
=
0
;
if
(
axis
_
==
3
)
{
pos
=
CPUKernelUtils
::
CalcOffset
(
input_shape
_
,
dim0
,
dim1
,
dim2
,
index
);
}
else
if
(
axis
_
==
2
)
{
pos
=
CPUKernelUtils
::
CalcOffset
(
input_shape
_
,
dim0
,
dim1
,
index
,
0
);
}
else
if
(
axis
_
==
1
)
{
pos
=
CPUKernelUtils
::
CalcOffset
(
input_shape
_
,
dim0
,
index
,
0
,
0
);
}
else
if
(
axis
_
==
0
)
{
pos
=
CPUKernelUtils
::
CalcOffset
(
input_shape
_
,
index
,
0
,
0
,
0
);
if
(
axis
==
3
)
{
pos
=
CPUKernelUtils
::
CalcOffset
(
input_shape
,
dim0
,
dim1
,
dim2
,
index
);
}
else
if
(
axis
==
2
)
{
pos
=
CPUKernelUtils
::
CalcOffset
(
input_shape
,
dim0
,
dim1
,
index
,
0
);
}
else
if
(
axis
==
1
)
{
pos
=
CPUKernelUtils
::
CalcOffset
(
input_shape
,
dim0
,
index
,
0
,
0
);
}
else
if
(
axis
==
0
)
{
pos
=
CPUKernelUtils
::
CalcOffset
(
input_shape
,
index
,
0
,
0
,
0
);
}
if
(
pos
+
num
<=
input_lens_
)
{
mem_dest_addr_list
.
push_back
(
*
output_addr
);
mem_src_addr_list
.
push_back
(
input_addr
+
pos
);
if
(
pos
+
num
<=
input_lens
)
{
auto
ret
=
memcpy_s
(
output_addr
,
lens
,
input_addr
+
pos
,
lens
);
if
(
ret
!=
EOK
)
{
MS_LOG
(
EXCEPTION
)
<<
"LookUpTable task memcpy failed."
;
}
}
else
{
auto
ret
=
memset_s
(
output_addr
,
lens
,
0
,
lens
);
if
(
ret
!=
EOK
)
{
MS_LOG
(
EXCEPTION
)
<<
"LookUpTable task memset failed."
;
}
}
}
else
{
auto
ret
=
memset_s
(
output_addr
,
lens
,
0
,
lens
);
if
(
ret
!=
EOK
)
{
MS_LOG
(
EXCEPTION
)
<<
"LookUpTable task memset failed."
;
}
}
}
else
{
auto
ret
=
memset_s
(
output_addr
,
lens
,
0
,
lens
);
if
(
ret
!=
EOK
)
{
MS_LOG
(
EXCEPTION
)
<<
"LookUpTable task memset failed."
;
}
}
*
output_addr
+=
num
;
output_addr
+=
num
;
}
}
void
EmbeddingLookUpCPUKernel
::
LookUpTable
(
const
std
::
vector
<
kernel
::
AddressPtr
>
&
inputs
,
size_t
dim0
,
size_t
dim1
,
size_t
dim2
,
float
**
output_addr
)
{
auto
input_addr
=
reinterpret_cast
<
float
*>
(
inputs
[
0
]
->
addr
);
auto
indices_addr
=
reinterpret_cast
<
int
*>
(
inputs
[
1
]
->
addr
);
size_t
num
=
CPUKernelUtils
::
GetElementNumOnAxis
(
input_shape_
,
axis_
);
float
*
task_out_addr
=
*
output_addr
;
const
size_t
thread_num
=
8
;
std
::
thread
threads
[
8
];
size_t
memcpy_lens
=
mem_dest_addr_list
.
size
();
size_t
start
=
0
;
size_t
ones_copy_lens
=
(
memcpy_lens
+
thread_num
-
1
)
/
thread_num
;
size_t
task_proc_lens
=
(
indices_lens_
+
thread_num
-
1
)
/
thread_num
;
size_t
i
;
size_t
task_offset
=
0
;
MS_LOG
(
DEBUG
)
<<
"indices_lens_: "
<<
indices_lens_
<<
" one task proc lens:"
<<
task_proc_lens
;
for
(
i
=
0
;
i
<
thread_num
;
i
++
)
{
if
(
start
>
memcpy_lens
)
{
if
(
task_offset
>=
indices_lens_
)
{
break
;
}
auto
end
=
(
start
+
ones_copy_lens
)
>
memcpy_lens
?
memcpy_lens
:
start
+
ones_copy_lens
;
threads
[
i
]
=
std
::
thread
(
memcpy_task
,
&
mem_dest_addr_list
,
&
mem_src_addr_list
,
start
,
end
,
lens
);
start
=
start
+
ones_copy_lens
;
MS_LOG
(
DEBUG
)
<<
"task_offset: "
<<
task_offset
<<
" task_proc_lenss:"
<<
task_proc_lens
;
threads
[
i
]
=
std
::
thread
(
LookUpTable_task
,
input_addr
,
task_out_addr
+
task_offset
*
num
,
indices_addr
+
task_offset
,
task_proc_lens
,
num
,
dim0
,
dim1
,
dim2
,
offset_
,
axis_
,
input_shape_
,
input_lens_
);
task_offset
+=
task_proc_lens
;
if
(
task_offset
+
task_proc_lens
>
indices_lens_
)
{
task_proc_lens
=
indices_lens_
-
task_offset
;
}
}
for
(
size_t
j
=
0
;
j
<
i
;
j
++
)
{
threads
[
j
].
join
();
}
*
output_addr
+=
num
*
indices_lens_
;
}
void
EmbeddingLookUpCPUKernel
::
CheckParam
(
const
CNodePtr
&
kernel_node
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录