提交 87371be6 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2127 Fix exception when FusedBatchNorm's first input 's shape dims not equal 4

Merge pull request !2127 from huanghui/fix-fusebatchnorm-split
......@@ -28,14 +28,14 @@
namespace mindspore {
namespace opt {
namespace {
void CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
std::vector<AnfNodePtr> *bn_training_reduce_outputs) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(bn_cnode);
if (bn_cnode->inputs().size() != kBnInputNum) {
MS_LOG(EXCEPTION) << "BN node has wrong input size";
MS_LOG(INFO) << "FusedbatchNorm's input size less than " << kBnInputNum << ". " << bn_cnode->DebugString();
return false;
}
// All the inputs of BNTrainingReduce are from the inputs of BN
std::vector<AnfNodePtr> bn_training_reduce_inputs = {
NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName))};
bn_training_reduce_inputs.push_back(bn_cnode->input(1));
......@@ -45,8 +45,9 @@ void CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &
MS_EXCEPTION_IF_NULL(kernel_info);
bn_training_reduce->set_kernel_info(kernel_info);
std::vector<size_t> bn_shape_i0 = AnfAlgo::GetPrevNodeOutputInferShape(bn_cnode, 0);
if (bn_shape_i0.size() != kShape4dDims) {
MS_LOG(EXCEPTION) << "Get shape of FusedBatchNorm fail";
if (bn_shape_i0.size() < kShape2dDims) {
MS_LOG(INFO) << "The FusedBatchNorm's first input's shape dims less than " << kShape2dDims;
return false;
}
std::vector<size_t> bn_training_reduce_shape = {bn_shape_i0[1]};
auto types = {kNumberTypeFloat32, kNumberTypeFloat32};
......@@ -56,6 +57,7 @@ void CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &
AnfAlgo::CopyNodeAttrs(bn_cnode, bn_training_reduce);
CreateMultipleOutputsOfAnfNode(graph, bn_training_reduce, kBNTrainingReduceOutputNum, bn_training_reduce_outputs);
return true;
}
AnfNodePtr CreateOutputsOfBNTrainingUpdate(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
......@@ -99,11 +101,15 @@ AnfNodePtr SplitFusedBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNo
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().size() < kBnInputNum) {
MS_LOG(EXCEPTION) << "op[FusedBatchNorm] has less than " << kBnInputNum << " inputs.";
MS_LOG(INFO) << "op[FusedBatchNorm] has less than " << kBnInputNum << " inputs.";
return nullptr;
}
// Create BNTrainingReduce node and get outputs of BNTrainingReduce
std::vector<AnfNodePtr> bn_training_reduce_outputs;
CreateOutputsOfBNTrainingReduce(func_graph, cnode, &bn_training_reduce_outputs);
if (!CreateOutputsOfBNTrainingReduce(func_graph, cnode, &bn_training_reduce_outputs)) {
MS_LOG(WARNING) << "Create BNTrainingReduce fail, quit split";
return nullptr;
}
if (bn_training_reduce_outputs.size() != kBN1OutputNum) {
MS_LOG(EXCEPTION) << "make outputs of op BNTrainingReduce fail";
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册