From 0b4bb023a7ef93669e9007f7e6241f24c6e98cb6 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Fri, 25 Sep 2020 21:35:40 +0800 Subject: [PATCH] Add static mode check on data() (#27495) * add static check on data() * follow comments * fix ut --- python/paddle/fluid/data.py | 2 ++ python/paddle/fluid/framework.py | 12 +++++++++++- python/paddle/fluid/layers/io.py | 2 ++ python/paddle/fluid/tests/unittests/test_data.py | 12 ++++++++++++ .../tests/unittests/test_deprecated_decorator.py | 2 ++ python/paddle/static/input.py | 2 ++ 6 files changed, 31 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/data.py b/python/paddle/fluid/data.py index dc57e9f71e..05ea66f544 100644 --- a/python/paddle/fluid/data.py +++ b/python/paddle/fluid/data.py @@ -19,10 +19,12 @@ from paddle.fluid import core from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.data_feeder import check_dtype, check_type from ..utils import deprecated +from paddle.fluid.framework import static_only __all__ = ['data'] +@static_only @deprecated(since="2.0.0", update_to="paddle.static.data") def data(name, shape, dtype='float32', lod_level=0): """ diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 797b32f5d4..c7e66bb287 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -217,7 +217,16 @@ def _dygraph_not_support_(func): def _dygraph_only_(func): def __impl__(*args, **kwargs): assert in_dygraph_mode( - ), "We Only support %s in dynamic mode, please call 'paddle.disable_static()' to enter dynamic mode." % func.__name__ + ), "We only support '%s()' in dynamic graph mode, please call 'paddle.disable_static()' to enter dynamic graph mode." % func.__name__ + return func(*args, **kwargs) + + return __impl__ + + +def _static_only_(func): + def __impl__(*args, **kwargs): + assert not in_dygraph_mode( + ), "We only support '%s()' in static graph mode, please call 'paddle.enable_static()' to enter static graph mode." % func.__name__ return func(*args, **kwargs) return __impl__ @@ -260,6 +269,7 @@ def deprecate_stat_dict(func): dygraph_not_support = wrap_decorator(_dygraph_not_support_) dygraph_only = wrap_decorator(_dygraph_only_) +static_only = wrap_decorator(_static_only_) fake_interface_only = wrap_decorator(_fake_interface_only_) diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index d513d44acf..6b98dea429 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -31,6 +31,7 @@ from ..unique_name import generate as unique_name import logging from ..data_feeder import check_dtype, check_type +from paddle.fluid.framework import static_only __all__ = [ 'data', 'read_file', 'double_buffer', 'py_reader', @@ -38,6 +39,7 @@ __all__ = [ ] +@static_only def data(name, shape, append_batch_size=True, diff --git a/python/paddle/fluid/tests/unittests/test_data.py b/python/paddle/fluid/tests/unittests/test_data.py index 8070148f8b..98739f6e16 100644 --- a/python/paddle/fluid/tests/unittests/test_data.py +++ b/python/paddle/fluid/tests/unittests/test_data.py @@ -99,5 +99,17 @@ class TestApiStaticDataError(unittest.TestCase): self.assertRaises(TypeError, test_shape_type) +class TestApiErrorWithDynamicMode(unittest.TestCase): + def test_error(self): + with program_guard(Program(), Program()): + paddle.disable_static() + self.assertRaises(AssertionError, fluid.data, 'a', [2, 25]) + self.assertRaises( + AssertionError, fluid.layers.data, 'b', shape=[2, 25]) + self.assertRaises( + AssertionError, paddle.static.data, 'c', shape=[2, 25]) + paddle.enable_static() + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_deprecated_decorator.py b/python/paddle/fluid/tests/unittests/test_deprecated_decorator.py index 2a80e20d69..97b6594eb3 100755 --- a/python/paddle/fluid/tests/unittests/test_deprecated_decorator.py +++ b/python/paddle/fluid/tests/unittests/test_deprecated_decorator.py @@ -72,6 +72,7 @@ class TestDeprecatedDocorator(unittest.TestCase): test old fluid elementwise_mul api, it should fire Warinng function, which insert the Warinng info on top of API's doc string. """ + paddle.enable_static() # Initialization x = fluid.data(name='x', shape=[3, 2, 1], dtype='float32') @@ -80,6 +81,7 @@ class TestDeprecatedDocorator(unittest.TestCase): # captured captured = get_warning_index(fluid.data) + paddle.disable_static() # testting self.assertGreater(expected, captured) diff --git a/python/paddle/static/input.py b/python/paddle/static/input.py index eb70320ea7..d7a3cfcdb9 100644 --- a/python/paddle/static/input.py +++ b/python/paddle/static/input.py @@ -19,10 +19,12 @@ from paddle.fluid import core, Variable from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.data_feeder import check_type from paddle.fluid.framework import convert_np_dtype_to_dtype_ +from paddle.fluid.framework import static_only __all__ = ['data', 'InputSpec'] +@static_only def data(name, shape, dtype=None, lod_level=0): """ **Data Layer** -- GitLab