提交 edf36423 编写于 作者: D dangqingqing

Add extra_layers in paddle.trainer.SGD.

上级 d94e1f51
......@@ -53,7 +53,7 @@ import data_type
__all__ = ['parse_network', 'data']
def parse_network(*outputs):
def parse_network(*outputs, **kwargs):
"""
Parse all output layers and then generate a ModelConfig object.
......@@ -75,6 +75,11 @@ def parse_network(*outputs):
"""
context = dict()
real_output = [each.to_proto(context=context) for each in outputs]
extra_layers = kwargs.get('extra_layers', None)
if extra_layers is not None:
extra_output = [
each.to_proto(context=context) for each in extra_layers
]
conf_helps.outputs(real_output)
return __parse__(__real_func__)
......
......@@ -53,14 +53,23 @@ class Topology(object):
and network configs.
"""
def __init__(self, layers):
if not isinstance(layers, collections.Sequence):
__check_layer_type__(layers)
layers = [layers]
for layer in layers:
__check_layer_type__(layer)
def __init__(self, layers, extra_layers=None):
def __check__(layers):
if not isinstance(layers, collections.Sequence):
__check_layer_type__(layers)
layers = [layers]
for layer in layers:
__check_layer_type__(layer)
return layers
layers = __check__(layers)
self.layers = layers
self.__model_config__ = v2_layer.parse_network(*layers)
if extra_layers is not None:
extra_layers = __check__(extra_layers)
self.layers.extend(extra_layers)
self.__model_config__ = v2_layer.parse_network(
*layers, extra_layers=extra_layers)
assert isinstance(self.__model_config__, ModelConfig)
def proto(self):
......
......@@ -39,7 +39,7 @@ class SGD(object):
:type parameters: paddle.v2.parameters.Parameters
"""
def __init__(self, cost, parameters, update_equation):
def __init__(self, cost, parameters, update_equation, extra_layers=None):
if not isinstance(parameters, v2_parameters.Parameters):
raise TypeError('parameters should be parameters')
......@@ -47,7 +47,7 @@ class SGD(object):
if not isinstance(update_equation, v2_optimizer.Optimizer):
raise TypeError("update equation parameter must be "
"paddle.v2.optimizer.Optimizer")
topology = Topology(cost)
topology = Topology(cost, extra_layers=extra_layers)
self.__optimizer__ = update_equation
self.__topology__ = topology
self.__parameters__ = parameters
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册