未验证 提交 2853f0c4 编写于 作者: L Leo Chen 提交者: GitHub

Set static shape for shape tensor with constant [part 1] (#28275)

* set static shape for shape tensor with constant

* remove debug code

* fix typo

* add ut

* refine code

* refine example
上级 4dc8c44b
......@@ -15145,6 +15145,7 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0,
helper.append_op(
type="uniform_random", inputs=inputs, attrs=attrs,
outputs={"Out": out})
utils.try_set_static_shape_tensor(out, shape)
return out
......
......@@ -17,7 +17,7 @@ import collections
import copy
import six
import numpy as np
from ..framework import Variable
from ..framework import Variable, in_dygraph_mode
from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
from ..layer_helper import LayerHelper
from sys import version_info
......@@ -378,3 +378,54 @@ def check_shape(shape):
raise TypeError(
"All elements in ``shape`` must be integers when it's a list or tuple"
)
def try_set_static_shape_tensor(tensor, shape):
"""Try to set static shape of tensor from a shape tensor.
For example,
import paddle
paddle.enable_static()
data = paddle.static.data(name="x", shape=[-1, 2], dtype='float32')
shape = paddle.shape(data) # shape should be [-1, 2] instead of [-1, -1]
x = paddle.uniform(shape)
print(x.shape)
# (-1, 2)
"""
if not in_dygraph_mode():
# static mode, and shape is not all inferred (contains -1)
if -1 in tensor.shape:
if isinstance(shape, Variable):
shape = try_get_constant_shape_from_tensor(shape)
if shape:
tensor.desc.set_shape(shape)
def try_get_constant_shape_from_tensor(shape_tensor):
"""Try to get shape from a tensor with constant value.
For example,
import paddle
paddle.enable_static()
data = paddle.static.data(name="x", shape=[-1, 2], dtype='float32')
shape = paddle.shape(data) # shape should be [-1, 2] instead of [-1, -1]
x = paddle.uniform(shape)
print(x.shape)
# (-1, 2)
"""
if not in_dygraph_mode():
try:
if shape_tensor.op is not None:
generate_op = shape_tensor.op
if generate_op.type == 'shape':
var = shape_tensor.block.vars[generate_op.input_arg_names[
0]]
return var.shape
except:
return None
return None
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import unittest
class StaticShapeInferrenceTest(unittest.TestCase):
def test_static_graph(self):
paddle.enable_static()
data = paddle.fluid.layers.data(
name="x", shape=[-1, 2], dtype='float32')
shape = paddle.fluid.layers.shape(data) # shape should be [-1, 2]
x = paddle.fluid.layers.uniform_random(shape)
self.assertEqual(x.shape, data.shape)
paddle.disable_static()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册