未验证 提交 df790b9b 编写于 作者: Y YuanRisheng 提交者: GitHub

fix bugs when clip xshape (#44898)

上级 62a98130
...@@ -17,14 +17,10 @@ limitations under the License. */ ...@@ -17,14 +17,10 @@ limitations under the License. */
namespace phi { namespace phi {
KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.OutputSize("XShape") > 0 && ctx.OutputSize("InnerCache") > 0) {
return KernelSignature("einsum_raw", return KernelSignature("einsum_raw",
{"Operands"}, {"Operands"},
{"equation"}, {"equation"},
{"Out", "InnerCache", "XShape"}); {"Out", "InnerCache", "XShape"});
} else {
return KernelSignature("einsum", {"Operands"}, {"equation"}, {"Out"});
}
} }
KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
......
...@@ -18,12 +18,8 @@ ...@@ -18,12 +18,8 @@
namespace phi { namespace phi {
KernelSignature SqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature SqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasOutput("XShape")) {
return KernelSignature( return KernelSignature(
"squeeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"}); "squeeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"});
} else {
return KernelSignature("squeeze", {"X"}, {"axes"}, {"Out"});
}
} }
KernelSignature SqueezeGradOpArgumentMapping( KernelSignature SqueezeGradOpArgumentMapping(
......
...@@ -18,13 +18,10 @@ ...@@ -18,13 +18,10 @@
namespace phi { namespace phi {
KernelSignature UnsqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature UnsqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasOutput("XShape")) {
if (ctx.InputSize("AxesTensorList") > 0) { if (ctx.InputSize("AxesTensorList") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensorList"; VLOG(2) << "unsqueeze2 in AxesTensorList";
return KernelSignature("unsqueeze_with_xshape", return KernelSignature(
{"X"}, "unsqueeze_with_xshape", {"X"}, {"AxesTensorList"}, {"Out", "XShape"});
{"AxesTensorList"},
{"Out", "XShape"});
} else if (ctx.InputSize("AxesTensor") > 0) { } else if (ctx.InputSize("AxesTensor") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensor"; VLOG(2) << "unsqueeze2 in AxesTensor";
return KernelSignature( return KernelSignature(
...@@ -34,18 +31,6 @@ KernelSignature UnsqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -34,18 +31,6 @@ KernelSignature UnsqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature(
"unsqueeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"}); "unsqueeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"});
} }
} else {
if (ctx.InputSize("AxesTensorList") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensorList";
return KernelSignature("unsqueeze", {"X"}, {"AxesTensorList"}, {"Out"});
} else if (ctx.InputSize("AxesTensor") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensor";
return KernelSignature("unsqueeze", {"X"}, {"AxesTensor"}, {"Out"});
} else {
VLOG(2) << "unsqueeze2 in axes";
return KernelSignature("unsqueeze", {"X"}, {"axes"}, {"Out"});
}
}
} }
KernelSignature UnsqueezeGradOpArgumentMapping( KernelSignature UnsqueezeGradOpArgumentMapping(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册