提交 0dc68a2c 编写于 作者: L liaogang

add getNonStaticParameters

上级 49020f0b
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
from api_v2_vgg import resnet_cifar10 from api_v2_vgg import vgg_bn_drop
from api_v2_resnet import vgg_bn_drop from api_v2_resnet import resnet_cifar10
import paddle.v2 as paddle import paddle.v2 as paddle
......
...@@ -142,6 +142,20 @@ Parameter* GradientMachine::getParameter(size_t i) throw(RangeError) { ...@@ -142,6 +142,20 @@ Parameter* GradientMachine::getParameter(size_t i) throw(RangeError) {
} }
} }
size_t GradientMachine::getNonStaticParameterSize() const {
return m->machine->getNonStaticParameters().size();
}
Parameter* GradientMachine::getNonStaticParameter(size_t i) throw(RangeError) {
auto params = m->machine->getNonStaticParameters();
if (i < params.size()) {
return Parameter::createFromSharedPtr(
&m->machine->getNonStaticParameters()[i]);
} else {
throw RangeError();
}
}
void GradientMachine::randParameters() { m->machine->randParameters(); } void GradientMachine::randParameters() { m->machine->randParameters(); }
Arguments* GradientMachine::getLayerOutput(const std::string& layerName) const Arguments* GradientMachine::getLayerOutput(const std::string& layerName) const
......
...@@ -768,6 +768,9 @@ public: ...@@ -768,6 +768,9 @@ public:
size_t getParameterSize() const; size_t getParameterSize() const;
Parameter* getParameter(size_t i) throw(RangeError); Parameter* getParameter(size_t i) throw(RangeError);
size_t getNonStaticParameterSize() const;
Parameter* getNonStaticParameter(size_t i) throw(RangeError);
void randParameters(); void randParameters();
Arguments* getLayerOutput(const std::string& layerName) const Arguments* getLayerOutput(const std::string& layerName) const
......
...@@ -195,6 +195,12 @@ def __monkeypatch_gradient_machine__(): ...@@ -195,6 +195,12 @@ def __monkeypatch_gradient_machine__():
swig_paddle.GradientMachine.getParameters = getParameters swig_paddle.GradientMachine.getParameters = getParameters
def getNonStaticParameters(self):
return (self.getNonStaticParameter(i)
for i in xrange(self.getNonStaticParameterSize()))
swig_paddle.GradientMachine.getParameters = getParameters
def getLayerOutputs(self, layerNames): def getLayerOutputs(self, layerNames):
""" """
getLayerOutputs. get outputs of layers and return a numpy matrix dict. getLayerOutputs. get outputs of layers and return a numpy matrix dict.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册