From 695b5a7fcd42e4e5678fdb4288cc8dd23240aac4 Mon Sep 17 00:00:00 2001
From: qiaolongfei
Date: Fri, 17 Feb 2017 18:09:55 +0800
Subject: [PATCH] change topology to layer

---
 demo/mnist/api_train_v2.py     |  5 ++---
 python/paddle/v2/parameters.py | 26 +++++++++++++-------------
 python/paddle/v2/trainer.py    |  9 ++++++---
 3 files changed, 21 insertions(+), 19 deletions(-)

diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py
index 59486ed1b3b..b5cc74ce67d 100644
--- a/demo/mnist/api_train_v2.py
+++ b/demo/mnist/api_train_v2.py
@@ -25,8 +25,7 @@ def main():
         act=paddle.activation.Softmax())
     cost = paddle.layer.classification_cost(input=inference, label=label)
 
-    topology = paddle.layer.parse_network(cost)
-    parameters = paddle.parameters.create(topology)
+    parameters = paddle.parameters.create(cost)
     for param_name in parameters.keys():
         array = parameters.get(param_name)
         array[:] = numpy.random.uniform(low=-1.0, high=1.0, size=array.shape)
@@ -46,7 +45,7 @@ def main():
     trainer = paddle.trainer.SGD(update_equation=adam_optimizer)
 
     trainer.train(train_data_reader=train_reader,
-                  topology=topology,
+                  topology=cost,
                   parameters=parameters,
                   event_handler=event_handler,
                   batch_size=32,  # batch size should be refactor in Data reader
diff --git a/python/paddle/v2/parameters.py b/python/paddle/v2/parameters.py
index e5b7dabcb8e..ea504d51047 100644
--- a/python/paddle/v2/parameters.py
+++ b/python/paddle/v2/parameters.py
@@ -1,27 +1,27 @@
 import numpy as np
-
-from paddle.proto.ModelConfig_pb2 import ModelConfig
-from paddle.proto.ParameterConfig_pb2 import ParameterConfig
+from . import layer as v2_layer
 import py_paddle.swig_paddle as api
+from paddle.proto.ParameterConfig_pb2 import ParameterConfig
 
 __all__ = ['Parameters', 'create']
 
 
-def create(*topologies):
+def create(*layers):
     """
-    Create parameter pool by topologies.
+    Create parameter pool by layers. In Paddle, layers can represent a
+    model config.
 
-    :param topologies:
+    :param layers: output layers of the network topology.
     :return:
     """
-    pool = Parameters()
-    for topo in topologies:
-        if not isinstance(topo, ModelConfig):
+    for layer in layers:
+        if not isinstance(layer, v2_layer.Layer):
             raise ValueError(
-                'create must pass a topologies which type is ModelConfig')
-
-        for param in topo.parameters:
-            pool.__append_config__(param)
+                'create must pass layers of type paddle.layer.Layer')
+    model_config = v2_layer.parse_network(*layers)
+    pool = Parameters()
+    for param in model_config.parameters:
+        pool.__append_config__(param)
     return pool
 
 
diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py
index 9ba13dc5c8a..4365bd41e70 100644
--- a/python/paddle/v2/trainer.py
+++ b/python/paddle/v2/trainer.py
@@ -1,12 +1,13 @@
 import collections
 
 import py_paddle.swig_paddle as api
+from paddle.proto.ModelConfig_pb2 import ModelConfig
 from py_paddle import DataProviderConverter
-from paddle.proto.ModelConfig_pb2 import ModelConfig
 
+from . import event as v2_event
+from . import layer as v2_layer
 from . import optimizer as v2_optimizer
 from . import parameters as v2_parameters
-from . import event as v2_event
 
 __all__ = ['ITrainer', 'SGD']
 
@@ -73,7 +74,7 @@ class SGD(ITrainer):
         Training method. Will train num_passes of input data.
 
         :param train_data_reader:
-        :param topology: Network Topology, a protobuf ModelConfig message.
+        :param topology: Network topology, represented by one or more Layers.
         :param parameters: The parameter pools.
         :param num_passes: The total train passes.
         :param test_data_reader:
@@ -87,6 +88,8 @@ class SGD(ITrainer):
         if event_handler is None:
             event_handler = default_event_handler
 
+        topology = v2_layer.parse_network(topology)
+
         __check_train_args__(**locals())
 
         gm = api.GradientMachine.createFromConfigProto(
-- 
GitLab
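
For reference, a minimal usage sketch of the API after this patch, pieced together only from the demo hunk above. The top-level import and the names inference, label, adam_optimizer, train_reader, and event_handler are assumed to be defined in the earlier, unchanged lines of demo/mnist/api_train_v2.py and are not shown here.

# Usage sketch (not part of the patch): parameters.create() and
# trainer.train() now accept layers directly; parse_network() is
# called internally instead of by the user.
# Assumes inference, label, adam_optimizer, train_reader and
# event_handler are defined as in the rest of the demo script.
import paddle.v2 as paddle

cost = paddle.layer.classification_cost(input=inference, label=label)

# Before this patch: parameters.create(paddle.layer.parse_network(cost))
parameters = paddle.parameters.create(cost)

trainer = paddle.trainer.SGD(update_equation=adam_optimizer)
trainer.train(train_data_reader=train_reader,
              topology=cost,  # a layer, no longer a parsed ModelConfig proto
              parameters=parameters,
              event_handler=event_handler,
              batch_size=32)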