magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 1c6a5f5c
Authored on Jun 15, 2020 by mindspore-ci-bot; committed via Gitee on Jun 15, 2020
!2121 replace first input of dropout_gen_mask of the subgraph instead of the whole sub graph
Merge pull request !2121 from yihuaijie/dev
Parents: 928c25eb, 7857d59c
Showing 3 changed files with 52 additions and 9 deletions (+52, -9):
mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc   +45  -6
mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h      +1  -1
mindspore/ccsrc/parallel/step_parallel.cc                     +6  -2
mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc

@@ -204,7 +204,7 @@ Status DropoutDoMaskInfo::InitForCostModel(const StrategyPtr &strategy) {
 PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(cnode);
-  if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
+  if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
     MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
   }

@@ -215,8 +215,7 @@ PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) {
   }
   auto dropout_gen_mask_cnode = dropout_gen_mask->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(dropout_gen_mask_cnode);
-  if (dropout_gen_mask_cnode->inputs().size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) {
+  if (dropout_gen_mask_cnode->size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) {
     MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE;
   }
   if (!IsValueNode<Primitive>(dropout_gen_mask_cnode->input(0))) {

@@ -233,11 +232,45 @@ PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) {
   return prim;
 }

+void SetGenMaskShape(const CNodePtr &cnode, const Shape &input_slice_shape) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
+    MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
+  }
+
+  AnfNodePtr dropout_gen_mask = cnode->input(DROPOUT_GEN_MASK_INDEX);
+  MS_EXCEPTION_IF_NULL(dropout_gen_mask);
+  if (!dropout_gen_mask->isa<CNode>()) {
+    MS_LOG(EXCEPTION) << "The dropout do mask cnode's input[" << DROPOUT_GEN_MASK_INDEX << "] must be a cnode.";
+  }
+
+  auto dropout_gen_mask_cnode = dropout_gen_mask->cast<CNodePtr>();
+  if (dropout_gen_mask_cnode->size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) {
+    MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE;
+  }
+  if (!IsValueNode<ValueTuple>(dropout_gen_mask_cnode->input(1))) {
+    MS_LOG(EXCEPTION) << "The input[1] of dropout gen mask cnode is not ValueTuple.";
+  }
+
+  FuncGraphPtr func_graph = cnode->func_graph();
+  MS_EXCEPTION_IF_NULL(func_graph);
+  FuncGraphManagerPtr manager = func_graph->manager();
+  if (manager == nullptr) {
+    MS_LOG(EXCEPTION) << "Failure: AddNode error since manager is nullptr.";
+  }
+
+  ValuePtr new_shape = MakeValue(input_slice_shape);
+  AnfNodePtr val = NewValueNode(new_shape);
+  (void)manager->Replace(dropout_gen_mask_cnode->input(1), val);
+}
+
 // DropoutDoMask needs to be used together with DropoutGenMask. Only the first input tensor of DropoutGenMask is
 // split. Find the DropoutGenMask node in the anf graph according to DropoutDoMask node, and modify the input shape
 // of DropoutGenMask according to the strategy of DropoutDoMask. When the DropoutDoMask performs repeated calculation
 // and both seeds of DropoutGenMask are 0, two new seeds are automatically generated for DropoutGenMask.
-Operator DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) {
+std::vector<Operator> DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) {
+  std::vector<Operator> replace_ops;
   MS_EXCEPTION_IF_NULL(cnode);
   PrimitivePtr prim = GetDropoutGenMaskPrim(cnode);
   MS_EXCEPTION_IF_NULL(prim);

@@ -260,15 +293,20 @@ Operator DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) {
   if ((attr.find(SEED0) == attr.end()) || (attr.find(SEED1) == attr.end())) {
     MS_LOG(EXCEPTION) << "The attrs of dropout gen mask must be have seed0 and seed1";
   }

+  Shape input_slice_shape = inputs_tensor_info_[0].slice_shape();
   int32_t seed_0 = GetValue<int32_t>(attr[SEED0]);
   int32_t seed_1 = GetValue<int32_t>(attr[SEED1]);
   if ((seed_0 == 0) && (seed_1 == 0) && (repeated_calc_num_ > 1)) {
     seed_0 = SEED_NUM;
     seed_1 = SEED_NUM;
     SEED_NUM++;
-  }
+  } else {
+    SetGenMaskShape(cnode, input_slice_shape);
+    MS_LOG(DEBUG) << "The input slice shape droupout is " << ShapeToString(input_slice_shape);
+    return replace_ops;
+  }

-  Shape input_slice_shape = inputs_tensor_info_[0].slice_shape();
   ValuePtr new_shape = MakeValue(input_slice_shape);
   Attr attr_0 = std::make_pair(SEED0, MakeValue(seed_0));
   Attr attr_1 = std::make_pair(SEED1, MakeValue(seed_1));

@@ -278,7 +316,8 @@ Operator DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) {
   OperatorParams params = {std::make_pair(param_0, 1), std::make_pair(param_1, 2)};
   OperatorArgs args = std::make_pair(attrs, params);
   Operator replace_op = {std::make_pair(DROPOUT_GEN_MASK, args)};
-  return replace_op;
+  replace_ops.push_back(replace_op);
+  return replace_ops;
 }
 }  // namespace parallel
 }  // namespace mindspore
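To make the mechanism of the new SetGenMaskShape helper easier to see in isolation, here is a minimal, self-contained sketch of the same "rewire a single input edge through the graph manager" idea. The Node and Manager types below are hypothetical stand-ins invented for this sketch, not MindSpore's AnfNode or FuncGraphManager; the point is only that the DropoutGenMask node and the rest of its subgraph stay in place while its shape operand is swapped for a new value node.

// Minimal sketch of the "replace one input of a node" pattern used by
// SetGenMaskShape above. Node/Manager are hypothetical stand-ins, kept here
// only to illustrate the idea: the DropoutGenMask subgraph is not replaced,
// only its shape operand (input 1 in the real code) is rewired.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::vector<std::shared_ptr<Node>> inputs;  // operand edges of this node
};
using NodePtr = std::shared_ptr<Node>;

struct Manager {
  // Rewire every use of `old_node` among `user`'s inputs to `new_node`.
  // The real FuncGraphManager::Replace updates all users in the graph;
  // a single user is enough for this sketch.
  void Replace(const NodePtr &user, const NodePtr &old_node, const NodePtr &new_node) {
    for (auto &input : user->inputs) {
      if (input == old_node) {
        input = new_node;
      }
    }
  }
};

int main() {
  // DropoutGenMask(shape, keep_prob): inputs[0] models the shape ValueTuple.
  auto full_shape = std::make_shared<Node>(Node{"shape=(8, 1024)", {}});
  auto keep_prob = std::make_shared<Node>(Node{"keep_prob=0.9", {}});
  auto gen_mask = std::make_shared<Node>(Node{"DropoutGenMask", {full_shape, keep_prob}});

  // After the parallel strategy is applied, only the sliced shape should feed
  // DropoutGenMask; the node itself and its other inputs are left untouched.
  auto slice_shape = std::make_shared<Node>(Node{"shape=(4, 1024)", {}});
  Manager manager;
  manager.Replace(gen_mask, full_shape, slice_shape);

  std::cout << gen_mask->name << " now reads its shape from: "
            << gen_mask->inputs[0]->name << "\n";  // prints the sliced shape
  return 0;
}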
mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h

@@ -41,7 +41,7 @@ class DropoutDoMaskInfo : public OperatorInfo {
   Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
   Status InitForCostModel(const StrategyPtr &strategy) override;
   std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override;
-  Operator GetDropoutGenMaskReplaceOp(const CNodePtr &cnode);
+  std::vector<Operator> GetDropoutGenMaskReplaceOp(const CNodePtr &cnode);

  protected:
   Status CheckStrategy(const StrategyPtr &strategy) override;
mindspore/ccsrc/parallel/step_parallel.cc

@@ -1876,11 +1876,15 @@ void HandleDropoutNode(const OperatorInfoPtr &distribute_operator, const CNodePt
   DropoutDoMaskInfoPtr dropout_do_mask = std::dynamic_pointer_cast<DropoutDoMaskInfo>(distribute_operator);
   MS_EXCEPTION_IF_NULL(dropout_do_mask);
-  Operator replace_op = dropout_do_mask->GetDropoutGenMaskReplaceOp(cnode);
+  std::vector<Operator> replace_op = dropout_do_mask->GetDropoutGenMaskReplaceOp(cnode);
+  if (replace_op.empty()) {
+    MS_LOG(DEBUG) << "No need to replace dropout_gen_mask";
+    return;
+  }
   if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
     MS_LOG(EXCEPTION) << "The size of drop out do mask cnode's input is not " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
   }
-  ReplaceOneOp(replace_op, cnode->input(DROPOUT_GEN_MASK_INDEX)->cast<CNodePtr>());
+  ReplaceOneOp(replace_op[0], cnode->input(DROPOUT_GEN_MASK_INDEX)->cast<CNodePtr>());
 }

 void HandleSpecialNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
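As a quick illustration of the caller contract that HandleDropoutNode now relies on, the sketch below mimics the shape of the new API with stand-in types (Operator is just a string here, and GetReplaceOp is a hypothetical placeholder for DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp): an empty vector means the gen-mask node was already patched in place and nothing is replaced; otherwise the single replacement op at index 0 is applied.

// Sketch of the new calling convention (hypothetical stand-in types, not the
// real MindSpore API): the replace-op getter now returns a vector that is
// empty when the DropoutGenMask node was already patched in place.
#include <iostream>
#include <string>
#include <vector>

using Operator = std::string;  // stand-in for parallel::Operator

// Models GetDropoutGenMaskReplaceOp: either hand back one replacement op
// (seeds were regenerated) or nothing (only the shape input was rewritten).
std::vector<Operator> GetReplaceOp(bool seeds_regenerated) {
  std::vector<Operator> replace_ops;
  if (seeds_regenerated) {
    replace_ops.push_back("DropoutGenMask(new seeds, sliced shape)");
  }
  return replace_ops;
}

int main() {
  for (bool regenerated : {true, false}) {
    std::vector<Operator> replace_op = GetReplaceOp(regenerated);
    if (replace_op.empty()) {
      std::cout << "No need to replace dropout_gen_mask\n";
      continue;
    }
    std::cout << "ReplaceOneOp with: " << replace_op[0] << "\n";
  }
  return 0;
}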