未验证 提交 1d4d89ba 编写于 作者: H Huihuang Zheng 提交者: GitHub

[Dy2stat] Add data_layer_not_check (#23351)

1. Add data_layer_not_check because it is needed in dygraph_to_static where input can be variable size
2. Remove warnings in static analysis because python cannot do exact static analysis
上级 d2801060
...@@ -23,10 +23,10 @@ import warnings ...@@ -23,10 +23,10 @@ import warnings
from paddle.fluid import framework from paddle.fluid import framework
from paddle.fluid import core, executor from paddle.fluid import core, executor
from paddle.fluid.data import data
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_static from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_static
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import data_layer_not_check
from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import in_dygraph_mode
__all__ = ['ProgramTranslator', 'convert_function_with_cache'] __all__ = ['ProgramTranslator', 'convert_function_with_cache']
...@@ -186,9 +186,9 @@ class ProgramCache(object): ...@@ -186,9 +186,9 @@ class ProgramCache(object):
batch_data, numpy.ndarray batch_data, numpy.ndarray
), "Input {} should be numpy.ndarray, but received {}.".format( ), "Input {} should be numpy.ndarray, but received {}.".format(
feed_name, type(batch_data)) feed_name, type(batch_data))
feed_layer = data( feed_layer = data_layer_not_check(
name=feed_name, name=feed_name,
shape=[-1] + list(batch_data.shape[1:]), shape=list(batch_data.shape),
dtype=str(batch_data.dtype)) dtype=str(batch_data.dtype))
self._inputs.append(feed_layer) self._inputs.append(feed_layer)
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
from __future__ import print_function from __future__ import print_function
import gast import gast
import warnings
from .utils import is_paddle_api, is_dygraph_api, is_numpy_api, index_in_list from .utils import is_paddle_api, is_dygraph_api, is_numpy_api, index_in_list
__all__ = ['AstNodeWrapper', 'NodeVarType', 'StaticAnalysisVisitor'] __all__ = ['AstNodeWrapper', 'NodeVarType', 'StaticAnalysisVisitor']
...@@ -71,19 +70,12 @@ class NodeVarType(object): ...@@ -71,19 +70,12 @@ class NodeVarType(object):
] ]
if in_type1 not in supported_types: if in_type1 not in supported_types:
warnings.warn("Binary Op on un supported in_type1 = %d " %
(in_type1))
return NodeVarType.UNKNOWN return NodeVarType.UNKNOWN
if in_type2 not in supported_types: if in_type2 not in supported_types:
warnings.warn("Binary Op on un supported in_type2 = %d " %
(in_type2))
return NodeVarType.UNKNOWN return NodeVarType.UNKNOWN
forbidden_types = [NodeVarType.NUMPY_NDARRAY, NodeVarType.TENSOR] forbidden_types = [NodeVarType.NUMPY_NDARRAY, NodeVarType.TENSOR]
if in_type1 in forbidden_types and in_type2 in forbidden_types: if in_type1 in forbidden_types and in_type2 in forbidden_types:
warnings.warn(
"Binary Op on un supported types: in_type1 = %d, in_type2 = %d"
% (in_type1, in_type2))
return NodeVarType.UNKNOWN return NodeVarType.UNKNOWN
return max(in_type1, in_type2) return max(in_type1, in_type2)
......
...@@ -17,19 +17,71 @@ from __future__ import print_function ...@@ -17,19 +17,71 @@ from __future__ import print_function
import six import six
import gast import gast
from paddle.fluid import core
from paddle.fluid.layers import fill_constant from paddle.fluid.layers import fill_constant
from paddle.fluid.layer_helper import LayerHelper
__all__ = ['to_static_variable_gast_node', 'create_static_variable_gast_node'] __all__ = [
'to_static_variable_gast_node', 'create_static_variable_gast_node',
'data_layer_not_check'
]
def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
"""
This function creates a variable on the global block. Unlike
`paddle.fluid.data` , the created variable doesn't check the dtype and the
shape of feed data because dygraph input data can be variable-length.
This API is used in translating dygraph into static graph.
Note:
The default :code:`stop_gradient` attribute of the Variable created by
this API is true, which means the gradient won't be passed backward
through the data Varaible. Set :code:`var.stop_gradient = False` If
user would like to pass backward gradient.
Args:
name (str): The name/alias of the variable, see :ref:`api_guide_Name`
for more details.
shape (list|tuple): List|Tuple of integers declaring the shape. You can
set "None" at a dimension to indicate the dimension can be of any
size. For example, it is useful to set changeable batch size as "None"
dtype (np.dtype|VarType|str, optional): The type of the data. Supported
dtype: bool, float16, float32, float64, int8, int16, int32, int64,
uint8. Default: float32
lod_level (int, optional): The LoD level of the LoDTensor. Usually users
don't have to set this value. For more details about when and how to
use LoD level, see :ref:`user_guide_lod_tensor` . Default: 0
Returns:
Variable: The global variable that gives access to the data.
"""
helper = LayerHelper('data', **locals())
shape = list(shape)
for i in six.moves.range(len(shape)):
if shape[i] is None:
shape[i] = -1
return helper.create_global_variable(
name=name,
shape=shape,
dtype=dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
stop_gradient=True,
lod_level=lod_level,
is_data=True,
need_check_feed=False)
def to_static_variable_gast_node(name): def to_static_variable_gast_node(name):
func_code = "{} = fluid.dygraph.dygraph_to_static.variable_trans_func.to_static_variable({})".format( func_code = "{} = fluid.dygraph.dygraph_to_static.variable_trans_func\
name, name) .to_static_variable({})".format(name, name)
return gast.parse(func_code).body[0] return gast.parse(func_code).body[0]
def create_static_variable_gast_node(name): def create_static_variable_gast_node(name):
func_code = "{} = fluid.data(name='{}', shape=[-1], dtype='float32')".format( func_code = "{} = fluid.dygraph.dygraph_to_static.variable_trans_func\
.data_layer_not_check(name='{}', shape=[-1], dtype='float32')".format(
name, name) name, name)
return gast.parse(func_code).body[0] return gast.parse(func_code).body[0]
......
...@@ -18,8 +18,34 @@ import gast ...@@ -18,8 +18,34 @@ import gast
import six import six
import unittest import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import data_layer_not_check
class TestDataLayerNotCheck(unittest.TestCase):
def test_create_none_shape(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
d = data_layer_not_check(name="d", shape=(None, -1, 3))
self.assertEqual(d.shape, (-1, -1, 3))
self.assertEqual(d.name, "d")
def test_feed_mismatch_shape(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
d = data_layer_not_check(name="d", shape=(1, 2, 3))
feed_in_data = np.random.uniform(size=[1, 2, 4]).astype(np.float32)
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
ret = exe.run(main_program,
feed={d.name: feed_in_data},
fetch_list=[d.name])
self.assertTrue(np.allclose(ret, feed_in_data))
class TestVariableTransFunc(unittest.TestCase): class TestVariableTransFunc(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册