提交 debd4e4d 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #2197 from qingqing01/variable_input

Expose more interfaces for Arguments in swig.
......@@ -151,4 +151,24 @@ int64_t Arguments::getBatchSize(size_t idx) const throw(RangeError) {
return a.getBatchSize();
}
void Arguments::setSlotFrameHeight(size_t idx, size_t h) throw(RangeError) {
auto& a = m->getArg(idx);
a.setFrameHeight(h);
}
void Arguments::setSlotFrameWidth(size_t idx, size_t w) throw(RangeError) {
auto& a = m->getArg(idx);
a.setFrameWidth(w);
}
size_t Arguments::getSlotFrameHeight(size_t idx) const throw(RangeError) {
auto& a = m->getArg(idx);
return a.getFrameHeight();
}
size_t Arguments::getSlotFrameWidth(size_t idx) const throw(RangeError) {
auto& a = m->getArg(idx);
return a.getFrameWidth();
}
void* Arguments::getInternalArgumentsPtr() const { return &m->outputs; }
......@@ -454,6 +454,25 @@ public:
IVector* vec) throw(RangeError);
void setSlotSequenceDim(size_t idx, IVector* vec) throw(RangeError);
/**
* Set the frame height of the idx-th Argument.
*
* @param ids The index of which Argument.
* @param h The height value.
*/
void setSlotFrameHeight(size_t idx, size_t h) throw(RangeError);
/**
* Set the frame height of the idx-th Argument.
*
* @param ids The index of which Argument.
* @param h The height value.
*/
void setSlotFrameWidth(size_t idx, size_t w) throw(RangeError);
size_t getSlotFrameHeight(size_t idx = 0) const throw(RangeError);
size_t getSlotFrameWidth(size_t idx = 0) const throw(RangeError);
float sum() const;
private:
......
......@@ -13,6 +13,7 @@
# limitations under the License.
from py_paddle import swig_paddle
import numpy as np
import unittest
......@@ -36,6 +37,17 @@ class TestArguments(unittest.TestCase):
np_arr = iv.toNumpyArrayInplace()
self.assertEqual(np_arr.shape, (6, ))
def test_arguments_shape(self):
h, w = 4, 6
v = np.random.rand(2, h * w)
m = swig_paddle.Matrix.createDense(v.flatten(), 2, h * w)
args = swig_paddle.Arguments.createArguments(1)
args.setSlotValue(0, m)
args.setSlotFrameHeight(0, h)
args.setSlotFrameWidth(0, w)
self.assertEqual(args.getSlotFrameHeight(), h)
self.assertEqual(args.getSlotFrameWidth(), w)
if __name__ == '__main__':
swig_paddle.initPaddle("--use_gpu=0")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册