提交 edf36423 编写于 作者: D dangqingqing

Add extra_layers in paddle.trainer.SGD.

上级 d94e1f51
...@@ -53,7 +53,7 @@ import data_type ...@@ -53,7 +53,7 @@ import data_type
__all__ = ['parse_network', 'data'] __all__ = ['parse_network', 'data']
def parse_network(*outputs): def parse_network(*outputs, **kwargs):
""" """
Parse all output layers and then generate a ModelConfig object. Parse all output layers and then generate a ModelConfig object.
...@@ -75,6 +75,11 @@ def parse_network(*outputs): ...@@ -75,6 +75,11 @@ def parse_network(*outputs):
""" """
context = dict() context = dict()
real_output = [each.to_proto(context=context) for each in outputs] real_output = [each.to_proto(context=context) for each in outputs]
extra_layers = kwargs.get('extra_layers', None)
if extra_layers is not None:
extra_output = [
each.to_proto(context=context) for each in extra_layers
]
conf_helps.outputs(real_output) conf_helps.outputs(real_output)
return __parse__(__real_func__) return __parse__(__real_func__)
......
...@@ -53,14 +53,23 @@ class Topology(object): ...@@ -53,14 +53,23 @@ class Topology(object):
and network configs. and network configs.
""" """
def __init__(self, layers): def __init__(self, layers, extra_layers=None):
if not isinstance(layers, collections.Sequence): def __check__(layers):
__check_layer_type__(layers) if not isinstance(layers, collections.Sequence):
layers = [layers] __check_layer_type__(layers)
for layer in layers: layers = [layers]
__check_layer_type__(layer) for layer in layers:
__check_layer_type__(layer)
return layers
layers = __check__(layers)
self.layers = layers self.layers = layers
self.__model_config__ = v2_layer.parse_network(*layers) if extra_layers is not None:
extra_layers = __check__(extra_layers)
self.layers.extend(extra_layers)
self.__model_config__ = v2_layer.parse_network(
*layers, extra_layers=extra_layers)
assert isinstance(self.__model_config__, ModelConfig) assert isinstance(self.__model_config__, ModelConfig)
def proto(self): def proto(self):
......
...@@ -39,7 +39,7 @@ class SGD(object): ...@@ -39,7 +39,7 @@ class SGD(object):
:type parameters: paddle.v2.parameters.Parameters :type parameters: paddle.v2.parameters.Parameters
""" """
def __init__(self, cost, parameters, update_equation): def __init__(self, cost, parameters, update_equation, extra_layers=None):
if not isinstance(parameters, v2_parameters.Parameters): if not isinstance(parameters, v2_parameters.Parameters):
raise TypeError('parameters should be parameters') raise TypeError('parameters should be parameters')
...@@ -47,7 +47,7 @@ class SGD(object): ...@@ -47,7 +47,7 @@ class SGD(object):
if not isinstance(update_equation, v2_optimizer.Optimizer): if not isinstance(update_equation, v2_optimizer.Optimizer):
raise TypeError("update equation parameter must be " raise TypeError("update equation parameter must be "
"paddle.v2.optimizer.Optimizer") "paddle.v2.optimizer.Optimizer")
topology = Topology(cost) topology = Topology(cost, extra_layers=extra_layers)
self.__optimizer__ = update_equation self.__optimizer__ = update_equation
self.__topology__ = topology self.__topology__ = topology
self.__parameters__ = parameters self.__parameters__ = parameters
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册