提交 cafb075e 编写于 作者: J jacquesqiao 提交者: GitHub

Merge pull request #1361 from jacquesqiao/hide-parse_network

hide parse_config
...@@ -25,8 +25,7 @@ def main(): ...@@ -25,8 +25,7 @@ def main():
act=paddle.activation.Softmax()) act=paddle.activation.Softmax())
cost = paddle.layer.classification_cost(input=inference, label=label) cost = paddle.layer.classification_cost(input=inference, label=label)
topology = paddle.layer.parse_network(cost) parameters = paddle.parameters.create(cost)
parameters = paddle.parameters.create(topology)
for param_name in parameters.keys(): for param_name in parameters.keys():
array = parameters.get(param_name) array = parameters.get(param_name)
array[:] = numpy.random.uniform(low=-1.0, high=1.0, size=array.shape) array[:] = numpy.random.uniform(low=-1.0, high=1.0, size=array.shape)
...@@ -46,7 +45,7 @@ def main(): ...@@ -46,7 +45,7 @@ def main():
trainer = paddle.trainer.SGD(update_equation=adam_optimizer) trainer = paddle.trainer.SGD(update_equation=adam_optimizer)
trainer.train(train_data_reader=train_reader, trainer.train(train_data_reader=train_reader,
topology=topology, topology=cost,
parameters=parameters, parameters=parameters,
event_handler=event_handler, event_handler=event_handler,
batch_size=32, # batch size should be refactor in Data reader batch_size=32, # batch size should be refactor in Data reader
......
import numpy as np import numpy as np
from . import layer as v2_layer
from paddle.proto.ModelConfig_pb2 import ModelConfig
from paddle.proto.ParameterConfig_pb2 import ParameterConfig
import py_paddle.swig_paddle as api import py_paddle.swig_paddle as api
from paddle.proto.ParameterConfig_pb2 import ParameterConfig
__all__ = ['Parameters', 'create'] __all__ = ['Parameters', 'create']
def create(*topologies): def create(*layers):
""" """
Create parameter pool by topologies. Create parameter pool by layers. In paddle, layer can be represent a
model config.
:param topologies: :param layers:
:return: :return:
""" """
pool = Parameters() for layer in layers:
for topo in topologies: if not isinstance(layer, v2_layer.Layer):
if not isinstance(topo, ModelConfig):
raise ValueError( raise ValueError(
'create must pass a topologies which type is ModelConfig') 'create must pass a topologies which type is paddle.layer.Layer')
model_config = v2_layer.parse_network(*layers)
for param in topo.parameters: pool = Parameters()
pool.__append_config__(param) for param in model_config.parameters:
pool.__append_config__(param)
return pool return pool
......
import collections import collections
import py_paddle.swig_paddle as api import py_paddle.swig_paddle as api
from paddle.proto.ModelConfig_pb2 import ModelConfig
from py_paddle import DataProviderConverter from py_paddle import DataProviderConverter
from paddle.proto.ModelConfig_pb2 import ModelConfig from . import event as v2_event
from . import layer as v2_layer
from . import optimizer as v2_optimizer from . import optimizer as v2_optimizer
from . import parameters as v2_parameters from . import parameters as v2_parameters
from . import event as v2_event
__all__ = ['ITrainer', 'SGD'] __all__ = ['ITrainer', 'SGD']
...@@ -73,7 +74,7 @@ class SGD(ITrainer): ...@@ -73,7 +74,7 @@ class SGD(ITrainer):
Training method. Will train num_passes of input data. Training method. Will train num_passes of input data.
:param train_data_reader: :param train_data_reader:
:param topology: Network Topology, a protobuf ModelConfig message. :param topology: Network Topology, use one or more Layers to represent it.
:param parameters: The parameter pools. :param parameters: The parameter pools.
:param num_passes: The total train passes. :param num_passes: The total train passes.
:param test_data_reader: :param test_data_reader:
...@@ -87,6 +88,8 @@ class SGD(ITrainer): ...@@ -87,6 +88,8 @@ class SGD(ITrainer):
if event_handler is None: if event_handler is None:
event_handler = default_event_handler event_handler = default_event_handler
topology = v2_layer.parse_network(topology)
__check_train_args__(**locals()) __check_train_args__(**locals())
gm = api.GradientMachine.createFromConfigProto( gm = api.GradientMachine.createFromConfigProto(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册