From 2853f0c4f979b95015466113376448abc2daef4d Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Wed, 28 Oct 2020 02:29:04 -0500 Subject: [PATCH] Set static shape for shape tensor with constant [part 1] (#28275) * set static shape for shape tensor with constant * remove debug code * fix typo * add ut * refine code * refine example --- python/paddle/fluid/layers/nn.py | 1 + python/paddle/fluid/layers/utils.py | 53 ++++++++++++++++++- ...tatic_shape_inferrence_for_shape_tensor.py | 31 +++++++++++ 3 files changed, 84 insertions(+), 1 deletion(-) create mode 100644 python/paddle/fluid/tests/unittests/test_static_shape_inferrence_for_shape_tensor.py diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index d5157abf1a9..adde9cbd19f 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -15145,6 +15145,7 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0, helper.append_op( type="uniform_random", inputs=inputs, attrs=attrs, outputs={"Out": out}) + utils.try_set_static_shape_tensor(out, shape) return out diff --git a/python/paddle/fluid/layers/utils.py b/python/paddle/fluid/layers/utils.py index 2095c9957e7..0d278d493bc 100644 --- a/python/paddle/fluid/layers/utils.py +++ b/python/paddle/fluid/layers/utils.py @@ -17,7 +17,7 @@ import collections import copy import six import numpy as np -from ..framework import Variable +from ..framework import Variable, in_dygraph_mode from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype from ..layer_helper import LayerHelper from sys import version_info @@ -378,3 +378,54 @@ def check_shape(shape): raise TypeError( "All elements in ``shape`` must be integers when it's a list or tuple" ) + + +def try_set_static_shape_tensor(tensor, shape): + """Try to set static shape of tensor from a shape tensor. + + For example, + + import paddle + paddle.enable_static() + data = paddle.static.data(name="x", shape=[-1, 2], dtype='float32') + shape = paddle.shape(data) # shape should be [-1, 2] instead of [-1, -1] + x = paddle.uniform(shape) + print(x.shape) + # (-1, 2) + + """ + if not in_dygraph_mode(): + # static mode, and shape is not all inferred (contains -1) + if -1 in tensor.shape: + if isinstance(shape, Variable): + shape = try_get_constant_shape_from_tensor(shape) + if shape: + tensor.desc.set_shape(shape) + + +def try_get_constant_shape_from_tensor(shape_tensor): + """Try to get shape from a tensor with constant value. + + For example, + + import paddle + paddle.enable_static() + data = paddle.static.data(name="x", shape=[-1, 2], dtype='float32') + shape = paddle.shape(data) # shape should be [-1, 2] instead of [-1, -1] + x = paddle.uniform(shape) + print(x.shape) + # (-1, 2) + + """ + if not in_dygraph_mode(): + try: + if shape_tensor.op is not None: + generate_op = shape_tensor.op + if generate_op.type == 'shape': + var = shape_tensor.block.vars[generate_op.input_arg_names[ + 0]] + return var.shape + except: + return None + + return None diff --git a/python/paddle/fluid/tests/unittests/test_static_shape_inferrence_for_shape_tensor.py b/python/paddle/fluid/tests/unittests/test_static_shape_inferrence_for_shape_tensor.py new file mode 100644 index 00000000000..2c6d646baf5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_static_shape_inferrence_for_shape_tensor.py @@ -0,0 +1,31 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import unittest + + +class StaticShapeInferrenceTest(unittest.TestCase): + def test_static_graph(self): + paddle.enable_static() + data = paddle.fluid.layers.data( + name="x", shape=[-1, 2], dtype='float32') + shape = paddle.fluid.layers.shape(data) # shape should be [-1, 2] + x = paddle.fluid.layers.uniform_random(shape) + self.assertEqual(x.shape, data.shape) + paddle.disable_static() + + +if __name__ == '__main__': + unittest.main() -- GitLab