From df790b9b94b53531dbdf3cae1407ec5f29500b6a Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Fri, 5 Aug 2022 16:44:13 +0800 Subject: [PATCH] fix bugs when clip xshape (#44898) --- paddle/phi/ops/compat/einsum_sig.cc | 12 +++------ paddle/phi/ops/compat/squeeze_sig.cc | 8 ++---- paddle/phi/ops/compat/unsqueeze_sig.cc | 37 ++++++++------------------ 3 files changed, 17 insertions(+), 40 deletions(-) diff --git a/paddle/phi/ops/compat/einsum_sig.cc b/paddle/phi/ops/compat/einsum_sig.cc index e5aa5709855..1030946980f 100644 --- a/paddle/phi/ops/compat/einsum_sig.cc +++ b/paddle/phi/ops/compat/einsum_sig.cc @@ -17,14 +17,10 @@ limitations under the License. */ namespace phi { KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) { - if (ctx.OutputSize("XShape") > 0 && ctx.OutputSize("InnerCache") > 0) { - return KernelSignature("einsum_raw", - {"Operands"}, - {"equation"}, - {"Out", "InnerCache", "XShape"}); - } else { - return KernelSignature("einsum", {"Operands"}, {"equation"}, {"Out"}); - } + return KernelSignature("einsum_raw", + {"Operands"}, + {"equation"}, + {"Out", "InnerCache", "XShape"}); } KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) { diff --git a/paddle/phi/ops/compat/squeeze_sig.cc b/paddle/phi/ops/compat/squeeze_sig.cc index a251b9f537c..4ca45903acf 100644 --- a/paddle/phi/ops/compat/squeeze_sig.cc +++ b/paddle/phi/ops/compat/squeeze_sig.cc @@ -18,12 +18,8 @@ namespace phi { KernelSignature SqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) { - if (ctx.HasOutput("XShape")) { - return KernelSignature( - "squeeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"}); - } else { - return KernelSignature("squeeze", {"X"}, {"axes"}, {"Out"}); - } + return KernelSignature( + "squeeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"}); } KernelSignature SqueezeGradOpArgumentMapping( diff --git a/paddle/phi/ops/compat/unsqueeze_sig.cc b/paddle/phi/ops/compat/unsqueeze_sig.cc index a2f184e7150..568097298b7 100644 --- a/paddle/phi/ops/compat/unsqueeze_sig.cc +++ b/paddle/phi/ops/compat/unsqueeze_sig.cc @@ -18,33 +18,18 @@ namespace phi { KernelSignature UnsqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) { - if (ctx.HasOutput("XShape")) { - if (ctx.InputSize("AxesTensorList") > 0) { - VLOG(2) << "unsqueeze2 in AxesTensorList"; - return KernelSignature("unsqueeze_with_xshape", - {"X"}, - {"AxesTensorList"}, - {"Out", "XShape"}); - } else if (ctx.InputSize("AxesTensor") > 0) { - VLOG(2) << "unsqueeze2 in AxesTensor"; - return KernelSignature( - "unsqueeze_with_xshape", {"X"}, {"AxesTensor"}, {"Out", "XShape"}); - } else { - VLOG(2) << "unsqueeze2 in axes"; - return KernelSignature( - "unsqueeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"}); - } + if (ctx.InputSize("AxesTensorList") > 0) { + VLOG(2) << "unsqueeze2 in AxesTensorList"; + return KernelSignature( + "unsqueeze_with_xshape", {"X"}, {"AxesTensorList"}, {"Out", "XShape"}); + } else if (ctx.InputSize("AxesTensor") > 0) { + VLOG(2) << "unsqueeze2 in AxesTensor"; + return KernelSignature( + "unsqueeze_with_xshape", {"X"}, {"AxesTensor"}, {"Out", "XShape"}); } else { - if (ctx.InputSize("AxesTensorList") > 0) { - VLOG(2) << "unsqueeze2 in AxesTensorList"; - return KernelSignature("unsqueeze", {"X"}, {"AxesTensorList"}, {"Out"}); - } else if (ctx.InputSize("AxesTensor") > 0) { - VLOG(2) << "unsqueeze2 in AxesTensor"; - return KernelSignature("unsqueeze", {"X"}, {"AxesTensor"}, {"Out"}); - } else { - VLOG(2) << "unsqueeze2 in axes"; - return KernelSignature("unsqueeze", {"X"}, {"axes"}, {"Out"}); - } + VLOG(2) << "unsqueeze2 in axes"; + return KernelSignature( + "unsqueeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"}); } } -- GitLab