Commit 270f1aa2 authored by Megvii Engine Team

feat(mgb/serialization): add Accessor for OprLoader to fix BN output compatibility

GitOrigin-RevId: 3b95da02c8fa3cd2a6c6d47d7ede93b7b36aa3a7
Parent c0ccd0ea
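
In short, operator loaders registered in `OprRegistry` now return an `OprWithOutputAccessor` instead of a bare `cg::OperatorNodeBase*`, so a loader for an old serialization format can remap the output list that deserialization code iterates over (the registration rename to `BatchNormV1` below is part of the same compatibility fix). The following is a minimal sketch of the idea, not code from this commit; the helper name `load_with_old_output_layout` and the "drop the newly appended output" mapping are assumptions for illustration:

```cpp
#include "megbrain/serialization/opr_registry.h"

// Hypothetical helper: wrap an operator built by the regular loader so that
// consumers which iterate "the outputs" only see what old model files recorded.
mgb::serialization::OprWithOutputAccessor load_with_old_output_layout(
        mgb::cg::OperatorNodeBase* opr /* built by the regular loader */) {
    using mgb::cg::VarNodeArray;
    return mgb::serialization::OprWithOutputAccessor(
            opr, [](const VarNodeArray& outputs) {
                VarNodeArray remapped = outputs;
                remapped.pop_back();  // hide an output appended in a newer opr version
                return remapped;
            });
}
```

The diff below threads this wrapper through the registry (`OprLoaderWrapper`), the imperative fallback, the graph loader, and the `MGB_SEREG_OPR*` macros.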
@@ -76,7 +76,7 @@ public:
     }
 };

-cg::OperatorNodeBase* apply_on_var_node(
+VarNodeArray apply_on_var_node(
        const OpDef& def, const VarNodeArray& inputs) {
    auto&& attr = def.cast_final_safe<OprAttr>();
    auto config = attr.config;
@@ -85,7 +85,7 @@ cg::OperatorNodeBase* apply_on_var_node(
    auto registry = serialization::OprRegistry::find_by_name(attr.type);
    mgb_assert(registry, "operator %s not found", attr.type.c_str());
    OprParamsLoadContext ctx{attr.param, inputs[0]->owner_graph()};
-    return registry->loader(ctx, inputs, config);
+    return registry->loader(ctx, inputs, config).usable_output();
 }

 std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) {
......
@@ -200,7 +200,7 @@ TEST(TestImperative, BatchNormGrad) {
     LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn};
     LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn};
     {
-        auto op = OprAttr::make("BatchNorm");
+        auto op = OprAttr::make("BatchNormV1");
         auto&& attr = op->cast_final_safe<OprAttr>();
         Param param;
         param.fwd_mode = Param::FwdMode::TRAINING;
@@ -210,7 +210,7 @@ TEST(TestImperative, BatchNormGrad) {
                 {false, false, false, false, false, true});
     }
     {
-        auto op = OprAttr::make("BatchNorm");
+        auto op = OprAttr::make("BatchNormV1");
         auto&& attr = op->cast_final_safe<OprAttr>();
         Param param;
         param.fwd_mode = Param::FwdMode::TRAINING;
......
@@ -59,7 +59,7 @@ TEST(TestImperative, Reduce) {
 }

 TEST(TestImperative, BatchNorm) {
-    auto op = OprAttr::make("BatchNorm");
+    auto op = OprAttr::make("BatchNormV1");
     auto&& attr = op->cast_final_safe<OprAttr>();
     using Param = opr::BatchNorm::Param;
     Param param;
......
@@ -16,14 +16,13 @@
 #include "megbrain/opr/dnn/correlation.h"
 #include "megbrain/opr/dnn/fake_quant.h"
 #include "megbrain/opr/dnn/images2neibs.h"
-#include "megbrain/opr/dnn/sliding_window_transpose.h"
-#include "megbrain/opr/dnn/adaptive_pooling.h"
 #include "megbrain/opr/dnn/local.h"
 #include "megbrain/opr/dnn/lrn.h"
 #include "megbrain/opr/dnn/lsq.h"
 #include "megbrain/opr/dnn/pooling.h"
 #include "megbrain/opr/dnn/roi_align.h"
 #include "megbrain/opr/dnn/roi_pooling.h"
+#include "megbrain/opr/dnn/sliding_window_transpose.h"
 #include "megbrain/opr/dnn/tqt.h"
 #include "megbrain/serialization/sereg.h"
 #include "megdnn/opr_param_defs.h"
@@ -390,6 +389,7 @@ struct OprMaker<opr::BatchNorm, 0> {
     }
 };

+// OprMaker in MGB_SEREG_OPR only support unique output opr
 template <>
 struct OprMaker<opr::BatchNormBackward, 6> {
     using Param = opr::BatchNormBackward::Param;
@@ -398,8 +398,8 @@ struct OprMaker<opr::BatchNormBackward, 6> {
             ComputingGraph& graph,
             const OperatorNodeConfig& config) {
         MGB_MARK_USED_VAR(graph);
-        return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4], i[5], param,
-                                            config)[0]
+        return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4], i[5],
+                                            param, config)[0]
                 .node()
                 ->owner_opr();
     }
@@ -575,8 +575,10 @@ MGB_SEREG_OPR(Convolution3DBackwardFilter, 0);
 using ConvBiasForwardV4 = ConvBiasForward;
 MGB_SEREG_OPR(ConvBiasForwardV4, 0);

-MGB_SEREG_OPR(BatchNorm, 0);
-MGB_SEREG_OPR(BatchNormBackward, 6);
+using BatchNormV1 = BatchNorm;
+using BatchNormBackwardV1 = BatchNormBackward;
+MGB_SEREG_OPR(BatchNormV1, 0);
+MGB_SEREG_OPR(BatchNormBackwardV1, 6);

 using LocalShareForwardV1 = LocalShareForward;
 using LocalShareBackwardDataV1 = LocalShareBackwardData;
......
@@ -39,7 +39,7 @@ namespace {
     return inst;
 }

-cg::OperatorNodeBase* dynamic_loader(
+OprWithOutputAccessor dynamic_loader(
        OprLoadContext &ctx, const cg::VarNodeArray &inputs,
        const OperatorNodeConfig &config) {
    auto name = ctx.load_buf_with_len();
@@ -171,4 +171,20 @@ std::vector<std::pair<size_t, std::string>> OprRegistry::dump_registries() {
 }
 #endif

+namespace {
+const VarNodeArray& default_accessor(const VarNodeArray& outputs) {
+    return outputs;
+}
+}
+
+OprWithOutputAccessor::OprWithOutputAccessor(cg::OperatorNodeBase* opr) : m_opr(opr) {
+    m_accessor = &default_accessor;
+};
+
+OprWithOutputAccessor::OprWithOutputAccessor(cg::OperatorNodeBase* opr, Accessor accessor)
+        : OprWithOutputAccessor(opr) {
+    if (accessor) {
+        m_accessor = accessor;
+    }
+};
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
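
For reference, the fallback in the two constructors above means an empty accessor behaves exactly like the single-argument form, while a non-empty one changes only the reported output list. A small illustrative check follows; the `probe` helper is made up for this sketch and is not part of the commit:

```cpp
#include "megbrain/common.h"
#include "megbrain/serialization/opr_registry.h"

void probe(mgb::cg::OperatorNodeBase* opr) {
    using mgb::serialization::OprWithOutputAccessor;
    using mgb::cg::VarNodeArray;

    OprWithOutputAccessor plain{opr};         // uses default_accessor (identity)
    OprWithOutputAccessor empty_fn{opr, {}};  // empty accessor also falls back to identity
    OprWithOutputAccessor first_only{opr, [](const VarNodeArray& out) {
        VarNodeArray ret;
        ret.push_back(out.at(0));             // publish only the first output
        return ret;
    }};

    mgb_assert(plain.output().size() == empty_fn.output().size());
    mgb_assert(first_only.output().size() == 1);
    mgb_assert(first_only.opr() == opr);      // the raw operator node is untouched
}
```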
@@ -207,7 +207,7 @@ cg::OperatorNodeBase* serialization::intl::copy_opr_shallow_default_impl(
     registry->dumper(dumper, opr);
     OprLoadContextMemory loader{opr.owner_graph(), dumper};
-    return registry->loader(loader, inputs, config);
+    return registry->loader(loader, inputs, config).opr();
 }

 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -782,7 +782,8 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(
     }

     // call loader
-    auto opr = registry->loader(*this, inputs, config);
+    auto accessor = registry->loader(*this, inputs, config);
+    auto opr = accessor.opr();

     // check opr type; note that:
     // 1. registry->type may be empty for dynamic opr loaders or legacy oprs
@@ -794,7 +795,7 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(
                opr ? opr->dyn_typeinfo()->name : nullptr, registry->type->name);
     // record output vars; read output names
     size_t i = 0;
-    for (auto ovar : opr->output()) {
+    for (auto ovar : accessor.output()) {
         if (!ovar->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
             m_id2varnode.push_back(ovar);
             if (fbopr->output_name()) {
......
@@ -19,16 +19,36 @@ namespace serialization {
     class OprDumpContext;
     class OprLoadContext;
     class OprShallowCopyContext;

+    class OprWithOutputAccessor {
+        cg::OperatorNodeBase* m_opr;
+        using Accessor = thin_function<const VarNodeArray(const VarNodeArray&)>;
+        Accessor m_accessor;
+
+    public:
+        OprWithOutputAccessor(cg::OperatorNodeBase* opr);
+        OprWithOutputAccessor(cg::OperatorNodeBase* opr, Accessor accessor);
+        VarNode* output(size_t idx) const { return output().at(idx); }
+        VarNodeArray output() const { return m_accessor(m_opr->output()); }
+        VarNodeArray usable_output() const { return m_accessor(m_opr->usable_output()); }
+        cg::OperatorNodeBase* opr() { return m_opr; }
+    };
+
     //! dump opr internal params to OprDumpContext
     using OprDumper = thin_function<void(
             OprDumpContext &ctx, const cg::OperatorNodeBase &opr)>;

     //! load and restore operator from OprLoadContext
+    //! is also used by GraphLoadConfig.
     using OprLoader = thin_function<cg::OperatorNodeBase*(
             OprLoadContext &ctx, const cg::VarNodeArray &inputs,
             const OperatorNodeConfig &config)>;

+    //! loader that can change opr output map for compatibility
+    using OprLoaderWrapper = thin_function<OprWithOutputAccessor(
+            OprLoadContext &ctx, const cg::VarNodeArray &inputs,
+            const OperatorNodeConfig &config)>;
+
     //! shallow copy function for a single operator
     using OprShallowCopy = thin_function<cg::OperatorNodeBase*(
             const OprShallowCopyContext &ctx,
@@ -41,7 +61,7 @@ namespace serialization {
         uint64_t persist_type_id;
         std::string name;
         OprDumper dumper;
-        OprLoader loader;
+        OprLoaderWrapper loader;
         OprShallowCopy shallow_copy;    //!< set to empty to use default impl
         uint64_t unversioned_type_id;

......
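
Taken together, the three accessors on the wrapper line up with the three consumer call sites touched in this commit: `opr()` for type checks and shallow copy, `output()` for the graph loader's var-id table, and `usable_output()` for the imperative fallback. Below is a condensed, illustrative view of that consumer-side pattern; the function name and its parameters are assumptions for the sketch, not code from the commit:

```cpp
#include "megbrain/common.h"
#include "megbrain/serialization/opr_registry.h"

void load_one_opr(const mgb::serialization::OprRegistry* registry,
                  mgb::serialization::OprLoadContext& ctx,
                  const mgb::cg::VarNodeArray& inputs,
                  const mgb::cg::OperatorNodeConfig& config) {
    auto accessor = registry->loader(ctx, inputs, config);

    // type checks (and shallow copy) still operate on the raw operator node
    mgb::cg::OperatorNodeBase* opr = accessor.opr();
    mgb_assert(opr);

    // graph loaders record the possibly remapped output list instead of
    // opr->output(), so old models keep the output layout they were dumped with
    for (auto&& ovar : accessor.output()) {
        mgb_assert(ovar);
    }
}
```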
@@ -167,16 +167,22 @@ namespace { \
 /*!
  * \brief register opr serialization methods
  */
 #define MGB_SEREG_OPR(_cls, _arity) \
     namespace { \
-    struct _OprReg##_cls { \
-        static void entry() { \
-            using Impl = ::mgb::serialization::OprLoadDumpImpl< \
-                    _cls, _arity>; \
-            MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, Impl::load); \
-        } \
-    }; \
-    } \
+    namespace ser = ::mgb::serialization; \
+    struct _OprReg##_cls { \
+        using Impl = ser::OprLoadDumpImpl<_cls, _arity>; \
+        static ser::OprWithOutputAccessor wrap_loader( \
+                ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \
+                const mgb::cg::OperatorNodeConfig& config) { \
+            return ser::OprWithOutputAccessor( \
+                    Impl::load(ctx, inputs, config)); \
+        } \
+        static void entry() { \
+            MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, wrap_loader); \
+        } \
+    }; \
+    } \
     MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls)

 //! use to check type is complete or not, midout need a complete type
@@ -187,33 +193,35 @@ template <class T>
 struct IsComplete<T, decltype(void(sizeof(T)))> : std::true_type {};

 //! call OprRegistry::add with only loader, used for backward compatibility
-#define MGB_SEREG_OPR_COMPAT(_name, _load) \
+#define MGB_SEREG_OPR_COMPAT_WITH_ACCESSOR(_name, _load, _accessor) \
     namespace { \
     static_assert(IsComplete<_name>(), \
                   "need a complete type for MGB_SEREG_OPR_COMPAT"); \
-    struct _OprReg##_name { \
-        static cg::OperatorNodeBase* compat_loader( \
-                serialization::OprLoadContext& ctx, \
-                const cg::VarNodeArray& inputs, \
-                const OperatorNodeConfig& config) { \
-            return _load( \
-                    static_cast<serialization::OprLoadContextRawPOD&>(ctx), \
-                    inputs, config); \
+    namespace ser = ::mgb::serialization; \
+    struct _OprReg##_name { \
+        static ser::OprWithOutputAccessor compat_loader( \
+                ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \
+                const mgb::cg::OperatorNodeConfig& config) { \
+            auto&& ctx_ = static_cast<ser::OprLoadContextRawPOD&>(ctx); \
+            return ser::OprWithOutputAccessor(_load(ctx_, inputs, config), \
+                                              _accessor); \
         } \
         static void entry() { \
-            ::mgb::serialization::OprRegistry::add( \
-                    {nullptr, \
-                     MGB_HASH_STR(#_name), \
-                     _MGB_SEREG_OPR_NAME_FROM_CLS(_name), \
-                     nullptr, \
-                     compat_loader, \
-                     {}, \
-                     {}}); \
-        } \
-    }; \
-    } \
+            ser::OprRegistry::add({nullptr, \
+                                   MGB_HASH_STR(#_name), \
+                                   _MGB_SEREG_OPR_NAME_FROM_CLS(_name), \
+                                   nullptr, \
+                                   compat_loader, \
+                                   {}, \
+                                   {}}); \
+        } \
+    }; \
+    } \
     MGB_SEREG_OPR_INTL_CALL_ENTRY(_name, _OprReg##_name)

+#define MGB_SEREG_OPR_COMPAT(_name, _load) \
+    MGB_SEREG_OPR_COMPAT_WITH_ACCESSOR(_name, _load, nullptr)
+
 /*!
  * \brief use \p _copy to implement shallow copy for given operator
  */
......
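
For completeness, here is a hypothetical use of the new compat macro, showing how an operator name from old graph files could stay loadable while the accessor reshapes the published output list. Everything except the macro itself is invented for illustration: `legacy_batch_norm_loader` (declaration only; it would rebuild the operator from the legacy on-disk params), `legacy_batch_norm_outputs`, the assumption that the extra output is the last one, and the assumption that the fragment lives in the same namespace as the project's existing sereg registrations:

```cpp
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/serialization/sereg.h"

namespace mgb {
namespace opr {

// assumed loader for graphs dumped under the old name; definition omitted here
cg::OperatorNodeBase* legacy_batch_norm_loader(
        serialization::OprLoadContextRawPOD& ctx, const cg::VarNodeArray& inputs,
        const OperatorNodeConfig& config);

// accessor: hide an output that old graph files never recorded
cg::VarNodeArray legacy_batch_norm_outputs(const cg::VarNodeArray& outputs) {
    cg::VarNodeArray remapped = outputs;
    remapped.pop_back();
    return remapped;
}

MGB_SEREG_OPR_COMPAT_WITH_ACCESSOR(
        BatchNorm, legacy_batch_norm_loader, legacy_batch_norm_outputs);

}  // namespace opr
}  // namespace mgb
```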