未验证 提交 a4519a5d 编写于 作者: L Leo Chen 提交者: GitHub

Dev/add fake_interface_only decorator for some function of Variable (#24083) (#24166)

* add decorator, test=develop

* add fake_interface_only, test=develop

* remove some dygraph_not_support, test=develop

* change dygraph to imperative, test=develop
上级 91ae7848
...@@ -203,7 +203,7 @@ def in_dygraph_mode(): ...@@ -203,7 +203,7 @@ def in_dygraph_mode():
def _dygraph_not_support_(func): def _dygraph_not_support_(func):
def __impl__(*args, **kwargs): def __impl__(*args, **kwargs):
assert not in_dygraph_mode( assert not in_dygraph_mode(
), "We don't support %s in Dygraph mode" % func.__name__ ), "We don't support %s in imperative mode" % func.__name__
return func(*args, **kwargs) return func(*args, **kwargs)
return __impl__ return __impl__
...@@ -212,14 +212,31 @@ def _dygraph_not_support_(func): ...@@ -212,14 +212,31 @@ def _dygraph_not_support_(func):
def _dygraph_only_(func): def _dygraph_only_(func):
def __impl__(*args, **kwargs): def __impl__(*args, **kwargs):
assert in_dygraph_mode( assert in_dygraph_mode(
), "We Only support %s in Dygraph mode, please use fluid.dygraph.guard() as context to run it in Dygraph Mode" % func.__name__ ), "We Only support %s in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative Mode" % func.__name__
return func(*args, **kwargs) return func(*args, **kwargs)
return __impl__ return __impl__
# NOTE(zhiqiu): This decorator is used for the APIs of Variable which is only
# used to make Variable and VarBase has same interfaces, like numpy. Since VarBase is not exposed in our
# official docments, logically, we want to keep VarBase and logically consistent. While, actually,
# in our implementation, there some APIs not supported, like numpy, because Variable contains the desc.
# So, those APIs are listed under class Variable to generate docs only.
# TODO(zhiqiu): We should make VarBase consistent with Variable in future, for example, by inheritting
# same base class.
def _fake_interface_only_(func):
def __impl__(*args, **kwargs):
raise AssertionError(
"'%s' should be called by imperative Varible in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative mode"
% func.__name__)
return __impl__
dygraph_not_support = wrap_decorator(_dygraph_not_support_) dygraph_not_support = wrap_decorator(_dygraph_not_support_)
dygraph_only = wrap_decorator(_dygraph_only_) dygraph_only = wrap_decorator(_dygraph_only_)
fake_interface_only = wrap_decorator(_fake_interface_only_)
def _dygraph_tracer(): def _dygraph_tracer():
...@@ -592,7 +609,6 @@ class VariableMetaClass(type): ...@@ -592,7 +609,6 @@ class VariableMetaClass(type):
def __instancecheck__(cls, instance): def __instancecheck__(cls, instance):
t = type(instance) t = type(instance)
if in_dygraph_mode(): if in_dygraph_mode():
return issubclass(t, core.VarBase) return issubclass(t, core.VarBase)
else: else:
return issubclass(t, Variable) return issubclass(t, Variable)
...@@ -954,7 +970,7 @@ class Variable(object): ...@@ -954,7 +970,7 @@ class Variable(object):
self._stop_gradient = stop_gradient self._stop_gradient = stop_gradient
self.is_data = is_data self.is_data = is_data
@dygraph_only @fake_interface_only
def detach(self): def detach(self):
""" """
**Notes**: **Notes**:
...@@ -984,7 +1000,7 @@ class Variable(object): ...@@ -984,7 +1000,7 @@ class Variable(object):
""" """
pass pass
@dygraph_only @fake_interface_only
def numpy(self): def numpy(self):
""" """
**Notes**: **Notes**:
...@@ -1016,7 +1032,7 @@ class Variable(object): ...@@ -1016,7 +1032,7 @@ class Variable(object):
""" """
pass pass
@dygraph_only @fake_interface_only
def set_value(self, value): def set_value(self, value):
""" """
**Notes**: **Notes**:
...@@ -1047,7 +1063,7 @@ class Variable(object): ...@@ -1047,7 +1063,7 @@ class Variable(object):
""" """
pass pass
@dygraph_only @fake_interface_only
def backward(self, backward_strategy=None): def backward(self, backward_strategy=None):
""" """
**Notes**: **Notes**:
...@@ -1085,7 +1101,7 @@ class Variable(object): ...@@ -1085,7 +1101,7 @@ class Variable(object):
""" """
pass pass
@dygraph_only @fake_interface_only
def gradient(self): def gradient(self):
""" """
**Notes**: **Notes**:
...@@ -1133,7 +1149,7 @@ class Variable(object): ...@@ -1133,7 +1149,7 @@ class Variable(object):
""" """
pass pass
@dygraph_only @fake_interface_only
def clear_gradient(self): def clear_gradient(self):
""" """
**Notes**: **Notes**:
...@@ -1200,9 +1216,6 @@ class Variable(object): ...@@ -1200,9 +1216,6 @@ class Variable(object):
print("=============with detail===============") print("=============with detail===============")
print(new_variable.to_string(True, True)) print(new_variable.to_string(True, True))
""" """
if in_dygraph_mode():
return
assert isinstance(throw_on_error, bool) and isinstance(with_details, assert isinstance(throw_on_error, bool) and isinstance(with_details,
bool) bool)
protostr = self.desc.serialize_to_string() protostr = self.desc.serialize_to_string()
...@@ -1249,17 +1262,11 @@ class Variable(object): ...@@ -1249,17 +1262,11 @@ class Variable(object):
assert linear.weight.gradient() is None assert linear.weight.gradient() is None
assert (out1.gradient() == 0).all() assert (out1.gradient() == 0).all()
""" """
if in_dygraph_mode(): return self._stop_gradient
pass
else:
return self._stop_gradient
@stop_gradient.setter @stop_gradient.setter
def stop_gradient(self, s): def stop_gradient(self, s):
if in_dygraph_mode(): self._stop_gradient = s
pass
else:
self._stop_gradient = s
@property @property
def persistable(self): def persistable(self):
...@@ -1284,19 +1291,11 @@ class Variable(object): ...@@ -1284,19 +1291,11 @@ class Variable(object):
dtype='float32') dtype='float32')
print("persistable of current Var is: {}".format(new_variable.persistable)) print("persistable of current Var is: {}".format(new_variable.persistable))
""" """
if in_dygraph_mode(): return self.desc.persistable()
pass
else:
return self.desc.persistable()
@persistable.setter @persistable.setter
def persistable(self, p): def persistable(self, p):
if in_dygraph_mode(): self.desc.set_persistable(p)
logging.warn(
"There will be no use to set persistable in Dygraph Mode, since "
"you can just do it by hold it as normal Python variable")
else:
self.desc.set_persistable(p)
@property @property
def name(self): def name(self):
...@@ -1316,10 +1315,7 @@ class Variable(object): ...@@ -1316,10 +1315,7 @@ class Variable(object):
dtype='float32') dtype='float32')
print("name of current Var is: {}".format(new_variable.name)) print("name of current Var is: {}".format(new_variable.name))
""" """
if in_dygraph_mode(): return cpt.to_text(self.desc.name())
pass
else:
return cpt.to_text(self.desc.name())
@property @property
def grad_name(self): def grad_name(self):
...@@ -1343,10 +1339,7 @@ class Variable(object): ...@@ -1343,10 +1339,7 @@ class Variable(object):
@name.setter @name.setter
def name(self, new_name): def name(self, new_name):
if in_dygraph_mode(): self.desc.set_name(new_name)
pass
else:
self.desc.set_name(new_name)
@property @property
def shape(self): def shape(self):
...@@ -1368,10 +1361,7 @@ class Variable(object): ...@@ -1368,10 +1361,7 @@ class Variable(object):
""" """
# convert to tuple, make it as same as numpy API. # convert to tuple, make it as same as numpy API.
if in_dygraph_mode(): return tuple(self.desc.shape())
pass
else:
return tuple(self.desc.shape())
@property @property
def dtype(self): def dtype(self):
...@@ -1391,13 +1381,9 @@ class Variable(object): ...@@ -1391,13 +1381,9 @@ class Variable(object):
dtype='float32') dtype='float32')
print("Dtype of current Var is: {}".format(new_variable.dtype)) print("Dtype of current Var is: {}".format(new_variable.dtype))
""" """
if in_dygraph_mode(): return self.desc.dtype()
pass
else:
return self.desc.dtype()
@property @property
@dygraph_not_support
def lod_level(self): def lod_level(self):
""" """
Indicating ``LoD`` info of current Variable, please refer to :ref:`api_fluid_LoDTensor_en` to check the meaning Indicating ``LoD`` info of current Variable, please refer to :ref:`api_fluid_LoDTensor_en` to check the meaning
...@@ -1420,10 +1406,6 @@ class Variable(object): ...@@ -1420,10 +1406,6 @@ class Variable(object):
dtype='float32') dtype='float32')
print("LoD Level of current Var is: {}".format(new_variable.lod_level)) print("LoD Level of current Var is: {}".format(new_variable.lod_level))
""" """
# TODO(minqiyang): Support lod_level in dygraph mode
if in_dygraph_mode():
raise Exception("Dygraph model DO NOT supprt lod")
if self.type == core.VarDesc.VarType.SELECTED_ROWS: if self.type == core.VarDesc.VarType.SELECTED_ROWS:
raise Exception("SelectedRows DO NOT supprt lod") raise Exception("SelectedRows DO NOT supprt lod")
...@@ -1447,10 +1429,7 @@ class Variable(object): ...@@ -1447,10 +1429,7 @@ class Variable(object):
dtype='float32') dtype='float32')
print("Type of current Var is: {}".format(new_variable.type)) print("Type of current Var is: {}".format(new_variable.type))
""" """
if in_dygraph_mode(): return self.desc.type()
pass
else:
return self.desc.type()
def _set_error_clip(self, error_clip): def _set_error_clip(self, error_clip):
""" """
...@@ -2018,10 +1997,7 @@ class Operator(object): ...@@ -2018,10 +1997,7 @@ class Operator(object):
@property @property
def type(self): def type(self):
if in_dygraph_mode(): return self.desc.type()
return self._type
else:
return self.desc.type()
def input(self, name): def input(self, name):
""" """
...@@ -3977,7 +3953,6 @@ class Program(object): ...@@ -3977,7 +3953,6 @@ class Program(object):
def _version(self): def _version(self):
return self.desc._version() return self.desc._version()
@dygraph_not_support
def clone(self, for_test=False): def clone(self, for_test=False):
""" """
**Notes**: **Notes**:
...@@ -4664,7 +4639,6 @@ class Program(object): ...@@ -4664,7 +4639,6 @@ class Program(object):
if other_var.stop_gradient: if other_var.stop_gradient:
var.stop_gradient = True var.stop_gradient = True
@dygraph_not_support
def list_vars(self): def list_vars(self):
""" """
Get all :ref:`api_guide_Variable_en` from this Program. A iterable object is returned. Get all :ref:`api_guide_Variable_en` from this Program. A iterable object is returned.
...@@ -4687,7 +4661,6 @@ class Program(object): ...@@ -4687,7 +4661,6 @@ class Program(object):
for each_var in list(each_block.vars.values()): for each_var in list(each_block.vars.values()):
yield each_var yield each_var
@dygraph_not_support
def all_parameters(self): def all_parameters(self):
""" """
Get all :ref:`api_guide_parameter_en` from this Program. A list object is returned. Get all :ref:`api_guide_parameter_en` from this Program. A list object is returned.
......
...@@ -158,7 +158,7 @@ class Test_Detach(unittest.TestCase): ...@@ -158,7 +158,7 @@ class Test_Detach(unittest.TestCase):
assert type(e) == AssertionError assert type(e) == AssertionError
assert str( assert str(
e e
) == 'We Only support detach in Dygraph mode, please use fluid.dygraph.guard() as context to run it in Dygraph Mode' ) == "'detach' should be called by imperative Varible in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative mode"
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -184,27 +184,35 @@ class TestVariable(unittest.TestCase): ...@@ -184,27 +184,35 @@ class TestVariable(unittest.TestCase):
with fluid.program_guard(default_main_program()): with fluid.program_guard(default_main_program()):
self._tostring() self._tostring()
# NOTE(zhiqiu): for coverage CI def test_fake_interface_only_api(self):
# TODO(zhiqiu): code clean for dygraph
def test_dygraph_deprecated_api(self):
b = default_main_program().current_block() b = default_main_program().current_block()
var = b.create_var(dtype="float64", lod_level=0) var = b.create_var(dtype="float64", lod_level=0)
with fluid.dygraph.guard(): with fluid.dygraph.guard():
self.assertIsNone(var.detach()) self.assertRaises(AssertionError, var.detach)
self.assertIsNone(var.numpy()) self.assertRaises(AssertionError, var.numpy)
self.assertIsNone(var.set_value(None)) self.assertRaises(AssertionError, var.set_value, None)
self.assertIsNone(var.backward()) self.assertRaises(AssertionError, var.backward)
self.assertIsNone(var.gradient()) self.assertRaises(AssertionError, var.gradient)
self.assertIsNone(var.clear_gradient()) self.assertRaises(AssertionError, var.clear_gradient)
self.assertIsNone(var.to_string(True))
self.assertIsNone(var.persistable) def test_variable_in_dygraph_mode(self):
b = default_main_program().current_block()
var = b.create_var(dtype="float64", shape=[1, 1])
with fluid.dygraph.guard():
self.assertTrue(var.to_string(True).startswith('name:'))
self.assertFalse(var.persistable)
var.persistable = True
self.assertTrue(var.persistable)
self.assertFalse(var.stop_gradient)
var.stop_gradient = True var.stop_gradient = True
self.assertIsNone(var.stop_gradient) self.assertTrue(var.stop_gradient)
var.stop_gradient = 'tmp'
self.assertIsNone(var.name) self.assertTrue(var.name.startswith('_generated_var_'))
self.assertIsNone(var.shape) self.assertEqual(var.shape, (1, 1))
self.assertIsNone(var.dtype) self.assertEqual(var.dtype, fluid.core.VarDesc.VarType.FP64)
self.assertIsNone(var.type) self.assertEqual(var.type, fluid.core.VarDesc.VarType.LOD_TENSOR)
def test_create_selected_rows(self): def test_create_selected_rows(self):
b = default_main_program().current_block() b = default_main_program().current_block()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册