未验证 提交 b0111778 编写于 作者: A Aurelius84 提交者: GitHub

enhance error message (#26808)

上级 3ce727a8
...@@ -393,19 +393,43 @@ class StaticLayer(object): ...@@ -393,19 +393,43 @@ class StaticLayer(object):
def concrete_program(self): def concrete_program(self):
""" """
Returns recent ConcreteProgram instance of decorated function. Returns recent ConcreteProgram instance of decorated function.
Examples:
.. code-block:: python
import paddle
from paddle.jit import to_static
from paddle.static import InputSpec
paddle.disable_static()
def foo(x, y):
z = x + y
return z
# usage 1:
decorated_foo = to_static(foo, input_spec=[InputSpec([10], name='x'), InputSpec([10], name='y')])
print(decorated_foo.concrete_program)
# usage 2:
decorated_foo = to_static(foo)
out_foo = decorated_foo(paddle.rand([10]), paddle.rand([10]))
print(decorated_foo.concrete_program)
""" """
# if specific the `input_spec`, the length of program_cache will always 1, # if specific the `input_spec`, the length of program_cache will always 1,
# else, return the last one. # else, return the last one.
cached_program_len = len(self._program_cache) cached_program_len = len(self._program_cache)
# If specific `input_spec`, apply convertion from dygraph layers into static Program. # If specific `input_spec`, apply convertion from dygraph layers into static Program.
if cached_program_len == 0: if cached_program_len == 0:
if len(self._function_spec.flat_input_spec) > 0:
input_spec = self._function_spec.input_spec input_spec = self._function_spec.input_spec
has_input_spec = (input_spec is not None and len(input_spec) > 0)
if has_input_spec:
concrete_program, _ = self.get_concrete_program(*input_spec) concrete_program, _ = self.get_concrete_program(*input_spec)
return concrete_program return concrete_program
else: else:
raise ValueError("No valid transformed program for {}".format( raise ValueError(
self._function_spec)) "No valid transformed program for {}.\n\t Please specific `input_spec` in `@paddle.jit.to_static` or feed input tensor to call the decorated function at once.\n".
format(self._function_spec))
# If more than one programs have been cached, return the recent converted program by default. # If more than one programs have been cached, return the recent converted program by default.
elif cached_program_len > 1: elif cached_program_len > 1:
logging.warning( logging.warning(
......
...@@ -13,9 +13,11 @@ ...@@ -13,9 +13,11 @@
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
import paddle
from paddle.static import InputSpec from paddle.static import InputSpec
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable, declarative, ProgramTranslator, Layer, jit from paddle.fluid.dygraph import to_variable, declarative, ProgramTranslator, Layer, jit
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ConcreteProgram
import unittest import unittest
...@@ -246,6 +248,29 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase): ...@@ -246,6 +248,29 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase):
concrete_program_5 = foo.get_concrete_program( concrete_program_5 = foo.get_concrete_program(
InputSpec([10]), InputSpec([10]), e=4) InputSpec([10]), InputSpec([10]), e=4)
def test_concrete_program(self):
with fluid.dygraph.guard(fluid.CPUPlace()):
# usage 1
foo_1 = paddle.jit.to_static(
foo_func,
input_spec=[
InputSpec(
[10], name='x'), InputSpec(
[10], name='y')
])
self.assertTrue(isinstance(foo_1.concrete_program, ConcreteProgram))
# usage 2
foo_2 = paddle.jit.to_static(foo_func)
out = foo_2(paddle.rand([10]), paddle.rand([10]))
self.assertTrue(isinstance(foo_2.concrete_program, ConcreteProgram))
# raise error
foo_3 = paddle.jit.to_static(foo_func)
with self.assertRaises(ValueError):
foo_3.concrete_program
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册