Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
6916215e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6916215e
编写于
11月 04, 2022
作者:
Z
zhangyikun02
提交者:
GitHub
11月 04, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
matmul_v2 support new case and fix masked_select bug for xpu, test=kunlun (#47370)
上级
cd59c10c
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
79 addition
and
13 deletion
+79
-13
paddle/phi/kernels/xpu/masked_select_kernel.cc
paddle/phi/kernels/xpu/masked_select_kernel.cc
+10
-8
paddle/phi/kernels/xpu/matmul_grad_kernel.cc
paddle/phi/kernels/xpu/matmul_grad_kernel.cc
+18
-0
paddle/phi/kernels/xpu/xpu_api_wrapper.h
paddle/phi/kernels/xpu/xpu_api_wrapper.h
+27
-5
python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py
...paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py
+24
-0
未找到文件。
paddle/phi/kernels/xpu/masked_select_kernel.cc
浏览文件 @
6916215e
...
...
@@ -62,14 +62,16 @@ void MaskedSelectKernel(const Context& dev_ctx,
auto
input_shape
=
vectorize
<
int
>
(
input_dim
);
auto
mask_shape
=
vectorize
<
int
>
(
mask_dim
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
xpu
::
masked_select
(
dev_ctx
.
x_context
(),
input_data
,
mask_data
,
out_data
,
input_shape
,
mask_shape
,
out_size_cpu
),
"masked_select"
);
if
(
out_size_cpu
>
0
)
{
PADDLE_ENFORCE_XDNN_SUCCESS
(
xpu
::
masked_select
(
dev_ctx
.
x_context
(),
input_data
,
mask_data
,
out_data
,
input_shape
,
mask_shape
,
out_size_cpu
),
"masked_select"
);
}
}
}
// namespace phi
...
...
paddle/phi/kernels/xpu/matmul_grad_kernel.cc
浏览文件 @
6916215e
...
...
@@ -56,6 +56,15 @@ void MatmulGradKernel(const Context& dev_ctx,
:
reinterpret_cast
<
XPUType
*>
(
dx
->
data
<
T
>
());
XPUType
*
c_2
=
(
dy
==
NULL
)
?
reinterpret_cast
<
XPUType
*>
(
NULL
)
:
reinterpret_cast
<
XPUType
*>
(
dy
->
data
<
T
>
());
if
(
info_forward
.
is_x_need_broadcast
)
{
XPUType
*
new_c_1
=
nullptr
;
new_c_1
=
RAII_GUARD
.
alloc_l3_or_gm
<
XPUType
>
(
info_forward
.
bs
*
info_forward
.
m
*
info_forward
.
k
);
PADDLE_ENFORCE_XDNN_NOT_NULL
(
new_c_1
);
c_1
=
new_c_1
;
}
XpuFcInfo
info_dx
;
XpuFcInfo
info_dy
;
std
::
tuple
<
XpuFcInfo
,
...
...
@@ -75,6 +84,15 @@ void MatmulGradKernel(const Context& dev_ctx,
std
::
tie
(
info_dx
,
info_dy
,
a_1
,
b_1
,
a_2
,
b_2
)
=
fc_info
;
if
(
dx
)
{
MatMulXPUFunction
<
XPUType
>
(
xpu_ctx
,
a_1
,
b_1
,
c_1
,
info_dx
,
1.0
f
);
if
(
info_forward
.
is_x_need_broadcast
)
{
int
r
=
xpu
::
reduce_sum
<
XPUType
>
(
xpu_ctx
,
c_1
,
reinterpret_cast
<
XPUType
*>
(
dx
->
data
<
T
>
()),
{
info_forward
.
bs
,
info_forward
.
m
,
info_forward
.
k
},
{
0
});
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"reduce_sum"
);
}
}
if
(
dy
)
{
MatMulXPUFunction
<
XPUType
>
(
xpu_ctx
,
a_2
,
b_2
,
c_2
,
info_dy
,
1.0
f
);
...
...
paddle/phi/kernels/xpu/xpu_api_wrapper.h
浏览文件 @
6916215e
...
...
@@ -58,6 +58,7 @@ struct XpuFcInfo {
float
*
max_x
;
float
*
max_y
;
float
*
max_out
;
bool
is_x_need_broadcast
;
XpuFcInfo
()
:
bs
(
0
),
m
(
0
),
...
...
@@ -70,7 +71,8 @@ struct XpuFcInfo {
stride_out
(
0
),
max_x
(
nullptr
),
max_y
(
nullptr
),
max_out
(
nullptr
)
{}
max_out
(
nullptr
),
is_x_need_broadcast
(
false
)
{}
void
InitFcInfo
(
int
bs
,
int
m
,
int
n
,
...
...
@@ -145,8 +147,12 @@ static void GetFCInfo(const phi::DDim& x_dims,
y_dims
.
to_str
(),
mat_dim_a
.
trans_
,
mat_dim_b
.
trans_
));
mat_dim_b
.
height_
*=
mat_dim_b
.
batch_size_
;
mat_dim_b
.
batch_size_
=
0
;
if
(
mat_dim_a
.
width_
==
mat_dim_b
.
batch_size_
*
mat_dim_b
.
height_
)
{
mat_dim_b
.
height_
*=
mat_dim_b
.
batch_size_
;
mat_dim_b
.
batch_size_
=
0
;
}
else
{
info
->
is_x_need_broadcast
=
true
;
}
}
if
(
mat_dim_a
.
width_
==
mat_dim_b
.
height_
)
{
...
...
@@ -171,7 +177,7 @@ static void GetFCInfo(const phi::DDim& x_dims,
info
->
m
=
mat_dim_a
.
height_
;
info
->
n
=
mat_dim_b
.
width_
;
info
->
k
=
mat_dim_a
.
width_
;
info
->
bs
=
mat_dim_a
.
batch_size_
;
info
->
bs
=
std
::
max
(
mat_dim_a
.
batch_size_
,
mat_dim_b
.
batch_size_
)
;
info
->
trans_x
=
trans_x
;
info
->
trans_y
=
trans_y
;
...
...
@@ -406,6 +412,7 @@ static void MatMulXPUFunction(xpu::Context* xpu_ctx,
float
*
max_x
=
fcinfo
.
max_x
;
float
*
max_y
=
fcinfo
.
max_y
;
float
*
max_out
=
fcinfo
.
max_out
;
bool
is_x_need_broadcast
=
fcinfo
.
is_x_need_broadcast
;
if
(
batch_size
<=
1
)
{
fc_api
(
xpu_ctx
,
...
...
@@ -428,6 +435,19 @@ static void MatMulXPUFunction(xpu::Context* xpu_ctx,
nullptr
,
xpu
::
Activation_t
::
LINEAR
);
}
else
{
const
XPUType
*
x_data
=
reinterpret_cast
<
const
XPUType
*>
(
x
);
if
(
is_x_need_broadcast
)
{
XPUType
*
x_broadcast_data
=
nullptr
;
xpu
::
ctx_guard
RAII_GUARD
(
xpu_ctx
);
x_broadcast_data
=
RAII_GUARD
.
alloc_l3_or_gm
<
XPUType
>
(
batch_size
*
m
*
k
);
PADDLE_ENFORCE_XDNN_NOT_NULL
(
x_broadcast_data
);
std
::
vector
<
int
>
x_shape
=
{
1
,
m
,
k
};
std
::
vector
<
int
>
new_x_shape
=
{
batch_size
,
m
,
k
};
int
r
=
xpu
::
broadcast
<
XPUType
>
(
xpu_ctx
,
x_data
,
x_broadcast_data
,
x_shape
,
new_x_shape
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"broadcast"
);
x_data
=
x_broadcast_data
;
}
// batch matmul
fc_batch_api
(
xpu_ctx
,
// Context* ctx,
batch_size
,
// int batch_size,
...
...
@@ -437,7 +457,7 @@ static void MatMulXPUFunction(xpu::Context* xpu_ctx,
n
,
// int n,
k
,
// int k,
alpha
,
// float alpha,
reinterpret_cast
<
const
XPUType
*>
(
x
),
// const TX* x,
x_data
,
// const TX* x,
ldx
,
// int stride_a,
reinterpret_cast
<
const
XPUType
*>
(
y
),
// const TW* w,
ldy
,
// int stride_b,
...
...
@@ -554,6 +574,7 @@ MatmulGradFcInfo(xpu::Context* xpu_ctx,
nullptr
,
max_dout
,
nullptr
);
dy_shape
.
is_x_need_broadcast
=
dout_shape
.
is_x_need_broadcast
;
dy_a
=
x
,
dy_b
=
dout_new
;
}
else
if
(
trans_y
)
{
// dx = dout * y
...
...
@@ -600,6 +621,7 @@ MatmulGradFcInfo(xpu::Context* xpu_ctx,
nullptr
,
max_dout
,
nullptr
);
dy_shape
.
is_x_need_broadcast
=
dout_shape
.
is_x_need_broadcast
;
dy_a
=
x
,
dy_b
=
dout_new
;
}
std
::
tuple
<
XpuFcInfo
,
XpuFcInfo
,
const
T
*
,
const
T
*
,
const
T
*
,
const
T
*>
...
...
python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py
浏览文件 @
6916215e
...
...
@@ -294,6 +294,30 @@ class XPUTestMatmulV2Op(XPUOpTestWrapper):
self
.
trans_x
=
False
self
.
trans_y
=
False
class
TestMatMulOp19
(
TestMatMulV2Op
):
"""
case 19 : (x.ndim <= 2) && (y.ndim >= 3),
x need to broadcast and trans_y is false
"""
def
config
(
self
):
self
.
x_shape
=
(
10
,
20
)
self
.
y_shape
=
(
2
,
20
,
4
)
self
.
trans_x
=
False
self
.
trans_y
=
False
class
TestMatMulOp20
(
TestMatMulV2Op
):
"""
case 20 : (x.ndim <= 2) && (y.ndim >= 3),
x need to broadcast and trans_y is false
"""
def
config
(
self
):
self
.
x_shape
=
(
20
,
10
)
self
.
y_shape
=
(
2
,
20
,
4
)
self
.
trans_x
=
True
self
.
trans_y
=
False
support_types
=
get_xpu_op_support_types
(
'matmul_v2'
)
for
stype
in
support_types
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录