未验证 提交 6e04334c 编写于 作者: H hong 提交者: GitHub

Check dygraph weight name (#22140)

* add parameter check; test=develop

* change parameter name checker in dygraph guard; test=develop

* fix test layers error; test=develop

* revert some code to develop; test=develop

* fix exampel error; test=develop

* fix comment error; test=develop

* fix comment error; test=develop
上级 6882b8eb
......@@ -331,6 +331,14 @@ class LayerHelperBase(object):
if in_dygraph_mode():
# In dygraph mode, we want the returned parameter to be
# initialized so that it can be used imperatively.
# check parameter name
is_used = unique_name.dygraph_parameter_name_checker(attr.name)
if is_used:
raise ValueError(
"parameter name [{}] have be been used. "
"In dygraph mode, the name of parameter can't be same."
"Please check the parameter attr value passed to self.create_parameter or "
"constructor of dygraph Layers".format(attr.name))
return self.main_program.global_block().create_parameter(
dtype=dtype,
shape=shape,
......
......@@ -1550,6 +1550,15 @@ class TestLayer(LayerTest):
class TestBook(LayerTest):
def setUp(self):
self.only_static_set = set({"make_word_embedding"})
self.not_compare_static_dygraph_set = set({
"make_gaussian_random", "make_gaussian_random_batch_size_like",
"make_kldiv_loss", "make_prelu",
"make_sampled_softmax_with_cross_entropy", "make_sampling_id",
"make_uniform_random_batch_size_like"
})
def test_all_layers(self):
attrs = (getattr(self, name) for name in dir(self))
methods = filter(inspect.ismethod, attrs)
......@@ -1572,9 +1581,12 @@ class TestBook(LayerTest):
feed=self._feed_dict,
fetch_list=fetch_list,
force_to_use_cpu=self._force_to_use_cpu)
else:
assert method.__name__ in ('make_get_places')
continue
if method.__name__ in self.only_static_set:
continue
with self.dynamic_graph(self._force_to_use_cpu):
dy_result = method()
......@@ -1582,7 +1594,9 @@ class TestBook(LayerTest):
dy_result = dy_result[0]
dy_result_value = dy_result.numpy()
self.assertTrue(np.array_equal(static_result[0], dy_result_value))
if method.__name__ not in self.not_compare_static_dygraph_set:
self.assertTrue(
np.array_equal(static_result[0], dy_result_value))
def _get_np_data(self, shape, dtype, append_batch_size=True):
np.random.seed(self.seed)
......
......@@ -51,6 +51,33 @@ class UniqueNameGenerator(object):
return self.prefix + "_".join([key, str(tmp)])
class DygraphParameterNameChecker(object):
"""
Check whether the name of parameter is used.
"""
def __init__(self):
self._name_set = set()
def __call__(self, name):
'''
Check whether the name is used. If not used, insert into the _name_set.
Args:
name(str): The name of parameter to check.
Returns(bool): If the name is in name_set, return True; Otherwise, return False.
'''
if name in self._name_set:
return True
else:
self._name_set.add(name)
return False
dygraph_parameter_name_checker = DygraphParameterNameChecker()
generator = UniqueNameGenerator()
......@@ -101,7 +128,7 @@ def generate_with_ignorable_key(key):
return generator(key)
def switch(new_generator=None):
def switch(new_generator=None, new_para_name_checker=None):
"""
Switch the namespace of in current context to a new namespace. Though
:code:`switch()` and :code:`guard()` can both change namespace,
......@@ -112,9 +139,13 @@ def switch(new_generator=None):
new_generator(UniqueNameGenerator, optional): A new UniqueNameGenerator, not
required normally. Default is None, which means switch to a new anonymous
namespace.
new_para_name_checker(DygraphParameterNameChecker, optional): A new DygraphParameterNameChecker,
not required normally. Default is None, which means switch to a new parameter name
checker.
Returns:
UniqueNameGenerator: The previous UniqueNameGenerator.
DygraphParameterNameChecker: The previous DygraphParameterNameChecker
Examples:
......@@ -125,22 +156,29 @@ def switch(new_generator=None):
name2 = fluid.unique_name.generate('fc')
print(name1, name2) # fc_0, fc_1
pre_generator = fluid.unique_name.switch() # switch to a new anonymous namespace.
pre_generator, pre_dygraph_name_checker = fluid.unique_name.switch() # switch to a new anonymous namespace.
name2 = fluid.unique_name.generate('fc')
print(name2) # fc_0
fluid.unique_name.switch(pre_generator) # switch back to pre_generator.
fluid.unique_name.switch(pre_generator, pre_dygraph_name_checker) # switch back to pre_generator.
name3 = fluid.unique_name.generate('fc')
print(name3) # fc_2, since pre_generator has generated fc_0, fc_1.
"""
global generator
old = generator
old_generator = generator
global dygraph_parameter_name_checker
old_para_name_checker = dygraph_parameter_name_checker
if new_generator is None:
generator = UniqueNameGenerator()
else:
generator = new_generator
return old
if new_para_name_checker is None:
dygraph_parameter_name_checker = DygraphParameterNameChecker()
else:
dygraph_parameter_name_checker = new_para_name_checker
return old_generator, old_para_name_checker
@signature_safe_contextmanager
......@@ -180,6 +218,7 @@ def guard(new_generator=None):
new_generator = UniqueNameGenerator(new_generator)
elif isinstance(new_generator, six.binary_type):
new_generator = UniqueNameGenerator(new_generator.decode())
old = switch(new_generator)
old_generator, old_para_name_checker = switch(new_generator)
yield
switch(old)
switch(old_generator, old_para_name_checker)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册