提交 95967313 编写于 作者: Y Yu Yang 提交者: GitHub

enhance default param_attrs (#5142)

上级 b44f4ccb
......@@ -75,18 +75,29 @@ class LayerHelper(object):
}
}
actual = self.kwargs.get('param_attr', None)
return actual if actual is not None else default
if actual is None:
actual = default
for default_field in default.keys():
if default_field not in actual:
actual[default_field] = default[default_field]
return actual
def bias_attr(self):
bias_attr = self.kwargs.get('bias_attr', None)
if bias_attr is True:
bias_attr = {
default = {
'name': None,
'init_attr': {
'type': 'fill_constant',
'value': 0.0
}
}
bias_attr = self.kwargs.get('bias_attr', None)
if bias_attr is True:
bias_attr = default
if isinstance(bias_attr, dict):
for default_field in default.keys():
if default_field not in bias_attr:
bias_attr[default_field] = default[default_field]
return bias_attr
def multiple_param_attr(self, length):
......
......@@ -103,40 +103,30 @@ class TestBook(unittest.TestCase):
next_word = layers.data(
name='nextw', shape=[1], data_type='int32', program=program)
embed_param_attr_1 = {
'name': 'shared_w',
'init_attr': {
'max': 1.0,
'type': 'uniform_random',
'min': -1.0
}
}
embed_param_attr_2 = {'name': 'shared_w'}
embed_first = layers.embedding(
input=first_word,
size=[dict_size, embed_size],
data_type='float32',
param_attr=embed_param_attr_1,
param_attr={'name': 'shared_w'},
program=program)
embed_second = layers.embedding(
input=second_word,
size=[dict_size, embed_size],
data_type='float32',
param_attr=embed_param_attr_2,
param_attr={'name': 'shared_w'},
program=program)
embed_third = layers.embedding(
input=third_word,
size=[dict_size, embed_size],
data_type='float32',
param_attr=embed_param_attr_2,
param_attr={'name': 'shared_w'},
program=program)
embed_forth = layers.embedding(
input=forth_word,
size=[dict_size, embed_size],
data_type='float32',
param_attr=embed_param_attr_2,
param_attr={'name': 'shared_w'},
program=program)
concat_embed = layers.concat(
......
......@@ -50,28 +50,18 @@ next_word = layers.data(
program=program,
init_program=init_program)
embed_param_attr_1 = {
'name': 'shared_w',
'init_attr': {
'max': 1.0,
'type': 'uniform_random',
'min': -1.0
}
}
embed_param_attr_2 = {'name': 'shared_w'}
embed_first = layers.embedding(
input=first_word,
size=[dict_size, embed_size],
data_type='float32',
param_attr=embed_param_attr_1,
param_attr={'name': 'shared_w'},
program=program,
init_program=init_program)
embed_second = layers.embedding(
input=second_word,
size=[dict_size, embed_size],
data_type='float32',
param_attr=embed_param_attr_2,
param_attr={'name': 'shared_w'},
program=program,
init_program=init_program)
......@@ -79,14 +69,14 @@ embed_third = layers.embedding(
input=third_word,
size=[dict_size, embed_size],
data_type='float32',
param_attr=embed_param_attr_2,
param_attr={'name': 'shared_w'},
program=program,
init_program=init_program)
embed_forth = layers.embedding(
input=forth_word,
size=[dict_size, embed_size],
data_type='float32',
param_attr=embed_param_attr_2,
param_attr={'name': 'shared_w'},
program=program,
init_program=init_program)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册