提交 051d15cf 编写于 作者: Q qiaolongfei

add prefetch and onPassEnd to PaddleApi.h

上级 dae8b9bc
...@@ -64,6 +64,14 @@ GradientMachine* GradientMachine::createByModelConfig( ...@@ -64,6 +64,14 @@ GradientMachine* GradientMachine::createByModelConfig(
return GradientMachine::createFromPaddleModelPtr(confPtr, mode, types); return GradientMachine::createFromPaddleModelPtr(confPtr, mode, types);
} }
void GradientMachine::onPassEnd() { m->machine->onPassEnd(); }
void GradientMachine::prefetch(const Arguments& inArgs) {
auto& in =
m->cast<std::vector<paddle::Argument>>(inArgs.getInternalArgumentsPtr());
m->machine->prefetch(in);
}
void GradientMachine::forward(const Arguments& inArgs, void GradientMachine::forward(const Arguments& inArgs,
Arguments* outArgs, Arguments* outArgs,
PassType passType) { PassType passType) {
......
...@@ -714,6 +714,16 @@ public: ...@@ -714,6 +714,16 @@ public:
GradientMatchineCreateMode mode = CREATE_MODE_NORMAL, GradientMatchineCreateMode mode = CREATE_MODE_NORMAL,
const std::vector<int>& parameterTypes = defaultParamTypes); const std::vector<int>& parameterTypes = defaultParamTypes);
/**
* Prefetch row ids of sparse parameter.
*/
void prefetch(const Arguments& inArgs);
/**
* Do some thing when train pass ended.
*/
void onPassEnd();
/** /**
* The forward stage of GradientMachine. * The forward stage of GradientMachine.
* *
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册