提交 7ae05ac8 编写于 作者: M Megvii Engine Team

feat(imperative): merge common c++ code to megbrain

GitOrigin-RevId: d093778e103a6977bb4c5c9da85005e276d60e50
上级 9e904f68
...@@ -213,6 +213,15 @@ if(MGE_WITH_TEST) ...@@ -213,6 +213,15 @@ if(MGE_WITH_TEST)
endif() endif()
option(MGE_WITH_DISTRIBUTED "Build with distributed support" ON) option(MGE_WITH_DISTRIBUTED "Build with distributed support" ON)
# Selects the _xxx.so build in place of mgb.so.
option(MGE_BUILD_XXX "Build _xxx.so instead of mgb.so " OFF)
# The load_and_run SDK is built unless it is switched off below.
option(MGE_BUILD_SDK "Build load_and_run" ON)
if(MGE_BUILD_XXX)
    # The _xxx.so build is compiled as C++17 and does not build the SDK.
    set(CMAKE_CXX_STANDARD 17)
    set(MGE_BUILD_SDK OFF)
endif()
if(NOT MGE_WITH_CUDA) if(NOT MGE_WITH_CUDA)
message("-- Disable distributed support, as CUDA is not enabled.") message("-- Disable distributed support, as CUDA is not enabled.")
...@@ -522,7 +531,7 @@ endif() ...@@ -522,7 +531,7 @@ endif()
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${MARCH}") set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${MARCH}")
# Forward the build switch to megbrain_build_config.h.in, which emits it via
# "#cmakedefine01 MGB_ENABLE_IMPERATIVE". CMake separates arguments by
# whitespace only, so a comma after the name would become part of the
# variable name and leave MGB_ENABLE_IMPERATIVE itself undefined (always 0).
set(MGB_ENABLE_IMPERATIVE ${MGE_BUILD_XXX})
# Write out megbrain_build_config.h # Write out megbrain_build_config.h
# It defines macros needed by both megbrain and dnn # It defines macros needed by both megbrain and dnn
configure_file(src/megbrain_build_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/genfiles/megbrain_build_config.h) configure_file(src/megbrain_build_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/genfiles/megbrain_build_config.h)
...@@ -566,14 +575,23 @@ if(MGE_WITH_DISTRIBUTED) ...@@ -566,14 +575,23 @@ if(MGE_WITH_DISTRIBUTED)
endif() endif()
add_subdirectory(src) add_subdirectory(src)
add_subdirectory(sdk/load-and-run)
if(MGE_BUILD_SDK)
add_subdirectory(sdk/load-and-run)
endif()
if(MGE_WITH_PYTHON_MODULE) if(MGE_WITH_PYTHON_MODULE)
add_subdirectory(python_module) if(MGE_BUILD_XXX)
add_subdirectory(imperative)
else()
add_subdirectory(python_module)
endif()
endif() endif()
if(MGE_WITH_TEST AND MGE_ENABLE_RTTI) if(MGE_WITH_TEST AND MGE_ENABLE_RTTI)
add_subdirectory(test) if(NOT MGE_BUILD_XXX)
add_subdirectory(test)
endif()
endif() endif()
if(TARGET mgb) if(TARGET mgb)
...@@ -597,6 +615,21 @@ if(TARGET mgb) ...@@ -597,6 +615,21 @@ if(TARGET mgb)
DEPENDS mgb DEPENDS mgb
VERBATIM VERBATIM
) )
elseif(TARGET _xxx)
add_custom_target(
develop
COMMAND ${CMAKE_COMMAND} -E create_symlink
${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/$<TARGET_FILE_NAME:${MODULE_NAME}>
${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/$<TARGET_FILE_NAME:${MODULE_NAME}>
COMMAND ${CMAKE_COMMAND} -E create_symlink
${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/ops/_internal/generated_ops.py
${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/ops/_internal/generated_ops.py
COMMAND ${CMAKE_COMMAND} -E create_symlink
${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/ops/_internal/param_defs.py
${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/ops/_internal/param_defs.py
DEPENDS _xxx
VERBATIM
)
endif() endif()
IF(APPLE) IF(APPLE)
......
...@@ -59,7 +59,9 @@ install(TARGETS opr_param_defs EXPORT ${MGE_EXPORT_TARGETS}) ...@@ -59,7 +59,9 @@ install(TARGETS opr_param_defs EXPORT ${MGE_EXPORT_TARGETS})
if(MGE_WITH_TEST) if(MGE_WITH_TEST)
add_subdirectory(test) if(NOT MGE_BUILD_XXX)
add_subdirectory(test)
endif()
endif() endif()
add_subdirectory(src) add_subdirectory(src)
......
...@@ -298,6 +298,9 @@ class PyWriter(IndentWriterBase): ...@@ -298,6 +298,9 @@ class PyWriter(IndentWriterBase):
_enum_member2num = None _enum_member2num = None
# Remember whether the output is generated for the imperative build.
# __call__ checks self._imperative to decide which module the emitted
# _as_dtype_num / _as_serialized_dtype helpers import.
# NOTE(review): super().__init__() is not called here; confirm that
# IndentWriterBase needs no constructor-time setup.
def __init__(self, for_imperative=False):
self._imperative = for_imperative
def __call__(self, fout, defs): def __call__(self, fout, defs):
super().__call__(fout) super().__call__(fout)
self._enum_member2num = [] self._enum_member2num = []
...@@ -339,19 +342,35 @@ class PyWriter(IndentWriterBase): ...@@ -339,19 +342,35 @@ class PyWriter(IndentWriterBase):
' return super()._missing_(value)\n' ' return super()._missing_(value)\n'
'\n' '\n'
) )
self._write( if not self._imperative:
'def _as_dtype_num(dtype):\n' self._write(
' import megengine._internal.mgb as m\n' 'def _as_dtype_num(dtype):\n'
' return m._get_dtype_num(dtype)\n' ' import megengine._internal.mgb as m\n'
'\n' ' return m._get_dtype_num(dtype)\n'
) '\n'
self._write( )
'''
def _as_serialized_dtype(dtype): self._write(
import megengine._internal.mgb as m 'def _as_serialized_dtype(dtype):\n'
return m._get_serialized_dtype(dtype) ' import megengine._internal.mgb as m\n'
''' ' return m._get_serialized_dtype(dtype)\n'
) '\n'
)
else:
self._write(
'def _as_dtype_num(dtype):\n'
' import xxx._xxx.utils as m\n'
' return m._get_dtype_num(dtype)\n'
'\n'
)
self._write(
'def _as_serialized_dtype(dtype):\n'
' import xxx._xxx.utils as m\n'
' return m._get_serialized_dtype(dtype)\n'
'\n'
)
self._process(defs) self._process(defs)
self._write( self._write(
''' '''
...@@ -777,8 +796,12 @@ def main(): ...@@ -777,8 +796,12 @@ def main():
'cpp file') 'cpp file')
parser.add_argument('input') parser.add_argument('input')
parser.add_argument('output') parser.add_argument('output')
# Opt-in flag; its value is passed to PyWriter(for_imperative=...) below.
parser.add_argument('--imperative', action='store_true',
                    help='generate files for imperative')
args = parser.parse_args() args = parser.parse_args()
for_imperative = args.imperative
with open(args.input) as fin: with open(args.input) as fin:
inputs = fin.read() inputs = fin.read()
exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc}) exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc})
...@@ -787,7 +810,7 @@ def main(): ...@@ -787,7 +810,7 @@ def main():
input_hash = input_hash.hexdigest() input_hash = input_hash.hexdigest()
if args.type == 'py': if args.type == 'py':
writer = PyWriter() writer = PyWriter(for_imperative=for_imperative)
else: else:
assert args.type == 'c++' assert args.type == 'c++'
if args.enumv: if args.enumv:
......
...@@ -151,27 +151,31 @@ if(ANDROID) ...@@ -151,27 +151,31 @@ if(ANDROID)
target_link_libraries(megbrain PUBLIC log) target_link_libraries(megbrain PUBLIC log)
endif() endif()
# Build as SHARED or STATIC depending on BUILD_SHARED_LIBS=ON/OFF if(NOT MGE_BUILD_XXX)
add_library(megengine) # Build as SHARED or STATIC depending on BUILD_SHARED_LIBS=ON/OFF
target_link_libraries(megengine PUBLIC megbrain megdnn) add_library(megengine)
if (UNIX AND NOT APPLE) target_link_libraries(megengine PUBLIC megbrain megdnn)
# TODO: Use target_link_options after upgrading to CMake 3.13 if (UNIX AND NOT APPLE)
target_link_options(megengine PRIVATE -Wl,--no-undefined -Wl,--version-script=${PROJECT_SOURCE_DIR}/python_module/src/version.ld) # TODO: Use target_link_options after upgrading to CMake 3.13
endif() # FIXME; Please use right directory for mgb or imperative
set_target_properties(megengine PROPERTIES CXX_VISIBILITY_PRESET default) target_link_options(megengine PRIVATE -Wl,--no-undefined -Wl,--version-script=${PROJECT_SOURCE_DIR}/python_module/src/version.ld)
set_target_properties(megengine PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE) endif()
if (MGE_WITH_DISTRIBUTED) set_target_properties(megengine PROPERTIES CXX_VISIBILITY_PRESET default)
set_target_properties(megengine PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
# Do not export targets if MGE_WITH_DISTRIBUTED is on. MegRay is not ready # Do not export targets if MGE_WITH_DISTRIBUTED is on. MegRay is not ready
# for this. # for this.
install(TARGETS megengine install(TARGETS megengine
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}) ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})
else() endif()
install(TARGETS megengine megbrain
if (NOT MGE_WITH_DISTRIBUTED)
install(TARGETS megbrain
EXPORT ${MGE_EXPORT_TARGETS} EXPORT ${MGE_EXPORT_TARGETS}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}) ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})
endif() endif()
foreach(_PATH ${MGB_INC}) foreach(_PATH ${MGB_INC})
install(DIRECTORY ${_PATH}/megbrain DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} FILES_MATCHING PATTERN "*.h") install(DIRECTORY ${_PATH}/megbrain DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} FILES_MATCHING PATTERN "*.h")
endforeach() endforeach()
...@@ -271,6 +271,23 @@ OperatorNodeBase* ComputingGraphImpl::insert_opr( ...@@ -271,6 +271,23 @@ OperatorNodeBase* ComputingGraphImpl::insert_opr(
std::unique_ptr<OperatorNodeBase> opr_uniqp) { std::unique_ptr<OperatorNodeBase> opr_uniqp) {
auto opr = opr_uniqp.get(); auto opr = opr_uniqp.get();
if (options().imperative_proxy_graph) {
if (!opr->inserted_in_graph()) {
m_opr_refkeeper.emplace_back(std::move(opr_uniqp));
opr->set_inserted_in_graph();
opr->init_output_comp_node();
opr->init_output_dtype();
opr->init_output_format();
// register static infer
{
auto&& mgr = static_infer_manager_impl();
auto old = mgr.set_register_allowed_opr(opr);
opr->init_output_static_infer_desc();
mgr.set_register_allowed_opr(old);
}
}
return opr;
}
if (opr->inserted_in_graph()) { if (opr->inserted_in_graph()) {
// FIXME: it's just a trick used for re-evaluation in eager evaluation // FIXME: it's just a trick used for re-evaluation in eager evaluation
// mode. Since comp_graph has already taken an ownership of the opr, // mode. Since comp_graph has already taken an ownership of the opr,
......
...@@ -133,6 +133,15 @@ void cg::register_grad_func(Typeinfo *opr_type, OprGradFunc grad) { ...@@ -133,6 +133,15 @@ void cg::register_grad_func(Typeinfo *opr_type, OprGradFunc grad) {
opr_type->name); opr_type->name);
} }
/*!
 * Look up the grad function stored in the grad-func registry for
 * \p opr_type.
 *
 * \return pointer to the registry entry, or nullptr if the registry has no
 *      entry for this type
 */
OprGradFunc* cg::lookup_grad_func(Typeinfo *opr_type) {
    auto&& registry = static_data().grad_func_registry;
    auto entry = registry.find(opr_type);
    return entry == registry.end() ? nullptr : &entry->second;
}
class GradManager::StreamStrongPropInfer { class GradManager::StreamStrongPropInfer {
DepOprIter m_opr_iter; DepOprIter m_opr_iter;
ThinHashSet<OperatorNodeBase*> m_strong_oprs; ThinHashSet<OperatorNodeBase*> m_strong_oprs;
......
...@@ -101,6 +101,11 @@ OperatorNodeBase::~OperatorNodeBase() noexcept { ...@@ -101,6 +101,11 @@ OperatorNodeBase::~OperatorNodeBase() noexcept {
} }
void OperatorNodeBase::execute(ExecEnv &env) { void OperatorNodeBase::execute(ExecEnv &env) {
if (owner_graph()->options().imperative_proxy_graph) {
do_execute(env);
return;
}
owner_graph()->event().signal_inplace<event::OprExecStart>(this, &env); owner_graph()->event().signal_inplace<event::OprExecStart>(this, &env);
// dispatch waiting commands // dispatch waiting commands
......
...@@ -230,6 +230,9 @@ VarNode& VarNode::format(TensorFormat format) { ...@@ -230,6 +230,9 @@ VarNode& VarNode::format(TensorFormat format) {
bool VarNode::set_fwd_in2out_readonly( bool VarNode::set_fwd_in2out_readonly(
VarNode *input, const SubTensorSpec &sub) { VarNode *input, const SubTensorSpec &sub) {
if (owner_graph()->options().imperative_proxy_graph) {
return false;
}
return static_cast<ComputingGraphImpl*>(owner_graph()) return static_cast<ComputingGraphImpl*>(owner_graph())
->var_node_mem_manager().fwd_in2out_readonly(input, sub, this); ->var_node_mem_manager().fwd_in2out_readonly(input, sub, this);
} }
...@@ -242,6 +245,7 @@ VarNode& VarNode::set_fwd_in2out_writable(VarNode *input) { ...@@ -242,6 +245,7 @@ VarNode& VarNode::set_fwd_in2out_writable(VarNode *input) {
VarNode& VarNode::set_fwd_in2out_writable_force(VarNode *input) { VarNode& VarNode::set_fwd_in2out_writable_force(VarNode *input) {
mgb_assert(!owner_graph()->options().imperative_proxy_graph);
static_cast<ComputingGraphImpl*>(owner_graph()) static_cast<ComputingGraphImpl*>(owner_graph())
->var_node_mem_manager().fwd_in2out_writable_force(input, this); ->var_node_mem_manager().fwd_in2out_writable_force(input, this);
return *this; return *this;
......
...@@ -440,6 +440,8 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>, ...@@ -440,6 +440,8 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
bool eager_evaluation = false; bool eager_evaluation = false;
#endif #endif
//! when set, the graph takes reduced code paths: insert_opr() only
//! initializes output comp node/dtype/format and static infer descs,
//! OperatorNodeBase::execute() calls do_execute() directly,
//! VarNode::set_fwd_in2out_readonly() returns false, and
//! WorkspaceLimitGetter leaves the workspace limit unchanged.
//! NOTE(review): presumably enabled by imperative::ProxyGraph -- confirm
bool imperative_proxy_graph = false;
//! add extra deps for the comp seq if a specific var is dependent //! add extra deps for the comp seq if a specific var is dependent
ThinHashMap<VarNode*, VarNodeArray> extra_vardeps; ThinHashMap<VarNode*, VarNodeArray> extra_vardeps;
......
...@@ -73,6 +73,11 @@ namespace cg { ...@@ -73,6 +73,11 @@ namespace cg {
*/ */
void register_grad_func(Typeinfo *opr_type, OprGradFunc grad); void register_grad_func(Typeinfo *opr_type, OprGradFunc grad);
/*!
* \brief lookup grad func for an operator type
*/
OprGradFunc* lookup_grad_func(Typeinfo *opr_type);
/*! /*!
* \brief add a callback to be invoked when grad of given var is computed * \brief add a callback to be invoked when grad of given var is computed
* *
......
...@@ -69,6 +69,10 @@ class OperatorNodeConfig final: public Hashable { ...@@ -69,6 +69,10 @@ class OperatorNodeConfig final: public Hashable {
return *this; return *this;
} }
//! read-only access to the stored name (m_name)
const Maybe<std::string>& name() const { return m_name; }
/*! /*!
* \brief update instance ID * \brief update instance ID
* *
......
...@@ -22,6 +22,10 @@ ...@@ -22,6 +22,10 @@
#include <mutex> #include <mutex>
namespace mgb { namespace mgb {
namespace imperative {
class ProxyGraph;
} // namespace imperative
namespace cg { namespace cg {
namespace static_infer { namespace static_infer {
class StaticInferManagerImpl; class StaticInferManagerImpl;
...@@ -576,6 +580,7 @@ class VarNode final: public GraphNodeBase { ...@@ -576,6 +580,7 @@ class VarNode final: public GraphNodeBase {
friend class VarDevMemDefragmenter; friend class VarDevMemDefragmenter;
friend class EagerEvalManager; friend class EagerEvalManager;
friend class MemAllocPlan; friend class MemAllocPlan;
friend class imperative::ProxyGraph;
}; };
enum class VarNode::Flag: uint32_t { enum class VarNode::Flag: uint32_t {
......
...@@ -29,6 +29,8 @@ ...@@ -29,6 +29,8 @@
#cmakedefine01 MGB_ENABLE_FBS_SERIALIZATION #cmakedefine01 MGB_ENABLE_FBS_SERIALIZATION
#cmakedefine01 MGB_IS_DEV #cmakedefine01 MGB_IS_DEV
#cmakedefine01 MGB_ENABLE_IMPERATIVE
// DNN related flags // DNN related flags
// Platform macro's // Platform macro's
#cmakedefine01 MEGDNN_WITH_CUDA #cmakedefine01 MEGDNN_WITH_CUDA
......
...@@ -40,29 +40,37 @@ BatchNormForward::BatchNormForward(VarNode *x, ...@@ -40,29 +40,37 @@ BatchNormForward::BatchNormForward(VarNode *x,
Super{x->owner_graph(), config, "batch_norm", Super{x->owner_graph(), config, "batch_norm",
{x, scale, bias, mean, variance}} {x, scale, bias, mean, variance}}
{ {
auto check_dest = [&](VarNode* dest) { if(owner_graph()->options().imperative_proxy_graph) {
auto dest_opr = dest->owner_opr(); m_force_inplace = false;
mgb_throw_if(!(dest_opr->same_type<SharedDeviceTensor>() || }
dest_opr->same_type<VolatileSharedDeviceTensor>()),
GraphError, if (m_force_inplace) {
"mean&variance in BatchNorm must be SharedDeviceTensor/VolatileSharedDeviceTensor; " auto check_dest = [&](VarNode* dest) {
"got %s{%s} actually", auto dest_opr = dest->owner_opr();
dest_opr->cname(), dest_opr->dyn_typeinfo()->name); mgb_throw_if(!(dest_opr->same_type<SharedDeviceTensor>() ||
}; dest_opr->same_type<VolatileSharedDeviceTensor>()),
check_dest(mean); GraphError,
check_dest(variance); "mean and variance in BatchNorm must be SharedDeviceTensor "
"or VolatileSharedDeviceTensor; got %s{%s} actually",
dest_opr->cname(), dest_opr->dyn_typeinfo()->name);
};
check_dest(mean);
check_dest(variance);
}
init_megdnn_opr(*this, param); init_megdnn_opr(*this, param);
add_input({x, scale, bias, mean, variance}); add_input({x, scale, bias, mean, variance});
output(0)-> if (m_force_inplace) {
set_fwd_in2out_writable_force(input(3)). output(0)->
add_flag(VarNode::Flag::NO_MEM_RECLAIM); set_fwd_in2out_writable_force(input(3)).
add_flag(VarNode::Flag::NO_MEM_RECLAIM);
output(1)-> output(1)->
set_fwd_in2out_writable_force(input(4)). set_fwd_in2out_writable_force(input(4)).
add_flag(VarNode::Flag::NO_MEM_RECLAIM); add_flag(VarNode::Flag::NO_MEM_RECLAIM);
}
} }
BatchNormForward::BatchNormForward(VarNode *x, BatchNormForward::BatchNormForward(VarNode *x,
...@@ -129,17 +137,40 @@ BatchNormForward::do_make_node_prop() const { ...@@ -129,17 +137,40 @@ BatchNormForward::do_make_node_prop() const {
void BatchNormForward::scn_do_execute() { void BatchNormForward::scn_do_execute() {
auto &&x = input(0)->dev_tensor(); auto &&x = input(0)->dev_tensor();
auto &&y = output(4)->dev_tensor();
mgb_assert(x.layout().is_contiguous() &&
y.layout().is_contiguous());
#if MGB_ENABLE_IMPERATIVE
if (input().size() == 5) { // need running mean/variance
auto &&o0 = output(0)->dev_tensor(),
&&o1 = output(1)->dev_tensor(),
&&i0 = input(3)->dev_tensor(),
&&i1 = input(4)->dev_tensor();
mgb_assert(o0.raw_ptr() && o1.raw_ptr()); // non-empty tensor
mgb_assert(o0.comp_node() == i0.comp_node() &&
o1.comp_node() == i1.comp_node() &&
o0.layout().eq_layout(i0.layout()) &&
o1.layout().eq_layout(i1.layout()));
if (!m_force_inplace) {
if (o0.raw_ptr() != i0.raw_ptr()) {
o0.copy_from_fixlayout(i0);
}
if (o1.raw_ptr() != i1.raw_ptr()) {
o1.copy_from_fixlayout(i1);
}
} else {
mgb_assert(o0.raw_ptr() == i0.raw_ptr()
&& o1.raw_ptr() == i1.raw_ptr());
}
}
#endif
auto scale = input(1)->dev_tensor().as_megdnn(); auto scale = input(1)->dev_tensor().as_megdnn();
auto bias = input(2)->dev_tensor().as_megdnn(); auto bias = input(2)->dev_tensor().as_megdnn();
auto mean = output(0)->dev_tensor().as_megdnn(); auto mean = output(0)->dev_tensor().as_megdnn();
auto variance = output(1)->dev_tensor().as_megdnn(); auto variance = output(1)->dev_tensor().as_megdnn();
auto save_mean = output(2)->dev_tensor().as_megdnn(); auto save_mean = output(2)->dev_tensor().as_megdnn();
auto save_variance = output(3)->dev_tensor().as_megdnn(); auto save_variance = output(3)->dev_tensor().as_megdnn();
auto &&y = output(4)->dev_tensor(); auto workspace = intl::get_megdnn_workspace_from_var(output().back());
auto workspace = intl::get_megdnn_workspace_from_var(
output().back());
mgb_assert(x.layout().is_contiguous() &&
y.layout().is_contiguous());
megdnn_opr()->exec(x.as_megdnn(), scale, bias, mean, variance, megdnn_opr()->exec(x.as_megdnn(), scale, bias, mean, variance,
save_mean, save_variance, y.as_megdnn(), workspace); save_mean, save_variance, y.as_megdnn(), workspace);
} }
...@@ -191,6 +222,14 @@ void BatchNormForward::init_output_dtype() { ...@@ -191,6 +222,14 @@ void BatchNormForward::init_output_dtype() {
} }
} }
/*
 * Override of mem_plan_fwd_in2out_writable(): when in-place update is not
 * forced and the opr has all five inputs, request writable forwarding of
 * the mean/variance inputs (3, 4) into outputs (0, 1).
 */
void BatchNormForward::mem_plan_fwd_in2out_writable() {
    if (m_force_inplace || input().size() != 5) {
        return;
    }
    // TODO: testing
    output(0)->set_fwd_in2out_writable(input(3));
    output(1)->set_fwd_in2out_writable(input(4));
}
MGB_IMPL_OPR_GRAD(BatchNormForward) { MGB_IMPL_OPR_GRAD(BatchNormForward) {
mgb_assert(wrt_idx < 5); mgb_assert(wrt_idx < 5);
if (wrt_idx < 3) { if (wrt_idx < 3) {
......
...@@ -271,17 +271,26 @@ WorkspaceLimitGetter::get_impl(ComputingGraph *graph) { ...@@ -271,17 +271,26 @@ WorkspaceLimitGetter::get_impl(ComputingGraph *graph) {
size_t WorkspaceLimitGetter::get_workspace_limit( size_t WorkspaceLimitGetter::get_workspace_limit(
ComputingGraph *graph, CompNode cn, size_t old_limit) { ComputingGraph *graph, CompNode cn, size_t old_limit) {
if (graph->options().imperative_proxy_graph) {
return old_limit;
}
if (!graph->options().seq_opt.enable_mem_reuse_alloc) if (!graph->options().seq_opt.enable_mem_reuse_alloc)
return old_limit; return old_limit;
return get_impl(graph)->get_workspace_limit(cn, old_limit); return get_impl(graph)->get_workspace_limit(cn, old_limit);
} }
bool WorkspaceLimitGetter::is_prealloc_run(ComputingGraph* graph) { bool WorkspaceLimitGetter::is_prealloc_run(ComputingGraph* graph) {
if (graph->options().imperative_proxy_graph) {
return false;
}
return graph->options().seq_opt.enable_mem_reuse_alloc && return graph->options().seq_opt.enable_mem_reuse_alloc &&
get_impl(graph)->is_prealloc_run(); get_impl(graph)->is_prealloc_run();
} }
VarNode* WorkspaceLimitGetter::register_to_graph(ComputingGraph *graph) { VarNode* WorkspaceLimitGetter::register_to_graph(ComputingGraph *graph) {
if (graph->options().imperative_proxy_graph) {
return nullptr;
}
auto maker = [graph](){ auto maker = [graph](){
return std::make_shared<Impl>(graph); return std::make_shared<Impl>(graph);
}; };
......
...@@ -75,6 +75,10 @@ MGB_DEFINE_OPR_CLASS(BatchNormForward, ...@@ -75,6 +75,10 @@ MGB_DEFINE_OPR_CLASS(BatchNormForward,
const TensorShapeArray &output_shapes) const override; const TensorShapeArray &output_shapes) const override;
void init_output_static_infer_desc() override; void init_output_static_infer_desc() override;
void init_output_dtype() override; void init_output_dtype() override;
void mem_plan_fwd_in2out_writable() override;
// if set to True, running mean/variance will be updated inplace
bool m_force_inplace = true;
}; };
using BatchNorm = BatchNormForward; using BatchNorm = BatchNormForward;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册