提交 043859b5 编写于 作者: L Liang Zhao

clean up code

上级 5250b9b2
......@@ -78,23 +78,24 @@ public:
useGpu(arguments[0].deviceId));
const MatrixPtr errorMat2 = Matrix::create(output->getHeight(),
1,
/* trans= */ false, false);
// useGpu(arguments[0].deviceId));
/* trans= */ false,
false);
errorMat->zeroMem();
if (label != nullptr) {
errorMat->classificationError(*output, *label); // top-1 error
size_t height = output->getHeight();
size_t width = 5; // config_.num_results();
size_t width = 5;
IVector::resizeOrCreate(maxIds_, height * width,
useGpu(arguments[0].deviceId));
Matrix::resizeOrCreate(maxValues_, height, width, false,
useGpu(arguments[0].deviceId));
IVector::resizeOrCreate(
maxIds_, height * width, useGpu(arguments[0].deviceId));
Matrix::resizeOrCreate(
maxValues_, height, width, false, useGpu(arguments[0].deviceId));
output->rowMax(*maxIds_, *maxValues_); // top-5 values
int* ids;
int* lbl;
int* ids = nullptr;
int* lbl = nullptr;
if (useGpu(arguments[0].deviceId)) {
IVectorPtr dest = IVector::create(maxIds_->getSize(), false);
hl_memcpy_device2host((void*)dest->getData(),
......@@ -112,10 +113,8 @@ public:
lbl = label->getData();
}
// real* result = errorMat->getData();
real* result2 = errorMat2->getData();
for (size_t i = 0; i < height; ++i) {
// result[i] = (ids[i * width] != lbl[i]); // top-1 error
result2[i] = (ids[i * width] != lbl[i]); // initialize top-5 error
for (size_t j = 1; j < width; ++j) {
if (result2[i] == 0.0) {
......@@ -141,10 +140,8 @@ public:
}
void printStats(std::ostream& os) const {
os << "top_1_error="
<< (numSamples_ ? totalScore_ / numSamples_ : 0)
<< " top_5_error="
<< (numSamples_ ? totalScore2_ / numSamples_ : 0);
os << "top_1_error=" << (numSamples_ ? totalScore_ / numSamples_ : 0)
<< " top_5_error=" << (numSamples_ ? totalScore2_ / numSamples_ : 0);
}
virtual real evalImp(std::vector<Argument>& arguments) {
......@@ -156,7 +153,6 @@ public:
mergeResultsOfAllClients(client);
}
private:
IVectorPtr maxIds_;
MatrixPtr maxValues_;
......
......@@ -495,6 +495,7 @@ def gradient_printer_evaluator(
"""
evaluator_base(name=name, type="gradient_printer", input=input)
@evaluator(EvaluatorAttribute.FOR_PRINT)
@wrap_name_default()
def maxid_printer_evaluator(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册