Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
9c005b18
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9c005b18
编写于
6月 06, 2020
作者:
W
WilliamLian
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Convert the first input/output format of the kernel info of DropoutDoMask nodes that are connected to the same DropoutGenMask
上级
10ebd81b
变更
12
隐藏空白更改
内联
并排
Showing 12 changed files with 230 additions and 73 deletions
+230
-73
mindspore/ccsrc/kernel/kernel.h
mindspore/ccsrc/kernel/kernel.h
+1
-1
mindspore/ccsrc/kernel/kernel_build_info.cc
mindspore/ccsrc/kernel/kernel_build_info.cc
+15
-0
mindspore/ccsrc/kernel/kernel_build_info.h
mindspore/ccsrc/kernel/kernel_build_info.h
+4
-0
mindspore/ccsrc/kernel/kernel_query.cc
mindspore/ccsrc/kernel/kernel_query.cc
+9
-1
mindspore/ccsrc/operator/ops.cc
mindspore/ccsrc/operator/ops.cc
+1
-0
mindspore/ccsrc/operator/ops.h
mindspore/ccsrc/operator/ops.h
+1
-0
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
.../ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
+3
-21
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.h
...e/ccsrc/pre_activate/ascend/ascend_backend_optimization.h
+0
-1
mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc
mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc
+1
-1
mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast_for_runop.cc
.../pre_activate/ascend/format_type/insert_cast_for_runop.cc
+0
-48
mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc
...ctivate/ascend/format_type/rectify_do_mask_kernel_info.cc
+154
-0
mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h
...activate/ascend/format_type/rectify_do_mask_kernel_info.h
+41
-0
未找到文件。
mindspore/ccsrc/kernel/kernel.h
浏览文件 @
9c005b18
...
...
@@ -31,7 +31,7 @@ enum KernelType : int { UNKNOWN_KERNEL_TYPE = 0, AUTO_DIFF_KERNEL, AICPU_KERNEL,
namespace
kernel
{
enum
Axis
{
enum
Axis
:
int
{
N
=
0
,
C
,
H
,
...
...
mindspore/ccsrc/kernel/kernel_build_info.cc
浏览文件 @
9c005b18
...
...
@@ -167,5 +167,20 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetOpPattern(OpPattern pattern) {
MS_EXCEPTION_IF_NULL
(
kernel_build_info_
);
kernel_build_info_
->
op_pattern_
=
pattern
;
}
// Overwrite the input format at position `index` of the kernel build info
// under construction.  This setter only replaces an existing entry — it never
// grows inputs_format_ — so it throws when `index` is past the end.
// Fix: corrected the typo "outof" in the exception message.
void KernelBuildInfo::KernelBuildInfoBuilder::SetInputFormat(const std::string &format, size_t index) {
  MS_EXCEPTION_IF_NULL(kernel_build_info_);
  if (index >= kernel_build_info_->inputs_format_.size()) {
    MS_LOG(EXCEPTION) << "index out of range!";
  }
  kernel_build_info_->inputs_format_[index] = format;
}
// Overwrite the output format at position `index` of the kernel build info
// under construction.  Mirrors SetInputFormat: replaces only, never grows
// outputs_format_, and throws when `index` is past the end.
// Fix: corrected the typo "outof" in the exception message.
void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputFormat(const std::string &format, size_t index) {
  MS_EXCEPTION_IF_NULL(kernel_build_info_);
  if (index >= kernel_build_info_->outputs_format_.size()) {
    MS_LOG(EXCEPTION) << "index out of range!";
  }
  kernel_build_info_->outputs_format_[index] = format;
}
}
// namespace kernel
}
// namespace mindspore
mindspore/ccsrc/kernel/kernel_build_info.h
浏览文件 @
9c005b18
...
...
@@ -131,6 +131,10 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
void
SetOpPattern
(
OpPattern
pattern
);
void
SetInputFormat
(
const
std
::
string
&
format
,
size_t
index
);
void
SetOutputFormat
(
const
std
::
string
&
format
,
size_t
index
);
std
::
shared_ptr
<
KernelBuildInfo
>
Build
();
private:
...
...
mindspore/ccsrc/kernel/kernel_query.cc
浏览文件 @
9c005b18
...
...
@@ -41,8 +41,16 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
}
else
{
MS_LOG
(
WARNING
)
<<
"All kernel Info list does not match any kernel info "
;
for
(
size_t
index
=
0
;
index
<
kernel_info_list
->
size
();
++
index
)
{
std
::
ostringstream
buffer
;
MS_EXCEPTION_IF_NULL
(
kernel_info_list
->
at
(
index
));
MS_LOG
(
WARNING
)
<<
"kernel [ "
<<
index
<<
" ] :"
<<
kernel_info_list
->
at
(
index
)
->
ToString
();
if
(
AnfAlgo
::
GetOutputTensorNum
(
kernel_node
)
!=
kernel_info_list
->
at
(
index
)
->
GetOutputNum
())
{
buffer
<<
"Kernel node's output size ["
<<
AnfAlgo
::
GetOutputTensorNum
(
kernel_node
)
<<
"]"
<<
" cannot match the kernel's output size ["
<<
kernel_info_list
->
at
(
index
)
->
GetOutputNum
()
<<
"]"
;
}
else
{
buffer
<<
"Kernel node's output size ["
<<
AnfAlgo
::
GetInputTensorNum
(
kernel_node
)
<<
"]"
<<
" cannot match the kernel's output size ["
<<
kernel_info_list
->
at
(
index
)
->
GetInputNum
()
<<
"]"
;
}
MS_LOG
(
WARNING
)
<<
"kernel [ "
<<
index
<<
" ] :"
<<
kernel_info_list
->
at
(
index
)
->
ToString
()
<<
buffer
.
str
();
}
kernel_info_list
->
clear
();
MS_LOG
(
WARNING
)
<<
"node"
<<
kernel_node
->
DebugString
()
<<
"'s output size : ["
...
...
mindspore/ccsrc/operator/ops.cc
浏览文件 @
9c005b18
...
...
@@ -205,6 +205,7 @@ const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGr
const
PrimitivePtr
kPrimLayerNormXBackprop
=
std
::
make_shared
<
Primitive
>
(
"LayerNormXBackprop"
);
const
PrimitivePtr
kPrimLayerNormBetaGammaBackprop
=
std
::
make_shared
<
Primitive
>
(
"LayerNormBetaGammaBackprop"
);
const
PrimitivePtr
kPrimDropoutGenMask
=
std
::
make_shared
<
Primitive
>
(
"DropoutGenMask"
);
const
PrimitivePtr
kPrimDropoutDoMask
=
std
::
make_shared
<
Primitive
>
(
"DropoutDoMask"
);
const
PrimitivePtr
kPrimOneHot
=
std
::
make_shared
<
Primitive
>
(
"OneHot"
);
const
PrimitivePtr
kPrimGelu
=
std
::
make_shared
<
Primitive
>
(
"Gelu"
);
const
PrimitivePtr
kPrimGeluGrad
=
std
::
make_shared
<
Primitive
>
(
"GeluGrad"
);
...
...
mindspore/ccsrc/operator/ops.h
浏览文件 @
9c005b18
...
...
@@ -211,6 +211,7 @@ extern const PrimitivePtr kPrimLayerNormGrad;
extern
const
PrimitivePtr
kPrimLayerNormXBackprop
;
extern
const
PrimitivePtr
kPrimLayerNormBetaGammaBackprop
;
extern
const
PrimitivePtr
kPrimDropoutGenMask
;
extern
const
PrimitivePtr
kPrimDropoutDoMask
;
extern
const
PrimitivePtr
kPrimOneHot
;
extern
const
PrimitivePtr
kPrimGelu
;
extern
const
PrimitivePtr
kPrimGeluGrad
;
...
...
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
浏览文件 @
9c005b18
...
...
@@ -54,6 +54,7 @@
#include "pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h"
#include "pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.h"
#include "pre_activate/ascend/format_type/insert_trans_op.h"
#include "pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h"
#include "pre_activate/pass/getitem_tuple.h"
#include "pre_activate/pass/optimize_dependence.h"
#include "pre_activate/pass/erase_visit_attr.h"
...
...
@@ -79,7 +80,6 @@
#include "pre_activate/ascend/format_type/deal_ref_trans_and_cast.h"
#include "pre_activate/ascend/enhancer/add_memcpy_async.h"
#include "pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h"
#include "pre_activate/ascend/format_type/insert_cast_for_runop.h"
#include "pre_activate/ascend/format_type/insert_transdata_for_runop.h"
#include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h"
#include "pre_activate/ascend/ir_fission/addn_fission.h"
...
...
@@ -145,6 +145,7 @@ void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_g
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
auto
optimizer
=
std
::
make_shared
<
GraphOptimizer
>
();
auto
data_layout_pm
=
std
::
make_shared
<
PassManager
>
(
"pynative_transop_pm"
);
data_layout_pm
->
AddPass
(
std
::
make_shared
<
RectifyDoMaskKernelInfo
>
());
data_layout_pm
->
AddPass
(
std
::
make_shared
<
RunOpInsertTransData
>
());
data_layout_pm
->
AddPass
(
std
::
make_shared
<
GetitemTuple
>
());
data_layout_pm
->
AddPass
(
std
::
make_shared
<
CommonSubexpressionElimination
>
());
...
...
@@ -157,30 +158,11 @@ void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_g
kernel_graph
->
SetExecOrderByDefault
();
}
void
RunOpAscendMixPrecision
(
const
std
::
shared_ptr
<
session
::
KernelGraph
>
&
kernel_graph
)
{
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
auto
optimizer
=
std
::
make_shared
<
GraphOptimizer
>
();
auto
mixed_precision_pm
=
std
::
make_shared
<
PassManager
>
(
"pynative_transop_pm"
);
mixed_precision_pm
->
AddPass
(
std
::
make_shared
<
RunOpInsertCast
>
());
mixed_precision_pm
->
AddPass
(
std
::
make_shared
<
GetitemTuple
>
());
mixed_precision_pm
->
AddPass
(
std
::
make_shared
<
CommonSubexpressionElimination
>
());
mixed_precision_pm
->
AddPass
(
std
::
make_shared
<
EliminateRedundantOp
>
());
mixed_precision_pm
->
AddPass
(
std
::
make_shared
<
OptimizeDependence
>
());
mixed_precision_pm
->
AddPass
(
std
::
make_shared
<
EraseVisitAttr
>
());
mixed_precision_pm
->
AddPass
(
std
::
make_shared
<
DealRefTransAndCast
>
());
mixed_precision_pm
->
AddPass
(
std
::
make_shared
<
GetitemTuple
>
());
mixed_precision_pm
->
AddPass
(
std
::
make_shared
<
MergeCastToOp
>
());
mixed_precision_pm
->
AddPass
(
std
::
make_shared
<
LayerNormBetaGammaBackpropFusion
>
());
mixed_precision_pm
->
AddPass
(
std
::
make_shared
<
EraseVisitAttr
>
());
optimizer
->
AddPassManager
(
mixed_precision_pm
);
(
void
)
optimizer
->
Optimize
(
kernel_graph
);
kernel_graph
->
SetExecOrderByDefault
();
}
void
AscendDataLayout
(
const
std
::
shared_ptr
<
session
::
KernelGraph
>
&
kernel_graph
)
{
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
auto
optimizer
=
std
::
make_shared
<
GraphOptimizer
>
();
auto
data_layout_pm
=
std
::
make_shared
<
PassManager
>
(
"transop_pm"
);
data_layout_pm
->
AddPass
(
std
::
make_shared
<
RectifyDoMaskKernelInfo
>
());
data_layout_pm
->
AddPass
(
std
::
make_shared
<
InsertTransOp
>
());
data_layout_pm
->
AddPass
(
std
::
make_shared
<
GetitemTuple
>
());
data_layout_pm
->
AddPass
(
std
::
make_shared
<
CommonSubexpressionElimination
>
());
...
...
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.h
浏览文件 @
9c005b18
...
...
@@ -20,7 +20,6 @@
namespace
mindspore
{
namespace
opt
{
void
RunOpAscendDataLayout
(
const
std
::
shared_ptr
<
session
::
KernelGraph
>
&
kernel_graph
);
void
RunOpAscendMixPrecision
(
const
std
::
shared_ptr
<
session
::
KernelGraph
>
&
kernel_graph
);
void
RunOpAscendBackendIRFusionOptimization
(
const
std
::
shared_ptr
<
session
::
KernelGraph
>
&
kernel_graph
);
void
AscendDataLayout
(
const
std
::
shared_ptr
<
session
::
KernelGraph
>
&
kernel_graph
);
void
AscendMixPrecision
(
const
std
::
shared_ptr
<
session
::
KernelGraph
>
&
kernel_graph
);
...
...
mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc
浏览文件 @
9c005b18
...
...
@@ -65,7 +65,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
dtype
=
AnfAlgo
::
GetInputDeviceDataType
(
cnode
,
insert_index
);
dst_format
=
AnfAlgo
::
GetInputFormat
(
cnode
,
insert_index
);
input_node
=
AnfAlgo
::
GetInputNode
(
cnode
,
insert_index
);
padding_axis
=
AnfAlgo
::
GetInputReshapeType
(
node
,
0
);
padding_axis
=
AnfAlgo
::
GetInputReshapeType
(
node
,
insert_index
);
}
bool
need_padding
=
false
;
if
(
is_insert_input
)
{
...
...
mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast_for_runop.cc
已删除
100644 → 0
浏览文件 @
10ebd81b
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/format_type/insert_cast_for_runop.h"
#include <memory>
#include "device/kernel_info.h"
#include "pre_activate/ascend/ascend_helper.h"
#include "pre_activate/common/helper.h"
#include "kernel/oplib/oplib.h"
#include "session/anf_runtime_algorithm.h"
#include "utils/utils.h"
namespace
mindspore
{
namespace
opt
{
// Match any not-yet-visited node followed by an arbitrary sequence of inputs.
const BaseRef RunOpInsertCast::DefinePattern() const {
  const VarPtr unvisited_node = std::make_shared<CondVar>(UnVisited);
  const VarPtr any_inputs = std::make_shared<SeqVar>();
  return VectorRef({unvisited_node, any_inputs});
}
// Pass body: for a real CNode kernel, mark it visited and delegate to
// InsertCastForInput to insert any casts its inputs require.  Returns the
// node produced by InsertCastForInput, or nullptr when the node is not a
// real kernel or the graph pointer is missing.
const AnfNodePtr RunOpInsertCast::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                          const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(node);
  if (!AnfAlgo::IsRealCNodeKernel(node) || func_graph == nullptr) {
    return nullptr;
  }
  // Mark before processing so the UnVisited pattern does not re-match this node.
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
  // process input
  CNodePtr cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  return InsertCastForInput(func_graph, cnode);
}
}
// namespace opt
}
// namespace mindspore
mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc
0 → 100644
浏览文件 @
9c005b18
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h"
#include <vector>
#include <map>
#include <string>
#include <memory>
#include "session/anf_runtime_algorithm.h"
#include "kernel/kernel_build_info.h"
#include "utils/utils.h"
#include "kernel/common_utils.h"
#include "utils/context/ms_context.h"
namespace
mindspore
{
namespace
opt
{
// Match every CNode: any head variable followed by any sequence of inputs.
const BaseRef RectifyDoMaskKernelInfo::DefinePattern() const {
  const VarPtr any_head = std::make_shared<Var>();
  const VarPtr any_inputs = std::make_shared<SeqVar>();
  return VectorRef({any_head, any_inputs});
}
// Pass body.  Two modes:
//  * PyNative mode: for each DropoutDoMask node whose first input format is
//    not the default format, rebuild its selected kernel build info with the
//    default format on the first input and first output.
//  * Graph mode: anchoring on a DropoutGenMask node, collect every
//    DropoutDoMask user of its output, verify they all consume inputs of the
//    same inferred shape, then let RectifyKernelInfo unify their formats.
// Always returns nullptr: the pass mutates kernel build info in place and
// never replaces nodes.
const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                  const EquivPtr &) const {
  if (node == nullptr || !node->isa<CNode>()) {
    return nullptr;
  }
  auto cnode = node->cast<CNodePtr>();
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->execution_mode() == kPynativeMode) {
    // PyNative mode rectifies DropoutDoMask nodes one at a time.
    if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutDoMask->name()) {
      return nullptr;
    }
    auto do_mask_input_format = AnfAlgo::GetInputFormat(node, 0);
    if (do_mask_input_format != kOpFormat_DEFAULT) {
      // Copy the currently selected build info and override only index 0 of
      // the input/output formats.
      auto builder =
        std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
      builder->SetInputFormat(kOpFormat_DEFAULT, 0);
      builder->SetOutputFormat(kOpFormat_DEFAULT, 0);
      AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
    }
    return nullptr;
  }
  // Graph mode: only DropoutGenMask producers anchor the rectification.
  if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutGenMask->name()) {
    return nullptr;
  }
  std::vector<CNodePtr> do_mask_node_list;
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto node_map = manager->node_users();
  auto iter = node_map.find(node);
  if (iter == node_map.end()) {
    MS_LOG(EXCEPTION) << "Cannot find the node " << node->DebugString() << " in the graph manager!";
  }
  // Collect every DropoutDoMask node that consumes this DropoutGenMask.
  auto gen_mask_output_nodes = iter->second;
  for (const auto &output_node : gen_mask_output_nodes) {
    if (AnfAlgo::GetCNodeName(output_node.first) == prim::kPrimDropoutDoMask->name()) {
      auto output_cnode = output_node.first->cast<CNodePtr>();
      do_mask_node_list.push_back(output_cnode);
    }
  }
  // All DropoutDoMask users sharing one mask must agree on their first
  // input's inferred shape, otherwise the shared mask cannot fit all of them.
  std::vector<size_t> input_shape;
  for (const auto &output_node : do_mask_node_list) {
    if (input_shape.empty()) {
      input_shape = AnfAlgo::GetPrevNodeOutputInferShape(output_node, 0);
      continue;
    }
    auto shape = AnfAlgo::GetPrevNodeOutputInferShape(output_node, 0);
    if (!kernel::IsSameShape(shape, input_shape)) {
      // NOTE(review): the message reads "GenMask" but the mismatch is between
      // DropoutDoMask input shapes — confirm the intended wording.
      MS_LOG(EXCEPTION) << "The DropOutGenMask connected with same genmask's shape must be equal!"
                        << " GenMask " << node->DebugString();
    }
  }
  RectifyKernelInfo(do_mask_node_list);
  return nullptr;
}
// Unify the first-input data format of every DropoutDoMask node fed by one
// DropoutGenMask.  Counts occurrences per format, remembers the first
// trans-requiring ("special") format seen, then decides a single target
// format: if the nodes already agree, do nothing; if too many distinct
// formats (or two different special formats) appear, fall back to the
// default format; otherwise pick via GetConvertFormat.
void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_mask_node_list) const {
  std::map<std::string, size_t> format_counter;  // format name -> occurrence count
  std::string special_format;                    // first trans-requiring format seen
  std::string convert_format;                    // chosen target; empty until decided
  for (const auto &do_mask : do_mask_node_list) {
    auto do_mask_data_format = AnfAlgo::GetInputFormat(do_mask, 0);
    if (special_format.empty() && kNeedTransFormatSet.find(do_mask_data_format) != kNeedTransFormatSet.end()) {
      special_format = do_mask_data_format;
    }
    if (format_counter.find(do_mask_data_format) == format_counter.end()) {
      format_counter[do_mask_data_format] = 1;
    } else {
      format_counter[do_mask_data_format] = format_counter[do_mask_data_format] + 1;
    }
    // if has two or more special format we need change all domask's format to default that can avoid insert more
    // transdata
    // NOTE(review): this guard fires only with MORE than two distinct formats
    // (size() > 2), not "two or more" as the comment above says — confirm
    // whether > 2 or >= 2 is intended.
    if (format_counter.size() > 2) {
      convert_format = kOpFormat_DEFAULT;
      break;
    }
    // A second, different special format also forces the default format.
    if (kNeedTransFormatSet.find(do_mask_data_format) != kNeedTransFormatSet.end() &&
        special_format != do_mask_data_format) {
      convert_format = kOpFormat_DEFAULT;
      break;
    }
  }
  if (format_counter.size() == 1) {
    // Already consistent — leave every node's selected build info untouched.
    return;
  }
  if (convert_format.empty()) {
    convert_format = GetConvertFormat(format_counter);
  }
  RectifyDropOutDoMaskKernelInfo(do_mask_node_list, convert_format);
}
// Pick the target format from the per-format occurrence counts: the most
// frequent format wins; on a tie, a later map entry replaces the current
// choice only when the current choice is not a trans-requiring ("special")
// format.
// Bug fix: `counter` was never advanced inside the loop, so the
// most-frequent comparison always compared against 0 — every entry matched
// `counter < iter.second`, the tie branch was dead (`0 == count` never holds
// for a counted format), and the function effectively returned the last map
// entry instead of the most frequent format.
std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string, size_t> &format_counter) const {
  std::string convert_format;
  size_t counter = 0;
  for (const auto &iter : format_counter) {
    if (counter < iter.second) {
      convert_format = iter.first;
      counter = iter.second;  // track the running maximum (was missing)
    } else if (counter == iter.second && kNeedTransFormatSet.find(convert_format) == kNeedTransFormatSet.end()) {
      convert_format = iter.first;
    }
  }
  return convert_format;
}
// Rewrite the selected kernel build info of every DropoutDoMask node in the
// list so that its first input and first output both use `format`.
void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list,
                                                             const std::string &format) const {
  for (const auto &do_mask_node : do_mask_node_list) {
    const auto selected_info = AnfAlgo::GetSelectKernelBuildInfo(do_mask_node);
    auto info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(selected_info);
    info_builder->SetInputFormat(format, 0);
    info_builder->SetOutputFormat(format, 0);
    AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), do_mask_node.get());
  }
}
}
// namespace opt
}
// namespace mindspore
mindspore/ccsrc/pre_activate/ascend/format_type/
insert_cast_for_runop
.h
→
mindspore/ccsrc/pre_activate/ascend/format_type/
rectify_do_mask_kernel_info
.h
浏览文件 @
9c005b18
/**
* Copyright 20
19
Huawei Technologies Co., Ltd
* Copyright 20
20
Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
...
...
@@ -13,23 +13,29 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H
#include <map>
#include <string>
#include <vector>
#include "pre_activate/common/optimizer.h"
#include "pre_activate/common/pattern_engine.h"
#include "ir/anf.h"
namespace
mindspore
{
namespace
opt
{
class
R
unOpInsertCast
:
public
PatternProcessPass
{
class
R
ectifyDoMaskKernelInfo
:
public
PatternProcessPass
{
public:
explicit
RunOpInsertCast
(
bool
multigraph
=
true
)
:
PatternProcessPass
(
"insert_cast_for_runop"
,
multigraph
)
{}
~
RunOpInsertCast
()
override
=
default
;
explicit
RectifyDoMaskKernelInfo
(
bool
multigraph
=
true
)
:
PatternProcessPass
(
"batch_norm_bert_fission"
,
multigraph
)
{}
~
RectifyDoMaskKernelInfo
()
override
=
default
;
const
BaseRef
DefinePattern
()
const
override
;
const
AnfNodePtr
Process
(
const
FuncGraphPtr
&
,
const
AnfNodePtr
&
,
const
EquivPtr
&
)
const
override
;
private:
void
RectifyKernelInfo
(
const
std
::
vector
<
CNodePtr
>
&
do_mask_node_list
)
const
;
std
::
string
GetConvertFormat
(
const
std
::
map
<
std
::
string
,
size_t
>
&
format_counter
)
const
;
void
RectifyDropOutDoMaskKernelInfo
(
const
std
::
vector
<
CNodePtr
>
&
do_mask_node_list
,
const
std
::
string
&
format
)
const
;
};
}
// namespace opt
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录