提交 95967313 编写于 作者: Y Yu Yang 提交者: GitHub

enhance default param_attrs (#5142)

上级 b44f4ccb
...@@ -75,18 +75,29 @@ class LayerHelper(object): ...@@ -75,18 +75,29 @@ class LayerHelper(object):
} }
} }
actual = self.kwargs.get('param_attr', None) actual = self.kwargs.get('param_attr', None)
return actual if actual is not None else default if actual is None:
actual = default
for default_field in default.keys():
if default_field not in actual:
actual[default_field] = default[default_field]
return actual
def bias_attr(self): def bias_attr(self):
bias_attr = self.kwargs.get('bias_attr', None) default = {
if bias_attr is True:
bias_attr = {
'name': None, 'name': None,
'init_attr': { 'init_attr': {
'type': 'fill_constant', 'type': 'fill_constant',
'value': 0.0 'value': 0.0
} }
} }
bias_attr = self.kwargs.get('bias_attr', None)
if bias_attr is True:
bias_attr = default
if isinstance(bias_attr, dict):
for default_field in default.keys():
if default_field not in bias_attr:
bias_attr[default_field] = default[default_field]
return bias_attr return bias_attr
def multiple_param_attr(self, length): def multiple_param_attr(self, length):
......
...@@ -103,40 +103,30 @@ class TestBook(unittest.TestCase): ...@@ -103,40 +103,30 @@ class TestBook(unittest.TestCase):
next_word = layers.data( next_word = layers.data(
name='nextw', shape=[1], data_type='int32', program=program) name='nextw', shape=[1], data_type='int32', program=program)
embed_param_attr_1 = {
'name': 'shared_w',
'init_attr': {
'max': 1.0,
'type': 'uniform_random',
'min': -1.0
}
}
embed_param_attr_2 = {'name': 'shared_w'}
embed_first = layers.embedding( embed_first = layers.embedding(
input=first_word, input=first_word,
size=[dict_size, embed_size], size=[dict_size, embed_size],
data_type='float32', data_type='float32',
param_attr=embed_param_attr_1, param_attr={'name': 'shared_w'},
program=program) program=program)
embed_second = layers.embedding( embed_second = layers.embedding(
input=second_word, input=second_word,
size=[dict_size, embed_size], size=[dict_size, embed_size],
data_type='float32', data_type='float32',
param_attr=embed_param_attr_2, param_attr={'name': 'shared_w'},
program=program) program=program)
embed_third = layers.embedding( embed_third = layers.embedding(
input=third_word, input=third_word,
size=[dict_size, embed_size], size=[dict_size, embed_size],
data_type='float32', data_type='float32',
param_attr=embed_param_attr_2, param_attr={'name': 'shared_w'},
program=program) program=program)
embed_forth = layers.embedding( embed_forth = layers.embedding(
input=forth_word, input=forth_word,
size=[dict_size, embed_size], size=[dict_size, embed_size],
data_type='float32', data_type='float32',
param_attr=embed_param_attr_2, param_attr={'name': 'shared_w'},
program=program) program=program)
concat_embed = layers.concat( concat_embed = layers.concat(
......
...@@ -50,28 +50,18 @@ next_word = layers.data( ...@@ -50,28 +50,18 @@ next_word = layers.data(
program=program, program=program,
init_program=init_program) init_program=init_program)
embed_param_attr_1 = {
'name': 'shared_w',
'init_attr': {
'max': 1.0,
'type': 'uniform_random',
'min': -1.0
}
}
embed_param_attr_2 = {'name': 'shared_w'}
embed_first = layers.embedding( embed_first = layers.embedding(
input=first_word, input=first_word,
size=[dict_size, embed_size], size=[dict_size, embed_size],
data_type='float32', data_type='float32',
param_attr=embed_param_attr_1, param_attr={'name': 'shared_w'},
program=program, program=program,
init_program=init_program) init_program=init_program)
embed_second = layers.embedding( embed_second = layers.embedding(
input=second_word, input=second_word,
size=[dict_size, embed_size], size=[dict_size, embed_size],
data_type='float32', data_type='float32',
param_attr=embed_param_attr_2, param_attr={'name': 'shared_w'},
program=program, program=program,
init_program=init_program) init_program=init_program)
...@@ -79,14 +69,14 @@ embed_third = layers.embedding( ...@@ -79,14 +69,14 @@ embed_third = layers.embedding(
input=third_word, input=third_word,
size=[dict_size, embed_size], size=[dict_size, embed_size],
data_type='float32', data_type='float32',
param_attr=embed_param_attr_2, param_attr={'name': 'shared_w'},
program=program, program=program,
init_program=init_program) init_program=init_program)
embed_forth = layers.embedding( embed_forth = layers.embedding(
input=forth_word, input=forth_word,
size=[dict_size, embed_size], size=[dict_size, embed_size],
data_type='float32', data_type='float32',
param_attr=embed_param_attr_2, param_attr={'name': 'shared_w'},
program=program, program=program,
init_program=init_program) init_program=init_program)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册