未验证 提交 0e492e43 编写于 作者: R ronnywang 提交者: GitHub

[XPU] add int32,fp32 support for conv2d_transpose (#51677)

* [XPU] add int32, fp32 support for conv2d_transpose

* update
上级 09ae2852
......@@ -25,18 +25,16 @@ void MatMul(const Context& dev_ctx,
const DenseTensor& b,
bool trans_b,
DenseTensor* out) {
using XPUT = typename XPUTypeTrait<T>::Type;
dev_ctx.template Alloc<T>(out);
xpu::Context* xpu_ctx = dev_ctx.x_context();
if (std::is_same<phi::dtype::float16, T>::value) {
MatMulXPUFunction<T, int16_t>(a, b, out, trans_a, trans_b, xpu_ctx);
int fccal_type = FCCalcType<XPUT>();
if (fccal_type == XPUFCCalcType::FC_INT32) {
MatMulXPUFunction<T, int32_t>(a, b, out, trans_a, trans_b, xpu_ctx);
} else if (fccal_type == XPUFCCalcType::FC_FLOAT) {
MatMulXPUFunction<T, float>(a, b, out, trans_a, trans_b, xpu_ctx);
} else {
if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) {
MatMulXPUFunction<T, int32_t>(a, b, out, trans_a, trans_b, xpu_ctx);
} else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) {
MatMulXPUFunction<T, float>(a, b, out, trans_a, trans_b, xpu_ctx);
} else {
MatMulXPUFunction<T, int16_t>(a, b, out, trans_a, trans_b, xpu_ctx);
}
MatMulXPUFunction<T, int16_t>(a, b, out, trans_a, trans_b, xpu_ctx);
}
}
......
......@@ -20,6 +20,7 @@ void BmmKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
using XPUT = typename XPUTypeTrait<T>::Type;
dev_ctx.template Alloc<T>(out);
if (x.numel() == 0 || y.numel() == 0) {
return;
......@@ -62,16 +63,13 @@ void BmmKernel(const Context& dev_ctx,
y_dims[1]));
xpu::Context* xpu_ctx = dev_ctx.x_context();
if (std::is_same<phi::dtype::float16, T>::value) {
MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, xpu_ctx);
int fccal_type = FCCalcType<XPUT>();
if (fccal_type == XPUFCCalcType::FC_INT32) {
MatMulXPUFunction<T, int32_t>(x, y, out, trans_x, trans_y, xpu_ctx);
} else if (fccal_type == XPUFCCalcType::FC_FLOAT) {
MatMulXPUFunction<T, float>(x, y, out, trans_x, trans_y, xpu_ctx);
} else {
if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) {
MatMulXPUFunction<T, int32_t>(x, y, out, trans_x, trans_y, xpu_ctx);
} else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) {
MatMulXPUFunction<T, float>(x, y, out, trans_x, trans_y, xpu_ctx);
} else {
MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, xpu_ctx);
}
MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, xpu_ctx);
}
}
} // namespace phi
......
......@@ -17,6 +17,7 @@
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
namespace phi {
......@@ -48,6 +49,8 @@ void Conv2dTransposeKernel(const Context& ctx,
const std::vector<int>& dilations,
const std::string& data_format,
DenseTensor* out) {
using XPUT = typename XPUTypeTrait<T>::Type;
// The filter will be reshaped in the calculations,
// so here use an assignment operation,
// that avoids modifying the variable in the Scope.
......@@ -76,26 +79,71 @@ void Conv2dTransposeKernel(const Context& ctx,
const int img_xh = static_cast<int>(out->dims()[2]);
const int img_xw = static_cast<int>(out->dims()[3]);
int r = xpu::conv2d_transpose_v2<float, float, float, int16_t>(
ctx.x_context(),
x.data<float>(),
filter_.data<float>(),
out->data<float>(),
batch_size,
img_yc,
img_xh,
img_xw,
img_xc,
ksize,
strides,
paddings_,
dilations_,
groups,
nullptr,
nullptr,
nullptr,
true);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_v2");
int fccal_type = FCCalcType<XPUT>();
if (fccal_type == XPUFCCalcType::FC_INT32) {
int r = xpu::conv2d_transpose_v2<float, float, float, int32_t>(
ctx.x_context(),
x.data<float>(),
filter_.data<float>(),
out->data<float>(),
batch_size,
img_yc,
img_xh,
img_xw,
img_xc,
ksize,
strides,
paddings_,
dilations_,
groups,
nullptr,
nullptr,
nullptr,
true);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_v2");
} else if (fccal_type == XPUFCCalcType::FC_FLOAT) {
int r = xpu::conv2d_transpose_v2<float, float, float, float>(
ctx.x_context(),
x.data<float>(),
filter_.data<float>(),
out->data<float>(),
batch_size,
img_yc,
img_xh,
img_xw,
img_xc,
ksize,
strides,
paddings_,
dilations_,
groups,
nullptr,
nullptr,
nullptr,
true);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_v2");
} else {
int r = xpu::conv2d_transpose_v2<float, float, float, int16_t>(
ctx.x_context(),
x.data<float>(),
filter_.data<float>(),
out->data<float>(),
batch_size,
img_yc,
img_xh,
img_xw,
img_xc,
ksize,
strides,
paddings_,
dilations_,
groups,
nullptr,
nullptr,
nullptr,
true);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_v2");
}
}
} // namespace phi
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册