未验证 提交 1d1bebf0 编写于 作者: L liym27 提交者: GitHub

[cherry-pick 2.0-beta] Raise RuntimeError if run the callable object decorated...

[cherry-pick 2.0-beta] Raise RuntimeError if run the callable object decorated by '@paddle.jit.to_static' not in dynamic mode.  (#26750) (#27053)

Change-Id: I21a07cc2bc39acb753ab8fc00c72e269ddef0df1
上级 a068168f
......@@ -24,6 +24,7 @@ import warnings
import gast
from paddle.fluid import framework
from paddle.fluid import in_dygraph_mode
from paddle.fluid.dygraph import layers
from paddle.fluid.data_feeder import check_type
from paddle.fluid.layers.utils import flatten
......@@ -32,6 +33,7 @@ from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import DygraphToStaticAst
from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA
from paddle.fluid.dygraph.dygraph_to_static.error import attach_error_data
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.origin_info import attach_origin_info
from paddle.fluid.dygraph.dygraph_to_static.origin_info import create_and_update_origin_info_map
from paddle.fluid.dygraph.dygraph_to_static.origin_info import update_op_callstack_with_origin_info
......@@ -283,13 +285,21 @@ class StaticLayer(object):
Return:
Outputs of decorated function.
"""
# 1. call dygraph function directly if not enable `declarative`
if not self._program_trans.enable_declarative:
warnings.warn(
"The decorator '@paddle.jit.to_static' doesn't work when setting ProgramTranslator.enable=False. "
logging_utils.warn(
"The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable=False. "
"We will just return dygraph output.")
return self._call_dygraph_function(*args, **kwargs)
if not in_dygraph_mode() and self._program_trans.enable_declarative:
raise RuntimeError(
"Failed to run the callable object {} decorated by '@paddle.jit.to_static', "
"because it does NOT in dynamic mode. Please disable the static mode to enter dynamic mode with the "
"following API: paddle.disable_static().".format(
self.dygraph_function))
# 2. trace ops from dygraph layers and cache the generated program.
args, kwargs = self._function_spec.unified_args_and_kwargs(args, kwargs)
try:
......
......@@ -13,13 +13,15 @@
# limitations under the License.
import numpy as np
import unittest
import paddle
from paddle.static import InputSpec
import paddle.fluid as fluid
from paddle.static import InputSpec
from paddle.fluid.dygraph import to_variable, declarative, ProgramTranslator, Layer, jit
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ConcreteProgram
import unittest
from test_basic_api_transformation import dyfunc_to_variable
program_trans = ProgramTranslator()
......@@ -181,6 +183,9 @@ def foo_func(a, b, c=1, d=2):
class TestDifferentInputSpecCacheProgram(unittest.TestCase):
def setUp(self):
program_trans.enable(True)
def test_with_different_input(self):
with fluid.dygraph.guard(fluid.CPUPlace()):
x_data = np.ones([16, 10]).astype('float32')
......@@ -272,5 +277,23 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase):
foo_3.concrete_program
class TestDeclarativeAPI(unittest.TestCase):
def test_error(self):
func = declarative(dyfunc_to_variable)
paddle.enable_static()
# Failed to run the callable object decorated by '@paddle.jit.to_static'
# if it does NOT in dynamic mode.
with self.assertRaises(RuntimeError):
func(np.ones(5).astype("int32"))
program_trans.enable(False)
with self.assertRaises(AssertionError):
# AssertionError: We Only support to_variable in imperative mode,
# please use fluid.dygraph.guard() as context to run it in imperative Mode
func(np.ones(5).astype("int32"))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册