From 1976f137b2dc28fa6080f6df27755a5503956240 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 9 May 2017 22:27:09 +0800 Subject: [PATCH] add attr() interface to layer --- python/paddle/v2/config_base.py | 5 +++++ python/paddle/v2/topology.py | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/python/paddle/v2/config_base.py b/python/paddle/v2/config_base.py index b0e8da563e..acda778e0a 100644 --- a/python/paddle/v2/config_base.py +++ b/python/paddle/v2/config_base.py @@ -16,6 +16,7 @@ import collections import re from paddle.trainer_config_helpers.default_decorators import wrap_name_default import paddle.trainer_config_helpers as conf_helps +from topology import Topology class LayerType(type): @@ -161,6 +162,10 @@ class Layer(object): """ return self.__context__[self.context_name()].size + def attr(self): + topo = Topology(self) + return topo.get_layer_proto(self.name) + def __convert_to_v2__(method_name, parent_names, diff --git a/python/paddle/v2/topology.py b/python/paddle/v2/topology.py index ff28c85c53..1e46e4973f 100644 --- a/python/paddle/v2/topology.py +++ b/python/paddle/v2/topology.py @@ -130,6 +130,12 @@ class Topology(object): return [(nm, data_layers[nm].type) for nm in self.proto().input_layer_names] + def get_layer_proto(self, name): + for layer in self.__model_config__.layers: + if layer.name == name: + return layer + return None + def __check_layer_type__(layer): if not isinstance(layer, v2_layer.LayerV2): -- GitLab