diff --git a/paddle/phi/ops/compat/einsum_sig.cc b/paddle/phi/ops/compat/einsum_sig.cc index e5aa5709855965047e52efe3b8caa78e62c14e6c..1030946980f86fcf265864e7cd44afcaf111cf32 100644 --- a/paddle/phi/ops/compat/einsum_sig.cc +++ b/paddle/phi/ops/compat/einsum_sig.cc @@ -17,14 +17,10 @@ limitations under the License. */ namespace phi { KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) { - if (ctx.OutputSize("XShape") > 0 && ctx.OutputSize("InnerCache") > 0) { - return KernelSignature("einsum_raw", - {"Operands"}, - {"equation"}, - {"Out", "InnerCache", "XShape"}); - } else { - return KernelSignature("einsum", {"Operands"}, {"equation"}, {"Out"}); - } + return KernelSignature("einsum_raw", + {"Operands"}, + {"equation"}, + {"Out", "InnerCache", "XShape"}); } KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) { diff --git a/paddle/phi/ops/compat/squeeze_sig.cc b/paddle/phi/ops/compat/squeeze_sig.cc index a251b9f537ccfa54d07daa3e9238fbc863c864ff..4ca45903acfa00386c9cbfed191ddb9b50443230 100644 --- a/paddle/phi/ops/compat/squeeze_sig.cc +++ b/paddle/phi/ops/compat/squeeze_sig.cc @@ -18,12 +18,8 @@ namespace phi { KernelSignature SqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) { - if (ctx.HasOutput("XShape")) { - return KernelSignature( - "squeeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"}); - } else { - return KernelSignature("squeeze", {"X"}, {"axes"}, {"Out"}); - } + return KernelSignature( + "squeeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"}); } KernelSignature SqueezeGradOpArgumentMapping( diff --git a/paddle/phi/ops/compat/unsqueeze_sig.cc b/paddle/phi/ops/compat/unsqueeze_sig.cc index a2f184e7150b8a18a4af176168ce337e3db8db6c..568097298b7acc86584b2de962e9ea06d73a26f5 100644 --- a/paddle/phi/ops/compat/unsqueeze_sig.cc +++ b/paddle/phi/ops/compat/unsqueeze_sig.cc @@ -18,33 +18,18 @@ namespace phi { KernelSignature UnsqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) { - if (ctx.HasOutput("XShape")) { - if (ctx.InputSize("AxesTensorList") > 0) { - VLOG(2) << "unsqueeze2 in AxesTensorList"; - return KernelSignature("unsqueeze_with_xshape", - {"X"}, - {"AxesTensorList"}, - {"Out", "XShape"}); - } else if (ctx.InputSize("AxesTensor") > 0) { - VLOG(2) << "unsqueeze2 in AxesTensor"; - return KernelSignature( - "unsqueeze_with_xshape", {"X"}, {"AxesTensor"}, {"Out", "XShape"}); - } else { - VLOG(2) << "unsqueeze2 in axes"; - return KernelSignature( - "unsqueeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"}); - } + if (ctx.InputSize("AxesTensorList") > 0) { + VLOG(2) << "unsqueeze2 in AxesTensorList"; + return KernelSignature( + "unsqueeze_with_xshape", {"X"}, {"AxesTensorList"}, {"Out", "XShape"}); + } else if (ctx.InputSize("AxesTensor") > 0) { + VLOG(2) << "unsqueeze2 in AxesTensor"; + return KernelSignature( + "unsqueeze_with_xshape", {"X"}, {"AxesTensor"}, {"Out", "XShape"}); } else { - if (ctx.InputSize("AxesTensorList") > 0) { - VLOG(2) << "unsqueeze2 in AxesTensorList"; - return KernelSignature("unsqueeze", {"X"}, {"AxesTensorList"}, {"Out"}); - } else if (ctx.InputSize("AxesTensor") > 0) { - VLOG(2) << "unsqueeze2 in AxesTensor"; - return KernelSignature("unsqueeze", {"X"}, {"AxesTensor"}, {"Out"}); - } else { - VLOG(2) << "unsqueeze2 in axes"; - return KernelSignature("unsqueeze", {"X"}, {"axes"}, {"Out"}); - } + VLOG(2) << "unsqueeze2 in axes"; + return KernelSignature( + "unsqueeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"}); } }