未验证 提交 df790b9b 编写于 作者: Y YuanRisheng 提交者: GitHub

fix bugs when clip xshape (#44898)

上级 62a98130
......@@ -17,14 +17,10 @@ limitations under the License. */
namespace phi {
KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.OutputSize("XShape") > 0 && ctx.OutputSize("InnerCache") > 0) {
return KernelSignature("einsum_raw",
{"Operands"},
{"equation"},
{"Out", "InnerCache", "XShape"});
} else {
return KernelSignature("einsum", {"Operands"}, {"equation"}, {"Out"});
}
return KernelSignature("einsum_raw",
{"Operands"},
{"equation"},
{"Out", "InnerCache", "XShape"});
}
KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
......
......@@ -18,12 +18,8 @@
namespace phi {
KernelSignature SqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasOutput("XShape")) {
return KernelSignature(
"squeeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"});
} else {
return KernelSignature("squeeze", {"X"}, {"axes"}, {"Out"});
}
return KernelSignature(
"squeeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"});
}
KernelSignature SqueezeGradOpArgumentMapping(
......
......@@ -18,33 +18,18 @@
namespace phi {
KernelSignature UnsqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasOutput("XShape")) {
if (ctx.InputSize("AxesTensorList") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensorList";
return KernelSignature("unsqueeze_with_xshape",
{"X"},
{"AxesTensorList"},
{"Out", "XShape"});
} else if (ctx.InputSize("AxesTensor") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensor";
return KernelSignature(
"unsqueeze_with_xshape", {"X"}, {"AxesTensor"}, {"Out", "XShape"});
} else {
VLOG(2) << "unsqueeze2 in axes";
return KernelSignature(
"unsqueeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"});
}
if (ctx.InputSize("AxesTensorList") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensorList";
return KernelSignature(
"unsqueeze_with_xshape", {"X"}, {"AxesTensorList"}, {"Out", "XShape"});
} else if (ctx.InputSize("AxesTensor") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensor";
return KernelSignature(
"unsqueeze_with_xshape", {"X"}, {"AxesTensor"}, {"Out", "XShape"});
} else {
if (ctx.InputSize("AxesTensorList") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensorList";
return KernelSignature("unsqueeze", {"X"}, {"AxesTensorList"}, {"Out"});
} else if (ctx.InputSize("AxesTensor") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensor";
return KernelSignature("unsqueeze", {"X"}, {"AxesTensor"}, {"Out"});
} else {
VLOG(2) << "unsqueeze2 in axes";
return KernelSignature("unsqueeze", {"X"}, {"axes"}, {"Out"});
}
VLOG(2) << "unsqueeze2 in axes";
return KernelSignature(
"unsqueeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"});
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册