提交 043859b5 编写于 作者: L Liang Zhao

clean up code

上级 5250b9b2
...@@ -78,23 +78,24 @@ public: ...@@ -78,23 +78,24 @@ public:
useGpu(arguments[0].deviceId)); useGpu(arguments[0].deviceId));
const MatrixPtr errorMat2 = Matrix::create(output->getHeight(), const MatrixPtr errorMat2 = Matrix::create(output->getHeight(),
1, 1,
/* trans= */ false, false); /* trans= */ false,
// useGpu(arguments[0].deviceId)); false);
errorMat->zeroMem(); errorMat->zeroMem();
if (label != nullptr) { if (label != nullptr) {
errorMat->classificationError(*output, *label); // top-1 error errorMat->classificationError(*output, *label); // top-1 error
size_t height = output->getHeight(); size_t height = output->getHeight();
size_t width = 5; // config_.num_results(); size_t width = 5;
IVector::resizeOrCreate(maxIds_, height * width, IVector::resizeOrCreate(
useGpu(arguments[0].deviceId)); maxIds_, height * width, useGpu(arguments[0].deviceId));
Matrix::resizeOrCreate(maxValues_, height, width, false, Matrix::resizeOrCreate(
useGpu(arguments[0].deviceId)); maxValues_, height, width, false, useGpu(arguments[0].deviceId));
output->rowMax(*maxIds_, *maxValues_); // top-5 values output->rowMax(*maxIds_, *maxValues_); // top-5 values
int* ids; int* ids = nullptr;
int* lbl; int* lbl = nullptr;
if (useGpu(arguments[0].deviceId)) { if (useGpu(arguments[0].deviceId)) {
IVectorPtr dest = IVector::create(maxIds_->getSize(), false); IVectorPtr dest = IVector::create(maxIds_->getSize(), false);
hl_memcpy_device2host((void*)dest->getData(), hl_memcpy_device2host((void*)dest->getData(),
...@@ -112,10 +113,8 @@ public: ...@@ -112,10 +113,8 @@ public:
lbl = label->getData(); lbl = label->getData();
} }
// real* result = errorMat->getData();
real* result2 = errorMat2->getData(); real* result2 = errorMat2->getData();
for (size_t i = 0; i < height; ++i) { for (size_t i = 0; i < height; ++i) {
// result[i] = (ids[i * width] != lbl[i]); // top-1 error
result2[i] = (ids[i * width] != lbl[i]); // initialize top-5 error result2[i] = (ids[i * width] != lbl[i]); // initialize top-5 error
for (size_t j = 1; j < width; ++j) { for (size_t j = 1; j < width; ++j) {
if (result2[i] == 0.0) { if (result2[i] == 0.0) {
...@@ -141,10 +140,8 @@ public: ...@@ -141,10 +140,8 @@ public:
} }
void printStats(std::ostream& os) const { void printStats(std::ostream& os) const {
os << "top_1_error=" os << "top_1_error=" << (numSamples_ ? totalScore_ / numSamples_ : 0)
<< (numSamples_ ? totalScore_ / numSamples_ : 0) << " top_5_error=" << (numSamples_ ? totalScore2_ / numSamples_ : 0);
<< " top_5_error="
<< (numSamples_ ? totalScore2_ / numSamples_ : 0);
} }
virtual real evalImp(std::vector<Argument>& arguments) { virtual real evalImp(std::vector<Argument>& arguments) {
...@@ -156,7 +153,6 @@ public: ...@@ -156,7 +153,6 @@ public:
mergeResultsOfAllClients(client); mergeResultsOfAllClients(client);
} }
private: private:
IVectorPtr maxIds_; IVectorPtr maxIds_;
MatrixPtr maxValues_; MatrixPtr maxValues_;
......
...@@ -495,6 +495,7 @@ def gradient_printer_evaluator( ...@@ -495,6 +495,7 @@ def gradient_printer_evaluator(
""" """
evaluator_base(name=name, type="gradient_printer", input=input) evaluator_base(name=name, type="gradient_printer", input=input)
@evaluator(EvaluatorAttribute.FOR_PRINT) @evaluator(EvaluatorAttribute.FOR_PRINT)
@wrap_name_default() @wrap_name_default()
def maxid_printer_evaluator( def maxid_printer_evaluator(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册