未验证 提交 04f56338 编写于 作者: Y ykkk2333 提交者: GitHub

add xpu tile and concat kernel int64, test=kunlun (#51349)

上级 615fc429
...@@ -140,6 +140,7 @@ XPUOpMap& get_kl2_ops() { ...@@ -140,6 +140,7 @@ XPUOpMap& get_kl2_ops() {
{"concat", {"concat",
XPUKernelSet({phi::DataType::FLOAT32, XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16, phi::DataType::FLOAT16,
phi::DataType::FLOAT64,
phi::DataType::INT64, phi::DataType::INT64,
phi::DataType::INT32})}, phi::DataType::INT32})},
{"conv2d_grad", {"conv2d_grad",
...@@ -730,6 +731,7 @@ XPUOpMap& get_kl2_ops() { ...@@ -730,6 +731,7 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::INT32, XPUKernelSet({phi::DataType::INT32,
phi::DataType::INT64, phi::DataType::INT64,
phi::DataType::BOOL, phi::DataType::BOOL,
phi::DataType::FLOAT64,
phi::DataType::FLOAT32})}, phi::DataType::FLOAT32})},
{"tile_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"tile_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"transpose2_grad", {"transpose2_grad",
......
...@@ -116,6 +116,7 @@ PD_REGISTER_KERNEL(concat, ...@@ -116,6 +116,7 @@ PD_REGISTER_KERNEL(concat,
ALL_LAYOUT, ALL_LAYOUT,
phi::ConcatKernel, phi::ConcatKernel,
float, float,
double,
phi::dtype::float16, phi::dtype::float16,
int64_t, int64_t,
int) {} int) {}
...@@ -112,6 +112,33 @@ void Pool2dGradKernel(const Context& ctx, ...@@ -112,6 +112,33 @@ void Pool2dGradKernel(const Context& ctx,
true); true);
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
// When output dim is 1 * 1 (1 * 1 * 1 in pool_3d), use scale
// and broadcast kernels to get same output, but better performance.
// Since the dim is special in particular models,
// use 'export XPU_POOLING_GRAD_SPECIAL=1' to open this path
if (out_h == 1 && out_w == 1 && std::is_same<T, float>::value &&
std::getenv("XPU_POOLING_GRAD_SPECIAL") != nullptr) {
xpu::ctx_guard RAII_GUARD(ctx.x_context());
float scale = 1.0 / (in_h * in_w);
float* scaled_dy = RAII_GUARD.alloc_l3_or_gm<float>(n * c);
r = xpu::scale(ctx.x_context(),
dout.data<float>(),
scaled_dy,
n * c,
true,
scale,
0.0);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
r = xpu::broadcast(ctx.x_context(),
scaled_dy,
dx->data<float>(),
{n, c, 1, 1},
{n, c, in_h, in_w});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast");
return;
}
r = xpu::adaptive_avg_pool2d_grad<XPUType>( r = xpu::adaptive_avg_pool2d_grad<XPUType>(
ctx.x_context(), ctx.x_context(),
reinterpret_cast<const XPUType*>(dout.data<T>()), reinterpret_cast<const XPUType*>(dout.data<T>()),
...@@ -267,6 +294,31 @@ void Pool3dGradKernel(const Context& ctx, ...@@ -267,6 +294,31 @@ void Pool3dGradKernel(const Context& ctx,
!channel_last); !channel_last);
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
if (out_d == 1 && out_h == 1 && out_w == 1 &&
std::is_same<T, float>::value &&
std::getenv("XPU_POOLING_GRAD_SPECIAL") != nullptr) {
xpu::ctx_guard RAII_GUARD(ctx.x_context());
float scale = 1.0 / (in_d * in_h * in_w);
float* scaled_dy = RAII_GUARD.alloc_l3_or_gm<float>(n * c);
r = xpu::scale(ctx.x_context(),
dout.data<float>(),
scaled_dy,
n * c,
true,
scale,
0.0);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
r = xpu::broadcast(ctx.x_context(),
scaled_dy,
dx->data<float>(),
{n, c, 1, 1, 1},
{n, c, in_d, in_h, in_w});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast");
return;
}
r = xpu::adaptive_avg_pool3d_grad<XPUType>( r = xpu::adaptive_avg_pool3d_grad<XPUType>(
ctx.x_context(), ctx.x_context(),
reinterpret_cast<const XPUType*>(dout.data<T>()), reinterpret_cast<const XPUType*>(dout.data<T>()),
......
...@@ -29,6 +29,7 @@ void TileKernel(const Context& dev_ctx, ...@@ -29,6 +29,7 @@ void TileKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const IntArray& repeat_times_arr, const IntArray& repeat_times_arr,
DenseTensor* out) { DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto rank = x.dims().size(); auto rank = x.dims().size();
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
rank, rank,
...@@ -104,12 +105,21 @@ void TileKernel(const Context& dev_ctx, ...@@ -104,12 +105,21 @@ void TileKernel(const Context& dev_ctx,
if (repeat_times == temp) { if (repeat_times == temp) {
out->Resize(x.dims()); out->Resize(x.dims());
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
int r = if (std::is_same<T, double>::value) {
xpu::copy(dev_ctx.x_context(), x.data<T>(), out->data<T>(), x.numel()); int r = xpu::copy(dev_ctx.x_context(),
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy"); reinterpret_cast<const int8_t*>(x.data<double>()),
reinterpret_cast<int8_t*>(out->data<double>()),
8 * x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
} else {
int r = xpu::copy(
dev_ctx.x_context(), x.data<T>(), out->data<T>(), x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
}
return; return;
} }
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
int ret = XPU_SUCCESS; int ret = XPU_SUCCESS;
if (std::is_same<T, bool>::value) { if (std::is_same<T, bool>::value) {
ret = xpu::broadcast<int8_t>(dev_ctx.x_context(), ret = xpu::broadcast<int8_t>(dev_ctx.x_context(),
...@@ -118,6 +128,24 @@ void TileKernel(const Context& dev_ctx, ...@@ -118,6 +128,24 @@ void TileKernel(const Context& dev_ctx,
vec_in_dims, vec_in_dims,
vec_out_dims); vec_out_dims);
} else if (std::is_same<T, double>::value) {
float* x_t = RAII_GUARD.alloc_l3_or_gm<float>(x.numel());
float* y_t = RAII_GUARD.alloc_l3_or_gm<float>(out->numel());
int r =
xpu::cast<XPUType, float>(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
x_t,
x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
ret = xpu::broadcast<float>(
dev_ctx.x_context(), x_t, y_t, vec_in_dims, vec_out_dims);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "broadcast");
r = xpu::cast<float, XPUType>(dev_ctx.x_context(),
y_t,
reinterpret_cast<XPUType*>(out->data<T>()),
out->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
} else { } else {
ret = xpu::broadcast<T>(dev_ctx.x_context(), ret = xpu::broadcast<T>(dev_ctx.x_context(),
x.data<T>(), x.data<T>(),
...@@ -131,4 +159,5 @@ void TileKernel(const Context& dev_ctx, ...@@ -131,4 +159,5 @@ void TileKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
tile, XPU, ALL_LAYOUT, phi::TileKernel, bool, float, int, int64_t) {} tile, XPU, ALL_LAYOUT, phi::TileKernel, bool, float, double, int, int64_t) {
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册