提交 7e6def58 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #1756 from qingqing01/v2_api_multi_leaf_node

Add extra_layers in paddle.trainer.SGD.
...@@ -53,20 +53,29 @@ import data_type ...@@ -53,20 +53,29 @@ import data_type
__all__ = ['parse_network', 'data'] __all__ = ['parse_network', 'data']
def parse_network(*outputs): def parse_network(output_layers, extra_layers=None):
""" """
Parse all output layers and then generate a ModelConfig object. Parse all layers in the neural network graph and
then generate a ModelConfig object.
.. note:: .. note::
This function is used internally in paddle.v2 module. User should never This function is used internally in paddle.v2 module. User should never
invoke this method. invoke this method.
:param outputs: Output layers. :param output_layers: Output layers.
:type outputs: Layer :type output_layers: Layer
:param extra_layers: Some layers in the neural network graph are not in the
path of output_layers.
:type extra_layers: Layer
:return: A ModelConfig object instance. :return: A ModelConfig object instance.
:rtype: ModelConfig :rtype: ModelConfig
""" """
if not isinstance(output_layers, collections.Sequence):
output_layers = [output_layers]
if extra_layers is not None and not isinstance(extra_layers,
collections.Sequence):
extra_layers = [extra_layers]
def __real_func__(): def __real_func__():
""" """
...@@ -74,7 +83,11 @@ def parse_network(*outputs): ...@@ -74,7 +83,11 @@ def parse_network(*outputs):
the plain old paddle configuration function. the plain old paddle configuration function.
""" """
context = dict() context = dict()
real_output = [each.to_proto(context=context) for each in outputs] real_output = [each.to_proto(context=context) for each in output_layers]
if extra_layers is not None:
extra_output = [
each.to_proto(context=context) for each in extra_layers
]
conf_helps.outputs(real_output) conf_helps.outputs(real_output)
return __parse__(__real_func__) return __parse__(__real_func__)
......
...@@ -59,13 +59,13 @@ class ImageLayerTest(unittest.TestCase): ...@@ -59,13 +59,13 @@ class ImageLayerTest(unittest.TestCase):
num_channels=16, num_channels=16,
pool_type=pooling.Max()) pool_type=pooling.Max())
maxout = layer.maxout(input=conv, num_channels=16, groups=4) maxout = layer.maxout(input=conv, num_channels=16, groups=4)
print layer.parse_network(maxpool, spp, maxout) print layer.parse_network([maxpool, spp, maxout])
def test_norm_layer(self): def test_norm_layer(self):
norm1 = layer.img_cmrnorm(input=conv, size=5) norm1 = layer.img_cmrnorm(input=conv, size=5)
norm2 = layer.batch_norm(input=conv) norm2 = layer.batch_norm(input=conv)
norm3 = layer.sum_to_one_norm(input=conv) norm3 = layer.sum_to_one_norm(input=conv)
print layer.parse_network(norm1, norm2, norm3) print layer.parse_network([norm1, norm2, norm3])
class AggregateLayerTest(unittest.TestCase): class AggregateLayerTest(unittest.TestCase):
...@@ -78,7 +78,8 @@ class AggregateLayerTest(unittest.TestCase): ...@@ -78,7 +78,8 @@ class AggregateLayerTest(unittest.TestCase):
first_seq = layer.first_seq(input=pixel) first_seq = layer.first_seq(input=pixel)
concat = layer.concat(input=[last_seq, first_seq]) concat = layer.concat(input=[last_seq, first_seq])
seq_concat = layer.seq_concat(a=last_seq, b=first_seq) seq_concat = layer.seq_concat(a=last_seq, b=first_seq)
print layer.parse_network(pool, last_seq, first_seq, concat, seq_concat) print layer.parse_network(
[pool, last_seq, first_seq, concat, seq_concat])
class MathLayerTest(unittest.TestCase): class MathLayerTest(unittest.TestCase):
...@@ -95,8 +96,10 @@ class MathLayerTest(unittest.TestCase): ...@@ -95,8 +96,10 @@ class MathLayerTest(unittest.TestCase):
tensor = layer.tensor(a=pixel, b=pixel, size=1000) tensor = layer.tensor(a=pixel, b=pixel, size=1000)
cos_sim = layer.cos_sim(a=pixel, b=pixel) cos_sim = layer.cos_sim(a=pixel, b=pixel)
trans = layer.trans(input=tensor) trans = layer.trans(input=tensor)
print layer.parse_network(addto, linear_comb, interpolation, power, print layer.parse_network([
scaling, slope, tensor, cos_sim, trans) addto, linear_comb, interpolation, power, scaling, slope, tensor,
cos_sim, trans
])
class ReshapeLayerTest(unittest.TestCase): class ReshapeLayerTest(unittest.TestCase):
...@@ -110,7 +113,8 @@ class ReshapeLayerTest(unittest.TestCase): ...@@ -110,7 +113,8 @@ class ReshapeLayerTest(unittest.TestCase):
repeat = layer.repeat(input=pixel, num_repeats=4) repeat = layer.repeat(input=pixel, num_repeats=4)
reshape = layer.seq_reshape(input=pixel, reshape_size=4) reshape = layer.seq_reshape(input=pixel, reshape_size=4)
rotate = layer.rotate(input=pixel, height=16, width=49) rotate = layer.rotate(input=pixel, height=16, width=49)
print layer.parse_network(block_expand, expand, repeat, reshape, rotate) print layer.parse_network(
[block_expand, expand, repeat, reshape, rotate])
class RecurrentLayerTest(unittest.TestCase): class RecurrentLayerTest(unittest.TestCase):
...@@ -119,7 +123,7 @@ class RecurrentLayerTest(unittest.TestCase): ...@@ -119,7 +123,7 @@ class RecurrentLayerTest(unittest.TestCase):
recurrent = layer.recurrent(input=word) recurrent = layer.recurrent(input=word)
lstm = layer.lstmemory(input=word) lstm = layer.lstmemory(input=word)
gru = layer.grumemory(input=word) gru = layer.grumemory(input=word)
print layer.parse_network(recurrent, lstm, gru) print layer.parse_network([recurrent, lstm, gru])
class CostLayerTest(unittest.TestCase): class CostLayerTest(unittest.TestCase):
...@@ -139,10 +143,10 @@ class CostLayerTest(unittest.TestCase): ...@@ -139,10 +143,10 @@ class CostLayerTest(unittest.TestCase):
cost10 = layer.sum_cost(input=inference) cost10 = layer.sum_cost(input=inference)
cost11 = layer.huber_cost(input=score, label=label) cost11 = layer.huber_cost(input=score, label=label)
print layer.parse_network(cost1, cost2) print layer.parse_network([cost1, cost2])
print layer.parse_network(cost3, cost4) print layer.parse_network([cost3, cost4])
print layer.parse_network(cost5, cost6) print layer.parse_network([cost5, cost6])
print layer.parse_network(cost7, cost8, cost9, cost10, cost11) print layer.parse_network([cost7, cost8, cost9, cost10, cost11])
crf = layer.crf(input=inference, label=label) crf = layer.crf(input=inference, label=label)
crf_decoding = layer.crf_decoding(input=inference, size=3) crf_decoding = layer.crf_decoding(input=inference, size=3)
...@@ -151,8 +155,8 @@ class CostLayerTest(unittest.TestCase): ...@@ -151,8 +155,8 @@ class CostLayerTest(unittest.TestCase):
nce = layer.nce(input=inference, label=label, num_classes=3) nce = layer.nce(input=inference, label=label, num_classes=3)
hsigmoid = layer.hsigmoid(input=inference, label=label, num_classes=3) hsigmoid = layer.hsigmoid(input=inference, label=label, num_classes=3)
print layer.parse_network(crf, crf_decoding, ctc, warp_ctc, nce, print layer.parse_network(
hsigmoid) [crf, crf_decoding, ctc, warp_ctc, nce, hsigmoid])
class OtherLayerTest(unittest.TestCase): class OtherLayerTest(unittest.TestCase):
...@@ -160,7 +164,7 @@ class OtherLayerTest(unittest.TestCase): ...@@ -160,7 +164,7 @@ class OtherLayerTest(unittest.TestCase):
maxid = layer.max_id(input=inference) maxid = layer.max_id(input=inference)
sampling_id = layer.sampling_id(input=inference) sampling_id = layer.sampling_id(input=inference)
eos = layer.eos(input=maxid, eos_id=5) eos = layer.eos(input=maxid, eos_id=5)
print layer.parse_network(maxid, sampling_id, eos) print layer.parse_network([maxid, sampling_id, eos])
def test_slicing_joining_layer(self): def test_slicing_joining_layer(self):
pad = layer.pad(input=conv, pad_c=[2, 3], pad_h=[1, 2], pad_w=[3, 1]) pad = layer.pad(input=conv, pad_c=[2, 3], pad_h=[1, 2], pad_w=[3, 1])
......
...@@ -53,14 +53,26 @@ class Topology(object): ...@@ -53,14 +53,26 @@ class Topology(object):
and network configs. and network configs.
""" """
def __init__(self, layers): def __init__(self, layers, extra_layers=None):
if not isinstance(layers, collections.Sequence): def __check__(layers):
__check_layer_type__(layers) if not isinstance(layers, collections.Sequence):
layers = [layers] __check_layer_type__(layers)
for layer in layers: layers = [layers]
__check_layer_type__(layer) for layer in layers:
__check_layer_type__(layer)
return layers
layers = __check__(layers)
self.layers = layers self.layers = layers
self.__model_config__ = v2_layer.parse_network(*layers) if extra_layers is not None:
extra_layers = __check__(extra_layers)
self.__model_config__ = v2_layer.parse_network(
layers, extra_layers=extra_layers)
if extra_layers is not None:
self.layers.extend(extra_layers)
assert isinstance(self.__model_config__, ModelConfig) assert isinstance(self.__model_config__, ModelConfig)
def proto(self): def proto(self):
......
...@@ -37,9 +37,12 @@ class SGD(object): ...@@ -37,9 +37,12 @@ class SGD(object):
:type cost: paddle.v2.config_base.Layer :type cost: paddle.v2.config_base.Layer
:param parameters: The parameters dictionary. :param parameters: The parameters dictionary.
:type parameters: paddle.v2.parameters.Parameters :type parameters: paddle.v2.parameters.Parameters
:param extra_layers: Some layers in the neural network graph are not
in the path of cost layer.
:type extra_layers: paddle.v2.config_base.Layer
""" """
def __init__(self, cost, parameters, update_equation): def __init__(self, cost, parameters, update_equation, extra_layers=None):
if not isinstance(parameters, v2_parameters.Parameters): if not isinstance(parameters, v2_parameters.Parameters):
raise TypeError('parameters should be parameters') raise TypeError('parameters should be parameters')
...@@ -47,7 +50,7 @@ class SGD(object): ...@@ -47,7 +50,7 @@ class SGD(object):
if not isinstance(update_equation, v2_optimizer.Optimizer): if not isinstance(update_equation, v2_optimizer.Optimizer):
raise TypeError("update equation parameter must be " raise TypeError("update equation parameter must be "
"paddle.v2.optimizer.Optimizer") "paddle.v2.optimizer.Optimizer")
topology = Topology(cost) topology = Topology(cost, extra_layers=extra_layers)
self.__optimizer__ = update_equation self.__optimizer__ = update_equation
self.__topology__ = topology self.__topology__ = topology
self.__parameters__ = parameters self.__parameters__ = parameters
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册