提交 23170e21 编写于 作者: XYZ_916's avatar XYZ_916 提交者: chajchaj

solve the bug divide by zero in softmax_with_cross_entropy_op, test=develop

上级 6327c33b
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
......@@ -44,10 +46,24 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = softmax->dims()[axis];
PADDLE_ENFORCE_GT(
axis_dim, 0,
platform::errors::InvalidArgument(
"The axis dimention should be larger than 0, but received "
"axis dimention is %d.",
axis_dim));
softmax_out->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace());
const int n = SizeToAxis(axis, softmax->dims());
PADDLE_ENFORCE_GT(
n, 0, platform::errors::InvalidArgument(
"The size of axis should be larger than 0, but received "
"axis size is %d.",
n));
const int d = SizeFromAxis(axis, softmax->dims());
Tensor softmax_2d, labels_2d, loss_2d, softmax_out_2d;
......@@ -80,11 +96,23 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
const int rank = logits->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = logits->dims()[axis];
PADDLE_ENFORCE_GT(
axis_dim, 0,
platform::errors::InvalidArgument(
"The axis dimention should be larger than 0, but received "
"axis dimention is %d.",
axis_dim));
softmax->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace());
const int n = SizeToAxis(axis, logits->dims());
PADDLE_ENFORCE_GT(
n, 0, platform::errors::InvalidArgument(
"The size of axis should be larger than 0, but received "
"axis size is %d.",
n));
const int d = SizeFromAxis(axis, logits->dims());
Tensor logits_2d, softmax_2d, labels_2d, loss_2d;
logits_2d.ShareDataWith(*logits).Resize({n, d});
......@@ -123,8 +151,20 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
const int rank = logit_grad->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = logit_grad->dims()[axis];
PADDLE_ENFORCE_GT(
axis_dim, 0,
platform::errors::InvalidArgument(
"The axis dimention should be larger than 0, but received "
"axis dimention is %d.",
axis_dim));
const int n = SizeToAxis(axis, logit_grad->dims());
PADDLE_ENFORCE_GT(
n, 0, platform::errors::InvalidArgument(
"The size of axis should be larger than 0, but received "
"axis size is %d.",
n));
const int d = SizeFromAxis(axis, logit_grad->dims());
Tensor logit_grad_2d, labels_2d, out_grad_2d;
logit_grad_2d.ShareDataWith(*logit_grad).Resize({n, d});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册