Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
da620ca1
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
da620ca1
编写于
3月 21, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
perf(imperative): specialize batchnorm implementation
GitOrigin-RevId: 83a82590441b9ea4078e5df3117f788652e96745
上级
5ebc9d50
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
115 addition
and
1 deletion
+115
-1
imperative/src/impl/ops/batch_norm.cpp
imperative/src/impl/ops/batch_norm.cpp
+115
-1
未找到文件。
imperative/src/impl/ops/batch_norm.cpp
浏览文件 @
da620ca1
...
...
@@ -10,6 +10,8 @@
*/
#include "megbrain/opr/dnn/batch_norm.h"
#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
#include "megbrain/imperative/graph_builder.h"
#include "megbrain/imperative/ops/autogen.h"
...
...
@@ -138,7 +140,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
SmallVector
<
LogicalTensorDesc
>
out_shapes
(
nr_out
);
auto
&&
i0
=
inputs
[
0
];
auto
&&
i1
=
inputs
[
1
];
// [running_mean, running_var,] save_mean, save_var
// [running_mean, running_var,] save_mean, save_var
iance
for
(
size_t
i
=
0
;
i
<
nr_out
-
2
;
++
i
)
{
out_shapes
[
i
]
=
{
i1
.
layout
,
i1
.
comp_node
};
}
...
...
@@ -148,10 +150,122 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return
{
out_shapes
,
out_shapes
[
nr_out
-
1
].
layout
.
ndim
!=
0
};
}
SmallVector
<
TensorPtr
>
apply_on_physical_tensor
(
const
OpDef
&
def
,
const
SmallVector
<
TensorPtr
>&
inputs
,
SmallVector
<
LogicalTensorDesc
>&
output_descs
,
const
bool
&
validated
)
{
auto
&&
op_def
=
def
.
cast_final_safe
<
BatchNorm
>
();
auto
&&
comp_node
=
inputs
[
0
]
->
comp_node
();
using
TensorND
=
megdnn
::
TensorND
;
SmallVector
<
TensorND
>
inp_tensornds
(
inputs
.
size
());
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
inp_tensornds
[
i
]
=
inputs
[
i
]
->
dnn_tensor
();
}
DnnOprCaller
<
megdnn
::
BN
>
dnn_opr
(
comp_node
);
dnn_opr
.
op
->
param
()
=
op_def
.
param
();
TensorLayout
src_layout
=
inputs
[
0
]
->
layout
();
TensorLayout
scale_layout
=
inputs
[
1
]
->
layout
();
bool
empty_input
=
src_layout
.
is_empty
();
size_t
nr_inp
=
inputs
.
size
();
DeviceTensorND
ws
,
reserve
;
size_t
sz
=
0
,
rsz
=
0
;
TensorLayout
w_layout
({
sz
},
dtype
::
Byte
());
TensorLayout
r_layout
({
rsz
},
dtype
::
Byte
());
if
(
!
empty_input
)
{
sz
=
dnn_opr
.
op
->
get_workspace_in_bytes
(
src_layout
,
src_layout
,
src_layout
,
src_layout
,
src_layout
,
src_layout
,
src_layout
,
src_layout
,
src_layout
);
rsz
=
dnn_opr
.
op
->
get_reserve_in_bytes
(
src_layout
);
w_layout
=
TensorLayout
({
sz
},
dtype
::
Byte
());
r_layout
=
TensorLayout
({
rsz
},
dtype
::
Byte
());
}
auto
wk
=
Blob
::
make
(
comp_node
,
sz
);
auto
ptr
=
wk
->
storage
().
get
();
megdnn
::
Workspace
dnn_wk
(
ptr
,
sz
);
reserve
=
BlobManager
::
inst
()
->
alloc_workspace_with_defrag
(
comp_node
,
r_layout
);
// alloc memory
DeviceTensorND
y
=
BlobManager
::
inst
()
->
alloc_workspace_with_defrag
(
comp_node
,
src_layout
);
DeviceTensorND
save_mean
=
BlobManager
::
inst
()
->
alloc_workspace_with_defrag
(
comp_node
,
scale_layout
);
DeviceTensorND
save_variance
=
BlobManager
::
inst
()
->
alloc_workspace_with_defrag
(
comp_node
,
scale_layout
);
if
(
op_def
.
fwd_mode
==
::
megdnn
::
param
::
BN
::
FwdMode
::
INFERENCE
)
{
if
(
!
empty_input
)
dnn_opr
.
op
->
exec
(
inp_tensornds
[
0
],
inp_tensornds
[
1
],
inp_tensornds
[
2
],
inp_tensornds
[
3
],
inp_tensornds
[
4
],
save_mean
.
as_megdnn
(),
save_variance
.
as_megdnn
(),
reserve
.
as_megdnn
(),
y
.
as_megdnn
(),
dnn_wk
);
return
{
inputs
[
3
],
inputs
[
4
],
Tensor
::
make
(
reserve
),
Tensor
::
make
(
y
)};
}
else
{
DeviceTensorND
mean
,
variance
;
if
(
nr_inp
==
5
)
{
mean
=
BlobManager
::
inst
()
->
alloc_workspace_with_defrag
(
comp_node
,
scale_layout
);
variance
=
BlobManager
::
inst
()
->
alloc_workspace_with_defrag
(
comp_node
,
scale_layout
);
megdnn
::
RefPtr
src_ptr1
(
inp_tensornds
[
3
].
get_ref_ptr
().
get_ptr
(),
inputs
[
3
]
->
offset
());
megdnn
::
RefPtr
dst_ptr1
(
mean
.
storage
().
get_ref_ptr
(),
mean
.
storage
().
offset
(),
false
);
comp_node
.
peer_copy_to_ref
(
comp_node
,
dst_ptr1
,
src_ptr1
,
scale_layout
.
span
().
high_byte
);
megdnn
::
RefPtr
src_ptr2
(
inp_tensornds
[
4
].
get_ref_ptr
().
get_ptr
(),
inputs
[
4
]
->
offset
());
megdnn
::
RefPtr
dst_ptr2
(
variance
.
storage
().
get_ref_ptr
(),
variance
.
storage
().
offset
(),
false
);
comp_node
.
peer_copy_to_ref
(
comp_node
,
dst_ptr2
,
src_ptr2
,
scale_layout
.
span
().
high_byte
);
if
(
!
empty_input
)
dnn_opr
.
op
->
exec
(
inp_tensornds
[
0
],
inp_tensornds
[
1
],
inp_tensornds
[
2
],
mean
.
as_megdnn
(),
variance
.
as_megdnn
(),
save_mean
.
as_megdnn
(),
save_variance
.
as_megdnn
(),
reserve
.
as_megdnn
(),
y
.
as_megdnn
(),
dnn_wk
);
return
{
Tensor
::
make
(
mean
),
Tensor
::
make
(
variance
),
Tensor
::
make
(
save_mean
),
Tensor
::
make
(
save_variance
),
Tensor
::
make
(
reserve
),
Tensor
::
make
(
y
)};
}
TensorLayout
m_layout
({
0
},
scale_layout
.
dtype
);
mean
=
BlobManager
::
inst
()
->
alloc_workspace_with_defrag
(
comp_node
,
m_layout
);
variance
=
BlobManager
::
inst
()
->
alloc_workspace_with_defrag
(
comp_node
,
m_layout
);
if
(
!
empty_input
)
{
dnn_opr
.
op
->
exec
(
inp_tensornds
[
0
],
inp_tensornds
[
1
],
inp_tensornds
[
2
],
mean
.
as_megdnn
(),
variance
.
as_megdnn
(),
save_mean
.
as_megdnn
(),
save_variance
.
as_megdnn
(),
reserve
.
as_megdnn
(),
y
.
as_megdnn
(),
dnn_wk
);
}
return
{
Tensor
::
make
(
save_mean
),
Tensor
::
make
(
save_variance
),
Tensor
::
make
(
reserve
),
Tensor
::
make
(
y
)};
}
}
OP_TRAIT_REG
(
BatchNorm
,
BatchNorm
,
opr
::
BatchNorm
)
.
make_from_op_node
(
make_from_op_node
)
.
apply_on_var_node
(
apply_on_var_node
)
.
infer_output_attrs_fallible
(
infer_output_attrs_fallible
)
.
apply_on_physical_tensor
(
apply_on_physical_tensor
)
.
fallback
();
}
// namespace bn
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录