提交 5b4e7d5c 编写于 作者: Y Yu Yang

complete value printer

上级 5ecc1a21
...@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and ...@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/gserver/evaluators/Evaluator.h" #include "paddle/gserver/evaluators/Evaluator.h"
#include "paddle/utils/Stat.h"
#include "paddle/gserver/gradientmachines/NeuralNetwork.h" #include "paddle/gserver/gradientmachines/NeuralNetwork.h"
#include "paddle/utils/Stat.h"
#include "paddle/utils/StringUtil.h"
DECLARE_int32(trainer_id); DECLARE_int32(trainer_id);
...@@ -801,7 +801,9 @@ void PrecisionRecallEvaluator::getNames(std::vector<std::string>* names) { ...@@ -801,7 +801,9 @@ void PrecisionRecallEvaluator::getNames(std::vector<std::string>* names) {
real PrecisionRecallEvaluator::getValue(const std::string& name, real PrecisionRecallEvaluator::getValue(const std::string& name,
Error* err) const { Error* err) const {
this->storeLocalValues(); this->storeLocalValues();
auto it = this->values_.find(name); std::vector<std::string> buffers;
paddle::str::split(name, '.', &buffers);
auto it = this->values_.find(buffers[buffers.size() - 1]);
if (it != this->values_.end() && err != nullptr) { if (it != this->values_.end() && err != nullptr) {
*err = Error("No such key %s", name.c_str()); *err = Error("No such key %s", name.c_str());
return .0f; return .0f;
...@@ -812,10 +814,12 @@ real PrecisionRecallEvaluator::getValue(const std::string& name, ...@@ -812,10 +814,12 @@ real PrecisionRecallEvaluator::getValue(const std::string& name,
std::string PrecisionRecallEvaluator::getType(const std::string& name, std::string PrecisionRecallEvaluator::getType(const std::string& name,
Error* err) const { Error* err) const {
this->storeLocalValues(); Error localErr;
auto it = this->values_.find(name); if (err == nullptr) {
if (it != this->values_.end() && err != nullptr) { err = &localErr;
*err = Error("No such key %s", name.c_str()); }
this->getValue(name, err);
if (!err->isOK()) {
return ""; return "";
} }
return "precision_recall"; return "precision_recall";
...@@ -989,12 +993,12 @@ static InitFunction __reg_type_auc_sum__([]() { ...@@ -989,12 +993,12 @@ static InitFunction __reg_type_auc_sum__([]() {
*/ */
class ValuePrinter : public Evaluator { class ValuePrinter : public Evaluator {
public: public:
ValuePrinter() {}
virtual void eval(const NeuralNetwork& nn) { virtual void eval(const NeuralNetwork& nn) {
layerOutputs_.clear();
for (const std::string& name : config_.input_layers()) { for (const std::string& name : config_.input_layers()) {
auto& argu = nn.getLayer(name)->getOutput(); auto& argu = nn.getLayer(name)->getOutput();
std::unordered_map<std::string, std::string> out; layerOutputs_[name] = std::unordered_map<std::string, std::string>();
auto& out = layerOutputs_[name];
argu.getValueString(&out); argu.getValueString(&out);
for (auto field : {"value", "id", "sequence pos", "sub-sequence pos"}) { for (auto field : {"value", "id", "sequence pos", "sub-sequence pos"}) {
auto it = out.find(field); auto it = out.find(field);
...@@ -1008,6 +1012,72 @@ public: ...@@ -1008,6 +1012,72 @@ public:
virtual void updateSamplesNum(const std::vector<Argument>& arguments) {} virtual void updateSamplesNum(const std::vector<Argument>& arguments) {}
virtual real evalImp(std::vector<Argument>& arguments) { return 0; } virtual real evalImp(std::vector<Argument>& arguments) { return 0; }
private:
std::unordered_map<std::string, std::unordered_map<std::string, std::string>>
layerOutputs_;
// Evaluator interface
public:
void getNames(std::vector<std::string>* names) {
for (auto layerIt = layerOutputs_.begin(); layerIt != layerOutputs_.end();
++layerIt) {
for (auto it = layerIt->second.begin(); it != layerIt->second.end();
++it) {
names->push_back(config_.name() + "." + layerIt->first + "." +
it->second);
}
}
}
real getValue(const std::string& name, Error* err) const {
(void)(name);
if (err != nullptr) {
*err = Error(
"ValuePrinter do not support getValue, use getValueString instead.");
}
return .0f;
}
std::string getValueStr(const std::string& name, Error* err) const {
std::vector<std::string> buffer;
str::split(name, '.', &buffer);
if (buffer.size() < 2) {
if (err != nullptr) {
*err = Error("No such key %s", name.c_str());
}
return "";
}
auto fieldName = buffer[buffer.size() - 1];
auto layerName = buffer[buffer.size() - 2];
auto layerIt = layerOutputs_.find(layerName);
if (layerIt == layerOutputs_.end()) {
if (err != nullptr) {
*err = Error("No such layer %s", layerName.c_str());
}
return "";
}
auto fieldIt = layerIt->second.find(fieldName);
if (fieldIt == layerIt->second.end()) {
if (err != nullptr) {
*err = Error("No such value field %s", fieldName.c_str());
}
return "";
}
return fieldIt->second;
}
std::string getType(const std::string& name, Error* err) const {
Error localErr;
if (err == nullptr) {
err = &localErr;
}
this->getValueStr(name, err);
if (!err->isOK()) {
return "";
}
return "value_printer";
}
}; };
REGISTER_EVALUATOR(value_printer, ValuePrinter); REGISTER_EVALUATOR(value_printer, ValuePrinter);
/** /**
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册