未验证 提交 1a8be158 编写于 作者: Z Zhang Zheng 提交者: GitHub

[Zero-Dim] Fix bug in masked_select for XPU (#49904)

上级 3bf74127
......@@ -33,6 +33,12 @@ void MaskedSelectGradKernel(const Context& dev_ctx,
auto mask_shape = phi::vectorize<int>(mask.dims());
auto xshape = phi::vectorize<int>(x_grad->dims());
if (mask.dims().size() == 0) {
mask_shape = std::vector<int>({1});
}
if (x_grad->dims().size() == 0) {
xshape = std::vector<int>({1});
}
int r = xpu::masked_select_grad(dev_ctx.x_context(),
input_data,
......
......@@ -61,6 +61,12 @@ void MaskedSelectKernel(const Context& dev_ctx,
auto input_shape = vectorize<int>(input_dim);
auto mask_shape = vectorize<int>(mask_dim);
if (input_dim.size() == 0) {
input_shape = std::vector<int>({1});
}
if (mask_dim.size() == 0) {
mask_shape = std::vector<int>({1});
}
if (out_size_cpu > 0) {
PADDLE_ENFORCE_XDNN_SUCCESS(xpu::masked_select(dev_ctx.x_context(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册