From e3d4da2de3a442c85cecdde8fbc9407b54dba0f0 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 26 Dec 2016 14:15:29 +0800 Subject: [PATCH] Add sum cost to Arguments --- paddle/api/Arguments.cpp | 4 ++++ paddle/api/PaddleAPI.h | 2 ++ paddle/api/test/testArguments.py | 2 ++ 3 files changed, 8 insertions(+) diff --git a/paddle/api/Arguments.cpp b/paddle/api/Arguments.cpp index 0cafbd896e2..41beed38a87 100644 --- a/paddle/api/Arguments.cpp +++ b/paddle/api/Arguments.cpp @@ -137,6 +137,10 @@ void Arguments::setSlotSequenceDim(size_t idx, IVector* vec) throw(RangeError) { a.cpuSequenceDims = m->cast(vec->getSharedPtr()); } +float Arguments::sumCosts() const { + return paddle::Argument::sumCosts(m->outputs); +} + int64_t Arguments::getBatchSize(size_t idx) const throw(RangeError) { auto& a = m->getArg(idx); return a.getBatchSize(); diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index 7521ff4c6c6..155e3e3afe6 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -454,6 +454,8 @@ public: IVector* vec) throw(RangeError); void setSlotSequenceDim(size_t idx, IVector* vec) throw(RangeError); + float sumCosts() const; + private: static Arguments* createByPaddleArgumentVector(void* ptr); void* getInternalArgumentsPtr() const; diff --git a/paddle/api/test/testArguments.py b/paddle/api/test/testArguments.py index 8cabecd242f..a04a805d7a6 100644 --- a/paddle/api/test/testArguments.py +++ b/paddle/api/test/testArguments.py @@ -22,6 +22,8 @@ class TestArguments(unittest.TestCase): args = swig_paddle.Arguments.createArguments(1) args.setSlotValue(0, m) + self.assertAlmostEqual(27.0, args.sumCosts()) + mat = args.getSlotValue(0) assert isinstance(mat, swig_paddle.Matrix) np_mat = mat.toNumpyMatInplace() -- GitLab