提交 0dc68a2c 编写于 作者: L liaogang

add getNonStaticParameters

上级 49020f0b
......@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License
from api_v2_vgg import resnet_cifar10
from api_v2_resnet import vgg_bn_drop
from api_v2_vgg import vgg_bn_drop
from api_v2_resnet import resnet_cifar10
import paddle.v2 as paddle
......
......@@ -142,6 +142,20 @@ Parameter* GradientMachine::getParameter(size_t i) throw(RangeError) {
}
}
size_t GradientMachine::getNonStaticParameterSize() const {
return m->machine->getNonStaticParameters().size();
}
Parameter* GradientMachine::getNonStaticParameter(size_t i) throw(RangeError) {
auto params = m->machine->getNonStaticParameters();
if (i < params.size()) {
return Parameter::createFromSharedPtr(
&m->machine->getNonStaticParameters()[i]);
} else {
throw RangeError();
}
}
void GradientMachine::randParameters() { m->machine->randParameters(); }
Arguments* GradientMachine::getLayerOutput(const std::string& layerName) const
......
......@@ -768,6 +768,9 @@ public:
size_t getParameterSize() const;
Parameter* getParameter(size_t i) throw(RangeError);
size_t getNonStaticParameterSize() const;
Parameter* getNonStaticParameter(size_t i) throw(RangeError);
void randParameters();
Arguments* getLayerOutput(const std::string& layerName) const
......
......@@ -195,6 +195,12 @@ def __monkeypatch_gradient_machine__():
swig_paddle.GradientMachine.getParameters = getParameters
def getNonStaticParameters(self):
return (self.getNonStaticParameter(i)
for i in xrange(self.getNonStaticParameterSize()))
swig_paddle.GradientMachine.getParameters = getParameters
def getLayerOutputs(self, layerNames):
"""
getLayerOutputs. get outputs of layers and return a numpy matrix dict.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册