From 8d6de4401289477ea9763bffb93840024ec5e693 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Fri, 19 Jun 2020 11:56:39 +0800 Subject: [PATCH] Refine check_type Error Message for @declarative (#25098) * Refine check_type for @declarative test=develop --- python/paddle/fluid/data_feeder.py | 21 +++++++-- .../dygraph_to_static/program_translator.py | 46 ++++++++++++++----- .../tests/unittests/test_imperative_basic.py | 9 ++++ 3 files changed, 62 insertions(+), 14 deletions(-) diff --git a/python/paddle/fluid/data_feeder.py b/python/paddle/fluid/data_feeder.py index 03e14a3fef..71c722fb86 100644 --- a/python/paddle/fluid/data_feeder.py +++ b/python/paddle/fluid/data_feeder.py @@ -60,7 +60,7 @@ def convert_dtype(dtype): u'float64', u'int8', u'int16', u'int32', u'int64', u'uint8' ]: # this code is a little bit dangerous, since error could happen - # when casting no-asci code to str in python2. + # when casting no-ascii code to str in python2. # but since the set itself is limited, so currently, it is good. # however, jointly supporting python2 and python3, (as well as python4 maybe) # may still be a long-lasting problem. @@ -76,8 +76,7 @@ def check_variable_and_dtype(input, expected_dtype, op_name, extra_message=''): - check_type(input, input_name, (Variable, core.VarBase), op_name, - extra_message) + check_type(input, input_name, Variable, op_name, extra_message) check_dtype(input.dtype, input_name, expected_dtype, op_name, extra_message) @@ -91,6 +90,22 @@ def check_type(input, input_name, expected_type, op_name, extra_message=''): # each step in dynamic graph mode, it will bring a heavy performance burden. if in_dygraph_mode(): return + + from .dygraph.dygraph_to_static.program_translator import in_declarative_mode + # NOTE: `in_declarative_mode` is used to determined whether this op is called under + # @declarative in transformation from dygrah to static layer. We add VarBase in + # expected_type to skip checking because varBase may be created and used in unusual way. + # Need a better design to be fix this. + if in_declarative_mode(): + if not isinstance(expected_type, tuple): + expected_type = (expected_type, ) + expected_type += (core.VarBase, ) + elif isinstance(input, core.VarBase): + raise TypeError( + "Please use `with fluid.dygraph.guard()` as context or `fluid.enable_dygraph()` to switch to imperative mode firstly. " + "Because received '{}' in {} is a imperative Variable.".format( + input_name, op_name)) + if not isinstance(input, expected_type): raise TypeError( "The type of '%s' in %s must be %s, but received %s. %s" % diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index 6f8edcdb0f..a74a9ff331 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -31,6 +31,7 @@ from paddle.fluid.dygraph.base import switch_to_static_graph from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_static from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code +from paddle.fluid.wrapped_decorator import signature_safe_contextmanager from paddle.fluid.dygraph.base import param_guard from paddle.fluid.data_feeder import check_type from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from @@ -155,6 +156,28 @@ class FunctionSpec(object): return self.__key() == self.__key() +# Flag that indicates whether running code under `@declarative` +_in_declarative_mode_ = False + + +def in_declarative_mode(): + """ + Return a bool value that indicates whether running code under `@declarative` + + """ + return _in_declarative_mode_ + + +@signature_safe_contextmanager +def _switch_declarative_mode_guard_(is_declarative=True): + + global _in_declarative_mode_ + original_val = _in_declarative_mode_ + _in_declarative_mode_ = is_declarative + yield + _in_declarative_mode_ = original_val + + class ConcreteProgram(object): def __init__(self, inputs, @@ -190,17 +213,18 @@ class ConcreteProgram(object): ).random_seed with framework.program_guard(main_program, startup_program): - # 1. Adds `fluid.data` layers for input if needed - inputs = func_spec.to_static_inputs(main_program) - - # 2. Gets all ParamBases in the function - all_parameters = list(func_spec.parameters().values()) - - # 3. Builds program only once and returns the output Variables. - with param_guard(func_spec.parameters(False)): - outputs = static_func(*inputs) - if not isinstance(outputs, (tuple, list)): - outputs = [outputs] if outputs else [] + with _switch_declarative_mode_guard_(is_declarative=True): + # 1. Adds `fluid.data` layers for input if needed + inputs = func_spec.to_static_inputs(main_program) + + # 2. Gets all ParamBases in the function + all_parameters = list(func_spec.parameters().values()) + + # 3. Builds program only once and returns the output Variables. + with param_guard(func_spec.parameters(False)): + outputs = static_func(*inputs) + if not isinstance(outputs, (tuple, list)): + outputs = [outputs] if outputs else [] return ConcreteProgram( inputs=inputs, diff --git a/python/paddle/fluid/tests/unittests/test_imperative_basic.py b/python/paddle/fluid/tests/unittests/test_imperative_basic.py index 9dead7dc74..75661644c1 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_basic.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_basic.py @@ -645,5 +645,14 @@ class TestDygraphUtils(unittest.TestCase): self.assertTrue(np.array_equal(res1.numpy(), res2.numpy())) +class TestDygraphGuardWithError(unittest.TestCase): + def test_without_guard(self): + with fluid.dygraph.guard(): + x = fluid.dygraph.to_variable(np.zeros([10, 10])) + with self.assertRaisesRegexp(TypeError, + "Please use `with fluid.dygraph.guard()"): + y = fluid.layers.matmul(x, x) + + if __name__ == '__main__': unittest.main() -- GitLab