提交 abfd9fdd 编写于 作者: T tensor-tang

Add clipping to avoid log(0) and NaN

上级 7483087c
...@@ -3637,7 +3637,7 @@ void CpuMatrix::oneHotCrossEntropy(Matrix& output, IVector& label) { ...@@ -3637,7 +3637,7 @@ void CpuMatrix::oneHotCrossEntropy(Matrix& output, IVector& label) {
for (size_t i = 0; i < numSamples; ++i, out += dim) { for (size_t i = 0; i < numSamples; ++i, out += dim) {
CHECK_GE(lbl[i], 0); CHECK_GE(lbl[i], 0);
CHECK_LT((size_t)lbl[i], dim); CHECK_LT((size_t)lbl[i], dim);
cost[i] = -std::log(out[lbl[i]]); cost[i] = -std::log(std::max(out[lbl[i]], real(FLT_MIN)));
} }
} }
...@@ -3652,7 +3652,7 @@ void CpuMatrix::oneHotCrossEntropyBp(Matrix& output, IVector& label) { ...@@ -3652,7 +3652,7 @@ void CpuMatrix::oneHotCrossEntropyBp(Matrix& output, IVector& label) {
real* grad = getData(); real* grad = getData();
int* lbl = label.getData(); int* lbl = label.getData();
for (size_t i = 0; i < numSamples; ++i, out += dim, grad += dim) { for (size_t i = 0; i < numSamples; ++i, out += dim, grad += dim) {
grad[lbl[i]] -= 1 / out[lbl[i]]; grad[lbl[i]] -= 1 / std::max(out[lbl[i]], real(FLT_MIN));
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册