From 1d1bebf063c788c736aec3e738bea255094c837b Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Sun, 6 Sep 2020 12:37:06 +0800 Subject: [PATCH] [cherry-pick 2.0-beta] Raise RuntimeError if run the callable object decorated by '@paddle.jit.to_static' not in dynamic mode. (#26750) (#27053) Change-Id: I21a07cc2bc39acb753ab8fc00c72e269ddef0df1 --- .../dygraph_to_static/program_translator.py | 14 ++++++++-- .../dygraph_to_static/test_declarative.py | 27 +++++++++++++++++-- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index ad7d6dfd3f9..cb489af44d0 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -24,6 +24,7 @@ import warnings import gast from paddle.fluid import framework +from paddle.fluid import in_dygraph_mode from paddle.fluid.dygraph import layers from paddle.fluid.data_feeder import check_type from paddle.fluid.layers.utils import flatten @@ -32,6 +33,7 @@ from paddle.fluid.dygraph.base import switch_to_static_graph from paddle.fluid.dygraph.dygraph_to_static import DygraphToStaticAst from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA from paddle.fluid.dygraph.dygraph_to_static.error import attach_error_data +from paddle.fluid.dygraph.dygraph_to_static import logging_utils from paddle.fluid.dygraph.dygraph_to_static.origin_info import attach_origin_info from paddle.fluid.dygraph.dygraph_to_static.origin_info import create_and_update_origin_info_map from paddle.fluid.dygraph.dygraph_to_static.origin_info import update_op_callstack_with_origin_info @@ -283,13 +285,21 @@ class StaticLayer(object): Return: Outputs of decorated function. """ + # 1. call dygraph function directly if not enable `declarative` if not self._program_trans.enable_declarative: - warnings.warn( - "The decorator '@paddle.jit.to_static' doesn't work when setting ProgramTranslator.enable=False. " + logging_utils.warn( + "The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable=False. " "We will just return dygraph output.") return self._call_dygraph_function(*args, **kwargs) + if not in_dygraph_mode() and self._program_trans.enable_declarative: + raise RuntimeError( + "Failed to run the callable object {} decorated by '@paddle.jit.to_static', " + "because it does NOT in dynamic mode. Please disable the static mode to enter dynamic mode with the " + "following API: paddle.disable_static().".format( + self.dygraph_function)) + # 2. trace ops from dygraph layers and cache the generated program. args, kwargs = self._function_spec.unified_args_and_kwargs(args, kwargs) try: diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py index e5a33e59a3b..949286f63ef 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py @@ -13,13 +13,15 @@ # limitations under the License. import numpy as np +import unittest + import paddle -from paddle.static import InputSpec import paddle.fluid as fluid +from paddle.static import InputSpec from paddle.fluid.dygraph import to_variable, declarative, ProgramTranslator, Layer, jit from paddle.fluid.dygraph.dygraph_to_static.program_translator import ConcreteProgram -import unittest +from test_basic_api_transformation import dyfunc_to_variable program_trans = ProgramTranslator() @@ -181,6 +183,9 @@ def foo_func(a, b, c=1, d=2): class TestDifferentInputSpecCacheProgram(unittest.TestCase): + def setUp(self): + program_trans.enable(True) + def test_with_different_input(self): with fluid.dygraph.guard(fluid.CPUPlace()): x_data = np.ones([16, 10]).astype('float32') @@ -272,5 +277,23 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase): foo_3.concrete_program +class TestDeclarativeAPI(unittest.TestCase): + def test_error(self): + func = declarative(dyfunc_to_variable) + + paddle.enable_static() + + # Failed to run the callable object decorated by '@paddle.jit.to_static' + # if it does NOT in dynamic mode. + with self.assertRaises(RuntimeError): + func(np.ones(5).astype("int32")) + + program_trans.enable(False) + with self.assertRaises(AssertionError): + # AssertionError: We Only support to_variable in imperative mode, + # please use fluid.dygraph.guard() as context to run it in imperative Mode + func(np.ones(5).astype("int32")) + + if __name__ == '__main__': unittest.main() -- GitLab