From ffd045a0d0c127e28eb9283c91f9f439984a0bdf Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Tue, 28 Feb 2017 21:49:54 +0800 Subject: [PATCH] Fix data_layers and data_type function in topology.py --- python/paddle/v2/topology.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/python/paddle/v2/topology.py b/python/paddle/v2/topology.py index a51b1073b4f..6329915350b 100644 --- a/python/paddle/v2/topology.py +++ b/python/paddle/v2/topology.py @@ -68,13 +68,18 @@ class Topology(object): get all data layer :return: """ - data_layers = set() + data_layers = dict() def find_data_layer(layer): if isinstance(layer, v2_layer.DataLayerV2): - data_layers.add(layer) - for parent_layer in layer.__parent_layers__.values(): - find_data_layer(parent_layer) + data_layers[layer.name] = layer + if not isinstance(layer, collections.Sequence): + for parent_layer in layer.__parent_layers__.values(): + find_data_layer(parent_layer) + else: + for each_l in layer: + for parent_layer in each_l.__parent_layers__.values(): + find_data_layer(parent_layer) for layer in self.layers: find_data_layer(layer) @@ -86,8 +91,12 @@ class Topology(object): get data_type from proto, such as: [('image', dense_vector(768)), ('label', integer_value(10))] """ - return [(data_layer.name, data_layer.type) - for data_layer in self.data_layers()] + + data_types_lists = [] + for each in self.__model_config__.input_layer_names: + data_layers = self.data_layers() + data_types_lists.append((each, data_layers[each].type)) + return data_types_lists def __check_layer_type__(layer): -- GitLab