diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 3994fa99f0ce22d3e88ebb33661527274d296d68..564fc1fd5adf26547d59e9d947a6935eda0ebb27 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -203,7 +203,7 @@ def in_dygraph_mode():
 def _dygraph_not_support_(func):
     def __impl__(*args, **kwargs):
         assert not in_dygraph_mode(
-        ), "We don't support %s in Dygraph mode" % func.__name__
+        ), "We don't support %s in imperative mode" % func.__name__
         return func(*args, **kwargs)
 
     return __impl__
@@ -212,14 +212,31 @@ def _dygraph_not_support_(func):
 def _dygraph_only_(func):
     def __impl__(*args, **kwargs):
         assert in_dygraph_mode(
-        ), "We Only support %s in Dygraph mode, please use fluid.dygraph.guard() as context to run it in Dygraph Mode" % func.__name__
+        ), "We only support %s in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative mode" % func.__name__
         return func(*args, **kwargs)
 
     return __impl__
 
 
+# NOTE(zhiqiu): This decorator is used for the APIs of Variable that exist only
+# to give Variable and VarBase the same interface, such as numpy. Since VarBase
+# is not exposed in our official documents, we logically want to keep VarBase
+# and Variable consistent. However, some of these APIs (e.g. numpy) cannot be
+# supported on Variable, because Variable only holds the desc, so they are
+# listed under class Variable only to generate docs.
+# TODO(zhiqiu): Make VarBase consistent with Variable, e.g. by inheriting from the same base class.
+def _fake_interface_only_(func):
+    def __impl__(*args, **kwargs):
+        raise AssertionError(
+            "'%s' should be called by imperative Variable in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative mode"
+            % func.__name__)
+
+    return __impl__
+
+
 dygraph_not_support = wrap_decorator(_dygraph_not_support_)
 dygraph_only = wrap_decorator(_dygraph_only_)
+fake_interface_only = wrap_decorator(_fake_interface_only_)
 
 
 def _dygraph_tracer():
@@ -592,7 +609,6 @@ class VariableMetaClass(type):
     def __instancecheck__(cls, instance):
         t = type(instance)
         if in_dygraph_mode():
-
             return issubclass(t, core.VarBase)
         else:
             return issubclass(t, Variable)
@@ -954,7 +970,7 @@ class Variable(object):
         self._stop_gradient = stop_gradient
         self.is_data = is_data
 
-    @dygraph_only
+    @fake_interface_only
     def detach(self):
         """
         **Notes**:
@@ -984,7 +1000,7 @@ class Variable(object):
         """
         pass
 
-    @dygraph_only
+    @fake_interface_only
     def numpy(self):
         """
         **Notes**:
@@ -1016,7 +1032,7 @@ class Variable(object):
         """
         pass
 
-    @dygraph_only
+    @fake_interface_only
     def set_value(self, value):
         """
         **Notes**:
@@ -1047,7 +1063,7 @@ class Variable(object):
         """
         pass
 
-    @dygraph_only
+    @fake_interface_only
     def backward(self, backward_strategy=None):
         """
         **Notes**:
@@ -1085,7 +1101,7 @@ class Variable(object):
         """
         pass
 
-    @dygraph_only
+    @fake_interface_only
     def gradient(self):
         """
         **Notes**:
@@ -1133,7 +1149,7 @@ class Variable(object):
         """
         pass
 
-    @dygraph_only
+    @fake_interface_only
     def clear_gradient(self):
         """
         **Notes**:
@@ -1200,9 +1216,6 @@ class Variable(object):
             print("=============with detail===============")
             print(new_variable.to_string(True, True))
         """
-        if in_dygraph_mode():
-            return
-
         assert isinstance(throw_on_error, bool) and isinstance(with_details,
                                                                 bool)
         protostr = self.desc.serialize_to_string()
@@ -1249,17 +1262,11 @@ class Variable(object):
             assert linear.weight.gradient() is None
             assert (out1.gradient() == 0).all()
         """
-        if in_dygraph_mode():
-            pass
-        else:
-            return self._stop_gradient
+        return self._stop_gradient
 
     @stop_gradient.setter
     def stop_gradient(self, s):
-        if in_dygraph_mode():
-            pass
-        else:
-            self._stop_gradient = s
+        self._stop_gradient = s
 
     @property
     def persistable(self):
@@ -1284,19 +1291,11 @@ class Variable(object):
                                                 dtype='float32')
             print("persistable of current Var is: {}".format(new_variable.persistable))
         """
-        if in_dygraph_mode():
-            pass
-        else:
-            return self.desc.persistable()
+        return self.desc.persistable()
 
     @persistable.setter
     def persistable(self, p):
-        if in_dygraph_mode():
-            logging.warn(
-                "There will be no use to set persistable in Dygraph Mode, since "
-                "you can just do it by hold it as normal Python variable")
-        else:
-            self.desc.set_persistable(p)
+        self.desc.set_persistable(p)
 
     @property
     def name(self):
@@ -1316,10 +1315,7 @@ class Variable(object):
                                                 dtype='float32')
             print("name of current Var is: {}".format(new_variable.name))
         """
-        if in_dygraph_mode():
-            pass
-        else:
-            return cpt.to_text(self.desc.name())
+        return cpt.to_text(self.desc.name())
 
     @property
     def grad_name(self):
@@ -1343,10 +1339,7 @@ class Variable(object):
 
     @name.setter
     def name(self, new_name):
-        if in_dygraph_mode():
-            pass
-        else:
-            self.desc.set_name(new_name)
+        self.desc.set_name(new_name)
 
     @property
     def shape(self):
@@ -1368,10 +1361,7 @@ class Variable(object):
 
         """
         # convert to tuple, make it as same as numpy API.
-        if in_dygraph_mode():
-            pass
-        else:
-            return tuple(self.desc.shape())
+        return tuple(self.desc.shape())
 
     @property
     def dtype(self):
@@ -1391,13 +1381,9 @@ class Variable(object):
                                                 dtype='float32')
             print("Dtype of current Var is: {}".format(new_variable.dtype))
         """
-        if in_dygraph_mode():
-            pass
-        else:
-            return self.desc.dtype()
+        return self.desc.dtype()
 
     @property
-    @dygraph_not_support
     def lod_level(self):
         """
         Indicating ``LoD`` info of current Variable, please refer to :ref:`api_fluid_LoDTensor_en` to check the meaning
@@ -1420,10 +1406,6 @@ class Variable(object):
                                                 dtype='float32')
             print("LoD Level of current Var is: {}".format(new_variable.lod_level))
         """
-        # TODO(minqiyang): Support lod_level in dygraph mode
-        if in_dygraph_mode():
-            raise Exception("Dygraph model DO NOT supprt lod")
-
         if self.type == core.VarDesc.VarType.SELECTED_ROWS:
             raise Exception("SelectedRows DO NOT supprt lod")
 
@@ -1447,10 +1429,7 @@ class Variable(object):
                                                 dtype='float32')
             print("Type of current Var is: {}".format(new_variable.type))
         """
-        if in_dygraph_mode():
-            pass
-        else:
-            return self.desc.type()
+        return self.desc.type()
 
     def _set_error_clip(self, error_clip):
         """
@@ -2018,10 +1997,7 @@ class Operator(object):
 
     @property
     def type(self):
-        if in_dygraph_mode():
-            return self._type
-        else:
-            return self.desc.type()
+        return self.desc.type()
 
     def input(self, name):
         """
@@ -3977,7 +3953,6 @@ class Program(object):
     def _version(self):
         return self.desc._version()
 
-    @dygraph_not_support
     def clone(self, for_test=False):
         """
         **Notes**:
@@ -4664,7 +4639,6 @@ class Program(object):
             if other_var.stop_gradient:
                 var.stop_gradient = True
 
-    @dygraph_not_support
     def list_vars(self):
         """
         Get all :ref:`api_guide_Variable_en` from this Program. A iterable object is returned.
@@ -4687,7 +4661,6 @@ class Program(object):
             for each_var in list(each_block.vars.values()):
                 yield each_var
 
-    @dygraph_not_support
     def all_parameters(self):
         """
         Get all :ref:`api_guide_parameter_en` from this Program. A list object is returned.
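The framework change above boils down to a small decorator pattern: a `fake_interface_only` method never runs its body and instead raises an `AssertionError` that points the user at the imperative (dygraph) `VarBase` API. The snippet below is a minimal, self-contained sketch of that pattern, not the Paddle source itself; it substitutes `functools.wraps` for Paddle's `wrap_decorator` to keep the example dependency-free.

```python
import functools


def fake_interface_only(func):
    # Sketch of the decorator added in framework.py: the wrapped stub never
    # executes and always fails loudly with a pointer to the imperative API.
    @functools.wraps(func)
    def __impl__(*args, **kwargs):
        raise AssertionError(
            "'%s' should be called by imperative Variable in imperative mode, "
            "please use fluid.dygraph.guard() as context to run it in imperative mode"
            % func.__name__)

    return __impl__


class Variable(object):
    @fake_interface_only
    def numpy(self):
        """Doc-only stub: the real numpy() lives on VarBase."""
        pass


try:
    Variable().numpy()
except AssertionError as e:
    print(e)  # 'numpy' should be called by imperative Variable in imperative mode, ...
```

This is also why the updated tests below expect `AssertionError` from `var.detach()`, `var.numpy()` and friends instead of a silent `None`.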
diff --git a/python/paddle/fluid/tests/unittests/test_detach.py b/python/paddle/fluid/tests/unittests/test_detach.py
index 59e9e9e41274e7a7b5492b9df70009383d60bb62..f0103f89a5940befb55b2148ecdb0453eeb5215c 100644
--- a/python/paddle/fluid/tests/unittests/test_detach.py
+++ b/python/paddle/fluid/tests/unittests/test_detach.py
@@ -158,7 +158,7 @@ class Test_Detach(unittest.TestCase):
             assert type(e) == AssertionError
             assert str(
                 e
-            ) == 'We Only support detach in Dygraph mode, please use fluid.dygraph.guard() as context to run it in Dygraph Mode'
+            ) == "'detach' should be called by imperative Variable in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative mode"
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/fluid/tests/unittests/test_variable.py b/python/paddle/fluid/tests/unittests/test_variable.py
index 5023d0e87d50eb7e34ccadc2455cbe7250d16f9f..8d5ab0a5be757a3f6f0c86854d281210e01d3d99 100644
--- a/python/paddle/fluid/tests/unittests/test_variable.py
+++ b/python/paddle/fluid/tests/unittests/test_variable.py
@@ -184,27 +184,35 @@ class TestVariable(unittest.TestCase):
         with fluid.program_guard(default_main_program()):
             self._tostring()
 
-    # NOTE(zhiqiu): for coverage CI
-    # TODO(zhiqiu): code clean for dygraph
-    def test_dygraph_deprecated_api(self):
+    def test_fake_interface_only_api(self):
         b = default_main_program().current_block()
         var = b.create_var(dtype="float64", lod_level=0)
         with fluid.dygraph.guard():
-            self.assertIsNone(var.detach())
-            self.assertIsNone(var.numpy())
-            self.assertIsNone(var.set_value(None))
-            self.assertIsNone(var.backward())
-            self.assertIsNone(var.gradient())
-            self.assertIsNone(var.clear_gradient())
-            self.assertIsNone(var.to_string(True))
-            self.assertIsNone(var.persistable)
+            self.assertRaises(AssertionError, var.detach)
+            self.assertRaises(AssertionError, var.numpy)
+            self.assertRaises(AssertionError, var.set_value, None)
+            self.assertRaises(AssertionError, var.backward)
+            self.assertRaises(AssertionError, var.gradient)
+            self.assertRaises(AssertionError, var.clear_gradient)
+
+    def test_variable_in_dygraph_mode(self):
+        b = default_main_program().current_block()
+        var = b.create_var(dtype="float64", shape=[1, 1])
+        with fluid.dygraph.guard():
+            self.assertTrue(var.to_string(True).startswith('name:'))
+
+            self.assertFalse(var.persistable)
+            var.persistable = True
+            self.assertTrue(var.persistable)
+
+            self.assertFalse(var.stop_gradient)
             var.stop_gradient = True
-            self.assertIsNone(var.stop_gradient)
-            var.stop_gradient = 'tmp'
-            self.assertIsNone(var.name)
-            self.assertIsNone(var.shape)
-            self.assertIsNone(var.dtype)
-            self.assertIsNone(var.type)
+            self.assertTrue(var.stop_gradient)
+
+            self.assertTrue(var.name.startswith('_generated_var_'))
+            self.assertEqual(var.shape, (1, 1))
+            self.assertEqual(var.dtype, fluid.core.VarDesc.VarType.FP64)
+            self.assertEqual(var.type, fluid.core.VarDesc.VarType.LOD_TENSOR)
 
     def test_create_selected_rows(self):
         b = default_main_program().current_block()
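For reference, here is a rough usage sketch mirroring the updated tests (it only assumes the fluid 1.x APIs already exercised in this PR, such as `fluid.dygraph.to_variable`, `fluid.default_main_program` and `Block.create_var`): under `fluid.dygraph.guard()`, a real `VarBase` supports `numpy()`, while a static-graph `Variable` keeps its desc-backed attributes but its doc-only stubs now fail loudly instead of silently returning `None`.

```python
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    # Imperative VarBase: numpy() and gradient() are real methods here.
    x = fluid.dygraph.to_variable(np.ones([2, 2], dtype='float32'))
    print(x.numpy())

    # Static-graph Variable: desc-backed attributes keep working under the guard...
    block = fluid.default_main_program().current_block()
    var = block.create_var(dtype='float64', shape=[1, 1])
    print(var.shape, var.dtype)

    # ...while the fake-interface stubs raise instead of returning None.
    try:
        var.numpy()
    except AssertionError as e:
        print(e)
```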