From d256512832d734134204ff483504551b3c184a57 Mon Sep 17 00:00:00 2001 From: Liang Zhao Date: Mon, 20 Feb 2017 11:18:22 -0800 Subject: [PATCH] Rewrite code according to reviewer comments --- paddle/gserver/evaluators/Evaluator.cpp | 34 ++++++++----------------- paddle/parameter/Parameter.cpp | 4 --- 2 files changed, 10 insertions(+), 28 deletions(-) diff --git a/paddle/gserver/evaluators/Evaluator.cpp b/paddle/gserver/evaluators/Evaluator.cpp index 30e91fa222..f2dd5cf807 100644 --- a/paddle/gserver/evaluators/Evaluator.cpp +++ b/paddle/gserver/evaluators/Evaluator.cpp @@ -102,36 +102,22 @@ public: maxValues_, height, width, false, useGpu(arguments[0].deviceId)); output->rowMax(*maxIds_, *maxValues_); // top-k values - int* ids = nullptr; - int* lbl = nullptr; IVectorPtr dest = IVector::create(maxIds_->getSize(), false); IVectorPtr dest2 = IVector::create(label->getSize(), false); - if (useGpu(arguments[0].deviceId)) { - hl_memcpy_device2host((void*)dest->getData(), - (void*)maxIds_->getData(), - sizeof(int) * maxIds_->getSize()); - ids = dest->getData(); - - hl_memcpy_device2host((void*)dest2->getData(), - (void*)label->getData(), - sizeof(int) * label->getSize()); - lbl = dest2->getData(); - } else { - ids = maxIds_->getData(); - lbl = label->getData(); - } + dest->copyFrom(*maxIds_); + dest2->copyFrom(*label); + int* ids = dest->getData(); + int* lbl = dest2->getData(); - real* result2 = errorMat2->getData(); for (size_t i = 0; i < height; ++i) { - result2[i] = (ids[i * width] != lbl[i]); // initialize top-k error - for (size_t j = 1; j < width; ++j) { - if (result2[i] == 0.0) { - break; - } - result2[i] = (ids[i * width + j] != lbl[i]); // top-k error + bool contain = false; + for (size_t j = 0; j < width && !contain; ++j) { + contain = (ids[i * width + j] == lbl[i]); + } + if (!contain) { + totalScore2_ += 1.0; // update top-k error } } - totalScore2_ += errorMat2->getSum(); } } else if (dynamic_cast(multiBinaryLabel.get()) || dynamic_cast(multiBinaryLabel.get())) { diff --git a/paddle/parameter/Parameter.cpp b/paddle/parameter/Parameter.cpp index 9e7b4ea39d..1ccded8187 100644 --- a/paddle/parameter/Parameter.cpp +++ b/paddle/parameter/Parameter.cpp @@ -375,10 +375,6 @@ bool Parameter::load(const std::string& filename) { std::ifstream fs(filename, std::ios_base::binary); if (!fs) { LOG(INFO) << "missing parameters [" << filename << "] while loading model."; - /*if (isStatic()) { - LOG(FATAL) << getName() << " is static but missing, not allowed."; - return false; - }*/ if (kMissParameterFail == FLAGS_load_missing_parameter_strategy) { LOG(FATAL) << getName() << " missing, not allowed."; return false; -- GitLab