From 051d15cf0044e330515ab0a10ca82e0fdb576105 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Fri, 23 Dec 2016 17:44:25 +0800 Subject: [PATCH] add prefetch and onPassEnd to PaddleApi.h --- paddle/api/GradientMachine.cpp | 8 ++++++++ paddle/api/PaddleAPI.h | 10 ++++++++++ 2 files changed, 18 insertions(+) diff --git a/paddle/api/GradientMachine.cpp b/paddle/api/GradientMachine.cpp index 297eaa19bb9..ced2293376c 100644 --- a/paddle/api/GradientMachine.cpp +++ b/paddle/api/GradientMachine.cpp @@ -64,6 +64,14 @@ GradientMachine* GradientMachine::createByModelConfig( return GradientMachine::createFromPaddleModelPtr(confPtr, mode, types); } +void GradientMachine::onPassEnd() { m->machine->onPassEnd(); } + +void GradientMachine::prefetch(const Arguments& inArgs) { + auto& in = + m->cast>(inArgs.getInternalArgumentsPtr()); + m->machine->prefetch(in); +} + void GradientMachine::forward(const Arguments& inArgs, Arguments* outArgs, PassType passType) { diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index 84a66719c33..7521ff4c6c6 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -714,6 +714,16 @@ public: GradientMatchineCreateMode mode = CREATE_MODE_NORMAL, const std::vector& parameterTypes = defaultParamTypes); + /** + * Prefetch row ids of sparse parameter. + */ + void prefetch(const Arguments& inArgs); + + /** + * Do some thing when train pass ended. + */ + void onPassEnd(); + /** * The forward stage of GradientMachine. * -- GitLab