Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
7ae05ac8
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
7ae05ac8
编写于
6月 24, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(imperative): merge common c++ code to megbrain
GitOrigin-RevId: d093778e103a6977bb4c5c9da85005e276d60e50
上级
9e904f68
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
220 addition
and
53 deletion
+220
-53
CMakeLists.txt
CMakeLists.txt
+37
-4
dnn/CMakeLists.txt
dnn/CMakeLists.txt
+3
-1
dnn/scripts/gen_param_defs.py
dnn/scripts/gen_param_defs.py
+37
-14
src/CMakeLists.txt
src/CMakeLists.txt
+16
-12
src/core/impl/graph/cg_impl.cpp
src/core/impl/graph/cg_impl.cpp
+17
-0
src/core/impl/graph/grad_manager.cpp
src/core/impl/graph/grad_manager.cpp
+9
-0
src/core/impl/graph/operator_node.cpp
src/core/impl/graph/operator_node.cpp
+5
-0
src/core/impl/graph/var_node.cpp
src/core/impl/graph/var_node.cpp
+4
-0
src/core/include/megbrain/graph/cg.h
src/core/include/megbrain/graph/cg.h
+2
-0
src/core/include/megbrain/graph/grad_impl.h
src/core/include/megbrain/graph/grad_impl.h
+5
-0
src/core/include/megbrain/graph/operator_node.h
src/core/include/megbrain/graph/operator_node.h
+4
-0
src/core/include/megbrain/graph/var_node.h
src/core/include/megbrain/graph/var_node.h
+5
-0
src/megbrain_build_config.h.in
src/megbrain_build_config.h.in
+2
-0
src/opr/impl/dnn/batch_norm.cpp
src/opr/impl/dnn/batch_norm.cpp
+61
-22
src/opr/impl/internal/megdnn_opr_wrapper.cpp
src/opr/impl/internal/megdnn_opr_wrapper.cpp
+9
-0
src/opr/include/megbrain/opr/dnn/batch_norm.h
src/opr/include/megbrain/opr/dnn/batch_norm.h
+4
-0
未找到文件。
CMakeLists.txt
浏览文件 @
7ae05ac8
...
...
@@ -213,6 +213,15 @@ if(MGE_WITH_TEST)
endif
()
option
(
MGE_WITH_DISTRIBUTED
"Build with distributed support"
ON
)
option
(
MGE_BUILD_XXX
"Build _xxx.so instead of mgb.so "
OFF
)
if
(
MGE_BUILD_XXX
)
set
(
CMAKE_CXX_STANDARD 17
)
endif
()
option
(
MGE_BUILD_SDK
"Build load_and_run"
ON
)
if
(
MGE_BUILD_XXX
)
set
(
MGE_BUILD_SDK OFF
)
endif
()
if
(
NOT MGE_WITH_CUDA
)
message
(
"-- Disable distributed support, as CUDA is not enabled."
)
...
...
@@ -522,7 +531,7 @@ endif()
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
${
MARCH
}
"
)
# Mirror the MGE_BUILD_XXX switch into MGB_ENABLE_IMPERATIVE, which
# src/megbrain_build_config.h.in consumes via `#cmakedefine01`.
# CMake separates arguments with whitespace only: a comma after the name would
# become part of the variable name and leave MGB_ENABLE_IMPERATIVE undefined.
set(MGB_ENABLE_IMPERATIVE ${MGE_BUILD_XXX})
# Write out megbrain_build_config.h
# It defines macros needed by both megbrain and dnn
configure_file
(
src/megbrain_build_config.h.in
${
CMAKE_CURRENT_BINARY_DIR
}
/genfiles/megbrain_build_config.h
)
...
...
@@ -566,14 +575,23 @@ if(MGE_WITH_DISTRIBUTED)
endif
()
add_subdirectory
(
src
)
add_subdirectory
(
sdk/load-and-run
)
if
(
MGE_BUILD_SDK
)
add_subdirectory
(
sdk/load-and-run
)
endif
()
if
(
MGE_WITH_PYTHON_MODULE
)
add_subdirectory
(
python_module
)
if
(
MGE_BUILD_XXX
)
add_subdirectory
(
imperative
)
else
()
add_subdirectory
(
python_module
)
endif
()
endif
()
if
(
MGE_WITH_TEST AND MGE_ENABLE_RTTI
)
add_subdirectory
(
test
)
if
(
NOT MGE_BUILD_XXX
)
add_subdirectory
(
test
)
endif
()
endif
()
if
(
TARGET mgb
)
...
...
@@ -597,6 +615,21 @@ if(TARGET mgb)
DEPENDS mgb
VERBATIM
)
elseif
(
TARGET _xxx
)
add_custom_target
(
develop
COMMAND
${
CMAKE_COMMAND
}
-E create_symlink
${
CMAKE_CURRENT_BINARY_DIR
}
/imperative/python/
${
PACKAGE_NAME
}
/$<TARGET_FILE_NAME:
${
MODULE_NAME
}
>
${
CMAKE_CURRENT_SOURCE_DIR
}
/imperative/python/
${
PACKAGE_NAME
}
/$<TARGET_FILE_NAME:
${
MODULE_NAME
}
>
COMMAND
${
CMAKE_COMMAND
}
-E create_symlink
${
CMAKE_CURRENT_BINARY_DIR
}
/imperative/python/
${
PACKAGE_NAME
}
/ops/_internal/generated_ops.py
${
CMAKE_CURRENT_SOURCE_DIR
}
/imperative/python/
${
PACKAGE_NAME
}
/ops/_internal/generated_ops.py
COMMAND
${
CMAKE_COMMAND
}
-E create_symlink
${
CMAKE_CURRENT_BINARY_DIR
}
/imperative/python/
${
PACKAGE_NAME
}
/ops/_internal/param_defs.py
${
CMAKE_CURRENT_SOURCE_DIR
}
/imperative/python/
${
PACKAGE_NAME
}
/ops/_internal/param_defs.py
DEPENDS _xxx
VERBATIM
)
endif
()
IF
(
APPLE
)
...
...
dnn/CMakeLists.txt
浏览文件 @
7ae05ac8
...
...
@@ -59,7 +59,9 @@ install(TARGETS opr_param_defs EXPORT ${MGE_EXPORT_TARGETS})
if
(
MGE_WITH_TEST
)
add_subdirectory
(
test
)
if
(
NOT MGE_BUILD_XXX
)
add_subdirectory
(
test
)
endif
()
endif
()
add_subdirectory
(
src
)
...
...
dnn/scripts/gen_param_defs.py
浏览文件 @
7ae05ac8
...
...
@@ -298,6 +298,9 @@ class PyWriter(IndentWriterBase):
_enum_member2num
=
None
def
__init__
(
self
,
for_imperative
=
False
):
self
.
_imperative
=
for_imperative
def
__call__
(
self
,
fout
,
defs
):
super
().
__call__
(
fout
)
self
.
_enum_member2num
=
[]
...
...
@@ -339,19 +342,35 @@ class PyWriter(IndentWriterBase):
' return super()._missing_(value)
\n
'
'
\n
'
)
self
.
_write
(
'def _as_dtype_num(dtype):
\n
'
' import megengine._internal.mgb as m
\n
'
' return m._get_dtype_num(dtype)
\n
'
'
\n
'
)
self
.
_write
(
'''
def _as_serialized_dtype(dtype):
import megengine._internal.mgb as m
return m._get_serialized_dtype(dtype)
'''
)
if
not
self
.
_imperative
:
self
.
_write
(
'def _as_dtype_num(dtype):
\n
'
' import megengine._internal.mgb as m
\n
'
' return m._get_dtype_num(dtype)
\n
'
'
\n
'
)
self
.
_write
(
'def _as_serialized_dtype(dtype):
\n
'
' import megengine._internal.mgb as m
\n
'
' return m._get_serialized_dtype(dtype)
\n
'
'
\n
'
)
else
:
self
.
_write
(
'def _as_dtype_num(dtype):
\n
'
' import xxx._xxx.utils as m
\n
'
' return m._get_dtype_num(dtype)
\n
'
'
\n
'
)
self
.
_write
(
'def _as_serialized_dtype(dtype):
\n
'
' import xxx._xxx.utils as m
\n
'
' return m._get_serialized_dtype(dtype)
\n
'
'
\n
'
)
self
.
_process
(
defs
)
self
.
_write
(
'''
...
...
@@ -777,8 +796,12 @@ def main():
'cpp file'
)
parser
.
add_argument
(
'input'
)
parser
.
add_argument
(
'output'
)
parser
.
add_argument
(
'--imperative'
,
action
=
'store_true'
,
help
=
'generate files for imperatvie '
)
args
=
parser
.
parse_args
()
for_imperative
=
args
.
imperative
with
open
(
args
.
input
)
as
fin
:
inputs
=
fin
.
read
()
exec
(
inputs
,
{
'pdef'
:
ParamDef
,
'Doc'
:
member_defs
.
Doc
})
...
...
@@ -787,7 +810,7 @@ def main():
input_hash
=
input_hash
.
hexdigest
()
if
args
.
type
==
'py'
:
writer
=
PyWriter
()
writer
=
PyWriter
(
for_imperative
=
for_imperative
)
else
:
assert
args
.
type
==
'c++'
if
args
.
enumv
:
...
...
src/CMakeLists.txt
浏览文件 @
7ae05ac8
...
...
@@ -151,27 +151,31 @@ if(ANDROID)
target_link_libraries
(
megbrain PUBLIC log
)
endif
()
# Build as SHARED or STATIC depending on BUILD_SHARED_LIBS=ON/OFF
add_library
(
megengine
)
target_link_libraries
(
megengine PUBLIC megbrain megdnn
)
if
(
UNIX AND NOT APPLE
)
# TODO: Use target_link_options after upgrading to CMake 3.13
target_link_options
(
megengine PRIVATE -Wl,--no-undefined -Wl,--version-script=
${
PROJECT_SOURCE_DIR
}
/python_module/src/version.ld
)
endif
()
set_target_properties
(
megengine PROPERTIES CXX_VISIBILITY_PRESET default
)
set_target_properties
(
megengine PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE
)
if
(
MGE_WITH_DISTRIBUTED
)
if
(
NOT MGE_BUILD_XXX
)
# Build as SHARED or STATIC depending on BUILD_SHARED_LIBS=ON/OFF
add_library
(
megengine
)
target_link_libraries
(
megengine PUBLIC megbrain megdnn
)
if
(
UNIX AND NOT APPLE
)
# TODO: Use target_link_options after upgrading to CMake 3.13
# FIXME: please use the right directory for mgb or imperative
target_link_options
(
megengine PRIVATE -Wl,--no-undefined -Wl,--version-script=
${
PROJECT_SOURCE_DIR
}
/python_module/src/version.ld
)
endif
()
set_target_properties
(
megengine PROPERTIES CXX_VISIBILITY_PRESET default
)
set_target_properties
(
megengine PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE
)
# Do not export targets if MGE_WITH_DISTRIBUTED is on. MegRay is not ready
# for this.
install
(
TARGETS megengine
LIBRARY DESTINATION
${
CMAKE_INSTALL_LIBDIR
}
ARCHIVE DESTINATION
${
CMAKE_INSTALL_LIBDIR
}
)
else
()
install
(
TARGETS megengine megbrain
endif
()
if
(
NOT MGE_WITH_DISTRIBUTED
)
install
(
TARGETS megbrain
EXPORT
${
MGE_EXPORT_TARGETS
}
LIBRARY DESTINATION
${
CMAKE_INSTALL_LIBDIR
}
ARCHIVE DESTINATION
${
CMAKE_INSTALL_LIBDIR
}
)
endif
()
foreach
(
_PATH
${
MGB_INC
}
)
install
(
DIRECTORY
${
_PATH
}
/megbrain DESTINATION
${
CMAKE_INSTALL_INCLUDEDIR
}
FILES_MATCHING PATTERN
"*.h"
)
endforeach
()
src/core/impl/graph/cg_impl.cpp
浏览文件 @
7ae05ac8
...
...
@@ -271,6 +271,23 @@ OperatorNodeBase* ComputingGraphImpl::insert_opr(
std
::
unique_ptr
<
OperatorNodeBase
>
opr_uniqp
)
{
auto
opr
=
opr_uniqp
.
get
();
if
(
options
().
imperative_proxy_graph
)
{
if
(
!
opr
->
inserted_in_graph
())
{
m_opr_refkeeper
.
emplace_back
(
std
::
move
(
opr_uniqp
));
opr
->
set_inserted_in_graph
();
opr
->
init_output_comp_node
();
opr
->
init_output_dtype
();
opr
->
init_output_format
();
// register static infer
{
auto
&&
mgr
=
static_infer_manager_impl
();
auto
old
=
mgr
.
set_register_allowed_opr
(
opr
);
opr
->
init_output_static_infer_desc
();
mgr
.
set_register_allowed_opr
(
old
);
}
}
return
opr
;
}
if
(
opr
->
inserted_in_graph
())
{
// FIXME: it's just a trick used for re-evaluation in eager evaluation
// mode. Since comp_graph has already taken an ownership of the opr,
...
...
src/core/impl/graph/grad_manager.cpp
浏览文件 @
7ae05ac8
...
...
@@ -133,6 +133,15 @@ void cg::register_grad_func(Typeinfo *opr_type, OprGradFunc grad) {
opr_type
->
name
);
}
/*!
 * \brief look up the grad function registered for an operator type
 * \param opr_type key used to search the grad function registry
 * \return pointer to the mapped value stored in the registry, or nullptr
 *      when the registry has no entry for \p opr_type
 */
OprGradFunc* cg::lookup_grad_func(Typeinfo* opr_type) {
    auto&& registry = static_data().grad_func_registry;
    auto pos = registry.find(opr_type);
    return pos == registry.end() ? nullptr : &pos->second;
}
class
GradManager
::
StreamStrongPropInfer
{
DepOprIter
m_opr_iter
;
ThinHashSet
<
OperatorNodeBase
*>
m_strong_oprs
;
...
...
src/core/impl/graph/operator_node.cpp
浏览文件 @
7ae05ac8
...
...
@@ -101,6 +101,11 @@ OperatorNodeBase::~OperatorNodeBase() noexcept {
}
void
OperatorNodeBase
::
execute
(
ExecEnv
&
env
)
{
if
(
owner_graph
()
->
options
().
imperative_proxy_graph
)
{
do_execute
(
env
);
return
;
}
owner_graph
()
->
event
().
signal_inplace
<
event
::
OprExecStart
>
(
this
,
&
env
);
// dispatch waiting commands
...
...
src/core/impl/graph/var_node.cpp
浏览文件 @
7ae05ac8
...
...
@@ -230,6 +230,9 @@ VarNode& VarNode::format(TensorFormat format) {
bool VarNode::set_fwd_in2out_readonly(
        VarNode* input, const SubTensorSpec& sub) {
    // When the owner graph is flagged as an imperative proxy graph, return
    // false without consulting the memory manager.
    const bool is_proxy_graph =
            owner_graph()->options().imperative_proxy_graph;
    if (is_proxy_graph)
        return false;

    // Otherwise delegate to the memory manager of the concrete graph impl.
    auto* graph_impl = static_cast<ComputingGraphImpl*>(owner_graph());
    auto&& mem_mgr = graph_impl->var_node_mem_manager();
    return mem_mgr.fwd_in2out_readonly(input, sub, this);
}
...
...
@@ -242,6 +245,7 @@ VarNode& VarNode::set_fwd_in2out_writable(VarNode *input) {
VarNode
&
VarNode
::
set_fwd_in2out_writable_force
(
VarNode
*
input
)
{
mgb_assert
(
!
owner_graph
()
->
options
().
imperative_proxy_graph
);
static_cast
<
ComputingGraphImpl
*>
(
owner_graph
())
->
var_node_mem_manager
().
fwd_in2out_writable_force
(
input
,
this
);
return
*
this
;
...
...
src/core/include/megbrain/graph/cg.h
浏览文件 @
7ae05ac8
...
...
@@ -440,6 +440,8 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
bool
eager_evaluation
=
false
;
#endif
bool
imperative_proxy_graph
=
false
;
//! add extra deps for the comp seq if a specific var is dependent
ThinHashMap
<
VarNode
*
,
VarNodeArray
>
extra_vardeps
;
...
...
src/core/include/megbrain/graph/grad_impl.h
浏览文件 @
7ae05ac8
...
...
@@ -73,6 +73,11 @@ namespace cg {
*/
void
register_grad_func
(
Typeinfo
*
opr_type
,
OprGradFunc
grad
);
/*!
* \brief lookup grad func for an operator type
*/
OprGradFunc
*
lookup_grad_func
(
Typeinfo
*
opr_type
);
/*!
* \brief add a callback to be invoked when grad of given var is computed
*
...
...
src/core/include/megbrain/graph/operator_node.h
浏览文件 @
7ae05ac8
...
...
@@ -69,6 +69,10 @@ class OperatorNodeConfig final: public Hashable {
return
*
this
;
}
/*!
 * \brief read-only access to the stored name (m_name)
 */
const Maybe<std::string>& name() const { return m_name; }
/*!
* \brief update instance ID
*
...
...
src/core/include/megbrain/graph/var_node.h
浏览文件 @
7ae05ac8
...
...
@@ -22,6 +22,10 @@
#include <mutex>
namespace
mgb
{
namespace
imperative
{
class
ProxyGraph
;
}
// namespace imperative
namespace
cg
{
namespace
static_infer
{
class
StaticInferManagerImpl
;
...
...
@@ -576,6 +580,7 @@ class VarNode final: public GraphNodeBase {
friend
class
VarDevMemDefragmenter
;
friend
class
EagerEvalManager
;
friend
class
MemAllocPlan
;
friend
class
imperative
::
ProxyGraph
;
};
enum
class
VarNode
::
Flag
:
uint32_t
{
...
...
src/megbrain_build_config.h.in
浏览文件 @
7ae05ac8
...
...
@@ -29,6 +29,8 @@
#cmakedefine01 MGB_ENABLE_FBS_SERIALIZATION
#cmakedefine01 MGB_IS_DEV
#cmakedefine01 MGB_ENABLE_IMPERATIVE
// DNN related flags
// Platform macro's
#cmakedefine01 MEGDNN_WITH_CUDA
...
...
src/opr/impl/dnn/batch_norm.cpp
浏览文件 @
7ae05ac8
...
...
@@ -40,29 +40,37 @@ BatchNormForward::BatchNormForward(VarNode *x,
Super
{
x
->
owner_graph
(),
config
,
"batch_norm"
,
{
x
,
scale
,
bias
,
mean
,
variance
}}
{
auto
check_dest
=
[
&
](
VarNode
*
dest
)
{
auto
dest_opr
=
dest
->
owner_opr
();
mgb_throw_if
(
!
(
dest_opr
->
same_type
<
SharedDeviceTensor
>
()
||
dest_opr
->
same_type
<
VolatileSharedDeviceTensor
>
()),
GraphError
,
"mean&variance in BatchNorm must be SharedDeviceTensor/VolatileSharedDeviceTensor; "
"got %s{%s} actually"
,
dest_opr
->
cname
(),
dest_opr
->
dyn_typeinfo
()
->
name
);
};
check_dest
(
mean
);
check_dest
(
variance
);
if
(
owner_graph
()
->
options
().
imperative_proxy_graph
)
{
m_force_inplace
=
false
;
}
if
(
m_force_inplace
)
{
auto
check_dest
=
[
&
](
VarNode
*
dest
)
{
auto
dest_opr
=
dest
->
owner_opr
();
mgb_throw_if
(
!
(
dest_opr
->
same_type
<
SharedDeviceTensor
>
()
||
dest_opr
->
same_type
<
VolatileSharedDeviceTensor
>
()),
GraphError
,
"mean and variance in BatchNorm must be SharedDeviceTensor "
"or VolatileSharedDeviceTensor; got %s{%s} actually"
,
dest_opr
->
cname
(),
dest_opr
->
dyn_typeinfo
()
->
name
);
};
check_dest
(
mean
);
check_dest
(
variance
);
}
init_megdnn_opr
(
*
this
,
param
);
add_input
({
x
,
scale
,
bias
,
mean
,
variance
});
output
(
0
)
->
set_fwd_in2out_writable_force
(
input
(
3
)).
add_flag
(
VarNode
::
Flag
::
NO_MEM_RECLAIM
);
if
(
m_force_inplace
)
{
output
(
0
)
->
set_fwd_in2out_writable_force
(
input
(
3
)).
add_flag
(
VarNode
::
Flag
::
NO_MEM_RECLAIM
);
output
(
1
)
->
set_fwd_in2out_writable_force
(
input
(
4
)).
add_flag
(
VarNode
::
Flag
::
NO_MEM_RECLAIM
);
output
(
1
)
->
set_fwd_in2out_writable_force
(
input
(
4
)).
add_flag
(
VarNode
::
Flag
::
NO_MEM_RECLAIM
);
}
}
BatchNormForward
::
BatchNormForward
(
VarNode
*
x
,
...
...
@@ -129,17 +137,40 @@ BatchNormForward::do_make_node_prop() const {
void
BatchNormForward
::
scn_do_execute
()
{
auto
&&
x
=
input
(
0
)
->
dev_tensor
();
auto
&&
y
=
output
(
4
)
->
dev_tensor
();
mgb_assert
(
x
.
layout
().
is_contiguous
()
&&
y
.
layout
().
is_contiguous
());
#if MGB_ENABLE_IMPERATIVE
if
(
input
().
size
()
==
5
)
{
// need running mean/variance
auto
&&
o0
=
output
(
0
)
->
dev_tensor
(),
&&
o1
=
output
(
1
)
->
dev_tensor
(),
&&
i0
=
input
(
3
)
->
dev_tensor
(),
&&
i1
=
input
(
4
)
->
dev_tensor
();
mgb_assert
(
o0
.
raw_ptr
()
&&
o1
.
raw_ptr
());
// non-empty tensor
mgb_assert
(
o0
.
comp_node
()
==
i0
.
comp_node
()
&&
o1
.
comp_node
()
==
i1
.
comp_node
()
&&
o0
.
layout
().
eq_layout
(
i0
.
layout
())
&&
o1
.
layout
().
eq_layout
(
i1
.
layout
()));
if
(
!
m_force_inplace
)
{
if
(
o0
.
raw_ptr
()
!=
i0
.
raw_ptr
())
{
o0
.
copy_from_fixlayout
(
i0
);
}
if
(
o1
.
raw_ptr
()
!=
i1
.
raw_ptr
())
{
o1
.
copy_from_fixlayout
(
i1
);
}
}
else
{
mgb_assert
(
o0
.
raw_ptr
()
==
i0
.
raw_ptr
()
&&
o1
.
raw_ptr
()
==
i1
.
raw_ptr
());
}
}
#endif
auto
scale
=
input
(
1
)
->
dev_tensor
().
as_megdnn
();
auto
bias
=
input
(
2
)
->
dev_tensor
().
as_megdnn
();
auto
mean
=
output
(
0
)
->
dev_tensor
().
as_megdnn
();
auto
variance
=
output
(
1
)
->
dev_tensor
().
as_megdnn
();
auto
save_mean
=
output
(
2
)
->
dev_tensor
().
as_megdnn
();
auto
save_variance
=
output
(
3
)
->
dev_tensor
().
as_megdnn
();
auto
&&
y
=
output
(
4
)
->
dev_tensor
();
auto
workspace
=
intl
::
get_megdnn_workspace_from_var
(
output
().
back
());
mgb_assert
(
x
.
layout
().
is_contiguous
()
&&
y
.
layout
().
is_contiguous
());
auto
workspace
=
intl
::
get_megdnn_workspace_from_var
(
output
().
back
());
megdnn_opr
()
->
exec
(
x
.
as_megdnn
(),
scale
,
bias
,
mean
,
variance
,
save_mean
,
save_variance
,
y
.
as_megdnn
(),
workspace
);
}
...
...
@@ -191,6 +222,14 @@ void BatchNormForward::init_output_dtype() {
}
}
/*!
 * Request writable forwarding from input(3) to output(0) and from input(4)
 * to output(1). Nothing is requested when m_force_inplace is set or when the
 * operator does not have exactly five inputs.
 */
void BatchNormForward::mem_plan_fwd_in2out_writable() {
    if (m_force_inplace || input().size() != 5)
        return;

    // TODO: testing
    output(0)->set_fwd_in2out_writable(input(3));
    output(1)->set_fwd_in2out_writable(input(4));
}
MGB_IMPL_OPR_GRAD
(
BatchNormForward
)
{
mgb_assert
(
wrt_idx
<
5
);
if
(
wrt_idx
<
3
)
{
...
...
src/opr/impl/internal/megdnn_opr_wrapper.cpp
浏览文件 @
7ae05ac8
...
...
@@ -271,17 +271,26 @@ WorkspaceLimitGetter::get_impl(ComputingGraph *graph) {
/*!
 * Return the workspace limit for \p cn on \p graph.
 *
 * \p old_limit is returned unchanged when the graph is an imperative proxy
 * graph or when seq_opt.enable_mem_reuse_alloc is off; otherwise the query is
 * forwarded to the implementation object obtained from get_impl().
 */
size_t WorkspaceLimitGetter::get_workspace_limit(
        ComputingGraph* graph, CompNode cn, size_t old_limit) {
    auto&& opts = graph->options();
    const bool keep_old_limit = opts.imperative_proxy_graph ||
                                !opts.seq_opt.enable_mem_reuse_alloc;
    if (keep_old_limit)
        return old_limit;
    return get_impl(graph)->get_workspace_limit(cn, old_limit);
}
/*!
 * Tell whether \p graph is currently in a prealloc run.
 *
 * Always false for imperative proxy graphs and for graphs with
 * seq_opt.enable_mem_reuse_alloc turned off; otherwise the answer comes from
 * the implementation object obtained from get_impl().
 */
bool WorkspaceLimitGetter::is_prealloc_run(ComputingGraph* graph) {
    auto&& opts = graph->options();
    if (opts.imperative_proxy_graph || !opts.seq_opt.enable_mem_reuse_alloc)
        return false;
    return get_impl(graph)->is_prealloc_run();
}
VarNode
*
WorkspaceLimitGetter
::
register_to_graph
(
ComputingGraph
*
graph
)
{
if
(
graph
->
options
().
imperative_proxy_graph
)
{
return
nullptr
;
}
auto
maker
=
[
graph
](){
return
std
::
make_shared
<
Impl
>
(
graph
);
};
...
...
src/opr/include/megbrain/opr/dnn/batch_norm.h
浏览文件 @
7ae05ac8
...
...
@@ -75,6 +75,10 @@ MGB_DEFINE_OPR_CLASS(BatchNormForward,
const
TensorShapeArray
&
output_shapes
)
const
override
;
void
init_output_static_infer_desc
()
override
;
void
init_output_dtype
()
override
;
void
mem_plan_fwd_in2out_writable
()
override
;
// if set to True, running mean/variance will be updated inplace
bool
m_force_inplace
=
true
;
};
using
BatchNorm
=
BatchNormForward
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录