提交 e00f06af 编写于 作者: L Liang Zhao

Add top-k error

上级 5a1d9263
......@@ -76,9 +76,55 @@ public:
1,
/* trans= */ false,
useGpu(arguments[0].deviceId));
const MatrixPtr errorMat2 = Matrix::create(output->getHeight(),
1,
/* trans= */ false, false);
// useGpu(arguments[0].deviceId));
errorMat->zeroMem();
if (label != nullptr) {
errorMat->classificationError(*output, *label);
errorMat->classificationError(output, label); // top-1 error
size_t height = output->getHeight();
size_t width = 5; // config_.num_results();
IVector::resizeOrCreate(maxIds_, height * width,
useGpu(arguments[0].deviceId));
Matrix::resizeOrCreate(maxValues_, height, width, false,
useGpu(arguments[0].deviceId));
output->rowMax(*maxIds_, *maxValues_); // top-5 values
int* ids;
int* lbl;
if (useGpu(arguments[0].deviceId)) {
IVectorPtr dest = IVector::create(maxIds_->getSize(), false);
hl_memcpy_device2host((void*)dest->getData(),
(void*)maxIds_->getData(),
sizeof(int) * maxIds_->getSize());
ids = dest->getData();
IVectorPtr dest2 = IVector::create(label->getSize(), false);
hl_memcpy_device2host((void*)dest2->getData(),
(void*)label->getData(),
sizeof(int) * label->getSize());
lbl = dest2->getData();
} else {
ids = maxIds_->getData();
lbl = label->getData();
}
// real* result = errorMat->getData();
real* result2 = errorMat2->getData();
for (size_t i = 0; i < height; ++i) {
// result[i] = (ids[i * width] != lbl[i]); // top-1 error
result2[i] = (ids[i * width] != lbl[i]); // initialize top-5 error
for (size_t j = 1; j < width; ++j) {
if (result2[i] == 0.0) {
break;
}
result2[i] = (ids[i * width + j] != lbl[i]); // top-5 error
}
}
totalScore2_ = errorMat2->getSum();
} else if (dynamic_cast<CpuSparseMatrix*>(multiBinaryLabel.get()) ||
dynamic_cast<GpuSparseMatrix*>(multiBinaryLabel.get())) {
errorMat->classificationErrorMulti(
......@@ -94,6 +140,13 @@ public:
return errorMat;
}
void printStats(std::ostream& os) const {
os << "top_1_error="
<< (numSamples_ ? totalScore_ / numSamples_ : 0)
<< " top_5_error="
<< (numSamples_ ? totalScore2_ / numSamples_ : 0);
}
virtual real evalImp(std::vector<Argument>& arguments) {
MatrixPtr errorMat = calcError(arguments);
return errorMat->getSum();
......@@ -102,6 +155,12 @@ public:
virtual void distributeEval(ParameterClient2* client) {
mergeResultsOfAllClients(client);
}
private:
IVectorPtr maxIds_;
MatrixPtr maxValues_;
double totalScore2_;
};
/**
......
......@@ -732,6 +732,7 @@ void GpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) {
size_t beam = maxVal.getWidth();
CHECK_EQ(maxIds.getSize(), numSamples * beam);
CHECK_EQ(maxVal.getHeight(), numSamples);
CHECK_EQ(maxVal.getWidth(), beam);
hl_matrix_top_k(maxVal.getData(),
maxVal.getStride(),
......@@ -3039,7 +3040,7 @@ void CpuMatrix::rowMax(Matrix& max) {
max.maxRows(*this);
}
/* get beam size of max ids and values */
/* Get the top k elements of each row of this matrix */
void CpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) {
CHECK(isContiguous());
CHECK(!maxIds.useGpu() && !maxVal.useGpu()) << "Matrix type are not equal";
......@@ -3047,6 +3048,7 @@ void CpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) {
size_t beam = maxVal.getWidth();
CHECK_EQ(maxIds.getSize(), numSamples * beam);
CHECK_EQ(maxVal.getHeight(), numSamples);
CHECK_EQ(maxVal.getWidth(), beam);
real* a = getData();
int* s = maxIds.getData();
......
......@@ -375,10 +375,10 @@ bool Parameter::load(const std::string& filename) {
std::ifstream fs(filename, std::ios_base::binary);
if (!fs) {
LOG(INFO) << "missing parameters [" << filename << "] while loading model.";
if (isStatic()) {
/*if (isStatic()) {
LOG(FATAL) << getName() << " is static but missing, not allowed.";
return false;
}
}*/
if (kMissParameterFail == FLAGS_load_missing_parameter_strategy) {
LOG(FATAL) << getName() << " missing, not allowed.";
return false;
......
......@@ -495,7 +495,6 @@ def gradient_printer_evaluator(
"""
evaluator_base(name=name, type="gradient_printer", input=input)
@evaluator(EvaluatorAttribute.FOR_PRINT)
@wrap_name_default()
def maxid_printer_evaluator(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册