提交 c8d33568 编写于 作者: Y yujianfeng

Add an new output to FusedMulApplyMomentum

上级 f23bfe0d
......@@ -23,6 +23,7 @@
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kAccumIndex = 1;
bool CheckValueNodeInputOfMul(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<ValueNode>()) {
......@@ -79,9 +80,19 @@ const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph
input_names_value[3] = "x1";
input_names_value.emplace_back("x2");
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names_value), new_node);
new_node->set_abstract(node->abstract());
auto node_to_output = cnode->input(kAccumIndex + 1);
MS_EXCEPTION_IF_NULL(node_to_output);
AbstractBasePtrList abstract_list{node->abstract(), node_to_output->abstract()};
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
new_node->set_abstract(abstract_tuple);
new_node->set_scope(node->scope());
return new_node;
// Create Output
std::vector<AnfNodePtr> new_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, new_node, kFusedMulApplyMomentumOutputNum, &new_outputs);
if (new_outputs.size() != kFusedMulApplyMomentumOutputNum) {
MS_LOG(EXCEPTION) << "Failed to create outputs of " << new_node->DebugString();
}
return new_outputs[0];
}
} // namespace opt
} // namespace mindspore
......@@ -92,6 +92,7 @@ constexpr size_t kApplyMomentumInputNum = 6;
constexpr size_t kBiasAddInputNum = 3;
constexpr size_t kTopkInputNum = 3;
constexpr size_t kLarsV2InputNum = 5;
constexpr size_t kFusedMulApplyMomentumOutputNum = 2;
enum FusedBatchNormInput {
kX = 1,
......
......@@ -47,6 +47,6 @@ def test_momentum_lossscale_fusion(tag):
@fns
def after(input0, input1, input2, input3, input4):
return make_tuple(FusedMulApplyMomentum(input0, input1, input2, input3, input4, constant))
return make_tuple(tuple_getitem(FusedMulApplyMomentum(input0, input1, input2, input3, input4, constant), 0))
return fns[tag]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册