提交 7ae05ac8 编写于 作者: M Megvii Engine Team

feat(imperative): merge common c++ code to megbrain

GitOrigin-RevId: d093778e103a6977bb4c5c9da85005e276d60e50
上级 9e904f68
......@@ -213,6 +213,15 @@ if(MGE_WITH_TEST)
endif()
option(MGE_WITH_DISTRIBUTED "Build with distributed support" ON)
option(MGE_BUILD_XXX "Build _xxx.so instead of mgb.so " OFF)
if(MGE_BUILD_XXX)
set(CMAKE_CXX_STANDARD 17)
endif()
option(MGE_BUILD_SDK "Build load_and_run" ON)
if(MGE_BUILD_XXX)
set(MGE_BUILD_SDK OFF)
endif()
if(NOT MGE_WITH_CUDA)
message("-- Disable distributed support, as CUDA is not enabled.")
......@@ -522,7 +531,7 @@ endif()
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${MARCH}")
set(MGB_ENABLE_IMPERATIVE, ${MGE_BUILD_XXX})
# Write out megbrain_build_config.h
# It defines macros needed by both megbrain and dnn
configure_file(src/megbrain_build_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/genfiles/megbrain_build_config.h)
......@@ -566,14 +575,23 @@ if(MGE_WITH_DISTRIBUTED)
endif()
add_subdirectory(src)
add_subdirectory(sdk/load-and-run)
if(MGE_BUILD_SDK)
add_subdirectory(sdk/load-and-run)
endif()
if(MGE_WITH_PYTHON_MODULE)
if(MGE_BUILD_XXX)
add_subdirectory(imperative)
else()
add_subdirectory(python_module)
endif()
endif()
if(MGE_WITH_TEST AND MGE_ENABLE_RTTI)
if(NOT MGE_BUILD_XXX)
add_subdirectory(test)
endif()
endif()
if(TARGET mgb)
......@@ -597,6 +615,21 @@ if(TARGET mgb)
DEPENDS mgb
VERBATIM
)
elseif(TARGET _xxx)
add_custom_target(
develop
COMMAND ${CMAKE_COMMAND} -E create_symlink
${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/$<TARGET_FILE_NAME:${MODULE_NAME}>
${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/$<TARGET_FILE_NAME:${MODULE_NAME}>
COMMAND ${CMAKE_COMMAND} -E create_symlink
${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/ops/_internal/generated_ops.py
${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/ops/_internal/generated_ops.py
COMMAND ${CMAKE_COMMAND} -E create_symlink
${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/ops/_internal/param_defs.py
${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/ops/_internal/param_defs.py
DEPENDS _xxx
VERBATIM
)
endif()
IF(APPLE)
......
......@@ -59,7 +59,9 @@ install(TARGETS opr_param_defs EXPORT ${MGE_EXPORT_TARGETS})
if(MGE_WITH_TEST)
if(NOT MGE_BUILD_XXX)
add_subdirectory(test)
endif()
endif()
add_subdirectory(src)
......
......@@ -298,6 +298,9 @@ class PyWriter(IndentWriterBase):
_enum_member2num = None
def __init__(self, for_imperative=False):
self._imperative = for_imperative
def __call__(self, fout, defs):
super().__call__(fout)
self._enum_member2num = []
......@@ -339,19 +342,35 @@ class PyWriter(IndentWriterBase):
' return super()._missing_(value)\n'
'\n'
)
if not self._imperative:
self._write(
'def _as_dtype_num(dtype):\n'
' import megengine._internal.mgb as m\n'
' return m._get_dtype_num(dtype)\n'
'\n'
)
self._write(
'''
def _as_serialized_dtype(dtype):
import megengine._internal.mgb as m
return m._get_serialized_dtype(dtype)
'''
'def _as_serialized_dtype(dtype):\n'
' import megengine._internal.mgb as m\n'
' return m._get_serialized_dtype(dtype)\n'
'\n'
)
else:
self._write(
'def _as_dtype_num(dtype):\n'
' import xxx._xxx.utils as m\n'
' return m._get_dtype_num(dtype)\n'
'\n'
)
self._write(
'def _as_serialized_dtype(dtype):\n'
' import xxx._xxx.utils as m\n'
' return m._get_serialized_dtype(dtype)\n'
'\n'
)
self._process(defs)
self._write(
'''
......@@ -777,8 +796,12 @@ def main():
'cpp file')
parser.add_argument('input')
parser.add_argument('output')
parser.add_argument('--imperative', action='store_true',
help='generate files for imperatvie ')
args = parser.parse_args()
for_imperative = args.imperative
with open(args.input) as fin:
inputs = fin.read()
exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc})
......@@ -787,7 +810,7 @@ def main():
input_hash = input_hash.hexdigest()
if args.type == 'py':
writer = PyWriter()
writer = PyWriter(for_imperative=for_imperative)
else:
assert args.type == 'c++'
if args.enumv:
......
......@@ -151,27 +151,31 @@ if(ANDROID)
target_link_libraries(megbrain PUBLIC log)
endif()
# Build as SHARED or STATIC depending on BUILD_SHARED_LIBS=ON/OFF
add_library(megengine)
target_link_libraries(megengine PUBLIC megbrain megdnn)
if (UNIX AND NOT APPLE)
if(NOT MGE_BUILD_XXX)
# Build as SHARED or STATIC depending on BUILD_SHARED_LIBS=ON/OFF
add_library(megengine)
target_link_libraries(megengine PUBLIC megbrain megdnn)
if (UNIX AND NOT APPLE)
# TODO: Use target_link_options after upgrading to CMake 3.13
# FIXME; Please use right directory for mgb or imperative
target_link_options(megengine PRIVATE -Wl,--no-undefined -Wl,--version-script=${PROJECT_SOURCE_DIR}/python_module/src/version.ld)
endif()
set_target_properties(megengine PROPERTIES CXX_VISIBILITY_PRESET default)
set_target_properties(megengine PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
if (MGE_WITH_DISTRIBUTED)
endif()
set_target_properties(megengine PROPERTIES CXX_VISIBILITY_PRESET default)
set_target_properties(megengine PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
# Do not export targets if MGE_WITH_DISTRIBUTED is on. MegRay is not ready
# for this.
install(TARGETS megengine
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})
else()
install(TARGETS megengine megbrain
endif()
if (NOT MGE_WITH_DISTRIBUTED)
install(TARGETS megbrain
EXPORT ${MGE_EXPORT_TARGETS}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})
endif()
foreach(_PATH ${MGB_INC})
install(DIRECTORY ${_PATH}/megbrain DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} FILES_MATCHING PATTERN "*.h")
endforeach()
......@@ -271,6 +271,23 @@ OperatorNodeBase* ComputingGraphImpl::insert_opr(
std::unique_ptr<OperatorNodeBase> opr_uniqp) {
auto opr = opr_uniqp.get();
if (options().imperative_proxy_graph) {
if (!opr->inserted_in_graph()) {
m_opr_refkeeper.emplace_back(std::move(opr_uniqp));
opr->set_inserted_in_graph();
opr->init_output_comp_node();
opr->init_output_dtype();
opr->init_output_format();
// register static infer
{
auto&& mgr = static_infer_manager_impl();
auto old = mgr.set_register_allowed_opr(opr);
opr->init_output_static_infer_desc();
mgr.set_register_allowed_opr(old);
}
}
return opr;
}
if (opr->inserted_in_graph()) {
// FIXME: it's just a trick used for re-evaluation in eager evaluation
// mode. Since comp_graph has already taken an ownership of the opr,
......
......@@ -133,6 +133,15 @@ void cg::register_grad_func(Typeinfo *opr_type, OprGradFunc grad) {
opr_type->name);
}
OprGradFunc* cg::lookup_grad_func(Typeinfo *opr_type) {
auto giter = static_data().grad_func_registry.find(opr_type);
if (giter != static_data().grad_func_registry.end()) {
return &giter->second;
} else {
return nullptr;
}
}
class GradManager::StreamStrongPropInfer {
DepOprIter m_opr_iter;
ThinHashSet<OperatorNodeBase*> m_strong_oprs;
......
......@@ -101,6 +101,11 @@ OperatorNodeBase::~OperatorNodeBase() noexcept {
}
void OperatorNodeBase::execute(ExecEnv &env) {
if (owner_graph()->options().imperative_proxy_graph) {
do_execute(env);
return;
}
owner_graph()->event().signal_inplace<event::OprExecStart>(this, &env);
// dispatch waiting commands
......
......@@ -230,6 +230,9 @@ VarNode& VarNode::format(TensorFormat format) {
bool VarNode::set_fwd_in2out_readonly(
VarNode *input, const SubTensorSpec &sub) {
if (owner_graph()->options().imperative_proxy_graph) {
return false;
}
return static_cast<ComputingGraphImpl*>(owner_graph())
->var_node_mem_manager().fwd_in2out_readonly(input, sub, this);
}
......@@ -242,6 +245,7 @@ VarNode& VarNode::set_fwd_in2out_writable(VarNode *input) {
VarNode& VarNode::set_fwd_in2out_writable_force(VarNode *input) {
mgb_assert(!owner_graph()->options().imperative_proxy_graph);
static_cast<ComputingGraphImpl*>(owner_graph())
->var_node_mem_manager().fwd_in2out_writable_force(input, this);
return *this;
......
......@@ -440,6 +440,8 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
bool eager_evaluation = false;
#endif
bool imperative_proxy_graph = false;
//! add extra deps for the comp seq if a specific var is dependent
ThinHashMap<VarNode*, VarNodeArray> extra_vardeps;
......
......@@ -73,6 +73,11 @@ namespace cg {
*/
void register_grad_func(Typeinfo *opr_type, OprGradFunc grad);
/*!
* \brief lookup grad func for an operator type
*/
OprGradFunc* lookup_grad_func(Typeinfo *opr_type);
/*!
* \brief add a callback to be invoked when grad of given var is computed
*
......
......@@ -69,6 +69,10 @@ class OperatorNodeConfig final: public Hashable {
return *this;
}
const Maybe<std::string>& name() const {
return m_name;
}
/*!
* \brief update instance ID
*
......
......@@ -22,6 +22,10 @@
#include <mutex>
namespace mgb {
namespace imperative {
class ProxyGraph;
} // namespace imperative
namespace cg {
namespace static_infer {
class StaticInferManagerImpl;
......@@ -576,6 +580,7 @@ class VarNode final: public GraphNodeBase {
friend class VarDevMemDefragmenter;
friend class EagerEvalManager;
friend class MemAllocPlan;
friend class imperative::ProxyGraph;
};
enum class VarNode::Flag: uint32_t {
......
......@@ -29,6 +29,8 @@
#cmakedefine01 MGB_ENABLE_FBS_SERIALIZATION
#cmakedefine01 MGB_IS_DEV
#cmakedefine01 MGB_ENABLE_IMPERATIVE
// DNN related flags
// Platform macro's
#cmakedefine01 MEGDNN_WITH_CUDA
......
......@@ -40,22 +40,29 @@ BatchNormForward::BatchNormForward(VarNode *x,
Super{x->owner_graph(), config, "batch_norm",
{x, scale, bias, mean, variance}}
{
if(owner_graph()->options().imperative_proxy_graph) {
m_force_inplace = false;
}
if (m_force_inplace) {
auto check_dest = [&](VarNode* dest) {
auto dest_opr = dest->owner_opr();
mgb_throw_if(!(dest_opr->same_type<SharedDeviceTensor>() ||
dest_opr->same_type<VolatileSharedDeviceTensor>()),
GraphError,
"mean&variance in BatchNorm must be SharedDeviceTensor/VolatileSharedDeviceTensor; "
"got %s{%s} actually",
"mean and variance in BatchNorm must be SharedDeviceTensor "
"or VolatileSharedDeviceTensor; got %s{%s} actually",
dest_opr->cname(), dest_opr->dyn_typeinfo()->name);
};
check_dest(mean);
check_dest(variance);
}
init_megdnn_opr(*this, param);
add_input({x, scale, bias, mean, variance});
if (m_force_inplace) {
output(0)->
set_fwd_in2out_writable_force(input(3)).
add_flag(VarNode::Flag::NO_MEM_RECLAIM);
......@@ -63,6 +70,7 @@ BatchNormForward::BatchNormForward(VarNode *x,
output(1)->
set_fwd_in2out_writable_force(input(4)).
add_flag(VarNode::Flag::NO_MEM_RECLAIM);
}
}
BatchNormForward::BatchNormForward(VarNode *x,
......@@ -129,17 +137,40 @@ BatchNormForward::do_make_node_prop() const {
void BatchNormForward::scn_do_execute() {
auto &&x = input(0)->dev_tensor();
auto &&y = output(4)->dev_tensor();
mgb_assert(x.layout().is_contiguous() &&
y.layout().is_contiguous());
#if MGB_ENABLE_IMPERATIVE
if (input().size() == 5) { // need running mean/variance
auto &&o0 = output(0)->dev_tensor(),
&&o1 = output(1)->dev_tensor(),
&&i0 = input(3)->dev_tensor(),
&&i1 = input(4)->dev_tensor();
mgb_assert(o0.raw_ptr() && o1.raw_ptr()); // non-empty tensor
mgb_assert(o0.comp_node() == i0.comp_node() &&
o1.comp_node() == i1.comp_node() &&
o0.layout().eq_layout(i0.layout()) &&
o1.layout().eq_layout(i1.layout()));
if (!m_force_inplace) {
if (o0.raw_ptr() != i0.raw_ptr()) {
o0.copy_from_fixlayout(i0);
}
if (o1.raw_ptr() != i1.raw_ptr()) {
o1.copy_from_fixlayout(i1);
}
} else {
mgb_assert(o0.raw_ptr() == i0.raw_ptr()
&& o1.raw_ptr() == i1.raw_ptr());
}
}
#endif
auto scale = input(1)->dev_tensor().as_megdnn();
auto bias = input(2)->dev_tensor().as_megdnn();
auto mean = output(0)->dev_tensor().as_megdnn();
auto variance = output(1)->dev_tensor().as_megdnn();
auto save_mean = output(2)->dev_tensor().as_megdnn();
auto save_variance = output(3)->dev_tensor().as_megdnn();
auto &&y = output(4)->dev_tensor();
auto workspace = intl::get_megdnn_workspace_from_var(
output().back());
mgb_assert(x.layout().is_contiguous() &&
y.layout().is_contiguous());
auto workspace = intl::get_megdnn_workspace_from_var(output().back());
megdnn_opr()->exec(x.as_megdnn(), scale, bias, mean, variance,
save_mean, save_variance, y.as_megdnn(), workspace);
}
......@@ -191,6 +222,14 @@ void BatchNormForward::init_output_dtype() {
}
}
void BatchNormForward::mem_plan_fwd_in2out_writable() {
if (!m_force_inplace && input().size() == 5) {
// TODO: testing
output(0)->set_fwd_in2out_writable(input(3));
output(1)->set_fwd_in2out_writable(input(4));
}
}
MGB_IMPL_OPR_GRAD(BatchNormForward) {
mgb_assert(wrt_idx < 5);
if (wrt_idx < 3) {
......
......@@ -271,17 +271,26 @@ WorkspaceLimitGetter::get_impl(ComputingGraph *graph) {
size_t WorkspaceLimitGetter::get_workspace_limit(
ComputingGraph *graph, CompNode cn, size_t old_limit) {
if (graph->options().imperative_proxy_graph) {
return old_limit;
}
if (!graph->options().seq_opt.enable_mem_reuse_alloc)
return old_limit;
return get_impl(graph)->get_workspace_limit(cn, old_limit);
}
bool WorkspaceLimitGetter::is_prealloc_run(ComputingGraph* graph) {
if (graph->options().imperative_proxy_graph) {
return false;
}
return graph->options().seq_opt.enable_mem_reuse_alloc &&
get_impl(graph)->is_prealloc_run();
}
VarNode* WorkspaceLimitGetter::register_to_graph(ComputingGraph *graph) {
if (graph->options().imperative_proxy_graph) {
return nullptr;
}
auto maker = [graph](){
return std::make_shared<Impl>(graph);
};
......
......@@ -75,6 +75,10 @@ MGB_DEFINE_OPR_CLASS(BatchNormForward,
const TensorShapeArray &output_shapes) const override;
void init_output_static_infer_desc() override;
void init_output_dtype() override;
void mem_plan_fwd_in2out_writable() override;
// if set to True, running mean/variance will be updated inplace
bool m_force_inplace = true;
};
using BatchNorm = BatchNormForward;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册