未验证 提交 435ab2aa 编写于 作者: L liym27 提交者: GitHub

Raise RuntimeError if run the callable object decorated by...

Raise RuntimeError if run the callable object decorated by '@paddle.jit.to_static' not in dynamic mode.  (#26750)
上级 2d2c31a6
...@@ -24,6 +24,7 @@ import warnings ...@@ -24,6 +24,7 @@ import warnings
import gast import gast
from paddle.fluid import framework from paddle.fluid import framework
from paddle.fluid import in_dygraph_mode
from paddle.fluid.dygraph import layers from paddle.fluid.dygraph import layers
from paddle.fluid.data_feeder import check_type from paddle.fluid.data_feeder import check_type
from paddle.fluid.layers.utils import flatten from paddle.fluid.layers.utils import flatten
...@@ -32,6 +33,7 @@ from paddle.fluid.dygraph.base import switch_to_static_graph ...@@ -32,6 +33,7 @@ from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import DygraphToStaticAst from paddle.fluid.dygraph.dygraph_to_static import DygraphToStaticAst
from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA
from paddle.fluid.dygraph.dygraph_to_static.error import attach_error_data from paddle.fluid.dygraph.dygraph_to_static.error import attach_error_data
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.origin_info import attach_origin_info from paddle.fluid.dygraph.dygraph_to_static.origin_info import attach_origin_info
from paddle.fluid.dygraph.dygraph_to_static.origin_info import create_and_update_origin_info_map from paddle.fluid.dygraph.dygraph_to_static.origin_info import create_and_update_origin_info_map
from paddle.fluid.dygraph.dygraph_to_static.origin_info import update_op_callstack_with_origin_info from paddle.fluid.dygraph.dygraph_to_static.origin_info import update_op_callstack_with_origin_info
...@@ -283,13 +285,21 @@ class StaticLayer(object): ...@@ -283,13 +285,21 @@ class StaticLayer(object):
Return: Return:
Outputs of decorated function. Outputs of decorated function.
""" """
# 1. call dygraph function directly if not enable `declarative` # 1. call dygraph function directly if not enable `declarative`
if not self._program_trans.enable_declarative: if not self._program_trans.enable_declarative:
warnings.warn( logging_utils.warn(
"The decorator '@paddle.jit.to_static' doesn't work when setting ProgramTranslator.enable=False. " "The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable=False. "
"We will just return dygraph output.") "We will just return dygraph output.")
return self._call_dygraph_function(*args, **kwargs) return self._call_dygraph_function(*args, **kwargs)
if not in_dygraph_mode() and self._program_trans.enable_declarative:
raise RuntimeError(
"Failed to run the callable object {} decorated by '@paddle.jit.to_static', "
"because it does NOT in dynamic mode. Please disable the static mode to enter dynamic mode with the "
"following API: paddle.disable_static().".format(
self.dygraph_function))
# 2. trace ops from dygraph layers and cache the generated program. # 2. trace ops from dygraph layers and cache the generated program.
args, kwargs = self._function_spec.unified_args_and_kwargs(args, kwargs) args, kwargs = self._function_spec.unified_args_and_kwargs(args, kwargs)
try: try:
......
...@@ -13,13 +13,15 @@ ...@@ -13,13 +13,15 @@
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
import unittest
import paddle import paddle
from paddle.static import InputSpec
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.static import InputSpec
from paddle.fluid.dygraph import to_variable, declarative, ProgramTranslator, Layer, jit from paddle.fluid.dygraph import to_variable, declarative, ProgramTranslator, Layer, jit
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ConcreteProgram from paddle.fluid.dygraph.dygraph_to_static.program_translator import ConcreteProgram
import unittest from test_basic_api_transformation import dyfunc_to_variable
program_trans = ProgramTranslator() program_trans = ProgramTranslator()
...@@ -181,6 +183,9 @@ def foo_func(a, b, c=1, d=2): ...@@ -181,6 +183,9 @@ def foo_func(a, b, c=1, d=2):
class TestDifferentInputSpecCacheProgram(unittest.TestCase): class TestDifferentInputSpecCacheProgram(unittest.TestCase):
def setUp(self):
program_trans.enable(True)
def test_with_different_input(self): def test_with_different_input(self):
with fluid.dygraph.guard(fluid.CPUPlace()): with fluid.dygraph.guard(fluid.CPUPlace()):
x_data = np.ones([16, 10]).astype('float32') x_data = np.ones([16, 10]).astype('float32')
...@@ -272,5 +277,23 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase): ...@@ -272,5 +277,23 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase):
foo_3.concrete_program foo_3.concrete_program
class TestDeclarativeAPI(unittest.TestCase):
def test_error(self):
func = declarative(dyfunc_to_variable)
paddle.enable_static()
# Failed to run the callable object decorated by '@paddle.jit.to_static'
# if it does NOT in dynamic mode.
with self.assertRaises(RuntimeError):
func(np.ones(5).astype("int32"))
program_trans.enable(False)
with self.assertRaises(AssertionError):
# AssertionError: We Only support to_variable in imperative mode,
# please use fluid.dygraph.guard() as context to run it in imperative Mode
func(np.ones(5).astype("int32"))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册