Commit e00f06af, authored by Liang Zhao

Add top-k error

Parent commit: 5a1d9263
...@@ -76,9 +76,55 @@ public: ...@@ -76,9 +76,55 @@ public:
1, 1,
/* trans= */ false, /* trans= */ false,
useGpu(arguments[0].deviceId)); useGpu(arguments[0].deviceId));
const MatrixPtr errorMat2 = Matrix::create(output->getHeight(),
1,
/* trans= */ false, false);
// useGpu(arguments[0].deviceId));
errorMat->zeroMem(); errorMat->zeroMem();
if (label != nullptr) { if (label != nullptr) {
errorMat->classificationError(*output, *label); errorMat->classificationError(output, label); // top-1 error
size_t height = output->getHeight();
size_t width = 5; // config_.num_results();
IVector::resizeOrCreate(maxIds_, height * width,
useGpu(arguments[0].deviceId));
Matrix::resizeOrCreate(maxValues_, height, width, false,
useGpu(arguments[0].deviceId));
output->rowMax(*maxIds_, *maxValues_); // top-5 values
int* ids;
int* lbl;
if (useGpu(arguments[0].deviceId)) {
IVectorPtr dest = IVector::create(maxIds_->getSize(), false);
hl_memcpy_device2host((void*)dest->getData(),
(void*)maxIds_->getData(),
sizeof(int) * maxIds_->getSize());
ids = dest->getData();
IVectorPtr dest2 = IVector::create(label->getSize(), false);
hl_memcpy_device2host((void*)dest2->getData(),
(void*)label->getData(),
sizeof(int) * label->getSize());
lbl = dest2->getData();
} else {
ids = maxIds_->getData();
lbl = label->getData();
}
// real* result = errorMat->getData();
real* result2 = errorMat2->getData();
for (size_t i = 0; i < height; ++i) {
// result[i] = (ids[i * width] != lbl[i]); // top-1 error
result2[i] = (ids[i * width] != lbl[i]); // initialize top-5 error
for (size_t j = 1; j < width; ++j) {
if (result2[i] == 0.0) {
break;
}
result2[i] = (ids[i * width + j] != lbl[i]); // top-5 error
}
}
totalScore2_ = errorMat2->getSum();
} else if (dynamic_cast<CpuSparseMatrix*>(multiBinaryLabel.get()) || } else if (dynamic_cast<CpuSparseMatrix*>(multiBinaryLabel.get()) ||
dynamic_cast<GpuSparseMatrix*>(multiBinaryLabel.get())) { dynamic_cast<GpuSparseMatrix*>(multiBinaryLabel.get())) {
errorMat->classificationErrorMulti( errorMat->classificationErrorMulti(
...@@ -94,6 +140,13 @@ public: ...@@ -94,6 +140,13 @@ public:
return errorMat; return errorMat;
} }
void printStats(std::ostream& os) const {
  // Report averaged classification errors. Guard against division by zero
  // when no samples have been evaluated yet.
  // NOTE(review): totalScore_/totalScore2_ appear to accumulate top-1 / top-5
  // error sums respectively — confirm against calcError().
  const double top1Error = numSamples_ ? totalScore_ / numSamples_ : 0;
  const double top5Error = numSamples_ ? totalScore2_ / numSamples_ : 0;
  os << "top_1_error=" << top1Error << " top_5_error=" << top5Error;
}
virtual real evalImp(std::vector<Argument>& arguments) { virtual real evalImp(std::vector<Argument>& arguments) {
MatrixPtr errorMat = calcError(arguments); MatrixPtr errorMat = calcError(arguments);
return errorMat->getSum(); return errorMat->getSum();
...@@ -102,6 +155,12 @@ public: ...@@ -102,6 +155,12 @@ public:
virtual void distributeEval(ParameterClient2* client) { virtual void distributeEval(ParameterClient2* client) {
mergeResultsOfAllClients(client); mergeResultsOfAllClients(client);
} }
private:
IVectorPtr maxIds_;
MatrixPtr maxValues_;
double totalScore2_;
}; };
/** /**
......
...@@ -732,6 +732,7 @@ void GpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) { ...@@ -732,6 +732,7 @@ void GpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) {
size_t beam = maxVal.getWidth(); size_t beam = maxVal.getWidth();
CHECK_EQ(maxIds.getSize(), numSamples * beam); CHECK_EQ(maxIds.getSize(), numSamples * beam);
CHECK_EQ(maxVal.getHeight(), numSamples); CHECK_EQ(maxVal.getHeight(), numSamples);
CHECK_EQ(maxVal.getWidth(), beam);
hl_matrix_top_k(maxVal.getData(), hl_matrix_top_k(maxVal.getData(),
maxVal.getStride(), maxVal.getStride(),
...@@ -3039,7 +3040,7 @@ void CpuMatrix::rowMax(Matrix& max) { ...@@ -3039,7 +3040,7 @@ void CpuMatrix::rowMax(Matrix& max) {
max.maxRows(*this); max.maxRows(*this);
} }
/* get beam size of max ids and values */ /* Get the top k elements of each row of this matrix */
void CpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) { void CpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) {
CHECK(isContiguous()); CHECK(isContiguous());
CHECK(!maxIds.useGpu() && !maxVal.useGpu()) << "Matrix type are not equal"; CHECK(!maxIds.useGpu() && !maxVal.useGpu()) << "Matrix type are not equal";
...@@ -3047,6 +3048,7 @@ void CpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) { ...@@ -3047,6 +3048,7 @@ void CpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) {
size_t beam = maxVal.getWidth(); size_t beam = maxVal.getWidth();
CHECK_EQ(maxIds.getSize(), numSamples * beam); CHECK_EQ(maxIds.getSize(), numSamples * beam);
CHECK_EQ(maxVal.getHeight(), numSamples); CHECK_EQ(maxVal.getHeight(), numSamples);
CHECK_EQ(maxVal.getWidth(), beam);
real* a = getData(); real* a = getData();
int* s = maxIds.getData(); int* s = maxIds.getData();
......
...@@ -375,10 +375,10 @@ bool Parameter::load(const std::string& filename) { ...@@ -375,10 +375,10 @@ bool Parameter::load(const std::string& filename) {
std::ifstream fs(filename, std::ios_base::binary); std::ifstream fs(filename, std::ios_base::binary);
if (!fs) { if (!fs) {
LOG(INFO) << "missing parameters [" << filename << "] while loading model."; LOG(INFO) << "missing parameters [" << filename << "] while loading model.";
if (isStatic()) { /*if (isStatic()) {
LOG(FATAL) << getName() << " is static but missing, not allowed."; LOG(FATAL) << getName() << " is static but missing, not allowed.";
return false; return false;
} }*/
if (kMissParameterFail == FLAGS_load_missing_parameter_strategy) { if (kMissParameterFail == FLAGS_load_missing_parameter_strategy) {
LOG(FATAL) << getName() << " missing, not allowed."; LOG(FATAL) << getName() << " missing, not allowed.";
return false; return false;
......
...@@ -495,7 +495,6 @@ def gradient_printer_evaluator( ...@@ -495,7 +495,6 @@ def gradient_printer_evaluator(
""" """
evaluator_base(name=name, type="gradient_printer", input=input) evaluator_base(name=name, type="gradient_printer", input=input)
@evaluator(EvaluatorAttribute.FOR_PRINT) @evaluator(EvaluatorAttribute.FOR_PRINT)
@wrap_name_default() @wrap_name_default()
def maxid_printer_evaluator( def maxid_printer_evaluator(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册