未验证 提交 e588f2d9 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Zero-Dim] add 0D Tensor UT case for XPU and expand kernel support 0D (#53555)

* [Zero-Dim] add 0D Tensor UT case for XPU

* fix comment

* remove some unnecessary UT
上级 a37ef769
......@@ -37,7 +37,8 @@ void ExpandGradKernel(const Context& ctx,
// Two zero
if (out_grad_dims.size() == 0 && in_grad_dims.size() == 0) {
return;
out_grad_dims = {1};
in_grad_dims = {1};
}
int r = xpu::expand_grad<XPUType>(
......
......@@ -94,26 +94,17 @@ void ExpandKernel(const Context& ctx,
shape_size,
rank));
if (shape_size == 0) {
phi::DDim out_dims = phi::make_ddim(final_expand_shape);
out->Resize(out_dims);
ctx.template Alloc<T>(out);
int r = xpu::copy<XPUType>(ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
return;
}
DDim out_dims = phi::make_ddim(final_expand_shape);
out->Resize(out_dims);
ctx.template Alloc<T>(out);
auto& x_shape = vec_in_dims;
auto out_shape = phi::vectorize<int>(out_dims);
if (shape_size == 0) {
x_shape = {1};
out_shape = {1};
}
int r = XPU_SUCCESS;
if (std::is_same<T, bool>::value) {
auto x_data = reinterpret_cast<const int8_t*>(x.data<T>());
auto out_data = reinterpret_cast<int8_t*>(out->data<T>());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册