diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index f6edd1f34fe8d197422dfc2a7641fd3bb37adc66..cc45229fbd688d5a0d3d3105c9fa8e1402095cca 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -26,9 +26,7 @@ def main(): act=paddle.activation.Softmax()) cost = paddle.layer.classification_cost(input=inference, label=label) - topology = paddle.topology.Topology(cost) - - parameters = paddle.parameters.create(topology) + parameters = paddle.parameters.create([cost]) for param_name in parameters.keys(): array = parameters.get(param_name) array[:] = numpy.random.uniform(low=-1.0, high=1.0, size=array.shape) @@ -49,7 +47,7 @@ def main(): trainer.train( train_data_reader=train_reader, - topology=topology, + topology=[cost], parameters=parameters, event_handler=event_handler, batch_size=32) # batch size should be refactor in Data reader diff --git a/python/paddle/v2/parameters.py b/python/paddle/v2/parameters.py index b569afe3a1fc205d2fbab32989eadc9955304088..b8d4b287032cd3b2369e7ae7a0ef9bffc39576cf 100644 --- a/python/paddle/v2/parameters.py +++ b/python/paddle/v2/parameters.py @@ -7,15 +7,13 @@ import topology as v2_topology __all__ = ['Parameters', 'create'] -def create(topology): +def create(layers): """ Create parameter pool by topology. - :param topology: + :param layers: :return: """ - if not isinstance(topology, v2_topology.Topology): - raise ValueError( - 'create must pass a topology which type is topology.Topology') + topology = v2_topology.Topology(layers) pool = Parameters() for param in topology.proto().parameters: pool.__append_config__(param) diff --git a/python/paddle/v2/topology.py b/python/paddle/v2/topology.py index 6508b3ce881c4957b4a001506f9872aac17e9cec..bfd7ef171a5fb6421801f589e36e3172c42469f3 100644 --- a/python/paddle/v2/topology.py +++ b/python/paddle/v2/topology.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from paddle.proto.ModelConfig_pb2 import ModelConfig +import collections + import paddle.trainer_config_helpers as conf_helps -import layer as v2_layer +from paddle.proto.ModelConfig_pb2 import ModelConfig + import data_type +import layer as v2_layer __all__ = ['Topology'] @@ -26,11 +29,12 @@ class Topology(object): and network configs. """ - def __init__(self, *layers): + def __init__(self, layers): + if not isinstance(layers, collections.Sequence): + raise ValueError("input of Topology should be a list of Layer") for layer in layers: if not isinstance(layer, v2_layer.LayerV2): - raise ValueError('create must pass a topologies ' - 'which type is paddle.layer.Layer') + raise ValueError('layer should have type paddle.layer.Layer') self.layers = layers self.__model_config__ = v2_layer.parse_network(*layers) assert isinstance(self.__model_config__, ModelConfig) diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index c8da6e70cf528da64c40bd3245c4af6bf32b6a7f..969aa6e0e059350bed110ac7cd9a8af654d91dc5 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -73,7 +73,7 @@ class SGD(ITrainer): Training method. Will train num_passes of input data. :param train_data_reader: - :param topology: Network Topology, use one or more Layers to represent it. + :param topology: cost layers, use one or more Layers to represent it. :param parameters: The parameter pools. :param num_passes: The total train passes. :param test_data_reader: @@ -87,6 +87,8 @@ class SGD(ITrainer): if event_handler is None: event_handler = default_event_handler + topology = v2_topology.Topology(topology) + __check_train_args__(**locals()) gm = api.GradientMachine.createFromConfigProto(