Commit 1e6c87bd authored by Yu Yang

Merge branch 'feature/add_const_in_gradient_machine_eval' into feature/mnist_train_api
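This merge const-qualifies makeEvaluator() and eval() across the GradientMachine hierarchy, so evaluation can be driven through a const reference to the machine. A minimal sketch of why the qualifier matters; Evaluator and GradientMachine here are toy stand-ins reduced to the signatures touched by this commit, and evaluateOnce is a hypothetical helper, not Paddle code:

#include <memory>

// Toy stand-ins, reduced to the two signatures this commit changes.
class Evaluator {};

class GradientMachine {
public:
  virtual ~GradientMachine() = default;
  // const-qualified: creating and running an evaluator is treated as
  // not mutating the machine's logical state.
  virtual Evaluator* makeEvaluator() const = 0;
  virtual void eval(Evaluator* evaluator) const = 0;
};

// Without the const qualifiers above, this helper would not compile,
// because it only holds a const reference to the machine.
void evaluateOnce(const GradientMachine& machine) {
  std::unique_ptr<Evaluator> evaluator(machine.makeEvaluator());
  machine.eval(evaluator.get());
}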

@@ -181,12 +181,12 @@ public:
   /**
    * Create an evaluator which can be used for eval()
    */
-  virtual Evaluator* makeEvaluator() = 0;
+  virtual Evaluator* makeEvaluator() const = 0;
   /**
    * evaluate using the given evaluator
    */
-  virtual void eval(Evaluator* evaluator) = 0;
+  virtual void eval(Evaluator* evaluator) const = 0;
   std::vector<ParameterPtr>& getParameters() { return parameters_; }
...
@@ -327,11 +327,11 @@ void MultiGradientMachine::finish() {
   }
 }
-Evaluator* MultiGradientMachine::makeEvaluator() {
+Evaluator* MultiGradientMachine::makeEvaluator() const {
   return threads_[0]->getGradientMachine()->makeEvaluator();
 }
-void MultiGradientMachine::eval(Evaluator* evaluator) {
+void MultiGradientMachine::eval(Evaluator* evaluator) const {
   for (auto& thread : threads_) {
     SetDevice device(thread->getDeviceId());
     thread->getGradientMachine()->eval(evaluator);
...
@@ -193,9 +193,9 @@ public:
   virtual void finish();
-  virtual Evaluator* makeEvaluator();
+  virtual Evaluator* makeEvaluator() const;
-  virtual void eval(Evaluator* evaluator);
+  virtual void eval(Evaluator* evaluator) const;
   bool useGpu() const { return useGpu_; }
...
@@ -171,7 +171,7 @@ protected:
   std::vector<std::unique_ptr<Evaluator>> evaluators_;
 };
-Evaluator* MultiNetwork::makeEvaluator() {
+Evaluator* MultiNetwork::makeEvaluator() const {
   MultiCombinedEvaluator* multiCombinedEvaluator = new MultiCombinedEvaluator();
   for (size_t i = 0; i < subNetworks_.size(); i++) {
     std::unique_ptr<Evaluator> evaluator(subNetworks_[i]->makeEvaluator());
@@ -180,6 +180,6 @@ Evaluator* MultiNetwork::makeEvaluator() {
   return multiCombinedEvaluator;
 }
-void MultiNetwork::eval(Evaluator* evaluator) { evaluator->eval(*this); }
+void MultiNetwork::eval(Evaluator* evaluator) const { evaluator->eval(*this); }
 }  // namespace paddle
...
@@ -46,9 +46,9 @@ public:
   virtual void onPassEnd();
-  virtual Evaluator* makeEvaluator();
+  virtual Evaluator* makeEvaluator() const;
-  virtual void eval(Evaluator* evaluator);
+  virtual void eval(Evaluator* evaluator) const;
   const std::vector<std::unique_ptr<NeuralNetwork>>& getSubNetworks() const {
     return subNetworks_;
...
@@ -348,7 +348,7 @@ protected:
   std::vector<std::unique_ptr<Evaluator>> evaluators_;
 };
-Evaluator* NeuralNetwork::makeEvaluator() {
+Evaluator* NeuralNetwork::makeEvaluator() const {
   CombinedEvaluator* combinedEvaluator = new CombinedEvaluator();
   auto subModelConfig = std::find_if(config_.sub_models().begin(),
                                      config_.sub_models().end(),
@@ -383,7 +383,7 @@ Evaluator* NeuralNetwork::makeEvaluator() {
   return combinedEvaluator;
 }
-void NeuralNetwork::eval(Evaluator* evaluator) { evaluator->eval(*this); }
+void NeuralNetwork::eval(Evaluator* evaluator) const { evaluator->eval(*this); }
 void NeuralNetwork::setOutputGrad(const std::vector<Argument>& args) {
   CHECK_GE(outputLayers_.size(), args.size());
...
@@ -96,9 +96,9 @@ public:
   virtual void onPassEnd();
-  virtual Evaluator* makeEvaluator();
+  virtual Evaluator* makeEvaluator() const;
-  virtual void eval(Evaluator* evaluator);
+  virtual void eval(Evaluator* evaluator) const;
   virtual void resetState();
   virtual void setOutputGrad(const std::vector<Argument>& args);
...
@@ -593,7 +593,7 @@ void RecurrentGradientMachine::forwardBackward(
   LOG(FATAL) << "should not use this function";
 }
-void RecurrentGradientMachine::eval(Evaluator* evaluator) {
+void RecurrentGradientMachine::eval(Evaluator* evaluator) const {
   // call printers frame by frame
   for (int i = 0; i < maxSequenceLength_; ++i) {
     LOG(INFO) << "Recurrent Layer Group eval frame " << i << " begin";
...
@@ -63,7 +63,7 @@ public:
                        const UpdateCallback& callback);
   virtual void resetState() {}
-  virtual void eval(Evaluator* evaluator);
+  virtual void eval(Evaluator* evaluator) const;
   const std::vector<int>& getParameterIds() { return parameterIds_; }
...