Unverified · Commit 0e492e43 · Authored by ronnywang · Committed by GitHub

[XPU] add int32,fp32 support for conv2d_transpose (#51677)

* [XPU] add int32,fp32 support for conv2d_transpose

* update
Parent 09ae2852
@@ -25,18 +25,16 @@ void MatMul(const Context& dev_ctx,
             const DenseTensor& b,
             bool trans_b,
             DenseTensor* out) {
+  using XPUT = typename XPUTypeTrait<T>::Type;
   dev_ctx.template Alloc<T>(out);
   xpu::Context* xpu_ctx = dev_ctx.x_context();
-  if (std::is_same<phi::dtype::float16, T>::value) {
-    MatMulXPUFunction<T, int16_t>(a, b, out, trans_a, trans_b, xpu_ctx);
+  int fccal_type = FCCalcType<XPUT>();
+  if (fccal_type == XPUFCCalcType::FC_INT32) {
+    MatMulXPUFunction<T, int32_t>(a, b, out, trans_a, trans_b, xpu_ctx);
+  } else if (fccal_type == XPUFCCalcType::FC_FLOAT) {
+    MatMulXPUFunction<T, float>(a, b, out, trans_a, trans_b, xpu_ctx);
   } else {
-    if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) {
-      MatMulXPUFunction<T, int32_t>(a, b, out, trans_a, trans_b, xpu_ctx);
-    } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) {
-      MatMulXPUFunction<T, float>(a, b, out, trans_a, trans_b, xpu_ctx);
-    } else {
-      MatMulXPUFunction<T, int16_t>(a, b, out, trans_a, trans_b, xpu_ctx);
-    }
+    MatMulXPUFunction<T, int16_t>(a, b, out, trans_a, trans_b, xpu_ctx);
   }
 }
...
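The new FCCalcType<XPUT>() call centralizes the accumulation-type choice that this kernel previously made inline with std::getenv checks. Below is a minimal sketch of such a selector, reconstructed only from the env-var branches the patch removes; the real helper is declared in paddle/phi/kernels/xpu/xpu_api_wrapper.h and may behave differently, and every name with a "Sketch"/"_sk" suffix is hypothetical.

// Minimal sketch, not part of this patch: a selector in the spirit of
// FCCalcType, reconstructed from the env-var branches removed above.
#include <cstdlib>
#include <type_traits>

enum XPUFCCalcTypeSketch { FC_INT16 = 0, FC_INT32, FC_FLOAT };

struct float16_sk {};  // stand-in for the fp16 element type (phi::dtype::float16)

template <typename T>
int FCCalcTypeSketch() {
  if (std::is_same<T, float16_sk>::value) {
    return FC_INT16;  // fp16 inputs kept the int16 path in the removed code
  }
  if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) {
    return FC_INT32;  // opt in to int32 accumulation
  }
  if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) {
    return FC_FLOAT;  // opt in to fp32 accumulation
  }
  return FC_INT16;  // default accumulation type (assumed)
}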
@@ -20,6 +20,7 @@ void BmmKernel(const Context& dev_ctx,
                const DenseTensor& x,
                const DenseTensor& y,
                DenseTensor* out) {
+  using XPUT = typename XPUTypeTrait<T>::Type;
   dev_ctx.template Alloc<T>(out);
   if (x.numel() == 0 || y.numel() == 0) {
     return;
@@ -62,16 +63,13 @@ void BmmKernel(const Context& dev_ctx,
                         y_dims[1]));
   xpu::Context* xpu_ctx = dev_ctx.x_context();
-  if (std::is_same<phi::dtype::float16, T>::value) {
-    MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, xpu_ctx);
+  int fccal_type = FCCalcType<XPUT>();
+  if (fccal_type == XPUFCCalcType::FC_INT32) {
+    MatMulXPUFunction<T, int32_t>(x, y, out, trans_x, trans_y, xpu_ctx);
+  } else if (fccal_type == XPUFCCalcType::FC_FLOAT) {
+    MatMulXPUFunction<T, float>(x, y, out, trans_x, trans_y, xpu_ctx);
   } else {
-    if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) {
-      MatMulXPUFunction<T, int32_t>(x, y, out, trans_x, trans_y, xpu_ctx);
-    } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) {
-      MatMulXPUFunction<T, float>(x, y, out, trans_x, trans_y, xpu_ctx);
-    } else {
-      MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, xpu_ctx);
-    }
+    MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, xpu_ctx);
   }
 }
 } // namespace phi
...
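MatMul and BmmKernel now repeat the same three-way dispatch verbatim. If more XPU matmul-style kernels adopt it, the branching could be hoisted into one shared helper; the sketch below is not part of the patch, reuses DenseTensor, XPUTypeTrait, FCCalcType, and MatMulXPUFunction exactly as they appear in the hunks above, assumes the surrounding file's includes, and the helper name is hypothetical.

// Hypothetical consolidation, not in this patch: one place for the dispatch
// that MatMul and BmmKernel now both repeat.
template <typename T, typename Context>
void MatMulWithCalcTypeSketch(const Context& dev_ctx,
                              const DenseTensor& a,
                              const DenseTensor& b,
                              bool trans_a,
                              bool trans_b,
                              DenseTensor* out) {
  using XPUT = typename XPUTypeTrait<T>::Type;
  xpu::Context* xpu_ctx = dev_ctx.x_context();
  int fccal_type = FCCalcType<XPUT>();
  if (fccal_type == XPUFCCalcType::FC_INT32) {
    MatMulXPUFunction<T, int32_t>(a, b, out, trans_a, trans_b, xpu_ctx);
  } else if (fccal_type == XPUFCCalcType::FC_FLOAT) {
    MatMulXPUFunction<T, float>(a, b, out, trans_a, trans_b, xpu_ctx);
  } else {
    MatMulXPUFunction<T, int16_t>(a, b, out, trans_a, trans_b, xpu_ctx);
  }
}

BmmKernel would then call such a helper with trans_x/trans_y after its existing shape checks, and MatMul would simply forward its arguments.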
@@ -17,6 +17,7 @@
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/cpu/conv_util.h"
+#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"

 namespace phi {
@@ -48,6 +49,8 @@ void Conv2dTransposeKernel(const Context& ctx,
                            const std::vector<int>& dilations,
                            const std::string& data_format,
                            DenseTensor* out) {
+  using XPUT = typename XPUTypeTrait<T>::Type;
+
   // The filter will be reshaped in the calculations,
   // so here use an assignment operation,
   // that avoids modifying the variable in the Scope.
@@ -76,26 +79,71 @@ void Conv2dTransposeKernel(const Context& ctx,
   const int img_xh = static_cast<int>(out->dims()[2]);
   const int img_xw = static_cast<int>(out->dims()[3]);

-  int r = xpu::conv2d_transpose_v2<float, float, float, int16_t>(
-      ctx.x_context(),
-      x.data<float>(),
-      filter_.data<float>(),
-      out->data<float>(),
-      batch_size,
-      img_yc,
-      img_xh,
-      img_xw,
-      img_xc,
-      ksize,
-      strides,
-      paddings_,
-      dilations_,
-      groups,
-      nullptr,
-      nullptr,
-      nullptr,
-      true);
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_v2");
+  int fccal_type = FCCalcType<XPUT>();
+  if (fccal_type == XPUFCCalcType::FC_INT32) {
+    int r = xpu::conv2d_transpose_v2<float, float, float, int32_t>(
+        ctx.x_context(),
+        x.data<float>(),
+        filter_.data<float>(),
+        out->data<float>(),
+        batch_size,
+        img_yc,
+        img_xh,
+        img_xw,
+        img_xc,
+        ksize,
+        strides,
+        paddings_,
+        dilations_,
+        groups,
+        nullptr,
+        nullptr,
+        nullptr,
+        true);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_v2");
+  } else if (fccal_type == XPUFCCalcType::FC_FLOAT) {
+    int r = xpu::conv2d_transpose_v2<float, float, float, float>(
+        ctx.x_context(),
+        x.data<float>(),
+        filter_.data<float>(),
+        out->data<float>(),
+        batch_size,
+        img_yc,
+        img_xh,
+        img_xw,
+        img_xc,
+        ksize,
+        strides,
+        paddings_,
+        dilations_,
+        groups,
+        nullptr,
+        nullptr,
+        nullptr,
+        true);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_v2");
+  } else {
+    int r = xpu::conv2d_transpose_v2<float, float, float, int16_t>(
+        ctx.x_context(),
+        x.data<float>(),
+        filter_.data<float>(),
+        out->data<float>(),
+        batch_size,
+        img_yc,
+        img_xh,
+        img_xw,
+        img_xc,
+        ksize,
+        strides,
+        paddings_,
+        dilations_,
+        groups,
+        nullptr,
+        nullptr,
+        nullptr,
+        true);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_v2");
+  }
 }
 } // namespace phi
...
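The three conv2d_transpose_v2 branches above differ only in the TGEMM template argument; the other arguments are repeated verbatim. A possible follow-up, not part of this patch, is sketched below: the launcher name is hypothetical, the argument types are inferred from the kernel signature, and the xpu::conv2d_transpose_v2 call itself is copied from the hunk.

// Hypothetical launcher, not in this patch: only the accumulation type TGEMM
// varies between the three branches, so the call can be written once.
template <typename TGEMM>
int Conv2dTransposeLauncherSketch(xpu::Context* xpu_ctx,
                                  const float* x,
                                  const float* filter,
                                  float* out,
                                  int batch_size,
                                  int img_yc,
                                  int img_xh,
                                  int img_xw,
                                  int img_xc,
                                  const std::vector<int>& ksize,
                                  const std::vector<int>& strides,
                                  const std::vector<int>& paddings,
                                  const std::vector<int>& dilations,
                                  int groups) {
  // Same argument order as the calls in the hunk above.
  return xpu::conv2d_transpose_v2<float, float, float, TGEMM>(
      xpu_ctx, x, filter, out, batch_size, img_yc, img_xh, img_xw, img_xc,
      ksize, strides, paddings, dilations, groups,
      nullptr, nullptr, nullptr, true);
}

The kernel body would then pick TGEMM once from fccal_type and call the launcher in each branch, so the three accumulation paths cannot drift apart if the conv2d_transpose_v2 argument list ever changes.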