diff --git a/python/paddle/v2/topology.py b/python/paddle/v2/topology.py index a51b1073b4fc4fd3ac44c355e050b0d720944645..6329915350bbb4bcec5723fd249bc35b3ad40bc9 100644 --- a/python/paddle/v2/topology.py +++ b/python/paddle/v2/topology.py @@ -68,13 +68,18 @@ class Topology(object): get all data layer :return: """ - data_layers = set() + data_layers = dict() def find_data_layer(layer): if isinstance(layer, v2_layer.DataLayerV2): - data_layers.add(layer) - for parent_layer in layer.__parent_layers__.values(): - find_data_layer(parent_layer) + data_layers[layer.name] = layer + if not isinstance(layer, collections.Sequence): + for parent_layer in layer.__parent_layers__.values(): + find_data_layer(parent_layer) + else: + for each_l in layer: + for parent_layer in each_l.__parent_layers__.values(): + find_data_layer(parent_layer) for layer in self.layers: find_data_layer(layer) @@ -86,8 +91,12 @@ class Topology(object): get data_type from proto, such as: [('image', dense_vector(768)), ('label', integer_value(10))] """ - return [(data_layer.name, data_layer.type) - for data_layer in self.data_layers()] + + data_types_lists = [] + for each in self.__model_config__.input_layer_names: + data_layers = self.data_layers() + data_types_lists.append((each, data_layers[each].type)) + return data_types_lists def __check_layer_type__(layer):