Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
f1563d2d
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f1563d2d
编写于
7月 21, 2020
作者:
H
huanghui
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
insert memcpy async if hccl op cascade
上级
48325dea
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
288 addition
and
58 deletion
+288
-58
mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
...c/backend/optimizer/ascend/ascend_backend_optimization.cc
+2
-0
mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.cc
...imizer/ascend/enhancer/insert_memcpy_async_for_cascade.cc
+114
-0
mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h
...timizer/ascend/enhancer/insert_memcpy_async_for_cascade.h
+39
-0
mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc
...imizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc
+56
-41
mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h
...timizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h
+1
-1
tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc
...e/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc
+44
-5
tests/ut/cpp/python_input/gtest_input/pre_activate/insert_memcpy_async_for_hccl_op.py
...est_input/pre_activate/insert_memcpy_async_for_hccl_op.py
+32
-11
未找到文件。
mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
浏览文件 @
f1563d2d
...
...
@@ -87,6 +87,7 @@
#include "backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h"
#include "backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h"
#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h"
#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h"
#include "backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h"
#include "backend/optimizer/ascend/format_type/insert_transdata_for_runop.h"
#include "backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.h"
...
...
@@ -340,6 +341,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
other_pm
->
AddPass
(
std
::
make_shared
<
AllGatherFusion
>
());
other_pm
->
AddPass
(
std
::
make_shared
<
ReduceScatterFusion
>
());
other_pm
->
AddPass
(
std
::
make_shared
<
BroadcastFusion
>
());
other_pm
->
AddPass
(
std
::
make_shared
<
InsertMemcpyAsyncForCascade
>
());
other_pm
->
AddPass
(
std
::
make_shared
<
ParameterTransOpFusion
>
());
other_pm
->
AddPass
(
std
::
make_shared
<
RefreshParameterFormat
>
());
optimizer
->
AddPassManager
(
other_pm
);
...
...
mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.cc
0 → 100644
浏览文件 @
f1563d2d
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h"
#include <vector>
#include <set>
#include <string>
#include "utils/utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "frontend/optimizer/opt.h"
#include "backend/optimizer/ascend/ascend_helper.h"
namespace
mindspore
{
namespace
opt
{
namespace
{
bool
IsPartOutputsOfHcclOp
(
const
AnfNodePtr
&
node
,
const
CNodePtr
&
cur_hccl
,
const
FuncGraphPtr
&
graph
)
{
MS_EXCEPTION_IF_NULL
(
node
);
MS_EXCEPTION_IF_NULL
(
cur_hccl
);
MS_EXCEPTION_IF_NULL
(
graph
);
if
(
!
AnfAlgo
::
CheckPrimitiveType
(
node
,
prim
::
kPrimTupleGetItem
))
{
return
false
;
}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
auto
prev_node
=
cnode
->
input
(
kRealInputNodeIndexInTupleGetItem
);
MS_EXCEPTION_IF_NULL
(
prev_node
);
if
(
!
AnfAlgo
::
IsCommunicationOp
(
prev_node
))
{
return
false
;
}
auto
prev_hccl_op
=
prev_node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
prev_hccl_op
);
auto
manager
=
graph
->
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
auto
&
node_users
=
manager
->
node_users
();
auto
iter
=
node_users
.
find
(
prev_hccl_op
);
if
(
iter
==
node_users
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"node has no output in manager"
;
}
for
(
const
auto
&
node_index
:
iter
->
second
)
{
AnfNodePtr
output
=
node_index
.
first
;
MS_EXCEPTION_IF_NULL
(
output
);
if
(
IsPrimitiveCNode
(
output
,
prim
::
kPrimTupleGetItem
))
{
bool
is_contain
=
false
;
for
(
size_t
i
=
1
;
i
<
cur_hccl
->
size
();
++
i
)
{
if
(
cur_hccl
->
input
(
i
)
==
output
)
{
is_contain
=
true
;
break
;
}
}
if
(
!
is_contain
)
{
return
true
;
}
}
}
return
false
;
}
}
// namespace
/**
 * For every input of `hccl_node` that is a partially-consumed output of another
 * hccl op (see IsPartOutputsOfHcclOp), insert a memcpy_async node between the
 * two hccl ops and select a kernel for it.
 *
 * @param graph      Owning function graph; used for node creation.
 * @param hccl_node  Communication op whose inputs are examined.
 * @return A copy of `hccl_node` rewired to the inserted memcpy_async nodes,
 *         or nullptr when no insertion was needed.
 */
AnfNodePtr InsertMemcpyAsyncForCascade::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(hccl_node);
  std::vector<AnfNodePtr> memcpy_async_list;
  // Input 0 is the primitive value node; keep it as-is.
  std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)};
  for (size_t i = 1; i < hccl_node->size(); ++i) {
    auto input = hccl_node->input(i);
    MS_EXCEPTION_IF_NULL(input);
    // when input is also a hccl op and just part outputs of it linking with cur_hccl_op
    if (IsPartOutputsOfHcclOp(input, hccl_node, graph)) {
      auto memcpy_async = CreateMemcpyAsyncOp(graph, input);
      // Guard: CreateMemcpyAsyncOp may fail; do not dereference a null node below.
      MS_EXCEPTION_IF_NULL(memcpy_async);
      auto kernel_info = std::make_shared<device::KernelInfo>();
      memcpy_async->set_kernel_info(kernel_info);
      MS_EXCEPTION_IF_NULL(kernel_select_);
      kernel_select_->SelectKernel(memcpy_async->cast<CNodePtr>());
      new_inputs.push_back(memcpy_async);
      memcpy_async_list.push_back(memcpy_async);
    } else {
      new_inputs.push_back(input);
    }
  }
  // Rebuild the hccl node only when at least one memcpy_async was inserted;
  // otherwise signal "no change" to the caller with nullptr.
  if (!memcpy_async_list.empty()) {
    CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node);
    new_hccl_node->set_inputs(new_inputs);
    return new_hccl_node;
  }
  return nullptr;
}
/**
 * Pass entry point: applies the memcpy_async cascade insertion to every
 * communication-op CNode; all other nodes are left untouched (nullptr return
 * means "no replacement").
 */
