Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
2703ac5b
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2703ac5b
编写于
5月 19, 2020
作者:
W
wanghua
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bert percision problem
上级
2a1aad0f
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
14 addition
and
7 deletion
+14
-7
example/bert_clue/run_pretrain.py
example/bert_clue/run_pretrain.py
+1
-0
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
+8
-4
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
.../ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
+1
-1
mindspore/ops/_op_impl/tbe/tanh.py
mindspore/ops/_op_impl/tbe/tanh.py
+2
-0
tests/st/networks/models/bert/bert_tdt_lossscale.py
tests/st/networks/models/bert/bert_tdt_lossscale.py
+2
-2
未找到文件。
example/bert_clue/run_pretrain.py
浏览文件 @
2703ac5b
...
...
@@ -25,6 +25,7 @@ from mindspore.train.model import Model
from
mindspore.train.parallel_utils
import
ParallelMode
from
mindspore.nn.wrap.loss_scale
import
DynamicLossScaleUpdateCell
from
mindspore.train.callback
import
Callback
,
ModelCheckpoint
,
CheckpointConfig
from
mindspore.train.serialization
import
load_checkpoint
,
load_param_into_net
from
mindspore.model_zoo.Bert_NEZHA
import
BertNetworkWithLoss
,
BertTrainOneStepCell
,
BertTrainOneStepWithLossScaleCell
from
mindspore.nn.optim
import
Lamb
,
Momentum
,
AdamWeightDecayDynamicLR
from
dataset
import
create_bert_dataset
...
...
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
浏览文件 @
2703ac5b
...
...
@@ -40,6 +40,7 @@ enum MatchCountPriority : int {
MATCH_DTYPE_COUNT
=
MATCH_COUNT_PRIORITY_BEGIN
,
MATCH_FORMAT_COUNT
,
MATCH_SPECIAL_FORMAT_COUNT
,
MATCH_DEFAULT_FORMAT_COUNT
,
MATCH_OUTPUT_DTYPE_COUNT
,
MATCH_COUNT_PRIORITY_END
};
...
...
@@ -73,7 +74,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
auto
pre_output_format
=
AnfAlgo
::
GetPrevNodeOutputFormat
(
cnode
,
index
);
if
(
AnfAlgo
::
IsFeatureMapInput
(
cnode
,
index
)
&&
kNeedTransFormatSet
.
find
(
pre_output_format
)
!=
kNeedTransFormatSet
.
end
())
{
priority_matched_format
=
!
is_init
?
pr
iority_matched_format
:
pre_output
_format
;
priority_matched_format
=
!
is_init
?
pr
e_output_format
:
priority_matched
_format
;
is_init
=
true
;
}
// feature map has two or more special format;
...
...
@@ -83,7 +84,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
auto
input_shape_size
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
cnode
,
index
).
size
();
need_change_nd
=
(
need_change_nd
||
(
input_shape_size
!=
4
&&
input_shape_size
>
1
));
}
if
(
need_change_nd
)
{
if
(
need_change_nd
&&
priority_matched_format
!=
kOpFormat_FRAC_NZ
)
{
priority_matched_format
=
kOpFormat_DEFAULT
;
}
AnfAlgo
::
SetNodeAttr
(
kPriChoosenFormat
,
MakeValue
(
priority_matched_format
),
cnode
);
...
...
@@ -134,6 +135,9 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
if
(
kernel_build_info
.
GetInputFormat
(
input_index
)
==
pri_match_format
)
{
(
*
cur_kernelinfo_match_counts
)[
MATCH_SPECIAL_FORMAT_COUNT
]
+=
base_score
;
}
if
(
kernel_build_info
.
GetInputFormat
(
input_index
)
==
kOpFormat_DEFAULT
)
{
(
*
cur_kernelinfo_match_counts
)[
MATCH_DEFAULT_FORMAT_COUNT
]
+=
base_score
;
}
}
for
(
size_t
output_index
=
0
;
output_index
<
AnfAlgo
::
GetOutputTensorNum
(
kernel_node
);
++
output_index
)
{
...
...
@@ -410,10 +414,10 @@ std::shared_ptr<kernel::KernelBuildInfo> ChooseMatchedKernelInfo(
if
(
kernel_info_list
.
empty
())
{
return
nullptr
;
}
std
::
vector
<
int
>
most_match_counts
=
{
-
1
,
-
1
,
-
1
,
-
1
};
std
::
vector
<
int
>
most_match_counts
=
{
-
1
,
-
1
,
-
1
,
-
1
,
-
1
};
size_t
selected_index
=
0
;
for
(
size_t
info_index
=
0
;
info_index
<
kernel_info_list
.
size
();
++
info_index
)
{
std
::
vector
<
int
>
cur_kernel_info_match_counts
=
{
0
,
0
,
0
,
0
};
std
::
vector
<
int
>
cur_kernel_info_match_counts
=
{
0
,
0
,
0
,
0
,
0
};
auto
kernel_build_info
=
*
(
kernel_info_list
[
info_index
]);
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>
kernel_info_ptr
=
kernel_info_list
[
info_index
];
UpdateCurMatchCounts
(
*
kernel_info_ptr
,
kernel_node
,
&
cur_kernel_info_match_counts
);
...
...
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
浏览文件 @
2703ac5b
...
...
@@ -89,8 +89,8 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
ClipByNormNoDivSquareSumFusion
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
LambUpdateWithLRRuleFusion
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
ConfusionSoftmaxGradRule
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
LambNextMVRule
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
LambNextMVWithDecayRule
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
LambNextMVRule
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
LambNextRightRule
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
LambUpdateWithLrV2
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
ReshapeTransposeFusion
>
());
...
...
mindspore/ops/_op_impl/tbe/tanh.py
浏览文件 @
2703ac5b
...
...
@@ -29,6 +29,8 @@ tanh_op_info = TBERegOp("Tanh") \
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
.
dtype_format
(
DataType
.
F32_FracNZ
,
DataType
.
F32_FracNZ
)
\
.
dtype_format
(
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
)
\
.
get_op_info
()
...
...
tests/st/networks/models/bert/bert_tdt_lossscale.py
浏览文件 @
2703ac5b
...
...
@@ -170,8 +170,8 @@ def test_bert_tdt():
# assertion occurs while the loss value, overflow state or loss_scale value is wrong
loss_value
=
np
.
array
(
callback
.
loss_list
)
expect_loss_value
=
[
12.1918
125
,
11.966035
,
11.972114
,
11.982189
,
11.973948
,
12.610932
,
12.17564
,
12.840248
,
12.40
294
,
12.621653
]
expect_loss_value
=
[
12.1918
26
,
11.966009
,
11.972208
,
11.98216
,
11.973932
,
12.611078
,
12.17554
,
12.840299
,
12.40
3329
,
12.621632
]
print
(
"loss value: {}"
.
format
(
loss_value
))
assert
np
.
allclose
(
loss_value
,
expect_loss_value
,
0.00001
,
0.00001
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录