提交 d6f043b0 编写于 作者: Y yujianfeng

Fix batch norm bert fission for control depend case

上级 4009a8e5
......@@ -25,6 +25,7 @@ namespace opt {
namespace {
const std::vector<int> kOutputIndex{0, 3, 4, 5};
constexpr size_t kBatchNormRealOutputNum = 3;
constexpr size_t kBatchNormRealInputNum = 3;
bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2) {
MS_EXCEPTION_IF_NULL(n1);
......@@ -56,6 +57,9 @@ bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, s
for (const auto &node_index : manager->node_users()[bn]) {
AnfNodePtr output = node_index.first;
MS_EXCEPTION_IF_NULL(output);
if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
continue;
}
auto tuple_getiterm_cnode = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode);
auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem);
......@@ -77,7 +81,10 @@ AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodeP
MS_EXCEPTION_IF_NULL(bn);
auto bn_cnode = bn->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(bn_cnode);
CheckCNodeInputSize(bn_cnode, kBatchNormInputNum + 1);
if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) {
MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than "
<< kBatchNormRealInputNum + 1;
}
std::vector<AnfNodePtr> bn_training_reduce_inputs = {
NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), bn_cnode->input(1)};
auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs);
......@@ -100,7 +107,10 @@ AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNod
MS_EXCEPTION_IF_NULL(bn);
auto bn_cnode = bn->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(bn_cnode);
CheckCNodeInputSize(bn_cnode, kBatchNormInputNum + 1);
if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) {
MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than "
<< kBatchNormRealInputNum + 1;
}
if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) {
MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum
<< ", but it is " << bn_training_reduce_outputs.size();
......@@ -164,7 +174,8 @@ const AnfNodePtr BatchNormBertFission::Process(const FuncGraphPtr &func_graph, c
(void)manager->Replace(output, bn_training_update_v2_outputs[output_index]);
output_index++;
}
return nullptr;
// Return the new node for control depends.
return bn_training_update_v2;
}
} // namespace opt
} // namespace mindspore
......@@ -90,9 +90,19 @@ ValuePtr FusedBatchNormFusion::GetFactor(const EquivPtr &equiv) const {
}
auto tensor_ptr = value->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor_ptr);
auto *tensor_data = static_cast<float *>(tensor_ptr->data_c());
if (tensor_ptr->data_type() == kNumberTypeFloat16) {
auto *half_data = static_cast<const Eigen::half *>(tensor_ptr->data_c());
MS_EXCEPTION_IF_NULL(half_data);
float float_data = Eigen::half_impl::half_to_float(half_data[0]);
return MakeValue(float_data);
} else if (tensor_ptr->data_type() == kNumberTypeFloat32) {
auto *tensor_data = static_cast<const float *>(tensor_ptr->data_c());
MS_EXCEPTION_IF_NULL(tensor_data);
return MakeValue(tensor_data[0]);
} else {
MS_LOG(WARNING) << "The factor data type of value node " << value_node->DebugString() << " is not fp16 or fp32";
return nullptr;
}
}
AnfNodePtr FusedBatchNormFusion::CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
......
......@@ -65,7 +65,7 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimTupleGetItem->name()) {
if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) || IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
return nullptr;
}
if (std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册