From dc530a714f1414f70a6e30c11f9d4a22eeaef544 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Thu, 18 May 2017 14:54:22 +0800 Subject: [PATCH] Expose more interfaces for Arguments in swig. --- paddle/api/Arguments.cpp | 20 ++++++++++++++++++++ paddle/api/PaddleAPI.h | 19 +++++++++++++++++++ paddle/api/test/testArguments.py | 12 ++++++++++++ 3 files changed, 51 insertions(+) diff --git a/paddle/api/Arguments.cpp b/paddle/api/Arguments.cpp index d49b189e2..c6f910691 100644 --- a/paddle/api/Arguments.cpp +++ b/paddle/api/Arguments.cpp @@ -151,4 +151,24 @@ int64_t Arguments::getBatchSize(size_t idx) const throw(RangeError) { return a.getBatchSize(); } +void Arguments::setSlotFrameHeight(size_t idx, size_t h) throw(RangeError) { + auto& a = m->getArg(idx); + a.setFrameHeight(h); +} + +void Arguments::setSlotFrameWidth(size_t idx, size_t w) throw(RangeError) { + auto& a = m->getArg(idx); + a.setFrameWidth(w); +} + +size_t Arguments::getSlotFrameHeight(size_t idx) const throw(RangeError) { + auto& a = m->getArg(idx); + return a.getFrameHeight(); +} + +size_t Arguments::getSlotFrameWidth(size_t idx) const throw(RangeError) { + auto& a = m->getArg(idx); + return a.getFrameWidth(); +} + void* Arguments::getInternalArgumentsPtr() const { return &m->outputs; } diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index d51204012..da0f157ab 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -454,6 +454,25 @@ public: IVector* vec) throw(RangeError); void setSlotSequenceDim(size_t idx, IVector* vec) throw(RangeError); + /** + * Set the frame height of the idx-th Argument. + * + * @param ids The index of which Argument. + * @param h The height value. + */ + void setSlotFrameHeight(size_t idx, size_t h) throw(RangeError); + + /** + * Set the frame height of the idx-th Argument. + * + * @param ids The index of which Argument. + * @param h The height value. + */ + void setSlotFrameWidth(size_t idx, size_t w) throw(RangeError); + + size_t getSlotFrameHeight(size_t idx = 0) const throw(RangeError); + size_t getSlotFrameWidth(size_t idx = 0) const throw(RangeError); + float sum() const; private: diff --git a/paddle/api/test/testArguments.py b/paddle/api/test/testArguments.py index 9fe44de94..4d40ffec9 100644 --- a/paddle/api/test/testArguments.py +++ b/paddle/api/test/testArguments.py @@ -13,6 +13,7 @@ # limitations under the License. from py_paddle import swig_paddle +import numpy as np import unittest @@ -36,6 +37,17 @@ class TestArguments(unittest.TestCase): np_arr = iv.toNumpyArrayInplace() self.assertEqual(np_arr.shape, (6, )) + def test_arguments_shape(self): + h, w = 4, 6 + v = np.random.rand(2, h * w) + m = swig_paddle.Matrix.createDense(v.flatten(), 2, h * w) + args = swig_paddle.Arguments.createArguments(1) + args.setSlotValue(0, m) + args.setSlotFrameHeight(0, h) + args.setSlotFrameWidth(0, w) + self.assertEqual(args.getSlotFrameHeight(), h) + self.assertEqual(args.getSlotFrameWidth(), w) + if __name__ == '__main__': swig_paddle.initPaddle("--use_gpu=0") -- GitLab