未验证 提交 59e8382d 编写于 作者: A Aurelius84 提交者: GitHub

[Phi]Fix expand_sig infershape BUG under static graph mode (#41936)

* [Phi]Fix expand_sig infershape BUG under static graph mode

* [Phi]Fix expand_sig infershape BUG under static graph mode

* [Phi]Fix unittest

* [Phi]Fix unittest
上级 84c8096c
...@@ -17,6 +17,11 @@ ...@@ -17,6 +17,11 @@
namespace phi { namespace phi {
KernelSignature ExpandOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ExpandOpArgumentMapping(const ArgumentMappingContext& ctx) {
const auto& shape = paddle::any_cast<std::vector<int>>(ctx.Attr("shape"));
// Infer output shape by Attr("shape") in CompileTime if it is specified.
if (!ctx.IsRuntime() && !shape.empty()) {
return KernelSignature("expand", {"X"}, {"shape"}, {"Out"});
}
if (ctx.HasInput("Shape")) { if (ctx.HasInput("Shape")) {
return KernelSignature("expand", {"X"}, {"Shape"}, {"Out"}); return KernelSignature("expand", {"X"}, {"Shape"}, {"Out"});
} else if (ctx.InputSize("expand_shapes_tensor") > 0) { } else if (ctx.InputSize("expand_shapes_tensor") > 0) {
...@@ -27,6 +32,12 @@ KernelSignature ExpandOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -27,6 +32,12 @@ KernelSignature ExpandOpArgumentMapping(const ArgumentMappingContext& ctx) {
} }
KernelSignature ExpandGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ExpandGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
const auto& shape = paddle::any_cast<std::vector<int>>(ctx.Attr("shape"));
// Infer output shape by Attr("shape") in CompileTime if it is specified.
if (!ctx.IsRuntime() && !shape.empty()) {
return KernelSignature(
"expand_grad", {"X", "Out@GRAD"}, {"shape"}, {"X@GRAD"});
}
if (ctx.HasInput("Shape")) { if (ctx.HasInput("Shape")) {
return KernelSignature( return KernelSignature(
"expand_grad", {"X", "Out@GRAD"}, {"Shape"}, {"X@GRAD"}); "expand_grad", {"X", "Out@GRAD"}, {"Shape"}, {"X@GRAD"});
......
...@@ -231,6 +231,18 @@ class TestExpandV2API(unittest.TestCase): ...@@ -231,6 +231,18 @@ class TestExpandV2API(unittest.TestCase):
assert np.array_equal(res_3, np.tile(input, (1, 1))) assert np.array_equal(res_3, np.tile(input, (1, 1)))
class TestExpandInferShape(unittest.TestCase):
def test_shape_with_var(self):
with program_guard(Program(), Program()):
x = paddle.static.data(shape=[-1, 1, 3], name='x')
fake_var = paddle.randn([2, 3])
target_shape = [
-1, paddle.shape(fake_var)[0], paddle.shape(fake_var)[1]
]
out = paddle.expand(x, shape=target_shape)
self.assertListEqual(list(out.shape), [-1, -1, -1])
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static() paddle.enable_static()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册