diff --git a/python/paddle/v2/layer.py b/python/paddle/v2/layer.py index 1e4efedde363f20fde168941adcb6e8a594b533a..a91c6586a48aae4c692088e52d179efe4f69dbdd 100644 --- a/python/paddle/v2/layer.py +++ b/python/paddle/v2/layer.py @@ -53,7 +53,7 @@ import data_type __all__ = ['parse_network', 'data'] -def parse_network(*outputs): +def parse_network(*outputs, **kwargs): """ Parse all output layers and then generate a ModelConfig object. @@ -75,6 +75,11 @@ def parse_network(*outputs): """ context = dict() real_output = [each.to_proto(context=context) for each in outputs] + extra_layers = kwargs.get('extra_layers', None) + if extra_layers is not None: + extra_output = [ + each.to_proto(context=context) for each in extra_layers + ] conf_helps.outputs(real_output) return __parse__(__real_func__) diff --git a/python/paddle/v2/topology.py b/python/paddle/v2/topology.py index f0679c5675b0c0f24f28f3df22efd4eb51ccbb3a..3e6fded839bf2ca0ddc263af1345e8216b7d2a94 100644 --- a/python/paddle/v2/topology.py +++ b/python/paddle/v2/topology.py @@ -53,14 +53,23 @@ class Topology(object): and network configs. """ - def __init__(self, layers): - if not isinstance(layers, collections.Sequence): - __check_layer_type__(layers) - layers = [layers] - for layer in layers: - __check_layer_type__(layer) + def __init__(self, layers, extra_layers=None): + def __check__(layers): + if not isinstance(layers, collections.Sequence): + __check_layer_type__(layers) + layers = [layers] + for layer in layers: + __check_layer_type__(layer) + return layers + + layers = __check__(layers) self.layers = layers - self.__model_config__ = v2_layer.parse_network(*layers) + if extra_layers is not None: + extra_layers = __check__(extra_layers) + self.layers.extend(extra_layers) + + self.__model_config__ = v2_layer.parse_network( + *layers, extra_layers=extra_layers) assert isinstance(self.__model_config__, ModelConfig) def proto(self): diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 4e432a52b209c825ca1b74393cd607db8f884f4f..cccb00794acdf936d2e11e68ae1573f99a8a5155 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -39,7 +39,7 @@ class SGD(object): :type parameters: paddle.v2.parameters.Parameters """ - def __init__(self, cost, parameters, update_equation): + def __init__(self, cost, parameters, update_equation, extra_layers=None): if not isinstance(parameters, v2_parameters.Parameters): raise TypeError('parameters should be parameters') @@ -47,7 +47,7 @@ class SGD(object): if not isinstance(update_equation, v2_optimizer.Optimizer): raise TypeError("update equation parameter must be " "paddle.v2.optimizer.Optimizer") - topology = Topology(cost) + topology = Topology(cost, extra_layers=extra_layers) self.__optimizer__ = update_equation self.__topology__ = topology self.__parameters__ = parameters