Unverified commit 6916215e, authored by zhangyikun02, committed by GitHub

matmul_v2 support new case and fix masked_select bug for xpu, test=kunlun (#47370)

Parent: cd59c10c
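The matmul_v2 part of this commit teaches the XPU kernel to handle the case where x is 2-D while y is a batched 3-D tensor whose folded shape does not line up with x, so x itself must be broadcast over the batch dimension. A minimal NumPy sketch of the intended semantics, mirroring new test case 19 below (illustrative only, not the XPU code):

import numpy as np

x = np.random.rand(10, 20).astype(np.float32)    # (m, k), no batch dimension
y = np.random.rand(2, 20, 4).astype(np.float32)  # (bs, k, n)

out = np.matmul(x, y)       # x is implicitly broadcast to (2, 10, 20)
assert out.shape == (2, 10, 4)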
@@ -62,14 +62,16 @@ void MaskedSelectKernel(const Context& dev_ctx,
   auto input_shape = vectorize<int>(input_dim);
   auto mask_shape = vectorize<int>(mask_dim);
-  PADDLE_ENFORCE_XDNN_SUCCESS(xpu::masked_select(dev_ctx.x_context(),
-                                                 input_data,
-                                                 mask_data,
-                                                 out_data,
-                                                 input_shape,
-                                                 mask_shape,
-                                                 out_size_cpu),
-                              "masked_select");
+  if (out_size_cpu > 0) {
+    PADDLE_ENFORCE_XDNN_SUCCESS(xpu::masked_select(dev_ctx.x_context(),
+                                                   input_data,
+                                                   mask_data,
+                                                   out_data,
+                                                   input_shape,
+                                                   mask_shape,
+                                                   out_size_cpu),
+                                "masked_select");
+  }
 }
 }  // namespace phi
...
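The masked_select fix guards the XDNN call when the mask selects zero elements: the output is empty, so there is nothing to copy and the device kernel can be skipped. A small NumPy sketch of the expected semantics for an all-false mask (illustrative only, not the XPU implementation):

import numpy as np

x = np.arange(6, dtype=np.float32)
mask = np.zeros(6, dtype=bool)   # selects nothing -> out_size_cpu == 0

out = x[mask]                    # empty result
assert out.shape == (0,)
# With out_size_cpu == 0 the fixed kernel skips xpu::masked_select entirely
# and returns the already-allocated empty output tensor.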
@@ -56,6 +56,15 @@ void MatmulGradKernel(const Context& dev_ctx,
                               : reinterpret_cast<XPUType*>(dx->data<T>());
   XPUType* c_2 = (dy == NULL) ? reinterpret_cast<XPUType*>(NULL)
                               : reinterpret_cast<XPUType*>(dy->data<T>());
+  if (info_forward.is_x_need_broadcast) {
+    XPUType* new_c_1 = nullptr;
+    new_c_1 = RAII_GUARD.alloc_l3_or_gm<XPUType>(
+        info_forward.bs * info_forward.m * info_forward.k);
+    PADDLE_ENFORCE_XDNN_NOT_NULL(new_c_1);
+    c_1 = new_c_1;
+  }
   XpuFcInfo info_dx;
   XpuFcInfo info_dy;
   std::tuple<XpuFcInfo,
@@ -75,6 +84,15 @@ void MatmulGradKernel(const Context& dev_ctx,
   std::tie(info_dx, info_dy, a_1, b_1, a_2, b_2) = fc_info;
   if (dx) {
     MatMulXPUFunction<XPUType>(xpu_ctx, a_1, b_1, c_1, info_dx, 1.0f);
+    if (info_forward.is_x_need_broadcast) {
+      int r = xpu::reduce_sum<XPUType>(
+          xpu_ctx,
+          c_1,
+          reinterpret_cast<XPUType*>(dx->data<T>()),
+          {info_forward.bs, info_forward.m, info_forward.k},
+          {0});
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
+    }
   }
   if (dy) {
     MatMulXPUFunction<XPUType>(xpu_ctx, a_2, b_2, c_2, info_dy, 1.0f);
...
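When x was broadcast in the forward pass, the per-batch gradients for x land in a temporary (bs, m, k) buffer (c_1 above) and must be summed over the batch axis to recover dX with x's original (m, k) shape, which is what the added xpu::reduce_sum does. A NumPy sketch of that reduction, with shapes assumed for illustration:

import numpy as np

bs, m, k, n = 2, 10, 20, 4
y = np.random.rand(bs, k, n).astype(np.float32)
dout = np.random.rand(bs, m, n).astype(np.float32)

# per-batch gradient w.r.t. the broadcast copies of x: dout @ y^T -> (bs, m, k)
dx_batched = np.matmul(dout, np.transpose(y, (0, 2, 1)))
# reduce over the batch axis, mirroring the added xpu::reduce_sum call
dx = dx_batched.sum(axis=0)      # (m, k), same shape as the original x
assert dx.shape == (m, k)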
@@ -58,6 +58,7 @@ struct XpuFcInfo {
   float* max_x;
   float* max_y;
   float* max_out;
+  bool is_x_need_broadcast;
   XpuFcInfo()
       : bs(0),
         m(0),
@@ -70,7 +71,8 @@ struct XpuFcInfo {
         stride_out(0),
         max_x(nullptr),
         max_y(nullptr),
-        max_out(nullptr) {}
+        max_out(nullptr),
+        is_x_need_broadcast(false) {}
   void InitFcInfo(int bs,
                   int m,
                   int n,
@@ -145,8 +147,12 @@ static void GetFCInfo(const phi::DDim& x_dims,
             y_dims.to_str(),
             mat_dim_a.trans_,
             mat_dim_b.trans_));
-    mat_dim_b.height_ *= mat_dim_b.batch_size_;
-    mat_dim_b.batch_size_ = 0;
+    if (mat_dim_a.width_ == mat_dim_b.batch_size_ * mat_dim_b.height_) {
+      mat_dim_b.height_ *= mat_dim_b.batch_size_;
+      mat_dim_b.batch_size_ = 0;
+    } else {
+      info->is_x_need_broadcast = true;
+    }
   }
   if (mat_dim_a.width_ == mat_dim_b.height_) {
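Previously, GetFCInfo always folded a 3-D y into a single 2-D matrix, which is only valid when x's width equals bs * k of y. The new branch keeps that fast path and otherwise marks x for broadcasting. A NumPy sketch of the two situations, with shapes chosen purely for illustration:

import numpy as np

# Fold path: x.width (40) == bs * y.height (2 * 20), so at the kernel level
# y is treated as one (bs*k, n) matrix and a single 2-D matmul suffices.
x1 = np.random.rand(10, 40).astype(np.float32)
y1 = np.random.rand(2, 20, 4).astype(np.float32)
out1 = x1 @ y1.reshape(40, 4)        # (10, 4)

# Broadcast path: x.width (20) != bs * y.height (40), so x must instead be
# broadcast over the batch dimension (is_x_need_broadcast = true).
x2 = np.random.rand(10, 20).astype(np.float32)
y2 = np.random.rand(2, 20, 4).astype(np.float32)
out2 = np.matmul(x2, y2)             # (2, 10, 4)

In the broadcast path the batch size now comes from y, which is why info->bs below takes the max of the two batch sizes instead of always using x's.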
@@ -171,7 +177,7 @@ static void GetFCInfo(const phi::DDim& x_dims,
   info->m = mat_dim_a.height_;
   info->n = mat_dim_b.width_;
   info->k = mat_dim_a.width_;
-  info->bs = mat_dim_a.batch_size_;
+  info->bs = std::max(mat_dim_a.batch_size_, mat_dim_b.batch_size_);
   info->trans_x = trans_x;
   info->trans_y = trans_y;
@@ -406,6 +412,7 @@ static void MatMulXPUFunction(xpu::Context* xpu_ctx,
   float* max_x = fcinfo.max_x;
   float* max_y = fcinfo.max_y;
   float* max_out = fcinfo.max_out;
+  bool is_x_need_broadcast = fcinfo.is_x_need_broadcast;
   if (batch_size <= 1) {
     fc_api(xpu_ctx,
@@ -428,6 +435,19 @@ static void MatMulXPUFunction(xpu::Context* xpu_ctx,
            nullptr,
            xpu::Activation_t::LINEAR);
   } else {
+    const XPUType* x_data = reinterpret_cast<const XPUType*>(x);
+    if (is_x_need_broadcast) {
+      XPUType* x_broadcast_data = nullptr;
+      xpu::ctx_guard RAII_GUARD(xpu_ctx);
+      x_broadcast_data = RAII_GUARD.alloc_l3_or_gm<XPUType>(batch_size * m * k);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(x_broadcast_data);
+      std::vector<int> x_shape = {1, m, k};
+      std::vector<int> new_x_shape = {batch_size, m, k};
+      int r = xpu::broadcast<XPUType>(
+          xpu_ctx, x_data, x_broadcast_data, x_shape, new_x_shape);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast");
+      x_data = x_broadcast_data;
+    }
     // batch matmul
     fc_batch_api(xpu_ctx,                             // Context* ctx,
                  batch_size,                          // int batch_size,
@@ -437,7 +457,7 @@ static void MatMulXPUFunction(xpu::Context* xpu_ctx,
                  n,                                   // int n,
                  k,                                   // int k,
                  alpha,                               // float alpha,
-                 reinterpret_cast<const XPUType*>(x), // const TX* x,
+                 x_data,                              // const TX* x,
                  ldx,                                 // int stride_a,
                  reinterpret_cast<const XPUType*>(y), // const TW* w,
                  ldy,                                 // int stride_b,
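On the XPU side the broadcast is materialized explicitly: x is tiled from (m, k) to (batch_size, m, k) with xpu::broadcast and the tiled buffer is fed to the batched FC call. A NumPy equivalent of that tiling step, with assumed shapes:

import numpy as np

bs, m, k, n = 2, 10, 20, 4
x = np.random.rand(m, k).astype(np.float32)
y = np.random.rand(bs, k, n).astype(np.float32)

# explicit tiling, analogous to the xpu::broadcast call above
x_tiled = np.broadcast_to(x, (bs, m, k))
out = np.matmul(x_tiled, y)                  # batched (bs, m, k) x (bs, k, n)
assert np.allclose(out, np.matmul(x, y))     # same result as implicit broadcasting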
@@ -554,6 +574,7 @@ MatmulGradFcInfo(xpu::Context* xpu_ctx,
                          nullptr,
                          max_dout,
                          nullptr);
+    dy_shape.is_x_need_broadcast = dout_shape.is_x_need_broadcast;
     dy_a = x, dy_b = dout_new;
   } else if (trans_y) {
     // dx = dout * y
@@ -600,6 +621,7 @@ MatmulGradFcInfo(xpu::Context* xpu_ctx,
                          nullptr,
                          max_dout,
                          nullptr);
+    dy_shape.is_x_need_broadcast = dout_shape.is_x_need_broadcast;
     dy_a = x, dy_b = dout_new;
   }
   std::tuple<XpuFcInfo, XpuFcInfo, const T*, const T*, const T*, const T*>
...
@@ -294,6 +294,30 @@ class XPUTestMatmulV2Op(XPUOpTestWrapper):
             self.trans_x = False
             self.trans_y = False
 
+    class TestMatMulOp19(TestMatMulV2Op):
+        """
+        case 19 : (x.ndim <= 2) && (y.ndim >= 3),
+                  x needs to broadcast and trans_y is false
+        """
+
+        def config(self):
+            self.x_shape = (10, 20)
+            self.y_shape = (2, 20, 4)
+            self.trans_x = False
+            self.trans_y = False
+
+    class TestMatMulOp20(TestMatMulV2Op):
+        """
+        case 20 : (x.ndim <= 2) && (y.ndim >= 3),
+                  x needs to broadcast, trans_x is true and trans_y is false
+        """
+
+        def config(self):
+            self.x_shape = (20, 10)
+            self.y_shape = (2, 20, 4)
+            self.trans_x = True
+            self.trans_y = False
+
 support_types = get_xpu_op_support_types('matmul_v2')
 for stype in support_types:
...
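The two new tests cover the broadcast case without and with trans_x. A NumPy sketch of what case 20 computes, with the shapes taken from the test (x is transposed before the broadcast matmul):

import numpy as np

x = np.random.rand(20, 10).astype(np.float32)    # (k, m), trans_x = True
y = np.random.rand(2, 20, 4).astype(np.float32)  # (bs, k, n)

out = np.matmul(x.T, y)      # x.T is (10, 20), broadcast over the batch of y
assert out.shape == (2, 10, 4)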