未验证 提交 a6b6bcbf 编写于 作者: C Chen Weihang 提交者: GitHub

[Phi] Add softmax with cross entropy infershape & yaml (#41351)

* add infershape and forward yaml

* add final_state call

* add base unittests

* add backward yaml and test

* fix without softmax test error

* add cross_entropy test
上级 a2b80145
......@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
namespace phi {
void BilinearTensorProductGradInferMeta(const MetaTensor& x,
......@@ -103,6 +105,69 @@ void Conv2dTransposeDoubleGradInferMeta(const MetaTensor& x,
}
}
void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label,
const MetaTensor& softmax,
const MetaTensor& loss_grad,
bool soft_label,
bool use_softmax,
bool numeric_stable_mode,
int ignore_index,
int axis,
MetaTensor* logits_grad,
MetaConfig config) {
auto softmax_dims = softmax.dims();
auto labels_dims = label.dims();
auto softmax_rank = softmax_dims.size();
PADDLE_ENFORCE_GE(axis,
-softmax_rank,
phi::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(Logits)."));
PADDLE_ENFORCE_LT(axis,
softmax_rank,
phi::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(Logits)."));
axis = phi::funcs::CanonicalAxis(axis, softmax_rank);
for (int i = 0; i < softmax_rank; i++) {
if (i != axis) {
if (config.is_runtime || (softmax_dims[i] > 0 && labels_dims[i] > 0)) {
PADDLE_ENFORCE_EQ(
softmax_dims[i],
labels_dims[i],
phi::errors::InvalidArgument(
"Input(Logits) and Input(Label) should in same shape in "
"dimensions except axis."));
}
}
}
if (soft_label) {
if (config.is_runtime ||
(softmax_dims[axis] > 0 && labels_dims[axis] > 0)) {
PADDLE_ENFORCE_EQ(softmax_dims[axis],
labels_dims[axis],
phi::errors::InvalidArgument(
"If Attr(soft_label) == true, "
"the axis dimension of "
"Input(X) and Input(Label) should be equal."));
}
} else {
if (config.is_runtime || labels_dims[axis] > 0) {
PADDLE_ENFORCE_EQ(
labels_dims[axis],
1UL,
phi::errors::InvalidArgument("If Attr(soft_label) == false, "
"the axis dimension of "
"Input(Label) should be 1."));
}
}
logits_grad->set_dims(softmax.dims());
logits_grad->set_dtype(softmax.dtype());
}
void GatherNdGradInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& out_grad,
......
......@@ -68,6 +68,17 @@ void Conv2dTransposeDoubleGradInferMeta(const MetaTensor& x,
MetaTensor* dfilter,
MetaTensor* ddout);
void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label,
const MetaTensor& softmax,
const MetaTensor& loss_grad,
bool soft_label,
bool use_softmax,
bool numeric_stable_mode,
int ignore_index,
int axis,
MetaTensor* logits_grad,
MetaConfig config = MetaConfig());
void GatherNdGradInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& out_grad,
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
namespace phi {
......@@ -753,6 +754,82 @@ void CrossInferMeta(const MetaTensor& x,
out->share_lod(x);
}
void CrossEntropyWithSoftmaxInferMeta(const MetaTensor& logits,
const MetaTensor& label,
bool soft_label,
bool use_softmax,
bool numeric_stable_mode,
int ignore_index,
int axis,
MetaTensor* softmax,
MetaTensor* loss,
MetaConfig config) {
auto logits_dims = logits.dims();
auto labels_dims = label.dims();
auto logits_rank = logits_dims.size();
PADDLE_ENFORCE_GE(axis,
-logits_rank,
phi::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(Logits)."));
PADDLE_ENFORCE_LT(axis,
logits_rank,
phi::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(Logits)."));
axis = phi::funcs::CanonicalAxis(axis, logits_rank);
for (int i = 0; i < logits_rank; i++) {
if (i != axis) {
if (config.is_runtime || (logits_dims[i] > 0 && labels_dims[i] > 0)) {
PADDLE_ENFORCE_EQ(logits_dims[i],
labels_dims[i],
phi::errors::InvalidArgument(
"Input(Logits) and Input(Label) should in "
"same shape in dimensions except axis."));
}
}
}
if (axis != logits_rank - 1) {
PADDLE_ENFORCE_EQ(
numeric_stable_mode,
true,
phi::errors::InvalidArgument("Attr(axis) can only be -1 "
"when not in numeric_stable_mode."));
}
if (soft_label) {
if (config.is_runtime || (logits_dims[axis] > 0 && labels_dims[axis] > 0)) {
PADDLE_ENFORCE_EQ(logits_dims[axis],
labels_dims[axis],
phi::errors::InvalidArgument(
"If Attr(soft_label) == true, "
"the axis dimension of "
"Input(X) and Input(Label) should be equal."));
}
} else {
if (config.is_runtime || labels_dims[axis] > 0) {
PADDLE_ENFORCE_EQ(
labels_dims[axis],
1UL,
phi::errors::InvalidArgument("If Attr(soft_label) == false, "
"the axis dimension of "
"Input(Label) should be 1."));
}
}
softmax->set_dims(logits_dims);
softmax->set_dtype(logits.dtype());
logits_dims[axis] = 1;
loss->set_dims(logits_dims);
loss->set_dtype(logits.dtype());
softmax->share_lod(logits);
loss->share_lod(logits);
}
void DistInferMeta(const MetaTensor& x,
const MetaTensor& y,
float p,
......
......@@ -117,6 +117,17 @@ void CrossInferMeta(const MetaTensor& x,
int axis,
MetaTensor* out);
void CrossEntropyWithSoftmaxInferMeta(const MetaTensor& logits,
const MetaTensor& label,
bool soft_label,
bool use_softmax,
bool numeric_stable_mode,
int ignore_index,
int axis,
MetaTensor* softmax,
MetaTensor* loss,
MetaConfig config = MetaConfig());
void DistInferMeta(const MetaTensor& x,
const MetaTensor& y,
float p,
......
......@@ -21,7 +21,7 @@ from paddle.utils import deprecated
from . import nn
from .layer_function_generator import templatedoc
from ..layer_helper import LayerHelper
from ..framework import Variable, _non_static_mode, static_only, _in_legacy_dygraph
from ..framework import Variable, _non_static_mode, static_only, _in_legacy_dygraph, in_dygraph_mode
from .. import core
from ..data_feeder import check_variable_and_dtype, check_type
from ..param_attr import ParamAttr
......@@ -1267,10 +1267,15 @@ def softmax_with_cross_entropy(logits,
ignore_index, 'numeric_stable_mode', numeric_stable_mode,
'axis', axis)
else:
softmax, loss = _C_ops.softmax_with_cross_entropy(
logits, label, 'soft_label', soft_label, 'ignore_index',
ignore_index, 'numeric_stable_mode', numeric_stable_mode,
'axis', axis)
if in_dygraph_mode():
softmax, loss = _C_ops.final_state_cross_entropy_with_softmax(
logits, label, soft_label, True, numeric_stable_mode,
ignore_index, axis)
if _in_legacy_dygraph():
softmax, loss = _C_ops.softmax_with_cross_entropy(
logits, label, 'soft_label', soft_label, 'ignore_index',
ignore_index, 'numeric_stable_mode', numeric_stable_mode,
'axis', axis)
if not return_softmax:
return loss
else:
......
......@@ -969,6 +969,7 @@ set_tests_properties(test_nearest_interp_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_profiler PROPERTIES TIMEOUT 120)
set_tests_properties(test_inplace_softmax_with_cross_entropy PROPERTIES TIMEOUT 120)
set_tests_properties(test_cross_entropy2_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_cross_entropy_loss PROPERTIES TIMEOUT 150)
set_tests_properties(test_fetch_unmerged PROPERTIES TIMEOUT 120)
set_tests_properties(test_gru_unit_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_activation_nn_grad PROPERTIES TIMEOUT 200)
......
......@@ -21,6 +21,7 @@ import unittest
from test_softmax_op import stable_softmax
from test_softmax_with_cross_entropy_op import cross_entropy
from paddle.fluid import Program, program_guard
from paddle.fluid.framework import _test_eager_guard
def log_softmax(x, axis=-1):
......@@ -1447,6 +1448,43 @@ class CrossEntropyLoss(unittest.TestCase):
self.assertTrue(np.allclose(static_ret, expected))
self.assertTrue(np.allclose(dy_ret_value, expected))
def test_soft_1d_dygraph_final_state_api(self):
with _test_eager_guard():
self.test_cross_entropy_loss_soft_1d()
self.test_cross_entropy_loss_soft_1d_weight()
self.test_cross_entropy_loss_soft_1d_mean()
self.test_cross_entropy_loss_soft_1d_weight_mean()
# put all testcases in one test will be failed
def test_soft_2d_dygraph_final_state_api(self):
with _test_eager_guard():
self.test_cross_entropy_loss_soft_2d()
self.test_cross_entropy_loss_soft_2d_weight_mean()
def test_other_dygraph_final_state_api(self):
with _test_eager_guard():
self.test_cross_entropy_loss_1d_with_mean_ignore()
self.test_cross_entropy_loss_1d_with_mean_ignore_negative()
self.test_cross_entropy_loss_1d_with_weight_mean_ignore()
self.test_cross_entropy_loss_1d_with_weight_mean_ignore_exceedlabel(
)
self.test_cross_entropy_loss_1d_with_weight_mean()
self.test_cross_entropy_loss_1d_with_weight_sum()
self.test_cross_entropy_loss_1d_with_weight_none()
self.test_cross_entropy_loss_1d_with_weight_none_func()
self.test_cross_entropy_loss_1d_mean()
self.test_cross_entropy_loss_1d_sum()
self.test_cross_entropy_loss_1d_none()
self.test_cross_entropy_loss_2d_with_weight_none()
self.test_cross_entropy_loss_2d_with_weight_axis_change_mean()
self.test_cross_entropy_loss_2d_with_weight_mean_ignore_exceedlabel(
)
self.test_cross_entropy_loss_2d_with_weight_mean()
self.test_cross_entropy_loss_2d_with_weight_sum()
self.test_cross_entropy_loss_2d_none()
self.test_cross_entropy_loss_2d_mean()
self.test_cross_entropy_loss_2d_sum()
class TestCrossEntropyFAPIError(unittest.TestCase):
def test_errors(self):
......
......@@ -1700,7 +1700,8 @@ def cross_entropy(input,
(got nput_dims{}, label_dims{})'.format(input_dims, label_dims))
if input_dims - 1 == label_dims:
label = paddle.unsqueeze(label, axis=axis)
if in_dynamic_mode():
if _non_static_mode():
if soft_label == False:
valid_label = paddle.cast(
label != ignore_index, dtype=label.dtype) * label
......@@ -1718,10 +1719,15 @@ def cross_entropy(input,
ignore_index, 'numeric_stable_mode', True, 'axis', axis,
'use_softmax', use_softmax)
else:
_, out = _C_ops.softmax_with_cross_entropy(
input, label, 'soft_label', soft_label, 'ignore_index',
ignore_index, 'numeric_stable_mode', True, 'axis', axis,
'use_softmax', use_softmax)
if in_dygraph_mode():
_, out = _C_ops.final_state_cross_entropy_with_softmax(
input, label, soft_label, use_softmax, True, ignore_index,
axis)
if _in_legacy_dygraph():
_, out = _C_ops.softmax_with_cross_entropy(
input, label, 'soft_label', soft_label, 'ignore_index',
ignore_index, 'numeric_stable_mode', True, 'axis', axis,
'use_softmax', use_softmax)
if weight is not None:
......
......@@ -382,6 +382,17 @@
func : cross
backward : cross_grad
# Part of python API paddle.nn.functional.cross_entropy
- api : cross_entropy_with_softmax
args : (Tensor input, Tensor label, bool soft_label, bool use_softmax, bool numeric_stable_mode, int ignore_index, int axis)
output : Tensor(softmax), Tensor(loss)
infer_meta :
func : CrossEntropyWithSoftmaxInferMeta
kernel :
func : cross_entropy_with_softmax
data_type : input
backward : cross_entropy_with_softmax_grad
- api : cumprod
args : (Tensor x, int dim)
output : Tensor(out)
......
......@@ -223,6 +223,16 @@
kernel :
func : cosh_grad
- backward_api : cross_entropy_with_softmax_grad
forward : cross_entropy_with_softmax (Tensor input, Tensor label, bool soft_label, bool use_softmax, bool numeric_stable_mode, int ignore_index, int axis) -> Tensor(softmax), Tensor(loss)
args : (Tensor label, Tensor softmax, Tensor loss_grad, bool soft_label, bool use_softmax, bool numeric_stable_mode, int ignore_index, int axis)
output : Tensor(input_grad)
infer_meta :
func : CrossEntropyWithSoftmaxGradInferMeta
kernel :
func : cross_entropy_with_softmax_grad
data_type : softmax
- backward_api : cross_grad
forward : cross (Tensor x, Tensor y, int axis = 9) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out_grad, int axis)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册