提交 2703ac5b 编写于 作者: W wanghua

fix bert precision problem

上级 2a1aad0f
...@@ -25,6 +25,7 @@ from mindspore.train.model import Model ...@@ -25,6 +25,7 @@ from mindspore.train.model import Model
from mindspore.train.parallel_utils import ParallelMode from mindspore.train.parallel_utils import ParallelMode
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecayDynamicLR from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecayDynamicLR
from dataset import create_bert_dataset from dataset import create_bert_dataset
......
...@@ -40,6 +40,7 @@ enum MatchCountPriority : int { ...@@ -40,6 +40,7 @@ enum MatchCountPriority : int {
MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN, MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN,
MATCH_FORMAT_COUNT, MATCH_FORMAT_COUNT,
MATCH_SPECIAL_FORMAT_COUNT, MATCH_SPECIAL_FORMAT_COUNT,
MATCH_DEFAULT_FORMAT_COUNT,
MATCH_OUTPUT_DTYPE_COUNT, MATCH_OUTPUT_DTYPE_COUNT,
MATCH_COUNT_PRIORITY_END MATCH_COUNT_PRIORITY_END
}; };
...@@ -73,7 +74,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) { ...@@ -73,7 +74,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index); auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index);
if (AnfAlgo::IsFeatureMapInput(cnode, index) && if (AnfAlgo::IsFeatureMapInput(cnode, index) &&
kNeedTransFormatSet.find(pre_output_format) != kNeedTransFormatSet.end()) { kNeedTransFormatSet.find(pre_output_format) != kNeedTransFormatSet.end()) {
priority_matched_format = !is_init ? priority_matched_format : pre_output_format; priority_matched_format = !is_init ? pre_output_format : priority_matched_format;
is_init = true; is_init = true;
} }
// feature map has two or more special format; // feature map has two or more special format;
...@@ -83,7 +84,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) { ...@@ -83,7 +84,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
auto input_shape_size = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index).size(); auto input_shape_size = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index).size();
need_change_nd = (need_change_nd || (input_shape_size != 4 && input_shape_size > 1)); need_change_nd = (need_change_nd || (input_shape_size != 4 && input_shape_size > 1));
} }
if (need_change_nd) { if (need_change_nd && priority_matched_format != kOpFormat_FRAC_NZ) {
priority_matched_format = kOpFormat_DEFAULT; priority_matched_format = kOpFormat_DEFAULT;
} }
AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode); AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode);
...@@ -134,6 +135,9 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons ...@@ -134,6 +135,9 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) { if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) {
(*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT] += base_score; (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT] += base_score;
} }
if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_DEFAULT) {
(*cur_kernelinfo_match_counts)[MATCH_DEFAULT_FORMAT_COUNT] += base_score;
}
} }
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
...@@ -410,10 +414,10 @@ std::shared_ptr<kernel::KernelBuildInfo> ChooseMatchedKernelInfo( ...@@ -410,10 +414,10 @@ std::shared_ptr<kernel::KernelBuildInfo> ChooseMatchedKernelInfo(
if (kernel_info_list.empty()) { if (kernel_info_list.empty()) {
return nullptr; return nullptr;
} }
std::vector<int> most_match_counts = {-1, -1, -1, -1}; std::vector<int> most_match_counts = {-1, -1, -1, -1, -1};
size_t selected_index = 0; size_t selected_index = 0;
for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) { for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0}; std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0, 0};
auto kernel_build_info = *(kernel_info_list[info_index]); auto kernel_build_info = *(kernel_info_list[info_index]);
std::shared_ptr<kernel::KernelBuildInfo> kernel_info_ptr = kernel_info_list[info_index]; std::shared_ptr<kernel::KernelBuildInfo> kernel_info_ptr = kernel_info_list[info_index];
UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts); UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts);
......
...@@ -89,8 +89,8 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { ...@@ -89,8 +89,8 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>()); ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>());
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>()); ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>());
ir_fusion_pm->AddPass(std::make_shared<ConfusionSoftmaxGradRule>()); ir_fusion_pm->AddPass(std::make_shared<ConfusionSoftmaxGradRule>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVRule>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRule>()); ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRule>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVRule>());
ir_fusion_pm->AddPass(std::make_shared<LambNextRightRule>()); ir_fusion_pm->AddPass(std::make_shared<LambNextRightRule>());
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLrV2>()); ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLrV2>());
ir_fusion_pm->AddPass(std::make_shared<ReshapeTransposeFusion>()); ir_fusion_pm->AddPass(std::make_shared<ReshapeTransposeFusion>());
......
...@@ -29,6 +29,8 @@ tanh_op_info = TBERegOp("Tanh") \ ...@@ -29,6 +29,8 @@ tanh_op_info = TBERegOp("Tanh") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \ .dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \
.get_op_info() .get_op_info()
......
...@@ -170,8 +170,8 @@ def test_bert_tdt(): ...@@ -170,8 +170,8 @@ def test_bert_tdt():
# assertion occurs while the loss value, overflow state or loss_scale value is wrong # assertion occurs while the loss value, overflow state or loss_scale value is wrong
loss_value = np.array(callback.loss_list) loss_value = np.array(callback.loss_list)
expect_loss_value = [12.1918125, 11.966035, 11.972114, 11.982189, 11.973948, 12.610932, 12.17564, 12.840248, expect_loss_value = [12.191826, 11.966009, 11.972208, 11.98216, 11.973932, 12.611078, 12.17554, 12.840299,
12.40294, 12.621653] 12.403329, 12.621632]
print("loss value: {}".format(loss_value)) print("loss value: {}".format(loss_value))
assert np.allclose(loss_value, expect_loss_value, 0.00001, 0.00001) assert np.allclose(loss_value, expect_loss_value, 0.00001, 0.00001)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册