未验证 提交 d069f2ca 编写于 作者: Y Yu Yang 提交者: GitHub

Make fluid.layers.fc support multiple param_attr (#6532)

Fix #6531
上级 1ba8f7fe
...@@ -36,6 +36,8 @@ class ParamAttr(object): ...@@ -36,6 +36,8 @@ class ParamAttr(object):
def to_attr(arg): def to_attr(arg):
if arg is None: if arg is None:
return ParamAttr() return ParamAttr()
elif isinstance(arg, list) or isinstance(arg, tuple):
return [ParamAttr.to_attr(a) for a in arg]
elif isinstance(arg, ParamAttr): elif isinstance(arg, ParamAttr):
return arg return arg
elif isinstance(arg, str) or isinstance(arg, unicode): elif isinstance(arg, str) or isinstance(arg, unicode):
......
...@@ -29,7 +29,10 @@ class TestBook(unittest.TestCase): ...@@ -29,7 +29,10 @@ class TestBook(unittest.TestCase):
label = layers.data(name='label', shape=[1], dtype='int32') label = layers.data(name='label', shape=[1], dtype='int32')
hidden1 = layers.fc(input=images, size=128, act='relu') hidden1 = layers.fc(input=images, size=128, act='relu')
hidden2 = layers.fc(input=hidden1, size=64, act='relu') hidden2 = layers.fc(input=hidden1, size=64, act='relu')
predict = layers.fc(input=hidden2, size=10, act='softmax') predict = layers.fc(input=[hidden2, hidden1],
size=10,
act='softmax',
param_attr=["sftmax.w1", "sftmax.w2"])
cost = layers.cross_entropy(input=predict, label=label) cost = layers.cross_entropy(input=predict, label=label)
avg_cost = layers.mean(x=cost) avg_cost = layers.mean(x=cost)
self.assertIsNotNone(avg_cost) self.assertIsNotNone(avg_cost)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册