diff --git a/paddle/api/Arguments.cpp b/paddle/api/Arguments.cpp index d49b189e253f7a0792fe3f1fe7c8fdbb7071acd4..c6f9106912c475dda76b4c11e0793cc5a9f78d3f 100644 --- a/paddle/api/Arguments.cpp +++ b/paddle/api/Arguments.cpp @@ -151,4 +151,24 @@ int64_t Arguments::getBatchSize(size_t idx) const throw(RangeError) { return a.getBatchSize(); } +void Arguments::setSlotFrameHeight(size_t idx, size_t h) throw(RangeError) { + auto& a = m->getArg(idx); + a.setFrameHeight(h); +} + +void Arguments::setSlotFrameWidth(size_t idx, size_t w) throw(RangeError) { + auto& a = m->getArg(idx); + a.setFrameWidth(w); +} + +size_t Arguments::getSlotFrameHeight(size_t idx) const throw(RangeError) { + auto& a = m->getArg(idx); + return a.getFrameHeight(); +} + +size_t Arguments::getSlotFrameWidth(size_t idx) const throw(RangeError) { + auto& a = m->getArg(idx); + return a.getFrameWidth(); +} + void* Arguments::getInternalArgumentsPtr() const { return &m->outputs; } diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index d51204012171c9887acd5f578f913143182efe36..da0f157abd68c73c45f498cf9ef2726aac67c95b 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -454,6 +454,25 @@ public: IVector* vec) throw(RangeError); void setSlotSequenceDim(size_t idx, IVector* vec) throw(RangeError); + /** + * Set the frame height of the idx-th Argument. + * + * @param ids The index of which Argument. + * @param h The height value. + */ + void setSlotFrameHeight(size_t idx, size_t h) throw(RangeError); + + /** + * Set the frame height of the idx-th Argument. + * + * @param ids The index of which Argument. + * @param h The height value. + */ + void setSlotFrameWidth(size_t idx, size_t w) throw(RangeError); + + size_t getSlotFrameHeight(size_t idx = 0) const throw(RangeError); + size_t getSlotFrameWidth(size_t idx = 0) const throw(RangeError); + float sum() const; private: diff --git a/paddle/api/test/testArguments.py b/paddle/api/test/testArguments.py index 9fe44de94ea6ddb71d2dfbb2243fc86ede0d0531..4d40ffec9a030bf756a515266b2c33915fcc4e10 100644 --- a/paddle/api/test/testArguments.py +++ b/paddle/api/test/testArguments.py @@ -13,6 +13,7 @@ # limitations under the License. from py_paddle import swig_paddle +import numpy as np import unittest @@ -36,6 +37,17 @@ class TestArguments(unittest.TestCase): np_arr = iv.toNumpyArrayInplace() self.assertEqual(np_arr.shape, (6, )) + def test_arguments_shape(self): + h, w = 4, 6 + v = np.random.rand(2, h * w) + m = swig_paddle.Matrix.createDense(v.flatten(), 2, h * w) + args = swig_paddle.Arguments.createArguments(1) + args.setSlotValue(0, m) + args.setSlotFrameHeight(0, h) + args.setSlotFrameWidth(0, w) + self.assertEqual(args.getSlotFrameHeight(), h) + self.assertEqual(args.getSlotFrameWidth(), w) + if __name__ == '__main__': swig_paddle.initPaddle("--use_gpu=0")