未验证 提交 1a8be158 编写于 作者: Z Zhang Zheng 提交者: GitHub

[Zero-Dim] Fix bug in masked_select for XPU (#49904)

上级 3bf74127
...@@ -33,6 +33,12 @@ void MaskedSelectGradKernel(const Context& dev_ctx, ...@@ -33,6 +33,12 @@ void MaskedSelectGradKernel(const Context& dev_ctx,
auto mask_shape = phi::vectorize<int>(mask.dims()); auto mask_shape = phi::vectorize<int>(mask.dims());
auto xshape = phi::vectorize<int>(x_grad->dims()); auto xshape = phi::vectorize<int>(x_grad->dims());
if (mask.dims().size() == 0) {
mask_shape = std::vector<int>({1});
}
if (x_grad->dims().size() == 0) {
xshape = std::vector<int>({1});
}
int r = xpu::masked_select_grad(dev_ctx.x_context(), int r = xpu::masked_select_grad(dev_ctx.x_context(),
input_data, input_data,
......
...@@ -61,6 +61,12 @@ void MaskedSelectKernel(const Context& dev_ctx, ...@@ -61,6 +61,12 @@ void MaskedSelectKernel(const Context& dev_ctx,
auto input_shape = vectorize<int>(input_dim); auto input_shape = vectorize<int>(input_dim);
auto mask_shape = vectorize<int>(mask_dim); auto mask_shape = vectorize<int>(mask_dim);
if (input_dim.size() == 0) {
input_shape = std::vector<int>({1});
}
if (mask_dim.size() == 0) {
mask_shape = std::vector<int>({1});
}
if (out_size_cpu > 0) { if (out_size_cpu > 0) {
PADDLE_ENFORCE_XDNN_SUCCESS(xpu::masked_select(dev_ctx.x_context(), PADDLE_ENFORCE_XDNN_SUCCESS(xpu::masked_select(dev_ctx.x_context(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册