提交 6b3c33d0 编写于 作者: C Cao Ying 提交者: GitHub

Merge pull request #2523 from emailweixu/rnn_evaluator

evaluator for recurrent group.
......@@ -321,7 +321,7 @@ public:
}
}
virtual void eval(const NeuralNetwork& nn) {
virtual void eval(const NeuralNetwork& nn) override {
for (auto& evaluator : evaluators_) {
evaluator->eval(nn);
}
......@@ -396,6 +396,30 @@ private:
}
};
class SubnetEvaluator : public CombinedEvaluator {
public:
SubnetEvaluator(const std::string& layerName,
std::unique_ptr<Evaluator>&& evaluator)
: layerName_(layerName) {
addEvaluator(std::move(evaluator));
}
virtual void eval(const NeuralNetwork& nn) override {
const LayerPtr& layer = nn.getLayer(layerName_);
CHECK(layer) << "Nonexisted layer: " << layerName_ << " in submodel "
<< nn.getName();
bool accessed = false;
layer->accessSubNetwork([this, &accessed](NeuralNetwork& subnet) {
subnet.eval(evaluators_[0].get());
accessed = true;
});
CHECK(accessed) << "There is no subnetwork for layer " << layerName_
<< " in submodel " << nn.getName();
}
protected:
std::string layerName_;
};
Evaluator* NeuralNetwork::makeEvaluator() const {
CombinedEvaluator* combinedEvaluator = new CombinedEvaluator();
auto subModelConfig = std::find_if(config_.sub_models().begin(),
......@@ -422,6 +446,15 @@ Evaluator* NeuralNetwork::makeEvaluator() const {
combinedEvaluator->addEvaluator(std::move(evaluator));
}
}
for (auto& layer : layers_) {
layer->accessSubNetwork(
[layer, combinedEvaluator](NeuralNetwork& subnet) {
std::unique_ptr<Evaluator> subEvaluator(new SubnetEvaluator(
layer->getName(),
std::unique_ptr<Evaluator>(subnet.makeEvaluator())));
combinedEvaluator->addEvaluator(std::move(subEvaluator));
});
}
} else {
for (const EvaluatorConfig& evalConfig : config_.evaluators()) {
std::unique_ptr<Evaluator> evaluator(Evaluator::create(evalConfig));
......
......@@ -129,6 +129,8 @@ public:
static NeuralNetwork* newNeuralNetwork(const std::string& name = "",
NeuralNetwork* rootNetwork = nullptr);
const std::string& getName() const { return subModelName_; }
protected:
/**
* The constructor of NeuralNetwork.
......
......@@ -288,10 +288,6 @@ void RecurrentGradientMachine::init(
parameterIds_.push_back(para->getID());
}
}
if (subModelConfig->evaluator_names_size() > 0) {
evaluator_.reset(frames_[0]->makeEvaluator());
}
}
void RecurrentGradientMachine::resizeOrCreateFrames(int numFrames) {
......@@ -562,9 +558,6 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs,
std::vector<Argument> outArgs;
frames_[i]->forward(inArgs, &outArgs, passType);
}
if (evaluator_ && passType == PASS_TEST) {
this->eval(evaluator_.get());
}
reorganizeOutput(passType);
}
......@@ -581,11 +574,6 @@ void RecurrentGradientMachine::backward(const UpdateCallback& callback) {
for (auto& memoryFrameLine : memoryFrameLines_) {
memoryFrameLine.bootLayer->backward(nullptr);
}
// call printers here so the gradient can be printed
if (evaluator_) {
this->eval(evaluator_.get());
}
}
void RecurrentGradientMachine::forwardBackward(
......@@ -599,9 +587,9 @@ void RecurrentGradientMachine::forwardBackward(
void RecurrentGradientMachine::eval(Evaluator* evaluator) const {
// call printers frame by frame
for (int i = 0; i < maxSequenceLength_; ++i) {
LOG(INFO) << "Recurrent Layer Group eval frame " << i << " begin";
VLOG(2) << "Recurrent Layer Group eval frame " << i << " begin";
evaluator->eval(*(frames_[i].get()));
LOG(INFO) << "Recurrent Layer Group eval frame " << i << " end";
VLOG(2) << "Recurrent Layer Group eval frame " << i << " end";
}
}
......@@ -1097,10 +1085,6 @@ void RecurrentGradientMachine::oneWaySearch(size_t batchSize) {
copyDataOutlinkFrame(machineCur);
// call value printer
if (evaluator_) {
evaluator_->eval(*(frames_[machineCur].get()));
}
// check eos
const IVectorPtr& eosVec =
eosFrameLine_->layers[machineCur]->getOutput().ids;
......
......@@ -429,8 +429,6 @@ protected:
std::vector<int>
parameterIds_; // parameters actually used by this Layer Group
std::unique_ptr<Evaluator> evaluator_; // frame printers in this layer group
// store final argument of outFrameLines_
std::vector<Argument> dataArgs_;
// store each frame's output argument of outFrameLines_
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册