Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
3244a9de
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3244a9de
编写于
7月 25, 2022
作者:
光明和真理
提交者:
GitHub
7月 25, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[MLU]transpose convbpf output to HWCN for better performance (#44552)
上级
f9cd526b
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
11 addition
and
7 deletion
+11
-7
paddle/fluid/operators/conv_op_mlu.cc
paddle/fluid/operators/conv_op_mlu.cc
+11
-7
未找到文件。
paddle/fluid/operators/conv_op_mlu.cc
浏览文件 @
3244a9de
...
...
@@ -436,6 +436,8 @@ class MLUDepthwiseConvGradOpKernel : public framework::OpKernel<T> {
Tensor
output_grad_tensor
(
output_grad
->
type
());
const
std
::
vector
<
int
>
perm_to_nhwc
=
{
0
,
2
,
3
,
1
};
const
std
::
vector
<
int
>
perm_to_nchw
=
{
0
,
3
,
1
,
2
};
const
std
::
vector
<
int
>
perm_hwcm_to_mchw
=
{
3
,
2
,
0
,
1
};
const
std
::
vector
<
int
>
perm_mchw_to_hwcm
=
{
2
,
3
,
1
,
0
};
if
(
channel_last
)
{
input_tensor
.
ShareDataWith
(
*
input
);
output_grad_tensor
.
ShareDataWith
(
*
output_grad
);
...
...
@@ -462,10 +464,12 @@ class MLUDepthwiseConvGradOpKernel : public framework::OpKernel<T> {
auto
filter_grad_dims
=
filter_grad
->
dims
();
Tensor
temp_filter_grad
(
filter_grad
->
type
());
temp_filter_grad
.
mutable_data
<
T
>
({
filter_grad_dims
[
0
],
filter_grad_dims
[
2
],
filter_grad_dims
[
3
],
filter_grad_dims
[
1
]},
// Details about setting diff_w hwcn for better performance, see the CNNL
// documentation.
temp_filter_grad
.
mutable_data
<
T
>
({
filter_grad_dims
[
perm_mchw_to_hwcm
[
0
]],
filter_grad_dims
[
perm_mchw_to_hwcm
[
1
]],
filter_grad_dims
[
perm_mchw_to_hwcm
[
2
]],
filter_grad_dims
[
perm_mchw_to_hwcm
[
3
]]},
ctx
.
GetPlace
());
cnnlDataType_t
tensor_dtype
=
ToCnnlDataType
<
T
>
();
...
...
@@ -474,7 +478,7 @@ class MLUDepthwiseConvGradOpKernel : public framework::OpKernel<T> {
MLUCnnlTensorDesc
out_grad_desc
(
output_grad_tensor
,
data_layout
,
tensor_dtype
);
MLUCnnlTensorDesc
temp_filter_grad_desc
(
temp_filter_grad
,
data_layout
,
tensor_dtype
);
temp_filter_grad
,
CNNL_LAYOUT_HWCN
,
tensor_dtype
);
MLUCnnlConvolutionDesc
conv_desc
(
in_dims_size
,
paddings
.
data
(),
...
...
@@ -492,9 +496,9 @@ class MLUDepthwiseConvGradOpKernel : public framework::OpKernel<T> {
temp_filter_grad_desc
.
get
(),
GetBasePtr
(
&
temp_filter_grad
));
// transpose filter_grad from
MHWC
to MCHW
// transpose filter_grad from
HWCM
to MCHW
TransposeFromMLUTensor
<
T
>
(
ctx
,
perm_
to_n
chw
,
perm_
hwcm_to_m
chw
,
&
temp_filter_grad
,
filter_grad
,
false
/*need_reshape_or_alloc*/
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录