BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit 44855da3 (unverified)
Authored Jan 20, 2023 by jakpiase; committed by GitHub on Jan 20, 2023
Parent: ee4e5323

Fix for bad_alloc in oneDNN matmul_grad kernel (#48593)

* fix for matmul_grad
* another fix for matmul_grad
* fix
Showing 1 changed file with 74 additions and 46 deletions

paddle/phi/kernels/onednn/matmul_grad_kernel.cc (+74, -46) @ 44855da3
@@ -19,37 +19,64 @@
 namespace phi {
 
-std::vector<int64_t> ExtendDimsWithOnes(const std::vector<int64_t> &dims,
-                                        int new_size) {
-  std::vector<int64_t> new_dims(new_size, 1);
-  for (size_t i = 0; i < dims.size(); ++i) {
-    new_dims[new_size - dims.size() + i] = dims[i];
-  }
-
-  return new_dims;
+void CalculateMatrixDims(const std::vector<int64_t> &x_dims,
+                         const std::vector<int64_t> &y_dims,
+                         const std::vector<int64_t> &out_dims,
+                         std::vector<int64_t> *x_bd_dims,
+                         std::vector<int64_t> *y_bd_dims,
+                         std::vector<int64_t> *out_bd_dims,
+                         bool trans_x,
+                         bool trans_y) {
+  if (x_dims.size() == 1) {
+    (*x_bd_dims)[x_bd_dims->size() - 1] = x_dims[0];
+  } else if (x_dims.size() == 2) {
+    (*x_bd_dims)[x_bd_dims->size() - 1] = x_dims[1];
+    (*x_bd_dims)[x_bd_dims->size() - 2] = x_dims[0];
+  } else {
+    for (size_t i = 0; i < x_dims.size(); ++i) {
+      (*x_bd_dims)[x_bd_dims->size() - x_dims.size() + i] = x_dims[i];
+    }
+  }
+  if (y_dims.size() == 1) {
+    (*y_bd_dims)[x_bd_dims->size() - 2] = y_dims[0];
+  } else if (y_dims.size() == 2) {
+    (*y_bd_dims)[y_bd_dims->size() - 1] = y_dims[1];
+    (*y_bd_dims)[y_bd_dims->size() - 2] = y_dims[0];
+  } else {
+    for (size_t i = 0; i < y_dims.size(); ++i) {
+      (*y_bd_dims)[y_bd_dims->size() - y_dims.size() + i] = y_dims[i];
+    }
+  }
+
+  for (size_t i = 0; i < x_bd_dims->size() - 2; ++i) {
+    (*out_bd_dims)[i] = std::max((*x_bd_dims)[i], (*y_bd_dims)[i]);
+  }
+  int h_idx = trans_x ? x_bd_dims->size() - 1 : x_bd_dims->size() - 2;
+  int w_idx = trans_y ? y_bd_dims->size() - 2 : y_bd_dims->size() - 1;
+
+  (*out_bd_dims)[x_bd_dims->size() - 2] = (*x_bd_dims)[h_idx];
+  (*out_bd_dims)[y_bd_dims->size() - 1] = (*y_bd_dims)[w_idx];
 }
 
 template <typename T>
 void CalculateGradMatrixDims(const OneDNNContext &dev_ctx,
                              DenseTensor *dx_tmp,
                              DenseTensor *dy_tmp,
-                             const std::vector<int64_t> &dx_dims,
-                             const std::vector<int64_t> &dy_dims,
                              std::vector<int64_t> *dx_bd_dims,
                              std::vector<int64_t> *dy_bd_dims) {
-  for (size_t i = 0; i < dx_dims.size() - 2; ++i) {
-    if (dx_dims[i] != dy_dims[i]) {
-      if (dx_dims[i] == 1) {
-        (*dx_bd_dims)[i] = dy_dims[i];
+  for (size_t i = 0; i < dx_bd_dims->size() - 2; ++i) {
+    if ((*dx_bd_dims)[i] != (*dy_bd_dims)[i]) {
+      if ((*dx_bd_dims)[i] == 1) {
+        (*dx_bd_dims)[i] = (*dy_bd_dims)[i];
       } else {
-        (*dy_bd_dims)[i] = dx_dims[i];
+        (*dy_bd_dims)[i] = (*dx_bd_dims)[i];
       }
     }
   }
 
-  dx_tmp->Resize(make_ddim((*dx_bd_dims)));
+  dx_tmp->Resize(make_ddim(*dx_bd_dims));
   dev_ctx.template Alloc<T>(dx_tmp);
-  dy_tmp->Resize(make_ddim((*dy_bd_dims)));
+  dy_tmp->Resize(make_ddim(*dy_bd_dims));
   dev_ctx.template Alloc<T>(dy_tmp);
 }
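For orientation, the new CalculateMatrixDims right-aligns both operand shapes into rank-ndims buffers padded with leading ones, broadcasts the batch dims element-wise, and takes the two innermost output dims from the matrix product. Below is a minimal standalone sketch of that same broadcast rule; it is not part of the patch, and PadWithOnes, the driver in main, and the sample shapes are purely illustrative:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Right-align `dims` into a rank-`ndims` shape padded with leading 1s,
// mirroring what CalculateMatrixDims does for operands of rank > 2.
std::vector<int64_t> PadWithOnes(const std::vector<int64_t> &dims,
                                 size_t ndims) {
  std::vector<int64_t> padded(ndims, 1);
  for (size_t i = 0; i < dims.size(); ++i) {
    padded[ndims - dims.size() + i] = dims[i];
  }
  return padded;
}

int main() {
  // Hypothetical shapes: x is [2, 1, 3, 4], y is [5, 4, 6].
  std::vector<int64_t> x_dims{2, 1, 3, 4};
  std::vector<int64_t> y_dims{5, 4, 6};
  size_t ndims = std::max(std::max(x_dims.size(), y_dims.size()), size_t{3});

  auto x_bd = PadWithOnes(x_dims, ndims);  // [2, 1, 3, 4]
  auto y_bd = PadWithOnes(y_dims, ndims);  // [1, 5, 4, 6]

  // Broadcast every batch dim, then take (rows of x, cols of y) for the
  // two innermost output dims (trans_x = trans_y = false in this example).
  std::vector<int64_t> out_bd(ndims, 1);
  for (size_t i = 0; i < ndims - 2; ++i) {
    out_bd[i] = std::max(x_bd[i], y_bd[i]);
  }
  out_bd[ndims - 2] = x_bd[ndims - 2];
  out_bd[ndims - 1] = y_bd[ndims - 1];

  for (int64_t d : out_bd) std::cout << d << ' ';  // prints: 2 5 3 6
  std::cout << '\n';
}

Because the output buffer is sized from these broadcast dims up front, the kernel no longer has to grow x_dims/y_dims in place via ExtendDimsWithOnes, which is the allocation path the bad_alloc fix targets.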
@@ -58,7 +85,7 @@ void ReduceSumForMatmulGradOutput(const OneDNNContext &dev_ctx,
                                   const DenseTensor *dx_tmp,
                                   DenseTensor *dx,
                                   const std::vector<int64_t> &dx_dims,
-                                  const std::vector<int64_t> &squeezed_dims) {
+                                  const std::vector<int64_t> &x_dims) {
   funcs::ReductionOneDNNHandler<T> handler(dnnl::algorithm::reduction_sum,
                                            0.0f,
                                            0.0f,
@@ -66,7 +93,7 @@ void ReduceSumForMatmulGradOutput(const OneDNNContext &dev_ctx,
                                            dev_ctx.GetPlace(),
                                            dx_tmp,
                                            dx,
-                                           dx_dims);
+                                           x_dims);
   auto src_memory_p = handler.AcquireSrcMemory(dx_tmp);
   auto dst_memory_p = handler.AcquireDstMemory(dx);
@@ -79,8 +106,6 @@ void ReduceSumForMatmulGradOutput(const OneDNNContext &dev_ctx,
   reduction_p->execute(astream, reduction_args);
   astream.wait();
-
-  dx->set_mem_desc(dst_memory_p->get_desc().reshape(squeezed_dims));
 }
 
 template <typename T, typename Context>
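The reduce-sum step exists because, as the comment in the kernel below puts it, "reduce sum must be calculated upon broadcasted dims": when an operand's batch dim was broadcast in the forward pass, its gradient arrives with the broadcasted extent and must be summed back down to the operand's own shape. A hedged sketch of that bookkeeping; AxesToReduce and the shapes are invented for illustration, and the kernel itself delegates the actual reduction to funcs::ReductionOneDNNHandler:

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper (not part of the kernel): given the broadcasted
// gradient shape and the original operand shape (same rank, right-aligned),
// list the axes that a reduction_sum would have to collapse back to 1.
std::vector<int> AxesToReduce(const std::vector<int64_t> &grad_dims,
                              const std::vector<int64_t> &operand_dims) {
  std::vector<int> axes;
  for (size_t i = 0; i < grad_dims.size(); ++i) {
    if (operand_dims[i] == 1 && grad_dims[i] != 1) {
      axes.push_back(static_cast<int>(i));
    }
  }
  return axes;
}

int main() {
  // dx was materialized as [2, 5, 3, 4] because y broadcast over dim 1,
  // but x itself is [2, 1, 3, 4]; dim 1 must be summed away.
  for (int axis : AxesToReduce({2, 5, 3, 4}, {2, 1, 3, 4})) {
    std::cout << "reduce over axis " << axis << '\n';  // prints: axis 1
  }
}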
@@ -99,64 +124,67 @@ void MatmulGradKernel(const Context &dev_ctx,
   size_t ndims = std::max(x_dims.size(), y_dims.size());
   ndims = std::max<size_t>(ndims, 3);
 
-  if (x_dims.size() != ndims) {
-    x_dims = ExtendDimsWithOnes(x_dims, ndims);
-  }
-  if (y_dims.size() != ndims) {
-    y_dims = ExtendDimsWithOnes(y_dims, ndims);
-  }
-  if (dout_dims.size() != ndims) {
-    dout_dims = ExtendDimsWithOnes(dout_dims, ndims);
-  }
-
   // in broadcasting scenario new memory is required because
   // reduce sum must be calculated upon broadcasted dims
   DenseTensor dx_tmp, dy_tmp;
-  std::vector<int64_t> dx_bd_dims(x_dims);
-  std::vector<int64_t> dy_bd_dims(y_dims);
+  std::vector<int64_t> dout_bd_dims(ndims, 1);
+  std::vector<int64_t> x_bd_dims(ndims, 1);
+  std::vector<int64_t> y_bd_dims(ndims, 1);
+
+  CalculateMatrixDims(x_dims,
+                      y_dims,
+                      dout_dims,
+                      &x_bd_dims,
+                      &y_bd_dims,
+                      &dout_bd_dims,
+                      transpose_x,
+                      transpose_y);
+
+  std::vector<int64_t> dx_bd_dims(x_bd_dims);
+  std::vector<int64_t> dy_bd_dims(y_bd_dims);
 
   CalculateGradMatrixDims<T>(
-      dev_ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, &dx_bd_dims, &dy_bd_dims);
+      dev_ctx, &dx_tmp, &dy_tmp, &dx_bd_dims, &dy_bd_dims);
 
   if (transpose_x && transpose_y) {
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, y, dout, y_dims, dout_dims, true, true, &dx_tmp);
+        dev_ctx, y, dout, y_bd_dims, dout_bd_dims, true, true, &dx_tmp);
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, dout, x, dout_dims, x_dims, true, true, &dy_tmp);
+        dev_ctx, dout, x, dout_bd_dims, x_bd_dims, true, true, &dy_tmp);
   } else if (transpose_x) {
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, y, dout, y_dims, dout_dims, false, true, &dx_tmp);
+        dev_ctx, y, dout, y_bd_dims, dout_bd_dims, false, true, &dx_tmp);
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, x, dout, x_dims, dout_dims, false, false, &dy_tmp);
+        dev_ctx, x, dout, x_bd_dims, dout_bd_dims, false, false, &dy_tmp);
   } else if (transpose_y) {
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, dout, y, dout_dims, y_dims, false, false, &dx_tmp);
+        dev_ctx, dout, y, dout_bd_dims, y_bd_dims, false, false, &dx_tmp);
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, dout, x, dout_dims, x_dims, true, false, &dy_tmp);
+        dev_ctx, dout, x, dout_bd_dims, x_bd_dims, true, false, &dy_tmp);
   } else {
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, dout, y, dout_dims, y_dims, false, true, &dx_tmp);
+        dev_ctx, dout, y, dout_bd_dims, y_bd_dims, false, true, &dx_tmp);
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, x, dout, x_dims, dout_dims, true, false, &dy_tmp);
+        dev_ctx, x, dout, x_bd_dims, dout_bd_dims, true, false, &dy_tmp);
   }
 
-  if (x_dims != dx_bd_dims) {
+  if (x_bd_dims != dx_bd_dims) {
     ReduceSumForMatmulGradOutput<T>(
-        dev_ctx, &dx_tmp, dx, x_dims, vectorize(x.dims()));
+        dev_ctx, &dx_tmp, dx, dx_bd_dims, x_bd_dims);
   } else {
     *dx = std::move(dx_tmp);
   }
 
-  if (y_dims != dy_bd_dims) {
+  if (y_bd_dims != dy_bd_dims) {
     ReduceSumForMatmulGradOutput<T>(
-        dev_ctx, &dy_tmp, dy, y_dims, vectorize(y.dims()));
+        dev_ctx, &dy_tmp, dy, dy_bd_dims, y_bd_dims);
   } else {
     *dy = std::move(dy_tmp);
   }
 
-  dx->set_mem_desc(x.mem_desc());
   dx->Resize(x.dims());
-  dy->set_mem_desc(y.mem_desc());
+  dx->set_mem_desc(x.mem_desc().reshape(vectorize(x.dims())));
   dy->Resize(y.dims());
+  dy->set_mem_desc(y.mem_desc().reshape(vectorize(y.dims())));
 }
 
 template <typename T, typename Context>
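For reference, the four ExecuteMatmul dispatch branches above match the standard matrix-calculus gradient identities for Z = op(X) op(Y), where op applies the transpose flags; this worked summary is a reading aid, not part of the patch:

\[
\begin{aligned}
\text{no transpose:}\quad & dX = dZ\,Y^{\top}, & dY &= X^{\top}\,dZ \\
\text{transpose\_x:}\quad & dX = Y\,dZ^{\top}, & dY &= X\,dZ \\
\text{transpose\_y:}\quad & dX = dZ\,Y, & dY &= dZ^{\top}\,X \\
\text{both:}\quad & dX = Y^{\top}\,dZ^{\top}, & dY &= dZ^{\top}\,X^{\top}
\end{aligned}
\]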