Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
212b51ef
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
212b51ef
编写于
8月 29, 2022
作者:
C
cambriconhsq
提交者:
GitHub
8月 29, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[MLU] optimize matmul_grad_v2 dy (B,M,K)*(K,N) for better performance (#45336)
上级
ac0a2e50
变更
2
隐藏空白更改
内联
并排
Showing 2 changed files with 69 additions and 15 deletions
+69
-15
paddle/fluid/operators/matmul_v2_op_mlu.cc
paddle/fluid/operators/matmul_v2_op_mlu.cc
+56
-15
python/paddle/fluid/tests/unittests/mlu/test_matmul_v2_op_mlu.py
...paddle/fluid/tests/unittests/mlu/test_matmul_v2_op_mlu.py
+13
-0
未找到文件。
paddle/fluid/operators/matmul_v2_op_mlu.cc
浏览文件 @
212b51ef
...
...
@@ -68,6 +68,37 @@ static void MatMul2D(const framework::ExecutionContext& ctx,
GetBasePtr
(
Out
));
}
// Runs one 2-D matmul over X and Y after collapsing each operand's two leading
// dimensions into a single dimension: an operand with dims [d0, d1, d2] is
// described to CNNL as [d0 * d1, d2]. The result is written to Out, whose
// descriptor is built from Out's own dims with CNNL_LAYOUT_ARRAY.
//
// ctx:      execution context; supplies the place used when allocating Out.
// X, Y:     input tensors; dims [0], [1] and [2] of each are read, so both are
//           assumed to be (at least) rank 3 -- TODO confirm against callers.
// Out:      output tensor; allocated here only if it is not yet initialized.
// trans_x,
// trans_y:  transpose flags forwarded unchanged to MLUCnnl::Matmul.
template <typename T>
static void MatMul2DwithReduceBatch(const framework::ExecutionContext& ctx,
                                    const Tensor& X,
                                    const Tensor& Y,
                                    Tensor* Out,
                                    const bool trans_x,
                                    const bool trans_y) {
  if (!Out->initialized()) {
    Out->mutable_data<T>(ctx.GetPlace());
  }

  // Fold [d0, d1, d2] into the 2-D shape [d0 * d1, d2]. The product is formed
  // in int64_t and then narrowed to int for the descriptor.
  const auto fold_leading = [](const std::vector<int64_t>& dims) {
    return std::vector<int>{static_cast<int>(dims[0] * dims[1]),
                            static_cast<int>(dims[2])};
  };
  std::vector<int> x_shape_2d = fold_leading(phi::vectorize(X.dims()));
  std::vector<int> y_shape_2d = fold_leading(phi::vectorize(Y.dims()));

  const auto dtype = ToCnnlDataType<T>();
  MLUCnnlTensorDesc lhs_desc(2, x_shape_2d.data(), dtype);
  MLUCnnlTensorDesc rhs_desc(2, y_shape_2d.data(), dtype);
  MLUCnnlTensorDesc result_desc(*Out, CNNL_LAYOUT_ARRAY, dtype);

  MLUCnnl::Matmul(ctx,
                  trans_x,
                  trans_y,
                  lhs_desc.get(),
                  GetBasePtr(&X),
                  rhs_desc.get(),
                  GetBasePtr(&Y),
                  result_desc.get(),
                  GetBasePtr(Out));
}
template
<
typename
T
>
static
void
MatMulND
(
const
framework
::
ExecutionContext
&
ctx
,
const
Tensor
&
X
,
...
...
@@ -333,22 +364,32 @@ class MatMulGradV2MLUKernel : public framework::OpKernel<T> {
}
if
(
dY
)
{
Tensor
dy_temp
(
Y
->
type
());
if
(
y_dims
!=
y_bcast_dims
)
{
dy_temp
.
Resize
(
phi
::
make_ddim
(
y_bcast_dims
));
}
else
{
dY
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
dy_temp
.
ShareDataWith
(
*
dY
);
}
if
(
trans_y
)
{
MatMulND
<
T
>
(
ctx
,
dout_temp
,
x_temp
,
&
dy_temp
,
true
,
trans_x
);
// Case 3: [B, M, K] x [K, N] = [B, M, N] better performance
// otherwise, tensor dy_temp in else branch might encounter
// numel overflow due to cnnlTensorDescriptor limitation
if
(
x_dims
.
size
()
==
3
&&
phi
::
vectorize
(
Y
->
dims
()).
size
()
==
2
)
{
if
(
trans_y
)
{
MatMul2DwithReduceBatch
<
T
>
(
ctx
,
dout_temp
,
x_temp
,
dY
,
true
,
trans_x
);
}
else
{
MatMul2DwithReduceBatch
<
T
>
(
ctx
,
x_temp
,
dout_temp
,
dY
,
!
trans_x
,
false
);
}
}
else
{
MatMulND
<
T
>
(
ctx
,
x_temp
,
dout_temp
,
&
dy_temp
,
!
trans_x
,
false
);
}
if
(
y_dims
!=
y_bcast_dims
)
{
ReduceDims
<
T
>
(
ctx
,
y_dims
,
y_bcast_dims
,
dy_temp
,
dY
);
Tensor
dy_temp
(
Y
->
type
());
if
(
y_dims
!=
y_bcast_dims
)
{
dy_temp
.
Resize
(
phi
::
make_ddim
(
y_bcast_dims
));
}
else
{
dY
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
dy_temp
.
ShareDataWith
(
*
dY
);
}
if
(
trans_y
)
{
MatMulND
<
T
>
(
ctx
,
dout_temp
,
x_temp
,
&
dy_temp
,
true
,
trans_x
);
}
else
{
MatMulND
<
T
>
(
ctx
,
x_temp
,
dout_temp
,
&
dy_temp
,
!
trans_x
,
false
);
}
if
(
y_dims
!=
y_bcast_dims
)
{
ReduceDims
<
T
>
(
ctx
,
y_dims
,
y_bcast_dims
,
dy_temp
,
dY
);
}
}
}
}
...
...
python/paddle/fluid/tests/unittests/mlu/test_matmul_v2_op_mlu.py
浏览文件 @
212b51ef
...
...
@@ -264,6 +264,18 @@ class TestMatMuklOp17(TestMatMulV2Op):
self
.
trans_y
=
False
class TestMatMuklOp18(TestMatMulV2Op):
    """
    Case 18: gradient check for a special case -- a 3-D x of shape
    (2, 32, 100) multiplied by a 2-D y of shape (100, 10), with neither
    operand transposed.
    """

    def config(self):
        # Shapes of the two operands.
        self.x_shape, self.y_shape = (2, 32, 100), (100, 10)
        # Neither operand is transposed.
        self.trans_x = self.trans_y = False
class
TestMatMuklOpBroadcast1
(
TestMatMulV2Op
):
"""
case 14_3
...
...
@@ -328,6 +340,7 @@ create_test_fp16_class(TestMatMuklOp14)
create_test_fp16_class
(
TestMatMuklOp15
)
create_test_fp16_class
(
TestMatMuklOp16
)
create_test_fp16_class
(
TestMatMuklOp17
)
create_test_fp16_class
(
TestMatMuklOp18
)
class
TestMatMulV2API
(
unittest
.
TestCase
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录