提交 046349dd 编写于 作者: L Liang Zhao

Fix definition of hl_matrix_classification_error in hl_matrix_stub.h

上级 8fded24c
......@@ -35,8 +35,16 @@ inline void hl_sequence_softmax_forward(real* A_d,
inline void hl_matrix_softmax_derivative(
real* grad_d, real* output_d, real* sftmaxSum_d, int dimM, int dimN) {}
inline void hl_matrix_classification_error(
real* A_d, int* B_d, real* C_d, int dimM, int dimN) {}
inline void hl_matrix_classification_error(real* topVal,
int ldv,
int* topIds,
real* src,
int lds,
int dim,
int topkSize,
int numSamples,
int* label,
real* recResult) {}
inline void hl_matrix_cross_entropy(
real* A_d, real* C_d, int* label_d, int dimM, int dimN) {}
......
......@@ -764,7 +764,7 @@ TEST(Matrix, paramReluBackwardDiff) {
}
}
void testClassificationError(int numSamples, int dim) {
void testClassificationError(int numSamples, int dim, int topkSize) {
MatrixPtr cpuError = std::make_shared<CpuMatrix>(numSamples, 1);
MatrixPtr gpuError = std::make_shared<GpuMatrix>(numSamples, 1);
MatrixPtr cpuOutput = std::make_shared<CpuMatrix>(numSamples, dim);
......@@ -777,17 +777,22 @@ void testClassificationError(int numSamples, int dim) {
gpuOutput->copyFrom(*cpuOutput);
gpuLabel->copyFrom(*cpuLabel);
cpuError->classificationError(*cpuOutput, *cpuLabel);
gpuError->classificationError(*gpuOutput, *gpuLabel);
cpuError->classificationError(*cpuOutput, *cpuLabel, topkSize);
gpuError->classificationError(*gpuOutput, *gpuLabel, topkSize);
TensorCheckEqual(*cpuError, *gpuError);
}
TEST(Matrix, classificationError) {
for (auto numSamples : {1, 10, 100, 1000, 70000}) {
for (auto dim : {1, 10, 100, 1000}) {
VLOG(3) << " numSamples=" << numSamples << " dim=" << dim;
testClassificationError(numSamples, dim);
for (auto numSamples : {1, 5, 31, 90, 150, 300}) {
for (auto dim :
{1, 5, 8, 10, 15, 64, 80, 120, 256, 300, 1280, 5120, 50000}) {
for (auto topkSize : {1, 5, 10, 20, 40, (int)rand() % dim + 1}) {
if (topkSize > dim) continue;
VLOG(3) << " sample= " << numSamples << " topkSize= " << topkSize
<< " dim= " << dim;
testClassificationError(numSamples, dim, topkSize);
}
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册