提交 d2565128 编写于 作者: L Liang Zhao

Rewrite code according to reviewer comments

上级 e768721c
...@@ -102,36 +102,22 @@ public: ...@@ -102,36 +102,22 @@ public:
maxValues_, height, width, false, useGpu(arguments[0].deviceId)); maxValues_, height, width, false, useGpu(arguments[0].deviceId));
output->rowMax(*maxIds_, *maxValues_); // top-k values output->rowMax(*maxIds_, *maxValues_); // top-k values
int* ids = nullptr;
int* lbl = nullptr;
IVectorPtr dest = IVector::create(maxIds_->getSize(), false); IVectorPtr dest = IVector::create(maxIds_->getSize(), false);
IVectorPtr dest2 = IVector::create(label->getSize(), false); IVectorPtr dest2 = IVector::create(label->getSize(), false);
if (useGpu(arguments[0].deviceId)) { dest->copyFrom(*maxIds_);
hl_memcpy_device2host((void*)dest->getData(), dest2->copyFrom(*label);
(void*)maxIds_->getData(), int* ids = dest->getData();
sizeof(int) * maxIds_->getSize()); int* lbl = dest2->getData();
ids = dest->getData();
hl_memcpy_device2host((void*)dest2->getData(),
(void*)label->getData(),
sizeof(int) * label->getSize());
lbl = dest2->getData();
} else {
ids = maxIds_->getData();
lbl = label->getData();
}
real* result2 = errorMat2->getData();
for (size_t i = 0; i < height; ++i) { for (size_t i = 0; i < height; ++i) {
result2[i] = (ids[i * width] != lbl[i]); // initialize top-k error bool contain = false;
for (size_t j = 1; j < width; ++j) { for (size_t j = 0; j < width && !contain; ++j) {
if (result2[i] == 0.0) { contain = (ids[i * width + j] == lbl[i]);
break; }
} if (!contain) {
result2[i] = (ids[i * width + j] != lbl[i]); // top-k error totalScore2_ += 1.0; // update top-k error
} }
} }
totalScore2_ += errorMat2->getSum();
} }
} else if (dynamic_cast<CpuSparseMatrix*>(multiBinaryLabel.get()) || } else if (dynamic_cast<CpuSparseMatrix*>(multiBinaryLabel.get()) ||
dynamic_cast<GpuSparseMatrix*>(multiBinaryLabel.get())) { dynamic_cast<GpuSparseMatrix*>(multiBinaryLabel.get())) {
......
...@@ -375,10 +375,6 @@ bool Parameter::load(const std::string& filename) { ...@@ -375,10 +375,6 @@ bool Parameter::load(const std::string& filename) {
std::ifstream fs(filename, std::ios_base::binary); std::ifstream fs(filename, std::ios_base::binary);
if (!fs) { if (!fs) {
LOG(INFO) << "missing parameters [" << filename << "] while loading model."; LOG(INFO) << "missing parameters [" << filename << "] while loading model.";
/*if (isStatic()) {
LOG(FATAL) << getName() << " is static but missing, not allowed.";
return false;
}*/
if (kMissParameterFail == FLAGS_load_missing_parameter_strategy) { if (kMissParameterFail == FLAGS_load_missing_parameter_strategy) {
LOG(FATAL) << getName() << " missing, not allowed."; LOG(FATAL) << getName() << " missing, not allowed.";
return false; return false;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册