未验证 提交 5a4ceb32 编写于 作者: X xiongkun 提交者: GitHub

[InferMeta] add compile-time infermeta logic for stack infermeta. (#45528)

* add compile-time infermeta logic for stack infermeta.

* add unittest for stack infermeta where -1 exists in shapes.

* remove backward changes.
上级 36739748
...@@ -2408,7 +2408,8 @@ void SgdInferMeta(const MetaTensor& param, ...@@ -2408,7 +2408,8 @@ void SgdInferMeta(const MetaTensor& param,
void StackInferMeta(const std::vector<const MetaTensor*>& x, void StackInferMeta(const std::vector<const MetaTensor*>& x,
int axis, int axis,
MetaTensor* out) { MetaTensor* out,
MetaConfig config) {
PADDLE_ENFORCE_GT(x.size(), PADDLE_ENFORCE_GT(x.size(),
0UL, 0UL,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
...@@ -2416,17 +2417,10 @@ void StackInferMeta(const std::vector<const MetaTensor*>& x, ...@@ -2416,17 +2417,10 @@ void StackInferMeta(const std::vector<const MetaTensor*>& x,
" received value is:%d.", " received value is:%d.",
x.size())); x.size()));
const auto& input_dims = GetMetaTensorsDim(x); const auto& input_dims = GetMetaTensorsDim(x);
for (size_t i = 1; i < input_dims.size(); ++i) { // we reuse concat logic to compute out_dim. we set concat_axis==-1 to check
PADDLE_ENFORCE_EQ(input_dims[i], // every axis in input_tensors.
input_dims[0], auto out_dim =
phi::errors::InvalidArgument( phi::funcs::ComputeAndCheckShape(config.is_runtime, input_dims, -1);
"Dims of all Inputs(X) must be the same, but"
" received input %d dim is:%d not equal to input 0"
" dim:%d.",
i,
input_dims[i],
input_dims[0]));
}
int rank = input_dims[0].size(); int rank = input_dims[0].size();
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
axis, axis,
...@@ -2445,7 +2439,7 @@ void StackInferMeta(const std::vector<const MetaTensor*>& x, ...@@ -2445,7 +2439,7 @@ void StackInferMeta(const std::vector<const MetaTensor*>& x,
rank, rank,
axis)); axis));
if (axis < 0) axis += (rank + 1); if (axis < 0) axis += (rank + 1);
auto vec = phi::vectorize<int>(input_dims[0]); auto vec = phi::vectorize<int>(out_dim);
vec.insert(vec.begin() + axis, input_dims.size()); vec.insert(vec.begin() + axis, input_dims.size());
out->set_dims(phi::make_ddim(vec)); out->set_dims(phi::make_ddim(vec));
out->set_dtype(x.at(0)->dtype()); out->set_dtype(x.at(0)->dtype());
......
...@@ -452,7 +452,8 @@ void SgdInferMeta(const MetaTensor& param, ...@@ -452,7 +452,8 @@ void SgdInferMeta(const MetaTensor& param,
void StackInferMeta(const std::vector<const MetaTensor*>& x, void StackInferMeta(const std::vector<const MetaTensor*>& x,
int axis, int axis,
MetaTensor* out); MetaTensor* out,
MetaConfig config = MetaConfig());
void UnchangedMultiInferMeta(const std::vector<const MetaTensor*>& x, void UnchangedMultiInferMeta(const std::vector<const MetaTensor*>& x,
std::vector<MetaTensor*> out); std::vector<MetaTensor*> out);
......
...@@ -18,6 +18,7 @@ import paddle ...@@ -18,6 +18,7 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from op_test import OpTest, convert_float_to_uint16 from op_test import OpTest, convert_float_to_uint16
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.framework import Program, program_guard
class TestStackOpBase(OpTest): class TestStackOpBase(OpTest):
...@@ -268,5 +269,30 @@ class API_DygraphTest(unittest.TestCase): ...@@ -268,5 +269,30 @@ class API_DygraphTest(unittest.TestCase):
self.assertRaises(Exception, paddle.stack, x) self.assertRaises(Exception, paddle.stack, x)
class TestStackOpWithNegativeShape(unittest.TestCase):
def test_out(self):
main_prg, startup_prg = Program(), Program()
with program_guard(main_prg, startup_prg):
b = paddle.static.data(name='b', shape=[-1], dtype='int64')
e = paddle.static.data(name='e', shape=[3], dtype='int64')
k = paddle.stack([b, e], axis=0)
exe = paddle.static.Executor()
exe.run(startup_prg)
out = exe.run(main_prg,
feed={
'b': np.ones([
3,
]).astype("int64"),
'e': np.zeros([
3,
]).astype("int64")
},
fetch_list=[k])
np.testing.assert_allclose(out[0],
np.array([[1, 1, 1], [0, 0, 0]]),
rtol=1e-05)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册