From b9dd33f815ed94237ddda930c99db838b76460e1 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Thu, 23 Feb 2017 23:30:17 +0800 Subject: [PATCH] hide Topology --- demo/mnist/api_train_v2.py | 6 ++---- python/paddle/v2/parameters.py | 8 +++----- python/paddle/v2/topology.py | 14 +++++++++----- python/paddle/v2/trainer.py | 4 +++- 4 files changed, 17 insertions(+), 15 deletions(-) diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index f6edd1f34fe..cc45229fbd6 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -26,9 +26,7 @@ def main(): act=paddle.activation.Softmax()) cost = paddle.layer.classification_cost(input=inference, label=label) - topology = paddle.topology.Topology(cost) - - parameters = paddle.parameters.create(topology) + parameters = paddle.parameters.create([cost]) for param_name in parameters.keys(): array = parameters.get(param_name) array[:] = numpy.random.uniform(low=-1.0, high=1.0, size=array.shape) @@ -49,7 +47,7 @@ def main(): trainer.train( train_data_reader=train_reader, - topology=topology, + topology=[cost], parameters=parameters, event_handler=event_handler, batch_size=32) # batch size should be refactor in Data reader diff --git a/python/paddle/v2/parameters.py b/python/paddle/v2/parameters.py index b569afe3a1f..b8d4b287032 100644 --- a/python/paddle/v2/parameters.py +++ b/python/paddle/v2/parameters.py @@ -7,15 +7,13 @@ import topology as v2_topology __all__ = ['Parameters', 'create'] -def create(topology): +def create(layers): """ Create parameter pool by topology. - :param topology: + :param layers: :return: """ - if not isinstance(topology, v2_topology.Topology): - raise ValueError( - 'create must pass a topology which type is topology.Topology') + topology = v2_topology.Topology(layers) pool = Parameters() for param in topology.proto().parameters: pool.__append_config__(param) diff --git a/python/paddle/v2/topology.py b/python/paddle/v2/topology.py index 6508b3ce881..bfd7ef171a5 100644 --- a/python/paddle/v2/topology.py +++ b/python/paddle/v2/topology.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from paddle.proto.ModelConfig_pb2 import ModelConfig +import collections + import paddle.trainer_config_helpers as conf_helps -import layer as v2_layer +from paddle.proto.ModelConfig_pb2 import ModelConfig + import data_type +import layer as v2_layer __all__ = ['Topology'] @@ -26,11 +29,12 @@ class Topology(object): and network configs. """ - def __init__(self, *layers): + def __init__(self, layers): + if not isinstance(layers, collections.Sequence): + raise ValueError("input of Topology should be a list of Layer") for layer in layers: if not isinstance(layer, v2_layer.LayerV2): - raise ValueError('create must pass a topologies ' - 'which type is paddle.layer.Layer') + raise ValueError('layer should have type paddle.layer.Layer') self.layers = layers self.__model_config__ = v2_layer.parse_network(*layers) assert isinstance(self.__model_config__, ModelConfig) diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index c8da6e70cf5..969aa6e0e05 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -73,7 +73,7 @@ class SGD(ITrainer): Training method. Will train num_passes of input data. :param train_data_reader: - :param topology: Network Topology, use one or more Layers to represent it. + :param topology: cost layers, use one or more Layers to represent it. :param parameters: The parameter pools. :param num_passes: The total train passes. :param test_data_reader: @@ -87,6 +87,8 @@ class SGD(ITrainer): if event_handler is None: event_handler = default_event_handler + topology = v2_topology.Topology(topology) + __check_train_args__(**locals()) gm = api.GradientMachine.createFromConfigProto( -- GitLab