未验证 提交 0b4bb023 编写于 作者: L Leo Chen 提交者: GitHub

Add static mode check on data() (#27495)

* add static check on data()

* follow comments

* fix ut
上级 a5b32637
......@@ -19,10 +19,12 @@ from paddle.fluid import core
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.data_feeder import check_dtype, check_type
from ..utils import deprecated
from paddle.fluid.framework import static_only
__all__ = ['data']
@static_only
@deprecated(since="2.0.0", update_to="paddle.static.data")
def data(name, shape, dtype='float32', lod_level=0):
"""
......
......@@ -217,7 +217,16 @@ def _dygraph_not_support_(func):
def _dygraph_only_(func):
def __impl__(*args, **kwargs):
assert in_dygraph_mode(
), "We Only support %s in dynamic mode, please call 'paddle.disable_static()' to enter dynamic mode." % func.__name__
), "We only support '%s()' in dynamic graph mode, please call 'paddle.disable_static()' to enter dynamic graph mode." % func.__name__
return func(*args, **kwargs)
return __impl__
def _static_only_(func):
def __impl__(*args, **kwargs):
assert not in_dygraph_mode(
), "We only support '%s()' in static graph mode, please call 'paddle.enable_static()' to enter static graph mode." % func.__name__
return func(*args, **kwargs)
return __impl__
......@@ -260,6 +269,7 @@ def deprecate_stat_dict(func):
dygraph_not_support = wrap_decorator(_dygraph_not_support_)
dygraph_only = wrap_decorator(_dygraph_only_)
static_only = wrap_decorator(_static_only_)
fake_interface_only = wrap_decorator(_fake_interface_only_)
......
......@@ -31,6 +31,7 @@ from ..unique_name import generate as unique_name
import logging
from ..data_feeder import check_dtype, check_type
from paddle.fluid.framework import static_only
__all__ = [
'data', 'read_file', 'double_buffer', 'py_reader',
......@@ -38,6 +39,7 @@ __all__ = [
]
@static_only
def data(name,
shape,
append_batch_size=True,
......
......@@ -99,5 +99,17 @@ class TestApiStaticDataError(unittest.TestCase):
self.assertRaises(TypeError, test_shape_type)
class TestApiErrorWithDynamicMode(unittest.TestCase):
def test_error(self):
with program_guard(Program(), Program()):
paddle.disable_static()
self.assertRaises(AssertionError, fluid.data, 'a', [2, 25])
self.assertRaises(
AssertionError, fluid.layers.data, 'b', shape=[2, 25])
self.assertRaises(
AssertionError, paddle.static.data, 'c', shape=[2, 25])
paddle.enable_static()
if __name__ == "__main__":
unittest.main()
......@@ -72,6 +72,7 @@ class TestDeprecatedDocorator(unittest.TestCase):
test old fluid elementwise_mul api, it should fire Warinng function,
which insert the Warinng info on top of API's doc string.
"""
paddle.enable_static()
# Initialization
x = fluid.data(name='x', shape=[3, 2, 1], dtype='float32')
......@@ -80,6 +81,7 @@ class TestDeprecatedDocorator(unittest.TestCase):
# captured
captured = get_warning_index(fluid.data)
paddle.disable_static()
# testting
self.assertGreater(expected, captured)
......
......@@ -19,10 +19,12 @@ from paddle.fluid import core, Variable
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.data_feeder import check_type
from paddle.fluid.framework import convert_np_dtype_to_dtype_
from paddle.fluid.framework import static_only
__all__ = ['data', 'InputSpec']
@static_only
def data(name, shape, dtype=None, lod_level=0):
"""
**Data Layer**
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册