diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index 59486ed1b3ba494a20b06b7ef7027fc3e86c043c..b5cc74ce67dfc8e1afa65bd52f5ec600260032ce 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -25,8 +25,7 @@ def main(): act=paddle.activation.Softmax()) cost = paddle.layer.classification_cost(input=inference, label=label) - topology = paddle.layer.parse_network(cost) - parameters = paddle.parameters.create(topology) + parameters = paddle.parameters.create(cost) for param_name in parameters.keys(): array = parameters.get(param_name) array[:] = numpy.random.uniform(low=-1.0, high=1.0, size=array.shape) @@ -46,7 +45,7 @@ def main(): trainer = paddle.trainer.SGD(update_equation=adam_optimizer) trainer.train(train_data_reader=train_reader, - topology=topology, + topology=cost, parameters=parameters, event_handler=event_handler, batch_size=32, # batch size should be refactor in Data reader diff --git a/python/paddle/v2/parameters.py b/python/paddle/v2/parameters.py index e5b7dabcb8eb3a845dedea663f978e7a9820495d..ea504d5104716d157add87ed3f6e31ea69e0a3f0 100644 --- a/python/paddle/v2/parameters.py +++ b/python/paddle/v2/parameters.py @@ -1,27 +1,27 @@ import numpy as np - -from paddle.proto.ModelConfig_pb2 import ModelConfig -from paddle.proto.ParameterConfig_pb2 import ParameterConfig +from . import layer as v2_layer import py_paddle.swig_paddle as api +from paddle.proto.ParameterConfig_pb2 import ParameterConfig __all__ = ['Parameters', 'create'] -def create(*topologies): +def create(*layers): """ - Create parameter pool by topologies. + Create parameter pool by layers. In paddle, a layer can represent a + model config. 
- :param topologies: + :param layers: :return: """ - pool = Parameters() - for topo in topologies: - if not isinstance(topo, ModelConfig): + for layer in layers: + if not isinstance(layer, v2_layer.Layer): raise ValueError( - 'create must pass a topologies which type is ModelConfig') - - for param in topo.parameters: - pool.__append_config__(param) + 'create must pass a topology of type paddle.layer.Layer') + model_config = v2_layer.parse_network(*layers) + pool = Parameters() + for param in model_config.parameters: + pool.__append_config__(param) return pool diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 9ba13dc5c8a81f8dcf39260d1a44dcdcc7c22ad5..4365bd41e7073bce4112e5813dbf1517856c06f5 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -1,12 +1,13 @@ import collections import py_paddle.swig_paddle as api +from paddle.proto.ModelConfig_pb2 import ModelConfig from py_paddle import DataProviderConverter -from paddle.proto.ModelConfig_pb2 import ModelConfig +from . import event as v2_event +from . import layer as v2_layer from . import optimizer as v2_optimizer from . import parameters as v2_parameters -from . import event as v2_event __all__ = ['ITrainer', 'SGD'] @@ -73,7 +74,7 @@ class SGD(ITrainer): Training method. Will train num_passes of input data. :param train_data_reader: - :param topology: Network Topology, a protobuf ModelConfig message. + :param topology: Network Topology, represented by one or more Layers. :param parameters: The parameter pools. :param num_passes: The total train passes. :param test_data_reader: @@ -87,6 +88,8 @@ class SGD(ITrainer): if event_handler is None: event_handler = default_event_handler + topology = v2_layer.parse_network(topology) + __check_train_args__(**locals()) gm = api.GradientMachine.createFromConfigProto(