提交 d4295465 编写于 作者: W wanghua 提交者: 高东海

modify bert test file

上级 2de97f25
...@@ -82,12 +82,6 @@ bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel:: ...@@ -82,12 +82,6 @@ bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::
} }
return true; return true;
}; };
if (AnfAlgo::GetCNodeName(kernel_node) == "LayerNormBetaGammaBackprop" ||
AnfAlgo::GetCNodeName(kernel_node) == "LayerNormXBackprop") {
if (AnfAlgo::GetPrevNodeOutputFormat(kernel_node, 0) != kernel_build_info.GetInputFormat(0)) {
return true;
}
}
if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) { if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) {
return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) && return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&
AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0); AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0);
...@@ -161,7 +155,7 @@ bool PriorityChooseItem(const std::vector<int> &cur_item, std::vector<int> *best ...@@ -161,7 +155,7 @@ bool PriorityChooseItem(const std::vector<int> &cur_item, std::vector<int> *best
return false; return false;
} }
} }
return true; return false;
} }
void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr<CNode> &kernel_node, void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr<CNode> &kernel_node,
......
...@@ -27,7 +27,7 @@ from mindspore.common.tensor import Tensor ...@@ -27,7 +27,7 @@ from mindspore.common.tensor import Tensor
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.callback import Callback from mindspore.train.callback import Callback
from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell
from mindspore.nn.optim import Lamb from mindspore.nn.optim import Momentum
from mindspore import log as logger from mindspore import log as logger
_current_dir = os.path.dirname(os.path.realpath(__file__)) _current_dir = os.path.dirname(os.path.realpath(__file__))
DATA_DIR = ["/home/workspace/mindspore_dataset/bert/example/examples.tfrecord"] DATA_DIR = ["/home/workspace/mindspore_dataset/bert/example/examples.tfrecord"]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册