未验证 提交 7ba6279a 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2stat] Refine error msg of @to_static if not in imperative mode (#27371)

* refine error mesg
上级 dd4c2d86
......@@ -370,6 +370,7 @@ class StaticLayer(object):
Returns:
Traced ConcreteProgram and executable translated Layer.
"""
# 1. unify args/kwargs and replace Tensor with InputSpec
if len(args) != len(self._function_spec.args_name):
args, kwargs = self._function_spec.unified_args_and_kwargs(args,
......@@ -522,6 +523,19 @@ def _switch_declarative_mode_guard_(is_declarative=True):
_in_declarative_mode_ = original_val
def _verify_init_in_dynamic_mode(class_instance):
"""
Verifies the instance is initialized in dynamic mode.
"""
if isinstance(class_instance, layers.Layer):
if not class_instance._init_in_dynamic_mode:
raise RuntimeError(
" `paddle.jit.to_static` is only available in dynamic mode. Please call `paddle.disable_static()` before "
"initializing your Layer class `{}` . Because parameters of Layer class should be initialized firstly "
"in dynamic mode while applying transformation.".format(
class_instance))
class ConcreteProgram(object):
__slots__ = [
......@@ -554,6 +568,9 @@ class ConcreteProgram(object):
func_spec(FunctionSpec): A FunctionSpec instance for decorated function.
input_spec(list[InputSpec]):
"""
# verify the instance is initialized in imperative mode.
_verify_init_in_dynamic_mode(class_instance)
# Transforms dygraph function into static function and caches it.
dygraph_function = func_spec.dygraph_function
static_func = convert_to_static(dygraph_function)
......
......@@ -91,6 +91,7 @@ class Layer(core.Layer):
self._helper = LayerObjectHelper(self._full_name)
self._built = False
self._dtype = dtype
self._init_in_dynamic_mode = framework.in_dygraph_mode()
self._parameters = collections.OrderedDict()
# Buffers the variable (not parameter) created in layer
......
......@@ -358,5 +358,24 @@ class TestDecorateModelDirectly(unittest.TestCase):
self.assertListEqual(list(input_shape), [-1, 16, 10])
class TestErrorWithInitFromStaticMode(unittest.TestCase):
def test_raise_error(self):
# disable imperative
paddle.enable_static()
net = SimpleNet()
with self.assertRaisesRegexp(RuntimeError,
"only available in dynamic mode"):
net.forward.concrete_program
with self.assertRaisesRegexp(RuntimeError,
"only available in dynamic mode"):
net.forward.inputs
with self.assertRaisesRegexp(RuntimeError,
"only available in dynamic mode"):
net.forward.outputs
if __name__ == '__main__':
unittest.main()
......@@ -21,6 +21,7 @@ import numpy as np
import textwrap
import unittest
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.jit import declarative
......@@ -279,5 +280,33 @@ class TestEnableDeclarative(unittest.TestCase):
static_output.numpy(), dygraph_output.numpy(), atol=1e-4))
class Net(fluid.dygraph.layers.Layer):
def __init__(self):
super(Net, self).__init__()
def forward(self, x):
return x + 1
class TestErrorWithInitFromStaticMode(unittest.TestCase):
def setUp(self):
self.program_translator = ProgramTranslator()
self.x = np.random.randn(10, 32).astype('float32')
def test_raise_error(self):
# disable imperative
paddle.enable_static()
net = Net()
self.program_translator.enable(True)
with self.assertRaisesRegexp(RuntimeError,
"only available in dynamic mode"):
self.program_translator.get_output(net.forward, self.x)
with self.assertRaisesRegexp(RuntimeError,
"only available in dynamic mode"):
self.program_translator.get_program(net.forward, self.x)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册