提交 42641f17 编写于 作者: Z ZPaC

Add protection in cross entropy kernel.

上级 a8efea5c
...@@ -27,7 +27,7 @@ __global__ void CrossEntropyWithSparseKernel(const T *logits, const S *labels, c ...@@ -27,7 +27,7 @@ __global__ void CrossEntropyWithSparseKernel(const T *logits, const S *labels, c
for (size_t i = 0; i < batch_size; ++i) { for (size_t i = 0; i < batch_size; ++i) {
T logit = logits[i * class_num + labels[i]]; T logit = logits[i * class_num + labels[i]];
if (logit <= 0) { if (logit <= 0) {
logit += epsilon; logit = epsilon;
} }
total_loss += -logf(logit); total_loss += -logf(logit);
} }
...@@ -54,8 +54,9 @@ __global__ void CrossEntropyGradWithSparseKernel(const T *logits, const S *label ...@@ -54,8 +54,9 @@ __global__ void CrossEntropyGradWithSparseKernel(const T *logits, const S *label
template <typename T, typename S> template <typename T, typename S>
__global__ void CrossEntropyKernel(const T *logits, const S *labels, const size_t class_num, T *losses, T *dlogits) { __global__ void CrossEntropyKernel(const T *logits, const S *labels, const size_t class_num, T *losses, T *dlogits) {
losses[threadIdx.x] = 0; losses[threadIdx.x] = 0;
T epsilon = 1e-6;
for (int i = threadIdx.x * class_num; i < (threadIdx.x + 1) * class_num; ++i) { for (int i = threadIdx.x * class_num; i < (threadIdx.x + 1) * class_num; ++i) {
losses[threadIdx.x] -= logf(logits[i]) * labels[i]; losses[threadIdx.x] -= logf((logits[i] <= 0 ? epsilon : logits[i])) * labels[i];
dlogits[i] = logits[i] - labels[i]; dlogits[i] = logits[i] - labels[i];
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册