未验证 提交 0b4bb023 编写于 作者: L Leo Chen 提交者: GitHub

Add static mode check on data() (#27495)

* add static check on data()

* follow comments

* fix ut
上级 a5b32637
...@@ -19,10 +19,12 @@ from paddle.fluid import core ...@@ -19,10 +19,12 @@ from paddle.fluid import core
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.data_feeder import check_dtype, check_type from paddle.fluid.data_feeder import check_dtype, check_type
from ..utils import deprecated from ..utils import deprecated
from paddle.fluid.framework import static_only
__all__ = ['data'] __all__ = ['data']
@static_only
@deprecated(since="2.0.0", update_to="paddle.static.data") @deprecated(since="2.0.0", update_to="paddle.static.data")
def data(name, shape, dtype='float32', lod_level=0): def data(name, shape, dtype='float32', lod_level=0):
""" """
......
...@@ -217,7 +217,16 @@ def _dygraph_not_support_(func): ...@@ -217,7 +217,16 @@ def _dygraph_not_support_(func):
def _dygraph_only_(func): def _dygraph_only_(func):
def __impl__(*args, **kwargs): def __impl__(*args, **kwargs):
assert in_dygraph_mode( assert in_dygraph_mode(
), "We Only support %s in dynamic mode, please call 'paddle.disable_static()' to enter dynamic mode." % func.__name__ ), "We only support '%s()' in dynamic graph mode, please call 'paddle.disable_static()' to enter dynamic graph mode." % func.__name__
return func(*args, **kwargs)
return __impl__
def _static_only_(func):
def __impl__(*args, **kwargs):
assert not in_dygraph_mode(
), "We only support '%s()' in static graph mode, please call 'paddle.enable_static()' to enter static graph mode." % func.__name__
return func(*args, **kwargs) return func(*args, **kwargs)
return __impl__ return __impl__
...@@ -260,6 +269,7 @@ def deprecate_stat_dict(func): ...@@ -260,6 +269,7 @@ def deprecate_stat_dict(func):
dygraph_not_support = wrap_decorator(_dygraph_not_support_) dygraph_not_support = wrap_decorator(_dygraph_not_support_)
dygraph_only = wrap_decorator(_dygraph_only_) dygraph_only = wrap_decorator(_dygraph_only_)
static_only = wrap_decorator(_static_only_)
fake_interface_only = wrap_decorator(_fake_interface_only_) fake_interface_only = wrap_decorator(_fake_interface_only_)
......
...@@ -31,6 +31,7 @@ from ..unique_name import generate as unique_name ...@@ -31,6 +31,7 @@ from ..unique_name import generate as unique_name
import logging import logging
from ..data_feeder import check_dtype, check_type from ..data_feeder import check_dtype, check_type
from paddle.fluid.framework import static_only
__all__ = [ __all__ = [
'data', 'read_file', 'double_buffer', 'py_reader', 'data', 'read_file', 'double_buffer', 'py_reader',
...@@ -38,6 +39,7 @@ __all__ = [ ...@@ -38,6 +39,7 @@ __all__ = [
] ]
@static_only
def data(name, def data(name,
shape, shape,
append_batch_size=True, append_batch_size=True,
......
...@@ -99,5 +99,17 @@ class TestApiStaticDataError(unittest.TestCase): ...@@ -99,5 +99,17 @@ class TestApiStaticDataError(unittest.TestCase):
self.assertRaises(TypeError, test_shape_type) self.assertRaises(TypeError, test_shape_type)
class TestApiErrorWithDynamicMode(unittest.TestCase):
def test_error(self):
with program_guard(Program(), Program()):
paddle.disable_static()
self.assertRaises(AssertionError, fluid.data, 'a', [2, 25])
self.assertRaises(
AssertionError, fluid.layers.data, 'b', shape=[2, 25])
self.assertRaises(
AssertionError, paddle.static.data, 'c', shape=[2, 25])
paddle.enable_static()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -72,6 +72,7 @@ class TestDeprecatedDocorator(unittest.TestCase): ...@@ -72,6 +72,7 @@ class TestDeprecatedDocorator(unittest.TestCase):
test old fluid elementwise_mul api, it should fire Warinng function, test old fluid elementwise_mul api, it should fire Warinng function,
which insert the Warinng info on top of API's doc string. which insert the Warinng info on top of API's doc string.
""" """
paddle.enable_static()
# Initialization # Initialization
x = fluid.data(name='x', shape=[3, 2, 1], dtype='float32') x = fluid.data(name='x', shape=[3, 2, 1], dtype='float32')
...@@ -80,6 +81,7 @@ class TestDeprecatedDocorator(unittest.TestCase): ...@@ -80,6 +81,7 @@ class TestDeprecatedDocorator(unittest.TestCase):
# captured # captured
captured = get_warning_index(fluid.data) captured = get_warning_index(fluid.data)
paddle.disable_static()
# testting # testting
self.assertGreater(expected, captured) self.assertGreater(expected, captured)
......
...@@ -19,10 +19,12 @@ from paddle.fluid import core, Variable ...@@ -19,10 +19,12 @@ from paddle.fluid import core, Variable
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.data_feeder import check_type from paddle.fluid.data_feeder import check_type
from paddle.fluid.framework import convert_np_dtype_to_dtype_ from paddle.fluid.framework import convert_np_dtype_to_dtype_
from paddle.fluid.framework import static_only
__all__ = ['data', 'InputSpec'] __all__ = ['data', 'InputSpec']
@static_only
def data(name, shape, dtype=None, lod_level=0): def data(name, shape, dtype=None, lod_level=0):
""" """
**Data Layer** **Data Layer**
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册