提交 14df7711 编写于 作者: H huanghui

fix confusion_softmax_grad_rule pass

上级 507b63ea
......@@ -47,7 +47,7 @@ void SetAttrsForFusionNode(const AnfNodePtr &sub_anf, const AnfNodePtr &fusion_n
const BaseRef ConfusionSoftmaxGradRule::DefinePattern() const {
return VectorRef(
{prim::kPrimSub, input0_, VectorRef({prim::kPrimReduceSum, VectorRef({prim::kPrimMul, input0_, input1_})})});
{prim::kPrimSub, input0_, VectorRef({prim::kPrimReduceSum, VectorRef({prim::kPrimMul, input1_, input0_})})});
}
const AnfNodePtr ConfusionSoftmaxGradRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
......
......@@ -41,7 +41,7 @@ def test_confusion_softmax_grad_rule(tag):
@fns
def before(input0, input1):
res = mul(input0, input1)
res = mul(input1, input0)
# input axis will be convert to attr in ConstructKernelGraph step
res = reduce_sum(res, axis)
res = sub(input0, res)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册