未验证 提交 a4519a5d 编写于 作者: L Leo Chen 提交者: GitHub

Dev/add fake_interface_only decorator for some function of Variable (#24083) (#24166)

* add decorator, test=develop

* add fake_interface_only, test=develop

* remove some dygraph_not_support, test=develop

* change dygraph to imperative, test=develop
上级 91ae7848
......@@ -203,7 +203,7 @@ def in_dygraph_mode():
def _dygraph_not_support_(func):
def __impl__(*args, **kwargs):
assert not in_dygraph_mode(
), "We don't support %s in Dygraph mode" % func.__name__
), "We don't support %s in imperative mode" % func.__name__
return func(*args, **kwargs)
return __impl__
......@@ -212,14 +212,31 @@ def _dygraph_not_support_(func):
def _dygraph_only_(func):
def __impl__(*args, **kwargs):
assert in_dygraph_mode(
), "We Only support %s in Dygraph mode, please use fluid.dygraph.guard() as context to run it in Dygraph Mode" % func.__name__
), "We Only support %s in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative Mode" % func.__name__
return func(*args, **kwargs)
return __impl__
# NOTE(zhiqiu): This decorator is used for the APIs of Variable which is only
# used to make Variable and VarBase has same interfaces, like numpy. Since VarBase is not exposed in our
# official docments, logically, we want to keep VarBase and logically consistent. While, actually,
# in our implementation, there some APIs not supported, like numpy, because Variable contains the desc.
# So, those APIs are listed under class Variable to generate docs only.
# TODO(zhiqiu): We should make VarBase consistent with Variable in future, for example, by inheritting
# same base class.
def _fake_interface_only_(func):
def __impl__(*args, **kwargs):
raise AssertionError(
"'%s' should be called by imperative Varible in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative mode"
% func.__name__)
return __impl__
dygraph_not_support = wrap_decorator(_dygraph_not_support_)
dygraph_only = wrap_decorator(_dygraph_only_)
fake_interface_only = wrap_decorator(_fake_interface_only_)
def _dygraph_tracer():
......@@ -592,7 +609,6 @@ class VariableMetaClass(type):
def __instancecheck__(cls, instance):
t = type(instance)
if in_dygraph_mode():
return issubclass(t, core.VarBase)
else:
return issubclass(t, Variable)
......@@ -954,7 +970,7 @@ class Variable(object):
self._stop_gradient = stop_gradient
self.is_data = is_data
@dygraph_only
@fake_interface_only
def detach(self):
"""
**Notes**:
......@@ -984,7 +1000,7 @@ class Variable(object):
"""
pass
@dygraph_only
@fake_interface_only
def numpy(self):
"""
**Notes**:
......@@ -1016,7 +1032,7 @@ class Variable(object):
"""
pass
@dygraph_only
@fake_interface_only
def set_value(self, value):
"""
**Notes**:
......@@ -1047,7 +1063,7 @@ class Variable(object):
"""
pass
@dygraph_only
@fake_interface_only
def backward(self, backward_strategy=None):
"""
**Notes**:
......@@ -1085,7 +1101,7 @@ class Variable(object):
"""
pass
@dygraph_only
@fake_interface_only
def gradient(self):
"""
**Notes**:
......@@ -1133,7 +1149,7 @@ class Variable(object):
"""
pass
@dygraph_only
@fake_interface_only
def clear_gradient(self):
"""
**Notes**:
......@@ -1200,9 +1216,6 @@ class Variable(object):
print("=============with detail===============")
print(new_variable.to_string(True, True))
"""
if in_dygraph_mode():
return
assert isinstance(throw_on_error, bool) and isinstance(with_details,
bool)
protostr = self.desc.serialize_to_string()
......@@ -1249,16 +1262,10 @@ class Variable(object):
assert linear.weight.gradient() is None
assert (out1.gradient() == 0).all()
"""
if in_dygraph_mode():
pass
else:
return self._stop_gradient
@stop_gradient.setter
def stop_gradient(self, s):
if in_dygraph_mode():
pass
else:
self._stop_gradient = s
@property
......@@ -1284,18 +1291,10 @@ class Variable(object):
dtype='float32')
print("persistable of current Var is: {}".format(new_variable.persistable))
"""
if in_dygraph_mode():
pass
else:
return self.desc.persistable()
@persistable.setter
def persistable(self, p):
if in_dygraph_mode():
logging.warn(
"There will be no use to set persistable in Dygraph Mode, since "
"you can just do it by hold it as normal Python variable")
else:
self.desc.set_persistable(p)
@property
......@@ -1316,9 +1315,6 @@ class Variable(object):
dtype='float32')
print("name of current Var is: {}".format(new_variable.name))
"""
if in_dygraph_mode():
pass
else:
return cpt.to_text(self.desc.name())
@property
......@@ -1343,9 +1339,6 @@ class Variable(object):
@name.setter
def name(self, new_name):
if in_dygraph_mode():
pass
else:
self.desc.set_name(new_name)
@property
......@@ -1368,9 +1361,6 @@ class Variable(object):
"""
# convert to tuple, make it as same as numpy API.
if in_dygraph_mode():
pass
else:
return tuple(self.desc.shape())
@property
......@@ -1391,13 +1381,9 @@ class Variable(object):
dtype='float32')
print("Dtype of current Var is: {}".format(new_variable.dtype))
"""
if in_dygraph_mode():
pass
else:
return self.desc.dtype()
@property
@dygraph_not_support
def lod_level(self):
"""
Indicating ``LoD`` info of current Variable, please refer to :ref:`api_fluid_LoDTensor_en` to check the meaning
......@@ -1420,10 +1406,6 @@ class Variable(object):
dtype='float32')
print("LoD Level of current Var is: {}".format(new_variable.lod_level))
"""
# TODO(minqiyang): Support lod_level in dygraph mode
if in_dygraph_mode():
raise Exception("Dygraph model DO NOT supprt lod")
if self.type == core.VarDesc.VarType.SELECTED_ROWS:
raise Exception("SelectedRows DO NOT supprt lod")
......@@ -1447,9 +1429,6 @@ class Variable(object):
dtype='float32')
print("Type of current Var is: {}".format(new_variable.type))
"""
if in_dygraph_mode():
pass
else:
return self.desc.type()
def _set_error_clip(self, error_clip):
......@@ -2018,9 +1997,6 @@ class Operator(object):
@property
def type(self):
if in_dygraph_mode():
return self._type
else:
return self.desc.type()
def input(self, name):
......@@ -3977,7 +3953,6 @@ class Program(object):
def _version(self):
return self.desc._version()
@dygraph_not_support
def clone(self, for_test=False):
"""
**Notes**:
......@@ -4664,7 +4639,6 @@ class Program(object):
if other_var.stop_gradient:
var.stop_gradient = True
@dygraph_not_support
def list_vars(self):
"""
Get all :ref:`api_guide_Variable_en` from this Program. A iterable object is returned.
......@@ -4687,7 +4661,6 @@ class Program(object):
for each_var in list(each_block.vars.values()):
yield each_var
@dygraph_not_support
def all_parameters(self):
"""
Get all :ref:`api_guide_parameter_en` from this Program. A list object is returned.
......
......@@ -158,7 +158,7 @@ class Test_Detach(unittest.TestCase):
assert type(e) == AssertionError
assert str(
e
) == 'We Only support detach in Dygraph mode, please use fluid.dygraph.guard() as context to run it in Dygraph Mode'
) == "'detach' should be called by imperative Varible in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative mode"
if __name__ == '__main__':
......
......@@ -184,27 +184,35 @@ class TestVariable(unittest.TestCase):
with fluid.program_guard(default_main_program()):
self._tostring()
# NOTE(zhiqiu): for coverage CI
# TODO(zhiqiu): code clean for dygraph
def test_dygraph_deprecated_api(self):
def test_fake_interface_only_api(self):
b = default_main_program().current_block()
var = b.create_var(dtype="float64", lod_level=0)
with fluid.dygraph.guard():
self.assertIsNone(var.detach())
self.assertIsNone(var.numpy())
self.assertIsNone(var.set_value(None))
self.assertIsNone(var.backward())
self.assertIsNone(var.gradient())
self.assertIsNone(var.clear_gradient())
self.assertIsNone(var.to_string(True))
self.assertIsNone(var.persistable)
self.assertRaises(AssertionError, var.detach)
self.assertRaises(AssertionError, var.numpy)
self.assertRaises(AssertionError, var.set_value, None)
self.assertRaises(AssertionError, var.backward)
self.assertRaises(AssertionError, var.gradient)
self.assertRaises(AssertionError, var.clear_gradient)
def test_variable_in_dygraph_mode(self):
b = default_main_program().current_block()
var = b.create_var(dtype="float64", shape=[1, 1])
with fluid.dygraph.guard():
self.assertTrue(var.to_string(True).startswith('name:'))
self.assertFalse(var.persistable)
var.persistable = True
self.assertTrue(var.persistable)
self.assertFalse(var.stop_gradient)
var.stop_gradient = True
self.assertIsNone(var.stop_gradient)
var.stop_gradient = 'tmp'
self.assertIsNone(var.name)
self.assertIsNone(var.shape)
self.assertIsNone(var.dtype)
self.assertIsNone(var.type)
self.assertTrue(var.stop_gradient)
self.assertTrue(var.name.startswith('_generated_var_'))
self.assertEqual(var.shape, (1, 1))
self.assertEqual(var.dtype, fluid.core.VarDesc.VarType.FP64)
self.assertEqual(var.type, fluid.core.VarDesc.VarType.LOD_TENSOR)
def test_create_selected_rows(self):
b = default_main_program().current_block()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册