提交 948218da 编写于 作者: Y Yu Yang

Unify PrintLogic in PrintLayer/ValuePrinter.

上级 2c07dd50
......@@ -888,19 +888,10 @@ Evaluator* Evaluator::create(const EvaluatorConfig& config) {
*/
class ValuePrinter : public Evaluator {
public:
ValuePrinter() {}
virtual void eval(const NeuralNetwork& nn) {
for (const std::string& name : config_.input_layers()) {
auto& argu = nn.getLayer(name)->getOutput();
std::unordered_map<std::string, std::string> out;
argu.getValueString(&out);
for (auto field : {"value", "id", "sequence pos", "sub-sequence pos"}) {
auto it = out.find(field);
if (it != out.end()) {
LOG(INFO) << "layer=" << name << " " << field << ":\n" << it->second;
}
}
nn.getLayer(name)->getOutput().printValueString(LOG(INFO),
"layer=" + name + " ");
}
}
......
......@@ -19,25 +19,17 @@ namespace paddle {
class PrintLayer : public Layer {
public:
explicit PrintLayer(const LayerConfig& config) : Layer(config) {}
void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override {}
};
void PrintLayer::forward(PassType passType) {
Layer::forward(passType);
for (size_t i = 0; i != inputLayers_.size(); ++i) {
auto& argu = getInput(i);
const std::string& name = inputLayers_[i]->getName();
std::unordered_map<std::string, std::string> out;
argu.getValueString(&out);
for (auto field : {"value", "id", "sequence pos", "sub-sequence pos"}) {
auto it = out.find(field);
if (it != out.end()) {
LOG(INFO) << "layer=" << name << " " << field << ":\n" << it->second;
}
void forward(PassType passType) override {
Layer::forward(passType);
for (size_t i = 0; i != inputLayers_.size(); ++i) {
getInput(i).printValueString(LOG(INFO),
"layer=" + inputLayers_[i]->getName() + " ");
}
}
}
void backward(const UpdateCallback& callback) override {}
};
REGISTER_LAYER(print, PrintLayer);
......
......@@ -628,6 +628,18 @@ void Argument::getValueString(
}
}
void Argument::printValueString(std::ostream& stream,
const std::string& prefix) const {
std::unordered_map<std::string, std::string> out;
getValueString(&out);
for (auto field : {"value", "id", "sequence pos", "sub-sequence pos"}) {
auto it = out.find(field);
if (it != out.end()) {
stream << prefix << field << ":\n" << it->second;
}
}
}
void Argument::subArgFrom(const Argument& input,
size_t offset,
size_t height,
......
......@@ -305,6 +305,9 @@ struct Argument {
* @param out [out]: the return values.
*/
void getValueString(std::unordered_map<std::string, std::string>* out) const;
void printValueString(std::ostream& stream,
const std::string& prefix = "") const;
};
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册