Commit 90d98aa6 authored by mindspore-ci-bot, committed by Gitee

!1694 Check the input size of BatchNorm before fission in bert

Merge pull request !1694 from YuJianfeng/master
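
For context, the guard added in this PR follows a common rewrite-pass pattern: skip any node whose arity is not what the pass expects. Below is a minimal standalone sketch of that check, not the MindSpore source; InputSizeMatches is a hypothetical name, the value kBatchNormRealInputNum = 3 is inferred from the test truncation further down, and the sketch assumes a CNode-style input list that stores the primitive at index 0 ahead of the real operands.

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

constexpr size_t kBatchNormRealInputNum = 3;  // assumed value, inferred from the test below

// A CNode-style input list is {primitive, real operands...}, hence the +1:
// a well-formed BatchNorm here has kBatchNormRealInputNum + 1 entries.
bool InputSizeMatches(const std::vector<std::string> &inputs) {
  return inputs.size() == kBatchNormRealInputNum + 1;
}

int main() {
  std::vector<std::string> ok{"BatchNorm", "x", "scale", "bias"};          // matches, would be split
  std::vector<std::string> extra{"BatchNorm", "x", "scale", "bias", "m"};  // mismatched, skipped
  std::cout << InputSizeMatches(ok) << " " << InputSizeMatches(extra) << std::endl;
  return 0;
}
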
@@ -149,8 +149,17 @@ const BaseRef BatchNormBertFission::DefinePattern() const {
const AnfNodePtr BatchNormBertFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
std::vector<AnfNodePtr> bn_outputs;
if (!GetBatchNormOutputs(func_graph, node, &bn_outputs)) {
MS_LOG(INFO) << "The BatchNorm node should only have output 0, 3 and 4. The node should not be changed";
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
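// Note: inputs() includes the primitive at index 0, hence the +1 below.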
if (cnode->inputs().size() != kBatchNormRealInputNum + 1) {
MS_LOG(INFO) << "The input size of BatchNorm should be " << kBatchNormRealInputNum
<< ". The node should not be changed";
return nullptr;
}
AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node);
...
@@ -28,7 +28,7 @@ class TestHWBatchNormBertFission : public BackendCommon {
UT::PyFuncGraphFetcher get_py_fun_;
};
-TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_fusion) {
+TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_fission) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batch_norm_bert_fission", "before");
EXPECT_NE(g, nullptr);
std::vector<int> shp_x{32, 64, 112, 112};
@@ -40,6 +40,23 @@ TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_fusion) {
args_spec_list.push_back(y_abstract);
}
auto kg = GetKernelGraph(g, args_spec_list);
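// Walk from the graph return through the MakeTuple/TupleGetItem wrappers down to the BatchNorm CNode.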
auto ret = kg->get_return();
EXPECT_NE(ret, nullptr);
auto make_tuple0 = ret->input(1);
EXPECT_NE(make_tuple0, nullptr);
auto tuple_getitem0 = make_tuple0->cast<CNodePtr>()->input(1);
EXPECT_NE(tuple_getitem0, nullptr);
auto make_tuple1 = tuple_getitem0->cast<CNodePtr>()->input(1);
EXPECT_NE(make_tuple1, nullptr);
auto tuple_getitem1 = make_tuple1->cast<CNodePtr>()->input(1);
EXPECT_NE(tuple_getitem1, nullptr);
auto bn = tuple_getitem1->cast<CNodePtr>()->input(1);
EXPECT_NE(bn, nullptr);
auto bn_cnode = bn->cast<CNodePtr>();
EXPECT_NE(bn_cnode, nullptr);
auto inputs = bn_cnode->inputs();
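// Keep only the primitive plus the first three real inputs, so the node satisfies
// the new kBatchNormRealInputNum + 1 size check and fission applies.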
std::vector<AnfNodePtr> new_inputs(inputs.begin(), inputs.begin() + 4);
bn_cnode->set_inputs(new_inputs);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
@@ -50,5 +67,27 @@ TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_fusion) {
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_batch_norm_bert_fission", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_no_fission) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batch_norm_bert_fission", "before");
EXPECT_NE(g, nullptr);
std::vector<int> shp_x{32, 64, 112, 112};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
std::vector<int> shp_y{64};
auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y);
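// Per-channel abstracts for the remaining BatchNorm inputs (scale, bias, mean, variance).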
AbstractBasePtrList args_spec_list{x_abstract};
for (size_t i = 0; i < 4; ++i) {
args_spec_list.push_back(y_abstract);
}
auto kg = GetKernelGraph(g, args_spec_list);
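// Unlike the fission test, no inputs are truncated here: the BatchNorm keeps all of its
// original inputs, so the size check fails and the pass must leave the graph unchanged.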
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::BatchNormBertFission>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(kg);
EXPECT_TRUE(CheckEqualGraph(kg, new_graph));
}
} // namespace opt
} // namespace mindspore