提交 55684af2 编写于 作者: Q qiaolongfei

fix MultiGradientMachine train and infer

上级 7573e9eb
...@@ -171,6 +171,12 @@ MultiGradientMachine::MultiGradientMachine(const ModelConfig& config, ...@@ -171,6 +171,12 @@ MultiGradientMachine::MultiGradientMachine(const ModelConfig& config,
} }
} }
MultiGradientMachine::~MultiGradientMachine() {
for (auto& thread : threads_) {
thread->stop();
}
}
std::vector<const std::vector<ParameterPtr>*> std::vector<const std::vector<ParameterPtr>*>
MultiGradientMachine::getSlaveParameters() { MultiGradientMachine::getSlaveParameters() {
std::vector<const std::vector<ParameterPtr>*> vec; std::vector<const std::vector<ParameterPtr>*> vec;
...@@ -326,12 +332,6 @@ void MultiGradientMachine::onPassEnd() { ...@@ -326,12 +332,6 @@ void MultiGradientMachine::onPassEnd() {
} }
} }
void MultiGradientMachine::finish() {
for (auto& thread : threads_) {
thread->stop();
}
}
Evaluator* MultiGradientMachine::makeEvaluator() const { Evaluator* MultiGradientMachine::makeEvaluator() const {
return threads_[0]->getGradientMachine()->makeEvaluator(); return threads_[0]->getGradientMachine()->makeEvaluator();
} }
......
...@@ -176,6 +176,8 @@ public: ...@@ -176,6 +176,8 @@ public:
explicit MultiGradientMachine(const ModelConfig& config, bool useGpu); explicit MultiGradientMachine(const ModelConfig& config, bool useGpu);
virtual ~MultiGradientMachine();
virtual void prefetch(const std::vector<Argument>& inArgs); virtual void prefetch(const std::vector<Argument>& inArgs);
virtual void forward(const std::vector<Argument>& inArgs, virtual void forward(const std::vector<Argument>& inArgs,
...@@ -193,8 +195,6 @@ public: ...@@ -193,8 +195,6 @@ public:
virtual void onPassEnd(); virtual void onPassEnd();
virtual void finish();
virtual Evaluator* makeEvaluator() const; virtual Evaluator* makeEvaluator() const;
virtual void eval(Evaluator* evaluator) const; virtual void eval(Evaluator* evaluator) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册