未验证 提交 8ef97088 编写于 作者: G Ghost Screaming 提交者: GitHub

Fix bug of c_softmax_with_cross_entropy_op_xpu_op (#52296)

* Support ignore_index for c_softmax_with_cross_entropy_op.

* Polish code. Remove useless comments and add Testcase.

* Polish code for TestCase.

* Polish code.

* Polish code style.

* Polish code.

* Change loss calculation formula and ignore_index dtype.

* Polish TestCase.

* Fix bug of c_softmax_with_cross_entropy_op_xpu_op. Attribute 'ignore_index'
dtype is int64_t.
上级 dfa893fd
......@@ -33,12 +33,12 @@ template <typename DeviceContext, typename T>
class CSoftmaxWithCrossEntropyOp : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const int ignore_index = ctx.Attr<int>("ignore_index");
const int64_t ignore_index = ctx.Attr<int64_t>("ignore_index");
PADDLE_ENFORCE_LT(ignore_index,
0,
platform::errors::InvalidArgument(
"When SoftmaxWithCrossEntropy run on XPU, "
"ignore_index should be <=0, however it's %d",
"ignore_index should be <=0, however it's %ld",
ignore_index));
const int rid = ctx.Attr<int>("ring_id");
auto map = distributed::ProcessGroupMapFromGid::getInstance();
......@@ -460,12 +460,12 @@ class CSoftmaxWithCrossEntropyGrad : public framework::OpKernel<T> {
context.Output<phi::DenseTensor>(framework::GradVarName("Logits"));
const phi::DenseTensor* softmax =
context.Input<phi::DenseTensor>("Softmax");
const int ignore_index = context.Attr<int>("ignore_index");
const int64_t ignore_index = context.Attr<int64_t>("ignore_index");
PADDLE_ENFORCE_LT(ignore_index,
0,
platform::errors::InvalidArgument(
"When SoftmaxWithCrossEntropy run on XPU, "
"ignore_index should be <=0, however it's %d",
"ignore_index should be <=0, however it's %ld",
ignore_index));
const int rank = context.Attr<int>("rank");
auto& dev_ctx = context.template device_context<DeviceContext>();
......
......@@ -529,7 +529,7 @@ class ParallelCrossEntropy(paddle.nn.Layer):
mp_group(Group): The tensor parallel group.
name(str, optional): Normally there is no need for user to set this parameter.
For detailed information, please refer to :ref:`api_guide_Name` .
ignore_index (int, optional): Specifies a target value that is ignored and
ignore_index (long int, optional): Specifies a target value that is ignored and
does not contribute to the loss. A negative value means that no label value
needs to be ignored. Default is -100 .
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册