diff --git a/paddle/gserver/evaluators/Evaluator.cpp b/paddle/gserver/evaluators/Evaluator.cpp index 97b25b04c459ebe6f4e051bfd553b494905642ff..7ea4ed973cfa4f7caf1f6af06de9447e54021c3e 100644 --- a/paddle/gserver/evaluators/Evaluator.cpp +++ b/paddle/gserver/evaluators/Evaluator.cpp @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/gserver/evaluators/Evaluator.h" -#include "paddle/utils/Stat.h" - #include "paddle/gserver/gradientmachines/NeuralNetwork.h" +#include "paddle/utils/Stat.h" +#include "paddle/utils/StringUtil.h" DECLARE_int32(trainer_id); @@ -801,7 +801,9 @@ void PrecisionRecallEvaluator::getNames(std::vector* names) { real PrecisionRecallEvaluator::getValue(const std::string& name, Error* err) const { this->storeLocalValues(); - auto it = this->values_.find(name); + std::vector buffers; + paddle::str::split(name, '.', &buffers); + auto it = this->values_.find(buffers[buffers.size() - 1]); if (it != this->values_.end() && err != nullptr) { *err = Error("No such key %s", name.c_str()); return .0f; @@ -812,10 +814,12 @@ real PrecisionRecallEvaluator::getValue(const std::string& name, std::string PrecisionRecallEvaluator::getType(const std::string& name, Error* err) const { - this->storeLocalValues(); - auto it = this->values_.find(name); - if (it != this->values_.end() && err != nullptr) { - *err = Error("No such key %s", name.c_str()); + Error localErr; + if (err == nullptr) { + err = &localErr; + } + this->getValue(name, err); + if (!err->isOK()) { return ""; } return "precision_recall"; @@ -989,12 +993,12 @@ static InitFunction __reg_type_auc_sum__([]() { */ class ValuePrinter : public Evaluator { public: - ValuePrinter() {} - virtual void eval(const NeuralNetwork& nn) { + layerOutputs_.clear(); for (const std::string& name : config_.input_layers()) { auto& argu = nn.getLayer(name)->getOutput(); - std::unordered_map out; + layerOutputs_[name] = std::unordered_map(); + auto& out = layerOutputs_[name]; argu.getValueString(&out); for (auto field : {"value", "id", "sequence pos", "sub-sequence pos"}) { auto it = out.find(field); @@ -1008,6 +1012,72 @@ public: virtual void updateSamplesNum(const std::vector& arguments) {} virtual real evalImp(std::vector& arguments) { return 0; } + +private: + std::unordered_map> + layerOutputs_; + + // Evaluator interface +public: + void getNames(std::vector* names) { + for (auto layerIt = layerOutputs_.begin(); layerIt != layerOutputs_.end(); + ++layerIt) { + for (auto it = layerIt->second.begin(); it != layerIt->second.end(); + ++it) { + names->push_back(config_.name() + "." + layerIt->first + "." + + it->second); + } + } + } + + real getValue(const std::string& name, Error* err) const { + (void)(name); + if (err != nullptr) { + *err = Error( + "ValuePrinter do not support getValue, use getValueString instead."); + } + return .0f; + } + std::string getValueStr(const std::string& name, Error* err) const { + std::vector buffer; + str::split(name, '.', &buffer); + if (buffer.size() < 2) { + if (err != nullptr) { + *err = Error("No such key %s", name.c_str()); + } + return ""; + } + auto fieldName = buffer[buffer.size() - 1]; + auto layerName = buffer[buffer.size() - 2]; + auto layerIt = layerOutputs_.find(layerName); + if (layerIt == layerOutputs_.end()) { + if (err != nullptr) { + *err = Error("No such layer %s", layerName.c_str()); + } + return ""; + } + + auto fieldIt = layerIt->second.find(fieldName); + if (fieldIt == layerIt->second.end()) { + if (err != nullptr) { + *err = Error("No such value field %s", fieldName.c_str()); + } + return ""; + } + + return fieldIt->second; + } + std::string getType(const std::string& name, Error* err) const { + Error localErr; + if (err == nullptr) { + err = &localErr; + } + this->getValueStr(name, err); + if (!err->isOK()) { + return ""; + } + return "value_printer"; + } }; REGISTER_EVALUATOR(value_printer, ValuePrinter); /**