const AnfNodePtr InsertMemcpyAsyncForCascade::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                      const EquivPtr &) const {
  if (func_graph == nullptr || node == nullptr) {
    return nullptr;
  }
  // Only communication-op CNodes are candidates for the rewrite.
  if (!node->isa<CNode>() || !AnfAlgo::IsCommunicationOp(node)) {
    return nullptr;
  }
  return InsertMemcpyAsync(func_graph, node->cast<CNodePtr>());
}
}
// namespace opt
}
// namespace mindspore
mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h
0 → 100644
浏览文件 @
f1563d2d
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_CASCADE_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_CASCADE_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ascend_helper.h"
namespace
mindspore
{
namespace
opt
{
// Backend optimization pass: when hccl (communication) ops are cascaded and
// only part of the producer op's outputs feed the consumer op, inserts
// memcpy_async nodes between them. Matching/rewrite logic lives in the
// corresponding .cc file's Process/InsertMemcpyAsync.
class InsertMemcpyAsyncForCascade : public PatternProcessPass {
 public:
  // `multigraph` is forwarded to PatternProcessPass so the pass runs on
  // every sub-graph by default.
  explicit InsertMemcpyAsyncForCascade(bool multigraph = true)
      : PatternProcessPass("insert_memcpy_async_for_cascade", multigraph),
        kernel_select_(std::make_shared<KernelSelect>()) {}
  ~InsertMemcpyAsyncForCascade() override = default;
  // Returns the replacement node for `node`, or nullptr when no rewrite applies.
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  // Builds the rewired hccl node with memcpy_async inserted on cascaded inputs.
  AnfNodePtr InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const;
  // Kernel selector used to pick a kernel for each inserted memcpy_async.
  KernelSelectPtr kernel_select_;
};
}
// namespace opt
}
// namespace mindspore
#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_CASCADE_H_
mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc
浏览文件 @
f1563d2d
...
...
@@ -32,12 +32,17 @@ const std::set<std::string> kNeedInsertMemcpyOpSet = {kLambNextMVOpName, kLambNe
bool
IsParameterOrValueNode
(
const
AnfNodePtr
&
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
kernel_with_index
=
AnfAlgo
::
VisitKernelWithReturnType
(
node
,
0
,
true
);
return
kernel_with_index
.
first
->
isa
<
Parameter
>
()
||
kernel_with_index
.
first
->
isa
<
ValueNode
>
();
auto
real_node
=
kernel_with_index
.
first
;
MS_EXCEPTION_IF_NULL
(
real_node
);
if
(
real_node
->
isa
<
Parameter
>
())
{
return
true
;
}
return
real_node
->
isa
<
ValueNode
>
();
}
void
TransferControl
(
const
CNodePtr
&
hccl_node
,
const
AnfNodePtr
&
memcpy_async
,
const
FuncGraphPtr
&
graph
)
{
void
TransferControl
(
const
CNodePtr
&
hccl_node
,
const
std
::
vector
<
AnfNodePtr
>
&
memcpy_async_list
,
const
FuncGraphPtr
&
graph
)
{
MS_EXCEPTION_IF_NULL
(
hccl_node
);
MS_EXCEPTION_IF_NULL
(
memcpy_async
);
MS_EXCEPTION_IF_NULL
(
graph
);
auto
manager
=
graph
->
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
...
...
@@ -48,49 +53,62 @@ void TransferControl(const CNodePtr &hccl_node, const AnfNodePtr &memcpy_async,
}
// find hccl_node's output which is a control depend
for
(
const
auto
&
node_index
:
iter
->
second
)
{
AnfNodePtr
output
=
node_index
.
first
;
int
output_index
=
node_index
.
second
;
if
(
AnfAlgo
::
CheckPrimitiveType
(
output
,
prim
::
kPrimControlDepend
))
{
CNodePtr
control_depend
=
output
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
control_depend
);
std
::
vector
<
AnfNodePtr
>
new_inputs
;
for
(
size_t
i
=
0
;
i
<
control_depend
->
size
();
++
i
)
{
if
(
i
==
IntToSize
(
output_index
))
{
new_inputs
.
push_back
(
memcpy_async
);
}
else
{
new_inputs
.
push_back
(
control_depend
->
input
(
i
));
}
if
(
!
AnfAlgo
::
CheckPrimitiveType
(
node_index
.
first
,
prim
::
kPrimControlDepend
))
{
continue
;
}
CNodePtr
control_depend
=
node_index
.
first
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
control_depend
);
std
::
vector
<
AnfNodePtr
>
new_inputs
;
for
(
size_t
i
=
0
;
i
<
control_depend
->
size
();
++
i
)
{
if
(
i
==
IntToSize
(
node_index
.
second
))
{
std
::
vector
<
AnfNodePtr
>
make_tuple_inputs
=
{
NewValueNode
(
prim
::
kPrimMakeTuple
)};
make_tuple_inputs
.
insert
(
make_tuple_inputs
.
end
(),
memcpy_async_list
.
begin
(),
memcpy_async_list
.
end
());
make_tuple_inputs
.
emplace_back
(
hccl_node
);
auto
make_tuple
=
graph
->
NewCNode
(
make_tuple_inputs
);
MS_EXCEPTION_IF_NULL
(
make_tuple
);
new_inputs
.
push_back
(
make_tuple
);
}
else
{
new_inputs
.
push_back
(
control_depend
->
input
(
i
));
}
control_depend
->
set_inputs
(
new_inputs
);
}
control_depend
->
set_inputs
(
new_inputs
);
}
}
}
// namespace
bool
InsertMemcpyAsyncForHcclOp
::
NeedInsertMemcpy
(
const
FuncGraphPtr
&
graph
,
const
AnfNodePtr
&
input
)
const
{
bool
InsertMemcpyAsyncForHcclOp
::
NeedInsertMemcpy
(
const
FuncGraphPtr
&
graph
,
const
AnfNodePtr
&
input
,
const
CNodePtr
&
cur_node
)
const
{
MS_EXCEPTION_IF_NULL
(
graph
);
MS_EXCEPTION_IF_NULL
(
input
);
MS_EXCEPTION_IF_NULL
(
cur_node
);
// when input is a parameter or is a value node
if
(
IsParameterOrValueNode
(
input
))
{
return
true
;
}
// when input is a Ref or some special cnodes
if
(
kernel_query_
->
IsTbeRef
(
input
)
||
kNeedInsertMemcpyOpSet
.
find
(
AnfAlgo
::
GetCNodeName
(
input
))
!=
kNeedInsertMemcpyOpSet
.
end
())
{
return
true
;
}
if
(
input
->
isa
<
CNode
>
())
{
auto
manager
=
graph
->
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
auto
&
node_users
=
manager
->
node_users
();
auto
manager
=
graph
->
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
auto
&
node_users
=
manager
->
node_users
();
auto
iter
=
node_users
.
find
(
input
);
if
(
iter
==
node_users
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"node has no output in manager"
;
}
// when input is used by others
if
(
iter
->
second
.
size
()
>
1
)
{
return
true
;
// when input is a Ref cnode
if
(
kernel_query_
->
IsTbeRef
(
input
))
{
return
true
;
}
// when input is some special cnodes
if
(
kNeedInsertMemcpyOpSet
.
find
(
AnfAlgo
::
GetCNodeName
(
input
))
!=
kNeedInsertMemcpyOpSet
.
end
())
{
return
true
;
}
// when input is used by others
auto
iter
=
node_users
.
find
(
input
);
if
(
iter
==
node_users
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"node has no output in manager"
;
}
if
(
iter
->
second
.
size
()
>
1
)
{
return
true
;
}
}
return
false
;
}
...
...
@@ -98,21 +116,20 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con
void
InsertMemcpyAsyncForHcclOp
::
InsertMemcpyAsync
(
const
FuncGraphPtr
&
graph
,
const
CNodePtr
&
hccl_node
)
const
{
MS_EXCEPTION_IF_NULL
(
graph
);
MS_EXCEPTION_IF_NULL
(
hccl_node
);
bool
has_insert_memcpy
=
false
;
AnfNodePtr
memcpy_async
=
nullptr
;
std
::
vector
<
AnfNodePtr
>
memcpy_async_list
;
std
::
vector
<
AnfNodePtr
>
new_inputs
=
{
hccl_node
->
input
(
0
)};
for
(
size_t
i
=
1
;
i
<
hccl_node
->
size
();
++
i
)
{
auto
input
=
hccl_node
->
input
(
i
);
if
(
NeedInsertMemcpy
(
graph
,
input
))
{
memcpy_async
=
CreateMemcpyAsyncOp
(
graph
,
input
);
has_insert_memcpy
=
true
;
if
(
NeedInsertMemcpy
(
graph
,
input
,
hccl_node
))
{
auto
memcpy_async
=
CreateMemcpyAsyncOp
(
graph
,
input
);
new_inputs
.
push_back
(
memcpy_async
);
memcpy_async_list
.
push_back
(
memcpy_async
);
}
else
{
new_inputs
.
push_back
(
input
);
}
}
if
(
has_insert_memcpy
)
{
if
(
!
memcpy_async_list
.
empty
()
)
{
CNodePtr
new_hccl_node
=
std
::
make_shared
<
CNode
>
(
*
hccl_node
);
new_hccl_node
->
set_inputs
(
new_inputs
);
auto
manager
=
graph
->
manager
();
...
...
@@ -122,9 +139,7 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co
MS_LOG
(
DEBUG
)
<<
"end replace"
;
// transer hccl op's control to the memcpy_async
if
(
hccl_node
->
size
()
==
2
)
{
TransferControl
(
new_hccl_node
,
memcpy_async
,
graph
);
}
TransferControl
(
new_hccl_node
,
memcpy_async_list
,
graph
);
}
}
...
...
mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h
浏览文件 @
f1563d2d
...
...
@@ -32,7 +32,7 @@ class InsertMemcpyAsyncForHcclOp : public PatternProcessPass {
private:
void
InsertMemcpyAsync
(
const
FuncGraphPtr
&
graph
,
const
CNodePtr
&
hccl_node
)
const
;
bool
NeedInsertMemcpy
(
const
FuncGraphPtr
&
graph
,
const
AnfNodePtr
&
input
)
const
;
bool
NeedInsertMemcpy
(
const
FuncGraphPtr
&
graph
,
const
AnfNodePtr
&
input
,
const
CNodePtr
&
cur_node
)
const
;
KernelQueryPtr
kernel_query_
;
};
}
// namespace opt
...
...
tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc
浏览文件 @
f1563d2d
...
...
@@ -22,6 +22,7 @@
#include "utils/utils.h"
#include "backend/kernel_compiler/kernel_build_info.h"
#include "backend/optimizer/common/optimizer.h"
#include "ir/param_value.h"
#define private public
#define protected public
#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h"
...
...
@@ -44,12 +45,10 @@ class MockInsertMemcpyForHcclKernelQuery : public KernelQuery {
~
MockInsertMemcpyForHcclKernelQuery
()
override
=
default
;
bool
IsTbeRef
(
const
AnfNodePtr
&
node
)
override
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
if
(
cnode
==
nullptr
)
{
if
(
!
node
->
isa
<
CNode
>
())
{
return
false
;
}
auto
name
=
AnfAlgo
::
GetCNodeName
(
cnode
);
return
name
==
"ApplyMomentum"
;
return
AnfAlgo
::
GetCNodeName
(
node
->
cast
<
CNodePtr
>
())
==
"ApplyMomentum"
;
}
};
...
...
@@ -105,6 +104,11 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond2) {
AbstractBasePtrList
args_spec_list
{
x_abstract
};
auto
kg
=
GetKernelGraph
(
g
,
args_spec_list
);
EXPECT_NE
(
kg
,
nullptr
);
for
(
auto
p
:
kg
->
parameters
())
{
auto
param
=
p
->
cast
<
ParameterPtr
>
();
EXPECT_NE
(
param
,
nullptr
);
param
->
set_default_param
(
std
::
make_shared
<
ParamValue
>
());
}
auto
optimizer
=
std
::
make_shared
<
opt
::
GraphOptimizer
>
();
auto
pm
=
std
::
make_shared
<
opt
::
PassManager
>
();
...
...
@@ -146,10 +150,16 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond4) {
ASSERT_TRUE
(
g
!=
nullptr
);
std
::
vector
<
int
>
shp_x
{
1
,
64
,
112
,
112
};
auto
x_abstract
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
kFloat32
,
shp_x
);
AbstractBasePtrList
args_spec_list
{
x_abstract
,
x_abstract
,
x_abstract
,
x_abstract
,
x_abstract
};
AbstractBasePtrList
args_spec_list
{
x_abstract
,
x_abstract
};
auto
kg
=
GetKernelGraph
(
g
,
args_spec_list
);
EXPECT_NE
(
kg
,
nullptr
);
for
(
auto
p
:
kg
->
parameters
())
{
auto
param
=
p
->
cast
<
ParameterPtr
>
();
EXPECT_NE
(
param
,
nullptr
);
param
->
set_default_param
(
std
::
make_shared
<
ParamValue
>
());
}
auto
optimizer
=
std
::
make_shared
<
opt
::
GraphOptimizer
>
();
auto
pm
=
std
::
make_shared
<
opt
::
PassManager
>
();
auto
pass
=
std
::
make_shared
<
opt
::
InsertMemcpyAsyncForHcclOp
>
();
...
...
@@ -161,5 +171,34 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond4) {
FuncGraphPtr
g_after
=
get_py_fun_
.
CallAndParseRet
(
"test_insert_memcpy_async_for_hccl_op_cond4"
,
"after"
);
EXPECT_TRUE
(
CheckEqualGraph
(
g_after
,
new_graph
));
}
TEST_F
(
TestHWInsertMemcpyForHccl
,
test_cond5
)
{
get_py_fun_
.
SetDoResolve
(
true
);
FuncGraphPtr
g
=
get_py_fun_
.
CallAndParseRet
(
"test_insert_memcpy_async_for_hccl_op_cond5"
,
"before"
);
ASSERT_TRUE
(
g
!=
nullptr
);
std
::
vector
<
int
>
shp_x
{
1
,
64
,
112
,
112
};
auto
x_abstract
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
kFloat32
,
shp_x
);
AbstractBasePtrList
args_spec_list
{
x_abstract
,
x_abstract
,
x_abstract
};
auto
kg
=
GetKernelGraph
(
g
,
args_spec_list
);
EXPECT_NE
(
kg
,
nullptr
);
for
(
auto
p
:
kg
->
parameters
())
{
auto
param
=
p
->
cast
<
ParameterPtr
>
();
EXPECT_NE
(
param
,
nullptr
);
param
->
set_default_param
(
std
::
make_shared
<
ParamValue
>
());
}
auto
optimizer
=
std
::
make_shared
<
opt
::
GraphOptimizer
>
();
auto
pm
=
std
::
make_shared
<
opt
::
PassManager
>
();
auto
pass
=
std
::
make_shared
<
opt
::
InsertMemcpyAsyncForHcclOp
>
();
pass
->
kernel_query_
=
std
::
make_shared
<
MockInsertMemcpyForHcclKernelQuery
>
();
pm
->
AddPass
(
pass
);
optimizer
->
AddPassManager
(
pm
);
auto
new_graph
=
optimizer
->
Optimize
(
kg
);
kg
->
SetExecOrderByDefault
();
FuncGraphPtr
g_after
=
get_py_fun_
.
CallAndParseRet
(
"test_insert_memcpy_async_for_hccl_op_cond5"
,
"after"
);
EXPECT_TRUE
(
CheckEqualGraph
(
g_after
,
new_graph
));
}
}
// namespace opt
}
// namespace mindspore
tests/ut/cpp/python_input/gtest_input/pre_activate/insert_memcpy_async_for_hccl_op.py
浏览文件 @
f1563d2d
...
...
@@ -17,6 +17,7 @@ from mindspore.ops import Primitive
from
mindspore.ops
import
operations
as
P
all_reduce
=
P
.
AllReduce
()
broadcast
=
P
.
Broadcast
(
1
)
memcpy_async
=
Primitive
(
'memcpy_async'
)
make_tuple
=
Primitive
(
'make_tuple'
)
tuple_getitem
=
Primitive
(
'tuple_getitem'
)
...
...
@@ -101,20 +102,40 @@ def test_insert_memcpy_async_for_hccl_op_cond4(tag):
fns
=
FnDict
()
@
fns
def
before
(
a
,
b
,
c
,
d
,
e
):
res1
=
apply_momentun
(
a
,
b
,
c
,
d
,
e
)
res2
=
all_reduce
(
a
)
res
=
control_depend
(
res1
,
res2
)
res
=
make_tuple
(
res
,
res2
)
def
before
(
a
,
b
):
x
=
relu
(
a
)
y
=
all_reduce
(
b
)
res
=
control_depend
(
x
,
y
)
return
res
@
fns
def
after
(
a
,
b
,
c
,
d
,
e
):
res1
=
apply_momentun
(
a
,
b
,
c
,
d
,
e
)
res2
=
memcpy_async
(
a
)
res3
=
all_reduce
(
res2
)
res
=
control_depend
(
res1
,
res2
)
res
=
make_tuple
(
res
,
res3
)
def
after
(
a
,
b
):
x
=
relu
(
a
)
y1
=
memcpy_async
(
b
)
y2
=
all_reduce
(
y1
)
res
=
control_depend
(
x
,
make_tuple
(
y1
,
y2
))
return
make_tuple
(
res
)
return
fns
[
tag
]
def test_insert_memcpy_async_for_hccl_op_cond5(tag):
    # UT fixture: returns the "before"/"after" graph (selected by `tag`) for the
    # pass that inserts memcpy_async in front of a Broadcast whose inputs are
    # parameters, transferring the control_depend edge onto a make_tuple of the
    # inserted nodes plus the broadcast.
    fns = FnDict()

    @fns
    def before(a, b, c):
        x = relu(a)
        # NOTE(review): `before` passes one tuple arg while `after` passes two
        # positional args to broadcast — presumably equivalent after graph
        # parsing; confirm against the pass's expected-graph comparison.
        y = broadcast((b, c))
        res = control_depend(x, y)
        return res

    @fns
    def after(a, b, c):
        x = relu(a)
        # Each broadcast input gets its own memcpy_async barrier.
        m1 = memcpy_async(b)
        m2 = memcpy_async(c)
        y = broadcast(m1, m2)
        # Control edge is moved onto the whole group (memcpys + broadcast).
        res = control_depend(x, make_tuple(m1, m2, y))
        return make_tuple(res)

    return fns[tag]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录