diff --git a/paddle/api/GradientMachine.cpp b/paddle/api/GradientMachine.cpp index 297eaa19bb9981c7f07c90763d76494b7910af93..ced2293376cae51eb5ac9cd27133f13174f61e3c 100644 --- a/paddle/api/GradientMachine.cpp +++ b/paddle/api/GradientMachine.cpp @@ -64,6 +64,14 @@ GradientMachine* GradientMachine::createByModelConfig( return GradientMachine::createFromPaddleModelPtr(confPtr, mode, types); } +void GradientMachine::onPassEnd() { m->machine->onPassEnd(); } + +void GradientMachine::prefetch(const Arguments& inArgs) { + auto& in = + m->cast>(inArgs.getInternalArgumentsPtr()); + m->machine->prefetch(in); +} + void GradientMachine::forward(const Arguments& inArgs, Arguments* outArgs, PassType passType) { diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index 84a66719c33678fc4aeb038bb81a6b7c5d0c93fb..7521ff4c6c654c4b59e43abb365dad245a9bd189 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -714,6 +714,16 @@ public: GradientMatchineCreateMode mode = CREATE_MODE_NORMAL, const std::vector& parameterTypes = defaultParamTypes); + /** + * Prefetch row ids of sparse parameter. + */ + void prefetch(const Arguments& inArgs); + + /** + * Do some thing when train pass ended. + */ + void onPassEnd(); + /** * The forward stage of GradientMachine. *