未验证 提交 e08d0646 编写于 作者: X XiaociZhang 提交者: GitHub

[XPU] remove range check for ignore index (#56869)

* [XPU] remove range check for ignore index

* add a log
上级 d2fedeac
......@@ -34,12 +34,9 @@ class CSoftmaxWithCrossEntropyOp : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const int64_t ignore_index = ctx.Attr<int64_t>("ignore_index");
PADDLE_ENFORCE_LT(ignore_index,
0,
platform::errors::InvalidArgument(
"When SoftmaxWithCrossEntropy run on XPU, "
"ignore_index should be <=0, however it's %ld",
ignore_index));
if (ignore_index >= 0) {
LOG_FIRST_N(INFO, 1) << "XPU does not support ignore_index in mp.";
}
const int rid = ctx.Attr<int>("ring_id");
auto map = distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(rid)) {
......@@ -467,12 +464,9 @@ class CSoftmaxWithCrossEntropyGrad : public framework::OpKernel<T> {
const phi::DenseTensor* softmax =
context.Input<phi::DenseTensor>("Softmax");
const int64_t ignore_index = context.Attr<int64_t>("ignore_index");
PADDLE_ENFORCE_LT(ignore_index,
0,
platform::errors::InvalidArgument(
"When SoftmaxWithCrossEntropy run on XPU, "
"ignore_index should be <=0, however it's %ld",
ignore_index));
if (ignore_index >= 0) {
LOG_FIRST_N(INFO, 1) << "XPU does not support ignore_index in mp.";
}
const int rank = context.Attr<int>("rank");
auto& dev_ctx = context.template device_context<DeviceContext>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册