diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index e5835dc56d4718267e4715f77b35b862023591b3..56dc40cc7c9a69086e3751c916358ec171ee06d4 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2408,7 +2408,8 @@ void SgdInferMeta(const MetaTensor& param, void StackInferMeta(const std::vector& x, int axis, - MetaTensor* out) { + MetaTensor* out, + MetaConfig config) { PADDLE_ENFORCE_GT(x.size(), 0UL, phi::errors::InvalidArgument( @@ -2416,17 +2417,10 @@ void StackInferMeta(const std::vector& x, " received value is:%d.", x.size())); const auto& input_dims = GetMetaTensorsDim(x); - for (size_t i = 1; i < input_dims.size(); ++i) { - PADDLE_ENFORCE_EQ(input_dims[i], - input_dims[0], - phi::errors::InvalidArgument( - "Dims of all Inputs(X) must be the same, but" - " received input %d dim is:%d not equal to input 0" - " dim:%d.", - i, - input_dims[i], - input_dims[0])); - } + // we reuse concat logic to compute out_dim. we set concat_axis==-1 to check + // every axis in input_tensors. + auto out_dim = + phi::funcs::ComputeAndCheckShape(config.is_runtime, input_dims, -1); int rank = input_dims[0].size(); PADDLE_ENFORCE_GE( axis, @@ -2445,7 +2439,7 @@ void StackInferMeta(const std::vector& x, rank, axis)); if (axis < 0) axis += (rank + 1); - auto vec = phi::vectorize(input_dims[0]); + auto vec = phi::vectorize(out_dim); vec.insert(vec.begin() + axis, input_dims.size()); out->set_dims(phi::make_ddim(vec)); out->set_dtype(x.at(0)->dtype()); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 0296509e43750b97fc467d7ed1e03e066e6b816e..4e95303f1a02566343e3bc8a87ebb6661f3c0b01 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -452,7 +452,8 @@ void SgdInferMeta(const MetaTensor& param, void StackInferMeta(const std::vector& x, int axis, - MetaTensor* out); + MetaTensor* out, + MetaConfig config = MetaConfig()); void UnchangedMultiInferMeta(const std::vector& x, std::vector out); diff --git a/python/paddle/fluid/tests/unittests/test_stack_op.py b/python/paddle/fluid/tests/unittests/test_stack_op.py index f7b1254c880ef255681f3252d024a3bfa75a8afb..8fc004838da6c88a4aaecdc7ffda46476afac5f7 100644 --- a/python/paddle/fluid/tests/unittests/test_stack_op.py +++ b/python/paddle/fluid/tests/unittests/test_stack_op.py @@ -18,6 +18,7 @@ import paddle import paddle.fluid as fluid from op_test import OpTest, convert_float_to_uint16 import paddle.fluid.core as core +from paddle.fluid.framework import Program, program_guard class TestStackOpBase(OpTest): @@ -268,5 +269,30 @@ class API_DygraphTest(unittest.TestCase): self.assertRaises(Exception, paddle.stack, x) +class TestStackOpWithNegativeShape(unittest.TestCase): + + def test_out(self): + main_prg, startup_prg = Program(), Program() + with program_guard(main_prg, startup_prg): + b = paddle.static.data(name='b', shape=[-1], dtype='int64') + e = paddle.static.data(name='e', shape=[3], dtype='int64') + k = paddle.stack([b, e], axis=0) + exe = paddle.static.Executor() + exe.run(startup_prg) + out = exe.run(main_prg, + feed={ + 'b': np.ones([ + 3, + ]).astype("int64"), + 'e': np.zeros([ + 3, + ]).astype("int64") + }, + fetch_list=[k]) + np.testing.assert_allclose(out[0], + np.array([[1, 1, 1], [0, 0, 0]]), + rtol=1e-05) + + if __name__ == '__main__': unittest.main()