提交 b9dd33f8 编写于 作者: Q qiaolongfei

hide Topology

上级 53bd4a48
......@@ -26,9 +26,7 @@ def main():
act=paddle.activation.Softmax())
cost = paddle.layer.classification_cost(input=inference, label=label)
topology = paddle.topology.Topology(cost)
parameters = paddle.parameters.create(topology)
parameters = paddle.parameters.create([cost])
for param_name in parameters.keys():
array = parameters.get(param_name)
array[:] = numpy.random.uniform(low=-1.0, high=1.0, size=array.shape)
......@@ -49,7 +47,7 @@ def main():
trainer.train(
train_data_reader=train_reader,
topology=topology,
topology=[cost],
parameters=parameters,
event_handler=event_handler,
batch_size=32) # batch size should be refactor in Data reader
......
......@@ -7,15 +7,13 @@ import topology as v2_topology
__all__ = ['Parameters', 'create']
def create(topology):
def create(layers):
"""
Create parameter pool by topology.
:param topology:
:param layers:
:return:
"""
if not isinstance(topology, v2_topology.Topology):
raise ValueError(
'create must pass a topology which type is topology.Topology')
topology = v2_topology.Topology(layers)
pool = Parameters()
for param in topology.proto().parameters:
pool.__append_config__(param)
......
......@@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.proto.ModelConfig_pb2 import ModelConfig
import collections
import paddle.trainer_config_helpers as conf_helps
import layer as v2_layer
from paddle.proto.ModelConfig_pb2 import ModelConfig
import data_type
import layer as v2_layer
__all__ = ['Topology']
......@@ -26,11 +29,12 @@ class Topology(object):
and network configs.
"""
def __init__(self, *layers):
def __init__(self, layers):
if not isinstance(layers, collections.Sequence):
raise ValueError("input of Topology should be a list of Layer")
for layer in layers:
if not isinstance(layer, v2_layer.LayerV2):
raise ValueError('create must pass a topologies '
'which type is paddle.layer.Layer')
raise ValueError('layer should have type paddle.layer.Layer')
self.layers = layers
self.__model_config__ = v2_layer.parse_network(*layers)
assert isinstance(self.__model_config__, ModelConfig)
......
......@@ -73,7 +73,7 @@ class SGD(ITrainer):
Training method. Will train num_passes of input data.
:param train_data_reader:
:param topology: Network Topology, use one or more Layers to represent it.
:param topology: cost layers, use one or more Layers to represent it.
:param parameters: The parameter pools.
:param num_passes: The total train passes.
:param test_data_reader:
......@@ -87,6 +87,8 @@ class SGD(ITrainer):
if event_handler is None:
event_handler = default_event_handler
topology = v2_topology.Topology(topology)
__check_train_args__(**locals())
gm = api.GradientMachine.createFromConfigProto(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册