提交 6b3c33d0 编写于 作者: C Cao Ying 提交者: GitHub

Merge pull request #2523 from emailweixu/rnn_evaluator

evaluator for recurrent group.
...@@ -321,7 +321,7 @@ public: ...@@ -321,7 +321,7 @@ public:
} }
} }
virtual void eval(const NeuralNetwork& nn) { virtual void eval(const NeuralNetwork& nn) override {
for (auto& evaluator : evaluators_) { for (auto& evaluator : evaluators_) {
evaluator->eval(nn); evaluator->eval(nn);
} }
...@@ -396,6 +396,30 @@ private: ...@@ -396,6 +396,30 @@ private:
} }
}; };
class SubnetEvaluator : public CombinedEvaluator {
public:
SubnetEvaluator(const std::string& layerName,
std::unique_ptr<Evaluator>&& evaluator)
: layerName_(layerName) {
addEvaluator(std::move(evaluator));
}
virtual void eval(const NeuralNetwork& nn) override {
const LayerPtr& layer = nn.getLayer(layerName_);
CHECK(layer) << "Nonexisted layer: " << layerName_ << " in submodel "
<< nn.getName();
bool accessed = false;
layer->accessSubNetwork([this, &accessed](NeuralNetwork& subnet) {
subnet.eval(evaluators_[0].get());
accessed = true;
});
CHECK(accessed) << "There is no subnetwork for layer " << layerName_
<< " in submodel " << nn.getName();
}
protected:
std::string layerName_;
};
Evaluator* NeuralNetwork::makeEvaluator() const { Evaluator* NeuralNetwork::makeEvaluator() const {
CombinedEvaluator* combinedEvaluator = new CombinedEvaluator(); CombinedEvaluator* combinedEvaluator = new CombinedEvaluator();
auto subModelConfig = std::find_if(config_.sub_models().begin(), auto subModelConfig = std::find_if(config_.sub_models().begin(),
...@@ -422,6 +446,15 @@ Evaluator* NeuralNetwork::makeEvaluator() const { ...@@ -422,6 +446,15 @@ Evaluator* NeuralNetwork::makeEvaluator() const {
combinedEvaluator->addEvaluator(std::move(evaluator)); combinedEvaluator->addEvaluator(std::move(evaluator));
} }
} }
for (auto& layer : layers_) {
layer->accessSubNetwork(
[layer, combinedEvaluator](NeuralNetwork& subnet) {
std::unique_ptr<Evaluator> subEvaluator(new SubnetEvaluator(
layer->getName(),
std::unique_ptr<Evaluator>(subnet.makeEvaluator())));
combinedEvaluator->addEvaluator(std::move(subEvaluator));
});
}
} else { } else {
for (const EvaluatorConfig& evalConfig : config_.evaluators()) { for (const EvaluatorConfig& evalConfig : config_.evaluators()) {
std::unique_ptr<Evaluator> evaluator(Evaluator::create(evalConfig)); std::unique_ptr<Evaluator> evaluator(Evaluator::create(evalConfig));
......
...@@ -129,6 +129,8 @@ public: ...@@ -129,6 +129,8 @@ public:
static NeuralNetwork* newNeuralNetwork(const std::string& name = "", static NeuralNetwork* newNeuralNetwork(const std::string& name = "",
NeuralNetwork* rootNetwork = nullptr); NeuralNetwork* rootNetwork = nullptr);
const std::string& getName() const { return subModelName_; }
protected: protected:
/** /**
* The constructor of NeuralNetwork. * The constructor of NeuralNetwork.
......
...@@ -288,10 +288,6 @@ void RecurrentGradientMachine::init( ...@@ -288,10 +288,6 @@ void RecurrentGradientMachine::init(
parameterIds_.push_back(para->getID()); parameterIds_.push_back(para->getID());
} }
} }
if (subModelConfig->evaluator_names_size() > 0) {
evaluator_.reset(frames_[0]->makeEvaluator());
}
} }
void RecurrentGradientMachine::resizeOrCreateFrames(int numFrames) { void RecurrentGradientMachine::resizeOrCreateFrames(int numFrames) {
...@@ -562,9 +558,6 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs, ...@@ -562,9 +558,6 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs,
std::vector<Argument> outArgs; std::vector<Argument> outArgs;
frames_[i]->forward(inArgs, &outArgs, passType); frames_[i]->forward(inArgs, &outArgs, passType);
} }
if (evaluator_ && passType == PASS_TEST) {
this->eval(evaluator_.get());
}
reorganizeOutput(passType); reorganizeOutput(passType);
} }
...@@ -581,11 +574,6 @@ void RecurrentGradientMachine::backward(const UpdateCallback& callback) { ...@@ -581,11 +574,6 @@ void RecurrentGradientMachine::backward(const UpdateCallback& callback) {
for (auto& memoryFrameLine : memoryFrameLines_) { for (auto& memoryFrameLine : memoryFrameLines_) {
memoryFrameLine.bootLayer->backward(nullptr); memoryFrameLine.bootLayer->backward(nullptr);
} }
// call printers here so the gradient can be printed
if (evaluator_) {
this->eval(evaluator_.get());
}
} }
void RecurrentGradientMachine::forwardBackward( void RecurrentGradientMachine::forwardBackward(
...@@ -599,9 +587,9 @@ void RecurrentGradientMachine::forwardBackward( ...@@ -599,9 +587,9 @@ void RecurrentGradientMachine::forwardBackward(
void RecurrentGradientMachine::eval(Evaluator* evaluator) const { void RecurrentGradientMachine::eval(Evaluator* evaluator) const {
// call printers frame by frame // call printers frame by frame
for (int i = 0; i < maxSequenceLength_; ++i) { for (int i = 0; i < maxSequenceLength_; ++i) {
LOG(INFO) << "Recurrent Layer Group eval frame " << i << " begin"; VLOG(2) << "Recurrent Layer Group eval frame " << i << " begin";
evaluator->eval(*(frames_[i].get())); evaluator->eval(*(frames_[i].get()));
LOG(INFO) << "Recurrent Layer Group eval frame " << i << " end"; VLOG(2) << "Recurrent Layer Group eval frame " << i << " end";
} }
} }
...@@ -1097,10 +1085,6 @@ void RecurrentGradientMachine::oneWaySearch(size_t batchSize) { ...@@ -1097,10 +1085,6 @@ void RecurrentGradientMachine::oneWaySearch(size_t batchSize) {
copyDataOutlinkFrame(machineCur); copyDataOutlinkFrame(machineCur);
// call value printer
if (evaluator_) {
evaluator_->eval(*(frames_[machineCur].get()));
}
// check eos // check eos
const IVectorPtr& eosVec = const IVectorPtr& eosVec =
eosFrameLine_->layers[machineCur]->getOutput().ids; eosFrameLine_->layers[machineCur]->getOutput().ids;
......
...@@ -429,8 +429,6 @@ protected: ...@@ -429,8 +429,6 @@ protected:
std::vector<int> std::vector<int>
parameterIds_; // parameters actually used by this Layer Group parameterIds_; // parameters actually used by this Layer Group
std::unique_ptr<Evaluator> evaluator_; // frame printers in this layer group
// store final argument of outFrameLines_ // store final argument of outFrameLines_
std::vector<Argument> dataArgs_; std::vector<Argument> dataArgs_;
// store each frame's output argument of outFrameLines_ // store each frame's output argument of outFrameLines_
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册