提交 55684af2 编写于 作者: Q qiaolongfei

fix MultiGradientMachine train and infer

上级 7573e9eb
......@@ -171,6 +171,12 @@ MultiGradientMachine::MultiGradientMachine(const ModelConfig& config,
}
}
MultiGradientMachine::~MultiGradientMachine() {
for (auto& thread : threads_) {
thread->stop();
}
}
std::vector<const std::vector<ParameterPtr>*>
MultiGradientMachine::getSlaveParameters() {
std::vector<const std::vector<ParameterPtr>*> vec;
......@@ -326,12 +332,6 @@ void MultiGradientMachine::onPassEnd() {
}
}
void MultiGradientMachine::finish() {
for (auto& thread : threads_) {
thread->stop();
}
}
Evaluator* MultiGradientMachine::makeEvaluator() const {
return threads_[0]->getGradientMachine()->makeEvaluator();
}
......
......@@ -176,6 +176,8 @@ public:
explicit MultiGradientMachine(const ModelConfig& config, bool useGpu);
virtual ~MultiGradientMachine();
virtual void prefetch(const std::vector<Argument>& inArgs);
virtual void forward(const std::vector<Argument>& inArgs,
......@@ -193,8 +195,6 @@ public:
virtual void onPassEnd();
virtual void finish();
virtual Evaluator* makeEvaluator() const;
virtual void eval(Evaluator* evaluator) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册