未验证 提交 8d6de440 编写于 作者: A Aurelius84 提交者: GitHub

Refine check_type Error Message for @declarative (#25098)

* Refine check_type for @declarative test=develop
上级 b23801a2
...@@ -60,7 +60,7 @@ def convert_dtype(dtype): ...@@ -60,7 +60,7 @@ def convert_dtype(dtype):
u'float64', u'int8', u'int16', u'int32', u'int64', u'uint8' u'float64', u'int8', u'int16', u'int32', u'int64', u'uint8'
]: ]:
# this code is a little bit dangerous, since error could happen # this code is a little bit dangerous, since error could happen
# when casting no-asci code to str in python2. # when casting no-ascii code to str in python2.
# but since the set itself is limited, so currently, it is good. # but since the set itself is limited, so currently, it is good.
# however, jointly supporting python2 and python3, (as well as python4 maybe) # however, jointly supporting python2 and python3, (as well as python4 maybe)
# may still be a long-lasting problem. # may still be a long-lasting problem.
...@@ -76,8 +76,7 @@ def check_variable_and_dtype(input, ...@@ -76,8 +76,7 @@ def check_variable_and_dtype(input,
expected_dtype, expected_dtype,
op_name, op_name,
extra_message=''): extra_message=''):
check_type(input, input_name, (Variable, core.VarBase), op_name, check_type(input, input_name, Variable, op_name, extra_message)
extra_message)
check_dtype(input.dtype, input_name, expected_dtype, op_name, extra_message) check_dtype(input.dtype, input_name, expected_dtype, op_name, extra_message)
...@@ -91,6 +90,22 @@ def check_type(input, input_name, expected_type, op_name, extra_message=''): ...@@ -91,6 +90,22 @@ def check_type(input, input_name, expected_type, op_name, extra_message=''):
# each step in dynamic graph mode, it will bring a heavy performance burden. # each step in dynamic graph mode, it will bring a heavy performance burden.
if in_dygraph_mode(): if in_dygraph_mode():
return return
from .dygraph.dygraph_to_static.program_translator import in_declarative_mode
# NOTE: `in_declarative_mode` is used to determined whether this op is called under
# @declarative in transformation from dygrah to static layer. We add VarBase in
# expected_type to skip checking because varBase may be created and used in unusual way.
# Need a better design to be fix this.
if in_declarative_mode():
if not isinstance(expected_type, tuple):
expected_type = (expected_type, )
expected_type += (core.VarBase, )
elif isinstance(input, core.VarBase):
raise TypeError(
"Please use `with fluid.dygraph.guard()` as context or `fluid.enable_dygraph()` to switch to imperative mode firstly. "
"Because received '{}' in {} is a imperative Variable.".format(
input_name, op_name))
if not isinstance(input, expected_type): if not isinstance(input, expected_type):
raise TypeError( raise TypeError(
"The type of '%s' in %s must be %s, but received %s. %s" % "The type of '%s' in %s must be %s, but received %s. %s" %
......
...@@ -31,6 +31,7 @@ from paddle.fluid.dygraph.base import switch_to_static_graph ...@@ -31,6 +31,7 @@ from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_static from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_static
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
from paddle.fluid.dygraph.base import param_guard from paddle.fluid.dygraph.base import param_guard
from paddle.fluid.data_feeder import check_type from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from
...@@ -155,6 +156,28 @@ class FunctionSpec(object): ...@@ -155,6 +156,28 @@ class FunctionSpec(object):
return self.__key() == self.__key() return self.__key() == self.__key()
# Flag that indicates whether running code under `@declarative`
_in_declarative_mode_ = False
def in_declarative_mode():
"""
Return a bool value that indicates whether running code under `@declarative`
"""
return _in_declarative_mode_
@signature_safe_contextmanager
def _switch_declarative_mode_guard_(is_declarative=True):
global _in_declarative_mode_
original_val = _in_declarative_mode_
_in_declarative_mode_ = is_declarative
yield
_in_declarative_mode_ = original_val
class ConcreteProgram(object): class ConcreteProgram(object):
def __init__(self, def __init__(self,
inputs, inputs,
...@@ -190,17 +213,18 @@ class ConcreteProgram(object): ...@@ -190,17 +213,18 @@ class ConcreteProgram(object):
).random_seed ).random_seed
with framework.program_guard(main_program, startup_program): with framework.program_guard(main_program, startup_program):
# 1. Adds `fluid.data` layers for input if needed with _switch_declarative_mode_guard_(is_declarative=True):
inputs = func_spec.to_static_inputs(main_program) # 1. Adds `fluid.data` layers for input if needed
inputs = func_spec.to_static_inputs(main_program)
# 2. Gets all ParamBases in the function
all_parameters = list(func_spec.parameters().values()) # 2. Gets all ParamBases in the function
all_parameters = list(func_spec.parameters().values())
# 3. Builds program only once and returns the output Variables.
with param_guard(func_spec.parameters(False)): # 3. Builds program only once and returns the output Variables.
outputs = static_func(*inputs) with param_guard(func_spec.parameters(False)):
if not isinstance(outputs, (tuple, list)): outputs = static_func(*inputs)
outputs = [outputs] if outputs else [] if not isinstance(outputs, (tuple, list)):
outputs = [outputs] if outputs else []
return ConcreteProgram( return ConcreteProgram(
inputs=inputs, inputs=inputs,
......
...@@ -645,5 +645,14 @@ class TestDygraphUtils(unittest.TestCase): ...@@ -645,5 +645,14 @@ class TestDygraphUtils(unittest.TestCase):
self.assertTrue(np.array_equal(res1.numpy(), res2.numpy())) self.assertTrue(np.array_equal(res1.numpy(), res2.numpy()))
class TestDygraphGuardWithError(unittest.TestCase):
def test_without_guard(self):
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(np.zeros([10, 10]))
with self.assertRaisesRegexp(TypeError,
"Please use `with fluid.dygraph.guard()"):
y = fluid.layers.matmul(x, x)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册