未验证 提交 6916215e 编写于 作者: Z zhangyikun02 提交者: GitHub

matmul_v2 support new case and fix masked_select bug for xpu, test=kunlun (#47370)

上级 cd59c10c
......@@ -62,14 +62,16 @@ void MaskedSelectKernel(const Context& dev_ctx,
auto input_shape = vectorize<int>(input_dim);
auto mask_shape = vectorize<int>(mask_dim);
PADDLE_ENFORCE_XDNN_SUCCESS(xpu::masked_select(dev_ctx.x_context(),
input_data,
mask_data,
out_data,
input_shape,
mask_shape,
out_size_cpu),
"masked_select");
if (out_size_cpu > 0) {
PADDLE_ENFORCE_XDNN_SUCCESS(xpu::masked_select(dev_ctx.x_context(),
input_data,
mask_data,
out_data,
input_shape,
mask_shape,
out_size_cpu),
"masked_select");
}
}
} // namespace phi
......
......@@ -56,6 +56,15 @@ void MatmulGradKernel(const Context& dev_ctx,
: reinterpret_cast<XPUType*>(dx->data<T>());
XPUType* c_2 = (dy == NULL) ? reinterpret_cast<XPUType*>(NULL)
: reinterpret_cast<XPUType*>(dy->data<T>());
if (info_forward.is_x_need_broadcast) {
XPUType* new_c_1 = nullptr;
new_c_1 = RAII_GUARD.alloc_l3_or_gm<XPUType>(
info_forward.bs * info_forward.m * info_forward.k);
PADDLE_ENFORCE_XDNN_NOT_NULL(new_c_1);
c_1 = new_c_1;
}
XpuFcInfo info_dx;
XpuFcInfo info_dy;
std::tuple<XpuFcInfo,
......@@ -75,6 +84,15 @@ void MatmulGradKernel(const Context& dev_ctx,
std::tie(info_dx, info_dy, a_1, b_1, a_2, b_2) = fc_info;
if (dx) {
MatMulXPUFunction<XPUType>(xpu_ctx, a_1, b_1, c_1, info_dx, 1.0f);
if (info_forward.is_x_need_broadcast) {
int r = xpu::reduce_sum<XPUType>(
xpu_ctx,
c_1,
reinterpret_cast<XPUType*>(dx->data<T>()),
{info_forward.bs, info_forward.m, info_forward.k},
{0});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
}
}
if (dy) {
MatMulXPUFunction<XPUType>(xpu_ctx, a_2, b_2, c_2, info_dy, 1.0f);
......
......@@ -58,6 +58,7 @@ struct XpuFcInfo {
float* max_x;
float* max_y;
float* max_out;
bool is_x_need_broadcast;
XpuFcInfo()
: bs(0),
m(0),
......@@ -70,7 +71,8 @@ struct XpuFcInfo {
stride_out(0),
max_x(nullptr),
max_y(nullptr),
max_out(nullptr) {}
max_out(nullptr),
is_x_need_broadcast(false) {}
void InitFcInfo(int bs,
int m,
int n,
......@@ -145,8 +147,12 @@ static void GetFCInfo(const phi::DDim& x_dims,
y_dims.to_str(),
mat_dim_a.trans_,
mat_dim_b.trans_));
mat_dim_b.height_ *= mat_dim_b.batch_size_;
mat_dim_b.batch_size_ = 0;
if (mat_dim_a.width_ == mat_dim_b.batch_size_ * mat_dim_b.height_) {
mat_dim_b.height_ *= mat_dim_b.batch_size_;
mat_dim_b.batch_size_ = 0;
} else {
info->is_x_need_broadcast = true;
}
}
if (mat_dim_a.width_ == mat_dim_b.height_) {
......@@ -171,7 +177,7 @@ static void GetFCInfo(const phi::DDim& x_dims,
info->m = mat_dim_a.height_;
info->n = mat_dim_b.width_;
info->k = mat_dim_a.width_;
info->bs = mat_dim_a.batch_size_;
info->bs = std::max(mat_dim_a.batch_size_, mat_dim_b.batch_size_);
info->trans_x = trans_x;
info->trans_y = trans_y;
......@@ -406,6 +412,7 @@ static void MatMulXPUFunction(xpu::Context* xpu_ctx,
float* max_x = fcinfo.max_x;
float* max_y = fcinfo.max_y;
float* max_out = fcinfo.max_out;
bool is_x_need_broadcast = fcinfo.is_x_need_broadcast;
if (batch_size <= 1) {
fc_api(xpu_ctx,
......@@ -428,6 +435,19 @@ static void MatMulXPUFunction(xpu::Context* xpu_ctx,
nullptr,
xpu::Activation_t::LINEAR);
} else {
const XPUType* x_data = reinterpret_cast<const XPUType*>(x);
if (is_x_need_broadcast) {
XPUType* x_broadcast_data = nullptr;
xpu::ctx_guard RAII_GUARD(xpu_ctx);
x_broadcast_data = RAII_GUARD.alloc_l3_or_gm<XPUType>(batch_size * m * k);
PADDLE_ENFORCE_XDNN_NOT_NULL(x_broadcast_data);
std::vector<int> x_shape = {1, m, k};
std::vector<int> new_x_shape = {batch_size, m, k};
int r = xpu::broadcast<XPUType>(
xpu_ctx, x_data, x_broadcast_data, x_shape, new_x_shape);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast");
x_data = x_broadcast_data;
}
// batch matmul
fc_batch_api(xpu_ctx, // Context* ctx,
batch_size, // int batch_size,
......@@ -437,7 +457,7 @@ static void MatMulXPUFunction(xpu::Context* xpu_ctx,
n, // int n,
k, // int k,
alpha, // float alpha,
reinterpret_cast<const XPUType*>(x), // const TX* x,
x_data, // const TX* x,
ldx, // int stride_a,
reinterpret_cast<const XPUType*>(y), // const TW* w,
ldy, // int stride_b,
......@@ -554,6 +574,7 @@ MatmulGradFcInfo(xpu::Context* xpu_ctx,
nullptr,
max_dout,
nullptr);
dy_shape.is_x_need_broadcast = dout_shape.is_x_need_broadcast;
dy_a = x, dy_b = dout_new;
} else if (trans_y) {
// dx = dout * y
......@@ -600,6 +621,7 @@ MatmulGradFcInfo(xpu::Context* xpu_ctx,
nullptr,
max_dout,
nullptr);
dy_shape.is_x_need_broadcast = dout_shape.is_x_need_broadcast;
dy_a = x, dy_b = dout_new;
}
std::tuple<XpuFcInfo, XpuFcInfo, const T*, const T*, const T*, const T*>
......
......@@ -294,6 +294,30 @@ class XPUTestMatmulV2Op(XPUOpTestWrapper):
self.trans_x = False
self.trans_y = False
class TestMatMulOp19(TestMatMulV2Op):
"""
case 19 : (x.ndim <= 2) && (y.ndim >= 3),
x need to broadcast and trans_y is false
"""
def config(self):
self.x_shape = (10, 20)
self.y_shape = (2, 20, 4)
self.trans_x = False
self.trans_y = False
class TestMatMulOp20(TestMatMulV2Op):
"""
case 20 : (x.ndim <= 2) && (y.ndim >= 3),
x need to broadcast and trans_y is false
"""
def config(self):
self.x_shape = (20, 10)
self.y_shape = (2, 20, 4)
self.trans_x = True
self.trans_y = False
support_types = get_xpu_op_support_types('matmul_v2')
for stype in support_types:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册