diff --git a/paddle/api/Evaluator.cpp b/paddle/api/Evaluator.cpp index f9656db19a0386c9fb3ce1fe2e0263c4918d7073..681e3a380912339c531c16c88f43255c2f34c32f 100644 --- a/paddle/api/Evaluator.cpp +++ b/paddle/api/Evaluator.cpp @@ -33,3 +33,12 @@ std::vector Evaluator::getNames() const { m->rawPtr->getNames(&retv); return retv; } + +double Evaluator::getValue(const std::string name) const { + paddle::Error err; + double v = m->rawPtr->getValue(name, &err); + if (err) { + throw std::runtime_error(err.msg()); + } + return v; +} diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index f5dcfcf94c658a0f4c593cf25390b028eb7c247a..80c50cdb08bb16443f7772964c49cc1fbc591f37 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -902,6 +902,8 @@ public: std::vector getNames() const; + double getValue(const std::string name) const; + private: EvaluatorPrivate* m; diff --git a/paddle/api/test/testTrain.py b/paddle/api/test/testTrain.py index a90d15c272a3a2b56e35c979e053deb2b54eebc1..7061a4c43bf01158b5f084d0c310dedd81773a04 100644 --- a/paddle/api/test/testTrain.py +++ b/paddle/api/test/testTrain.py @@ -89,9 +89,14 @@ def main(): except Exception as e: print e + ev = m.makeEvaluator() + ev.start() m.forwardBackward(inArgs, outArgs, swig_paddle.PASS_TRAIN, update_callback) - + m.eval(ev) + ev.finish() + for name in ev.getNames(): + print name, ev.getValue(name) for optimizer in optimizers: optimizer.finishBatch()