diff --git a/paddle/gserver/evaluators/Evaluator.cpp b/paddle/gserver/evaluators/Evaluator.cpp index 30e91fa222fca571457dd1221cd6b3158cadd098..f2dd5cf8073f6ddf6311cc82b2ba7275ca4daf1c 100644 --- a/paddle/gserver/evaluators/Evaluator.cpp +++ b/paddle/gserver/evaluators/Evaluator.cpp @@ -102,36 +102,22 @@ public: maxValues_, height, width, false, useGpu(arguments[0].deviceId)); output->rowMax(*maxIds_, *maxValues_); // top-k values - int* ids = nullptr; - int* lbl = nullptr; IVectorPtr dest = IVector::create(maxIds_->getSize(), false); IVectorPtr dest2 = IVector::create(label->getSize(), false); - if (useGpu(arguments[0].deviceId)) { - hl_memcpy_device2host((void*)dest->getData(), - (void*)maxIds_->getData(), - sizeof(int) * maxIds_->getSize()); - ids = dest->getData(); - - hl_memcpy_device2host((void*)dest2->getData(), - (void*)label->getData(), - sizeof(int) * label->getSize()); - lbl = dest2->getData(); - } else { - ids = maxIds_->getData(); - lbl = label->getData(); - } + dest->copyFrom(*maxIds_); + dest2->copyFrom(*label); + int* ids = dest->getData(); + int* lbl = dest2->getData(); - real* result2 = errorMat2->getData(); for (size_t i = 0; i < height; ++i) { - result2[i] = (ids[i * width] != lbl[i]); // initialize top-k error - for (size_t j = 1; j < width; ++j) { - if (result2[i] == 0.0) { - break; - } - result2[i] = (ids[i * width + j] != lbl[i]); // top-k error + bool contain = false; + for (size_t j = 0; j < width && !contain; ++j) { + contain = (ids[i * width + j] == lbl[i]); + } + if (!contain) { + totalScore2_ += 1.0; // update top-k error } } - totalScore2_ += errorMat2->getSum(); } } else if (dynamic_cast(multiBinaryLabel.get()) || dynamic_cast(multiBinaryLabel.get())) { diff --git a/paddle/parameter/Parameter.cpp b/paddle/parameter/Parameter.cpp index 9e7b4ea39d4ea676d8eff7e938ce536cb8812ce7..1ccded818796798105a889df978618688b56ed36 100644 --- a/paddle/parameter/Parameter.cpp +++ b/paddle/parameter/Parameter.cpp @@ -375,10 +375,6 @@ bool Parameter::load(const std::string& filename) { std::ifstream fs(filename, std::ios_base::binary); if (!fs) { LOG(INFO) << "missing parameters [" << filename << "] while loading model."; - /*if (isStatic()) { - LOG(FATAL) << getName() << " is static but missing, not allowed."; - return false; - }*/ if (kMissParameterFail == FLAGS_load_missing_parameter_strategy) { LOG(FATAL) << getName() << " missing, not allowed."; return false;