提交 695b5a7f 编写于 作者: Q qiaolongfei

change topology to layer

上级 6089b7c6
......@@ -25,8 +25,7 @@ def main():
act=paddle.activation.Softmax())
cost = paddle.layer.classification_cost(input=inference, label=label)
topology = paddle.layer.parse_network(cost)
parameters = paddle.parameters.create(topology)
parameters = paddle.parameters.create(cost)
for param_name in parameters.keys():
array = parameters.get(param_name)
array[:] = numpy.random.uniform(low=-1.0, high=1.0, size=array.shape)
......@@ -46,7 +45,7 @@ def main():
trainer = paddle.trainer.SGD(update_equation=adam_optimizer)
trainer.train(train_data_reader=train_reader,
topology=topology,
topology=cost,
parameters=parameters,
event_handler=event_handler,
batch_size=32, # batch size should be refactor in Data reader
......
import numpy as np
from paddle.proto.ModelConfig_pb2 import ModelConfig
from paddle.proto.ParameterConfig_pb2 import ParameterConfig
from . import layer as v2_layer
import py_paddle.swig_paddle as api
from paddle.proto.ParameterConfig_pb2 import ParameterConfig
__all__ = ['Parameters', 'create']
def create(*topologies):
def create(*layers):
"""
Create parameter pool by topologies.
Create parameter pool by layers. In paddle, layer can be represent a
model config.
:param topologies:
:param layers:
:return:
"""
pool = Parameters()
for topo in topologies:
if not isinstance(topo, ModelConfig):
for layer in layers:
if not isinstance(layer, v2_layer.Layer):
raise ValueError(
'create must pass a topologies which type is ModelConfig')
for param in topo.parameters:
pool.__append_config__(param)
'create must pass a topologies which type is paddle.layer.Layer')
model_config = v2_layer.parse_network(*layers)
pool = Parameters()
for param in model_config.parameters:
pool.__append_config__(param)
return pool
......
import collections
import py_paddle.swig_paddle as api
from paddle.proto.ModelConfig_pb2 import ModelConfig
from py_paddle import DataProviderConverter
from paddle.proto.ModelConfig_pb2 import ModelConfig
from . import event as v2_event
from . import layer as v2_layer
from . import optimizer as v2_optimizer
from . import parameters as v2_parameters
from . import event as v2_event
__all__ = ['ITrainer', 'SGD']
......@@ -73,7 +74,7 @@ class SGD(ITrainer):
Training method. Will train num_passes of input data.
:param train_data_reader:
:param topology: Network Topology, a protobuf ModelConfig message.
:param topology: Network Topology, use one or more Layers to represent it.
:param parameters: The parameter pools.
:param num_passes: The total train passes.
:param test_data_reader:
......@@ -87,6 +88,8 @@ class SGD(ITrainer):
if event_handler is None:
event_handler = default_event_handler
topology = v2_layer.parse_network(topology)
__check_train_args__(**locals())
gm = api.GradientMachine.createFromConfigProto(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册