Commit 9779680d — mindspore
Authored May 18, 2020 by mindspore-ci-bot; committed by Gitee on May 18, 2020
!1204 Support common hccl op
Merge pull request !1204 from caifubi/support-hccl-multi-group
Parents: 311b7e71, 7d07e17f
Showing 14 changed files with 111 additions and 36 deletions (+111 -36)
cmake/dependency_graphengine.cmake                                          +2  -3
cmake/package.cmake                                                         +1  -1
graphengine                                                                 +1  -1
mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc                      +1  -2
mindspore/ccsrc/device/ascend/profiling/reporter/graph_desc_reporter.h      +1  -0
mindspore/ccsrc/device/ascend/tasksink/runtime_utils.cc                     +4  -4
mindspore/ccsrc/device/kernel_runtime.cc                                    +61 -19
mindspore/ccsrc/device/kernel_runtime.h                                     +6  -2
mindspore/ccsrc/kernel/hccl/hccl_kernel.cc                                  +2  -1
mindspore/ccsrc/kernel/hccl/hccl_kernel.h                                   +1  -0
mindspore/ccsrc/kernel/hccl/hcom_util.cc                                    +12 -1
mindspore/ccsrc/kernel/hccl/hcom_util.h                                     +2  -0
mindspore/ccsrc/pre_activate/ascend/enhancer/add_memcpy_async.cc            +1  -2
mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc  +16 -0

cmake/dependency_graphengine.cmake  (+2 -3)

@@ -9,11 +9,11 @@ include(${GE_SOURCE_DIR}/cmake/external_libs/eigen.cmake)
 include(${GE_SOURCE_DIR}/cmake/external_libs/gtest.cmake)
 include(${GE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake)
 include(${GE_SOURCE_DIR}/cmake/external_libs/onnx.cmake)
 include(${GE_SOURCE_DIR}/cmake/external_libs/securec.cmake)
-# for CPU/GPU mode, find c_sec and slog from local prebuild
+# for CPU/GPU mode, find slog from local prebuild
 if (NOT ENABLE_D)
     set(GE_PREBUILD_PATH ${GE_SOURCE_DIR}/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR})
     find_library(c_sec libc_sec.so ${GE_PREBUILD_PATH})
     find_library(slog libslog.so ${GE_PREBUILD_PATH})
 elseif (DEFINED ENV{D_LINK_PATH})
     set(GE_LIB_PATH $ENV{D_LINK_PATH})

@@ -28,7 +28,6 @@ elseif (DEFINED ENV{D_LINK_PATH})
         message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated")
     endif()
     set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH})
     find_library(c_sec libc_sec.so ${GE_LIB_PATH})
     find_library(slog libslog.so ${GE_LIB_PATH})
     find_library(mmpa libmmpa.so ${GE_LIB_PATH})
     find_library(runtime libruntime.so ${GE_LIB_PATH})

cmake/package.cmake  (+1 -1)

@@ -153,7 +153,7 @@ if (NOT ENABLE_GE)
         FILES
             ${CMAKE_BINARY_DIR}/graphengine/src/common/graph/libgraph.so
             ${CMAKE_SOURCE_DIR}/graphengine/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}/libslog.so
-            ${CMAKE_SOURCE_DIR}/graphengine/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}/libc_sec.so
+            ${CMAKE_SOURCE_DIR}/build/graphengine/libc_sec.so
         DESTINATION ${INSTALL_LIB_DIR}
         COMPONENT mindspore
     )

graphengine @ 579dcb75  (+1 -1)

-Subproject commit 995b6dadc0fbbe4b80a08196886a53a18bffa60e
+Subproject commit 579dcb75a990b533f9182733a6424f2bd66f0f23

mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc  (+1 -2)

@@ -333,8 +333,7 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
   bool status = ge::model_runner::ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_iter->first,
                                                                            model_iter->second, listener);
   if (!status) {
-    MS_LOG(ERROR) << "load task failed";
-    return false;
+    MS_LOG(EXCEPTION) << "Load Task Failed";
   }
   if (ProfilingManager::GetInstance().IsProfiling()) {
     auto task_ids = ge::model_runner::ModelRunner::Instance().GetTaskIdList(model_iter->first);

mindspore/ccsrc/device/ascend/profiling/reporter/graph_desc_reporter.h  (+1 -0)

@@ -29,6 +29,7 @@ class GraphDescReporter : public DescReporter {
  public:
   GraphDescReporter(uint32_t device_id, const std::string &file_name, std::vector<CNodePtr> cnode_list)
       : DescReporter(device_id, file_name, std::move(cnode_list)) {}
   ~GraphDescReporter() override = default;
   void ReportData() override;
 };
 }  // namespace ascend

mindspore/ccsrc/device/ascend/tasksink/runtime_utils.cc  (+4 -4)

@@ -60,7 +60,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
   const string tag_broadcast = kHcomBroadcast + std::to_string(task_counter++) + kUnderline + std::to_string(0);
   ret = hcom_broadcast(tag_broadcast.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
                        static_cast<u64>(task_info->count()), static_cast<hcclDataType_t>(task_info->data_type()),
-                       static_cast<u32>(task_info->root_id()), nullptr, stream);
+                       static_cast<u32>(task_info->root_id()), task_info->group().c_str(), stream);
   if (ret != HCCL_SUCCESS) {
     MS_LOG(ERROR) << "hcom_broadcast fail, return ret: " << static_cast<int>(ret);
     return false;

@@ -70,7 +70,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
   const string tag_all_gather = kHcomAllGather + std::to_string(task_counter++) + kUnderline + std::to_string(0);
   ret = hcom_all_gather(tag_all_gather.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
                         reinterpret_cast<void *>(task_info->output_data_addr()), static_cast<u64>(task_info->count()),
-                        static_cast<hcclDataType_t>(task_info->data_type()), nullptr, stream);
+                        static_cast<hcclDataType_t>(task_info->data_type()), task_info->group().c_str(), stream);
   if (ret != HCCL_SUCCESS) {
     MS_LOG(ERROR) << "hcom_all_gather fail, return ret: " << ret;
     return false;

@@ -81,7 +81,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
   ret = hcom_all_reduce(tag_all_reduce.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
                         reinterpret_cast<void *>(task_info->output_data_addr()), static_cast<u64>(task_info->count()),
                         static_cast<hcclDataType_t>(task_info->data_type()),
-                        static_cast<hcclRedOp_t>(task_info->op_type()), nullptr, stream);
+                        static_cast<hcclRedOp_t>(task_info->op_type()), task_info->group().c_str(), stream);
   if (ret != HCCL_SUCCESS) {
     MS_LOG(ERROR) << "hcom_all_reduce fail, return ret: " << ret;
     return false;

@@ -93,7 +93,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
   ret = hcom_reduce_scatter(tag_reduce_scatter.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
                             reinterpret_cast<void *>(task_info->output_data_addr()), static_cast<u64>(task_info->count()),
                             static_cast<hcclDataType_t>(task_info->data_type()),
-                            static_cast<hcclRedOp_t>(task_info->op_type()), nullptr, stream);
+                            static_cast<hcclRedOp_t>(task_info->op_type()), task_info->group().c_str(), stream);
   if (ret != HCCL_SUCCESS) {
     MS_LOG(ERROR) << "hcom_reduce_scatter fail, return ret: " << ret;
     return false;

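Context, not part of the diff: the only behavioral change in the four hunks above is that each hcom_* call now receives the task's group name instead of nullptr, so the collective runs on the communication group the operator was created with rather than on the implicit default group. A minimal sketch of that call-shape change, with hypothetical stand-ins (GroupedTask, FakeHcomBroadcast) in place of the real HCCL types:

#include <iostream>
#include <string>
#include <utility>

// Hypothetical stand-in for the HcclTaskInfo used above; only the group()
// accessor matters for this illustration.
class GroupedTask {
 public:
  explicit GroupedTask(std::string group) : group_(std::move(group)) {}
  const std::string &group() const { return group_; }

 private:
  std::string group_;
};

// Hypothetical stand-in for an hcom_* entry point. Before this commit the
// group argument was always nullptr (the default group); after it, the
// task's own group name is forwarded.
void FakeHcomBroadcast(const char *tag, const char *group) {
  std::cout << tag << " runs on group: " << (group ? group : "<default>") << "\n";
}

int main() {
  GroupedTask task("hccl_sub_group_0");
  FakeHcomBroadcast("broadcast_0_0", nullptr);               // old call shape
  FakeHcomBroadcast("broadcast_0_0", task.group().c_str());  // new call shape
  return 0;
}
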
mindspore/ccsrc/device/kernel_runtime.cc  (+61 -19)

@@ -15,6 +15,7 @@
  */
 #include "device/kernel_runtime.h"
 #include <vector>
+#include <utility>
 #include <numeric>
 #include <functional>

@@ -130,20 +131,16 @@ void KernelRuntime::AssignMemory(session::KernelGraph *graph) {
   mem_manager_->ResetDynamicMemory();
   AssignStaticMemory(graph);
   AssignDynamicMemory(graph);
   UpdateRefNodeOutputMem(graph);
 }

 void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors,
                                       session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
   // assign memory for input nodes
   RunOpAssignInputMemory(input_tensors, graph);
   AssignStaticMemoryValueNode(graph);
   for (const auto &cnode : graph->execution_order()) {
     // assign memory for output nodes
     RunOpAssignOutputMemory(cnode);
     // assign memory for workspace
     RunOpAssignWorkSpaceMemory(cnode);
   }
   UpdateRefNodeOutputMem(graph);

@@ -280,12 +277,22 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
 void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
   auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
+  std::vector<session::KernelWithIndex> non_communication_op;
+  // Assign Communicate Op Memory firstly.
   for (const auto &node : nodes) {
     auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
     MS_EXCEPTION_IF_NULL(item_with_index.first);
     if (!item_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(item_with_index.first)) {
       continue;
     }
+    if (AnfAlgo::IsCommunicationOp(item_with_index.first)) {
+      AssignCommunicationNodeMem(kStaticMem, item_with_index.first);
+    } else {
+      non_communication_op.emplace_back(item_with_index);
+    }
+  }
+  for (const auto &item_with_index : non_communication_op) {
     AssignNodeOutputMem(kStaticMem, item_with_index.first, SizeToInt(item_with_index.second));
   }
 }

@@ -322,6 +329,11 @@ void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) {
   }
 }

+void KernelRuntime::AssignCommunicationNodeMem(int flag, const AnfNodePtr &node) {
+  AssignCommunicationNodeInputMem(node);
+  AssignCommunicationNodeOutputMem(flag, node);
+}
+
 void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(mem_manager_);

@@ -335,8 +347,13 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   size_t total_size = 0;
+  size_t output_index = 0;
   std::vector<size_t> align_size_list;
   for (uint64_t mem_size : output_sizes) {
+    if (AnfAlgo::OutputAddrExist(node, output_index++)) {
+      MS_LOG(INFO) << "communication op addr exist";
+      continue;
+    }
     if (context_ptr->enable_hccl()) {
       mem_size = mem_manager_->GetCommonAlignSize(mem_size);
     }

@@ -353,7 +370,21 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr
   }
 }

-void KernelRuntime::UpdateCommunicationOpInputMem(const AnfNodePtr &node) {
+DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  auto kernel_mod = AnfAlgo::GetKernelMod(anf_node);
+  auto output_sizes = kernel_mod->GetOutputSizeList();
+  if (output_sizes.size() <= index) {
+    MS_LOG(EXCEPTION) << "Previous node output size < node index";
+  }
+  std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index);
+  auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index);
+  auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type);
+  AnfAlgo::SetOutputAddr(address, index, anf_node.get());
+  return address;
+}
+
+void KernelRuntime::AssignCommunicationNodeInputMem(const AnfNodePtr &node) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   MS_EXCEPTION_IF_NULL(node);

@@ -361,12 +392,16 @@ void KernelRuntime::UpdateCommunicationOpInputMem(const AnfNodePtr &node) {
   size_t total_size = 0;
   std::vector<std::pair<mindspore::device::DeviceAddress *, size_t>> addr_size;
   for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) {
-    auto address = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i);
-    MS_EXCEPTION_IF_NULL(address);
-    auto mem_size = address->size();
-    if (context_ptr->enable_hccl()) {
-      mem_size = mem_manager_->GetCommonAlignSize(mem_size);
+    auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i);
+    auto input_node = input_node_with_index.first;
+    DeviceAddressPtr address = nullptr;
+    if (input_node->isa<CNode>()) {
+      address = PreAssignCNodeMemory(input_node, input_node_with_index.second);
+    } else {
+      MS_LOG(EXCEPTION) << "Communication node inputs only support CNode";
     }
+    MS_EXCEPTION_IF_NULL(address);
+    auto mem_size = mem_manager_->GetCommonAlignSize(address->size());
     total_size += mem_size;
     addr_size.emplace_back(address.get(), mem_size);
   }

@@ -381,11 +416,6 @@ void KernelRuntime::UpdateCommunicationOpInputMem(const AnfNodePtr &node) {
 void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index) {
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(mem_manager_);
-  if (AnfAlgo::IsCommunicationOp(node)) {
-    UpdateCommunicationOpInputMem(node);
-    AssignCommunicationNodeOutputMem(flag, node);
-    return;
-  }
   if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) {
     MS_LOG(INFO) << "GetNext disable mem_reuse";
     flag = kDynamicMem;

@@ -506,10 +536,22 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
     mem_manager_->MallocReusedDynamicMem(graph);
     mem_flag = kReuseDynamicMem;
   }
-  auto &kernels = graph->execution_order();
-  for (auto &kernel : kernels) {
-    AssignNodeOutputMem(mem_flag, kernel, kGetAllOuts);
-    AssignWorkSpaceMem(mem_flag, kernel);
+  auto &execution_nodes = graph->execution_order();
+  std::vector<CNodePtr> compute_nodes;
+  // communication nodes first
+  for (auto &node : execution_nodes) {
+    if (AnfAlgo::IsCommunicationOp(node)) {
+      // skip if the memory is already alocated
+      AssignCommunicationNodeMem(mem_flag, node);
+    } else {
+      compute_nodes.emplace_back(node);
+    }
+  }
+  // then compute nodes
+  for (auto &node : compute_nodes) {
+    AssignNodeOutputMem(mem_flag, node, kGetAllOuts);
+    AssignWorkSpaceMem(mem_flag, node);
   }
 }

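Context, not part of the diff: after this change both AssignStaticMemoryOutput and AssignDynamicMemory walk the node list twice, giving communication nodes their memory before any other node so that a collective's inputs and outputs can be laid out as one contiguous, aligned block. A self-contained sketch of that partition-then-assign pattern; the Node struct and is_communication flag are illustrative stand-ins for CNodePtr and AnfAlgo::IsCommunicationOp:

#include <iostream>
#include <string>
#include <vector>

// Illustrative node record; MindSpore uses CNodePtr plus AnfAlgo::IsCommunicationOp.
struct Node {
  std::string name;
  bool is_communication;
};

void AssignMemory(const std::vector<Node> &execution_order) {
  std::vector<Node> compute_nodes;
  // Communication nodes are handled first, mirroring AssignCommunicationNodeMem.
  for (const auto &node : execution_order) {
    if (node.is_communication) {
      std::cout << "assign (communication): " << node.name << "\n";
    } else {
      compute_nodes.push_back(node);
    }
  }
  // Then everything else, still in the original execution order.
  for (const auto &node : compute_nodes) {
    std::cout << "assign (compute): " << node.name << "\n";
  }
}

int main() {
  AssignMemory({{"Conv2D", false}, {"AllReduce", true}, {"MatMul", false}});
  return 0;
}
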
mindspore/ccsrc/device/kernel_runtime.h  (+6 -2)

@@ -73,9 +73,12 @@ class KernelRuntime {
   void AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index);
   void AssignWorkSpaceMem(int flag, const AnfNodePtr &node);
   void AssignReuseWorkSpaceMem(const AnfNodePtr &node);
-  void AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node);
   void UpdateRefNodeOutputMem(const session::KernelGraph *graph);
-  void UpdateCommunicationOpInputMem(const AnfNodePtr &node);
+  void AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node);
+  void AssignCommunicationNodeInputMem(const AnfNodePtr &node);
+  void AssignCommunicationNodeMem(int flag, const AnfNodePtr &node);
 #ifdef ENABLE_DUMP_E2E
   bool SetDumpConf();
 #endif

@@ -91,6 +94,7 @@ class KernelRuntime {
   void RunOpAssignOutputMemory(const AnfNodePtr &kernel);
   void RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel);
   void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx);
+  DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index);

  protected:
   uint32_t device_id_{0};

mindspore/ccsrc/kernel/hccl/hccl_kernel.cc  (+2 -1)

@@ -90,6 +90,7 @@ bool HcclKernel::Init(const AnfNodePtr &anf_node) {
       return false;
     }
   }
+  HcomUtil::GetHcomGroup(NOT_NULL(anf_node), NOT_NULL(&group_));
   anf_node_ = anf_node;
   return true;
 }

@@ -147,7 +148,7 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu
   HcclTaskInfoPtr task_info_ptr = std::make_shared<HcclTaskInfo>(
     stream_id, hccl_type, input_data_addr, output_data_addr, workspace_address, workspace_num, 0, private_def, nullptr,
-    hccl_count_, root_id_, op_type_, data_type, RuntimeUtils::HcomBindModel, RuntimeUtils::HcomUnbindModel,
+    hccl_count_, root_id_, op_type_, data_type, group_, RuntimeUtils::HcomBindModel, RuntimeUtils::HcomUnbindModel,
     RuntimeUtils::HcomDistribute);
   MS_EXCEPTION_IF_NULL(task_info_ptr);
   return {task_info_ptr};

mindspore/ccsrc/kernel/hccl/hccl_kernel.h  (+1 -0)

@@ -54,6 +54,7 @@ class HcclKernel : public AscendKernelMod {
   mutable std::vector<size_t> workspace_size_list_;
   AnfNodePtr anf_node_;
   std::string op_name_;
+  std::string group_;
 };

 using HcclKernelCreater = std::function<std::shared_ptr<HcclKernel>()>;

mindspore/ccsrc/kernel/hccl/hcom_util.cc  (+12 -1)

@@ -176,11 +176,22 @@ bool HcomUtil::GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id) {
   auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
   MS_EXCEPTION_IF_NULL(primitive);
   if (primitive->GetAttr("root_rank") != nullptr) {
-    *root_id = GetValue<const vector<uint32_t>>(primitive->GetAttr("root_rank"))[0];
+    *root_id = (uint32_t)GetValue<int>(primitive->GetAttr("root_rank"));
   } else {
     MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_ROOT_INDEX fail, not support!";
     return false;
   }
   return true;
 }

+void HcomUtil::GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group) {
+  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto attr = primitive->GetAttr("group");
+  if (attr != nullptr) {
+    *group = GetValue<std::string>(attr);
+  } else {
+    MS_LOG(EXCEPTION) << "Get Hcom Group Attr of Op:" << anf_node->fullname_with_scope() << " failed";
+  }
+}
 }  // namespace mindspore

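Context, not part of the diff: the new HcomUtil::GetHcomGroup is a required-attribute lookup — if the operator's primitive carries a "group" attribute its string value is used, otherwise the op is rejected with an exception. A self-contained sketch of the same lookup pattern, using a plain std::map in place of the Primitive attribute table:

#include <iostream>
#include <map>
#include <stdexcept>
#include <string>

// Illustrative attribute table; MindSpore reads attributes via Primitive::GetAttr.
using AttrMap = std::map<std::string, std::string>;

std::string GetHcomGroup(const std::string &op_name, const AttrMap &attrs) {
  auto it = attrs.find("group");
  if (it == attrs.end()) {
    // Mirrors MS_LOG(EXCEPTION): a communication op without a group is an error.
    throw std::runtime_error("Get Hcom Group Attr of Op:" + op_name + " failed");
  }
  return it->second;
}

int main() {
  AttrMap attrs{{"group", "hccl_world_group"}};
  std::cout << GetHcomGroup("AllReduce-op0", attrs) << "\n";
  return 0;
}
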
mindspore/ccsrc/kernel/hccl/hcom_util.h  (+2 -0)

@@ -23,6 +23,7 @@
 #include <memory>
 #include "ir/dtype.h"
 #include "hccl/base.h"
+#include "utils/contract.h"

 namespace mindspore {
 using std::map;

@@ -61,6 +62,7 @@ class HcomUtil {
                                 const vector<vector<size_t>> &shape_list, uint64_t *total_count);
   static bool GetHcomOperationType(const AnfNodePtr &anf_node, hcclRedOp_t *op_type);
   static bool GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id);
+  static void GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group);
 };
 }  // namespace mindspore

mindspore/ccsrc/pre_activate/ascend/enhancer/add_memcpy_async.cc  (+1 -2)

@@ -66,8 +66,7 @@ const AnfNodePtr AddMemcpyAsync::Process(const FuncGraphPtr &func_graph, const A
     return nullptr;
   }
   auto cnode = node->cast<CNodePtr>();
-  auto op_name = AnfAlgo::GetCNodeName(cnode);
-  if (op_name != kAllReduceOpName && op_name != kAllGatherOpName && op_name != kReduceScatterOpName) {
+  if (!AnfAlgo::IsCommunicationOp(node)) {
     return nullptr;
   }
   return AddMemcpyAsyncIfInputIsUsedByOthers(func_graph, cnode);

mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc  (+16 -0)

@@ -173,6 +173,19 @@ const BaseRef DealRefTransAndCast::DefinePattern() const {
   return VectorRef({V, Xs});
 }

+void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
+  if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) {
+    auto input_size = AnfAlgo::GetInputTensorNum(cnode);
+    for (size_t i = 0; i < input_size; ++i) {
+      auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i);
+      auto input_node = input_node_with_index.first;
+      MS_EXCEPTION_IF_NULL(input_node);
+      MS_LOG(INFO) << "origin node:" << input_node->fullname_with_scope();
+      AddRefPairToKernelGraph(func_graph, cnode, nullptr, cnode, i, input_node_with_index);
+    }
+  }
+}
+
 const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
   if (node == nullptr || !node->isa<CNode>()) {

@@ -184,6 +197,9 @@ const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const A
   if (!AnfAlgo::IsRealCNodeKernel(cnode)) {
     return nullptr;
   }
+
+  DealBroadCastAsRef(graph, cnode);
+
   auto op_name = AnfAlgo::GetCNodeName(cnode);
   auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE);
   if (op_info == nullptr || !op_info->is_ref()) {