未验证 提交 df790b9b 编写于 作者: Y YuanRisheng 提交者: GitHub

fix bugs when clip xshape (#44898)

上级 62a98130
...@@ -17,14 +17,10 @@ limitations under the License. */ ...@@ -17,14 +17,10 @@ limitations under the License. */
namespace phi { namespace phi {
KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.OutputSize("XShape") > 0 && ctx.OutputSize("InnerCache") > 0) { return KernelSignature("einsum_raw",
return KernelSignature("einsum_raw", {"Operands"},
{"Operands"}, {"equation"},
{"equation"}, {"Out", "InnerCache", "XShape"});
{"Out", "InnerCache", "XShape"});
} else {
return KernelSignature("einsum", {"Operands"}, {"equation"}, {"Out"});
}
} }
KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
......
...@@ -18,12 +18,8 @@ ...@@ -18,12 +18,8 @@
namespace phi { namespace phi {
KernelSignature SqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature SqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasOutput("XShape")) { return KernelSignature(
return KernelSignature( "squeeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"});
"squeeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"});
} else {
return KernelSignature("squeeze", {"X"}, {"axes"}, {"Out"});
}
} }
KernelSignature SqueezeGradOpArgumentMapping( KernelSignature SqueezeGradOpArgumentMapping(
......
...@@ -18,33 +18,18 @@ ...@@ -18,33 +18,18 @@
namespace phi { namespace phi {
KernelSignature UnsqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature UnsqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasOutput("XShape")) { if (ctx.InputSize("AxesTensorList") > 0) {
if (ctx.InputSize("AxesTensorList") > 0) { VLOG(2) << "unsqueeze2 in AxesTensorList";
VLOG(2) << "unsqueeze2 in AxesTensorList"; return KernelSignature(
return KernelSignature("unsqueeze_with_xshape", "unsqueeze_with_xshape", {"X"}, {"AxesTensorList"}, {"Out", "XShape"});
{"X"}, } else if (ctx.InputSize("AxesTensor") > 0) {
{"AxesTensorList"}, VLOG(2) << "unsqueeze2 in AxesTensor";
{"Out", "XShape"}); return KernelSignature(
} else if (ctx.InputSize("AxesTensor") > 0) { "unsqueeze_with_xshape", {"X"}, {"AxesTensor"}, {"Out", "XShape"});
VLOG(2) << "unsqueeze2 in AxesTensor";
return KernelSignature(
"unsqueeze_with_xshape", {"X"}, {"AxesTensor"}, {"Out", "XShape"});
} else {
VLOG(2) << "unsqueeze2 in axes";
return KernelSignature(
"unsqueeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"});
}
} else { } else {
if (ctx.InputSize("AxesTensorList") > 0) { VLOG(2) << "unsqueeze2 in axes";
VLOG(2) << "unsqueeze2 in AxesTensorList"; return KernelSignature(
return KernelSignature("unsqueeze", {"X"}, {"AxesTensorList"}, {"Out"}); "unsqueeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"});
} else if (ctx.InputSize("AxesTensor") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensor";
return KernelSignature("unsqueeze", {"X"}, {"AxesTensor"}, {"Out"});
} else {
VLOG(2) << "unsqueeze2 in axes";
return KernelSignature("unsqueeze", {"X"}, {"axes"}, {"Out"});
}
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册