提交 b9dd33f8 编写于 作者: Q qiaolongfei

hide Topology

上级 53bd4a48
...@@ -26,9 +26,7 @@ def main(): ...@@ -26,9 +26,7 @@ def main():
act=paddle.activation.Softmax()) act=paddle.activation.Softmax())
cost = paddle.layer.classification_cost(input=inference, label=label) cost = paddle.layer.classification_cost(input=inference, label=label)
topology = paddle.topology.Topology(cost) parameters = paddle.parameters.create([cost])
parameters = paddle.parameters.create(topology)
for param_name in parameters.keys(): for param_name in parameters.keys():
array = parameters.get(param_name) array = parameters.get(param_name)
array[:] = numpy.random.uniform(low=-1.0, high=1.0, size=array.shape) array[:] = numpy.random.uniform(low=-1.0, high=1.0, size=array.shape)
...@@ -49,7 +47,7 @@ def main(): ...@@ -49,7 +47,7 @@ def main():
trainer.train( trainer.train(
train_data_reader=train_reader, train_data_reader=train_reader,
topology=topology, topology=[cost],
parameters=parameters, parameters=parameters,
event_handler=event_handler, event_handler=event_handler,
batch_size=32) # batch size should be refactor in Data reader batch_size=32) # batch size should be refactor in Data reader
......
...@@ -7,15 +7,13 @@ import topology as v2_topology ...@@ -7,15 +7,13 @@ import topology as v2_topology
__all__ = ['Parameters', 'create'] __all__ = ['Parameters', 'create']
def create(topology): def create(layers):
""" """
Create parameter pool by topology. Create parameter pool by topology.
:param topology: :param layers:
:return: :return:
""" """
if not isinstance(topology, v2_topology.Topology): topology = v2_topology.Topology(layers)
raise ValueError(
'create must pass a topology which type is topology.Topology')
pool = Parameters() pool = Parameters()
for param in topology.proto().parameters: for param in topology.proto().parameters:
pool.__append_config__(param) pool.__append_config__(param)
......
...@@ -12,10 +12,13 @@ ...@@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from paddle.proto.ModelConfig_pb2 import ModelConfig import collections
import paddle.trainer_config_helpers as conf_helps import paddle.trainer_config_helpers as conf_helps
import layer as v2_layer from paddle.proto.ModelConfig_pb2 import ModelConfig
import data_type import data_type
import layer as v2_layer
__all__ = ['Topology'] __all__ = ['Topology']
...@@ -26,11 +29,12 @@ class Topology(object): ...@@ -26,11 +29,12 @@ class Topology(object):
and network configs. and network configs.
""" """
def __init__(self, *layers): def __init__(self, layers):
if not isinstance(layers, collections.Sequence):
raise ValueError("input of Topology should be a list of Layer")
for layer in layers: for layer in layers:
if not isinstance(layer, v2_layer.LayerV2): if not isinstance(layer, v2_layer.LayerV2):
raise ValueError('create must pass a topologies ' raise ValueError('layer should have type paddle.layer.Layer')
'which type is paddle.layer.Layer')
self.layers = layers self.layers = layers
self.__model_config__ = v2_layer.parse_network(*layers) self.__model_config__ = v2_layer.parse_network(*layers)
assert isinstance(self.__model_config__, ModelConfig) assert isinstance(self.__model_config__, ModelConfig)
......
...@@ -73,7 +73,7 @@ class SGD(ITrainer): ...@@ -73,7 +73,7 @@ class SGD(ITrainer):
Training method. Will train num_passes of input data. Training method. Will train num_passes of input data.
:param train_data_reader: :param train_data_reader:
:param topology: Network Topology, use one or more Layers to represent it. :param topology: cost layers, use one or more Layers to represent it.
:param parameters: The parameter pools. :param parameters: The parameter pools.
:param num_passes: The total train passes. :param num_passes: The total train passes.
:param test_data_reader: :param test_data_reader:
...@@ -87,6 +87,8 @@ class SGD(ITrainer): ...@@ -87,6 +87,8 @@ class SGD(ITrainer):
if event_handler is None: if event_handler is None:
event_handler = default_event_handler event_handler = default_event_handler
topology = v2_topology.Topology(topology)
__check_train_args__(**locals()) __check_train_args__(**locals())
gm = api.GradientMachine.createFromConfigProto( gm = api.GradientMachine.createFromConfigProto(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册