未验证 提交 075d6b14 编写于 作者: J jameszhang 提交者: GitHub

[kunlun] bugfix for collective softmax_with_ce (#52565)

上级 d947b20a
@@ -144,6 +144,12 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::XPUContext, T> {
phi::DenseTensor predicted_logits;
predicted_logits =
ctx.AllocateTmpTensor<T, phi::XPUContext>({N, 1}, dev_ctx);
ret = xpu::constant<XPUType>(
dev_ctx.x_context(),
reinterpret_cast<XPUType*>(predicted_logits.data<T>()),
N,
0.0);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "constant");
const int start_index = rank * D;
const int end_index = start_index + D;
const auto& label_type = framework::TransToProtoVarType(labels->dtype());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册