diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc index a7c8d131fb9c81e024806f666368b9f6b94798d5..d05b9fafa1053a8267cd251edd2d3b5de3149aeb 100644 --- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc @@ -82,12 +82,6 @@ bool IsValidKernelInfo(const std::shared_ptr &kernel_node, const kernel:: } return true; }; - if (AnfAlgo::GetCNodeName(kernel_node) == "LayerNormBetaGammaBackprop" || - AnfAlgo::GetCNodeName(kernel_node) == "LayerNormXBackprop") { - if (AnfAlgo::GetPrevNodeOutputFormat(kernel_node, 0) != kernel_build_info.GetInputFormat(0)) { - return true; - } - } if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) { return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) && AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0); @@ -161,7 +155,7 @@ bool PriorityChooseItem(const std::vector &cur_item, std::vector *best return false; } } - return true; + return false; } void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr &kernel_node, diff --git a/tests/st/networks/models/bert/bert_tdt_no_lossscale.py b/tests/st/networks/models/bert/bert_tdt_no_lossscale.py index 6f3ffc7daddddfce963ad8fe5b9417589b7612a4..9cc11997e68e0b421257347787d276cae6813c4b 100644 --- a/tests/st/networks/models/bert/bert_tdt_no_lossscale.py +++ b/tests/st/networks/models/bert/bert_tdt_no_lossscale.py @@ -27,7 +27,7 @@ from mindspore.common.tensor import Tensor from mindspore.train.model import Model from mindspore.train.callback import Callback from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell -from mindspore.nn.optim import Lamb +from mindspore.nn.optim import Momentum from mindspore import log as logger _current_dir = os.path.dirname(os.path.realpath(__file__)) DATA_DIR = ["/home/workspace/mindspore_dataset/bert/example/examples.tfrecord"]