提交 270f1aa2 编写于 作者: M Megvii Engine Team

feat(mgb/serialization): add Accessor for OprLoader to fix BN output compatibility

GitOrigin-RevId: 3b95da02c8fa3cd2a6c6d47d7ede93b7b36aa3a7
上级 c0ccd0ea
......@@ -76,7 +76,7 @@ public:
}
};
cg::OperatorNodeBase* apply_on_var_node(
VarNodeArray apply_on_var_node(
const OpDef& def, const VarNodeArray& inputs) {
auto&& attr = def.cast_final_safe<OprAttr>();
auto config = attr.config;
......@@ -85,7 +85,7 @@ cg::OperatorNodeBase* apply_on_var_node(
auto registry = serialization::OprRegistry::find_by_name(attr.type);
mgb_assert(registry, "operator %s not found", attr.type.c_str());
OprParamsLoadContext ctx{attr.param, inputs[0]->owner_graph()};
return registry->loader(ctx, inputs, config);
return registry->loader(ctx, inputs, config).usable_output();
}
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) {
......
......@@ -200,7 +200,7 @@ TEST(TestImperative, BatchNormGrad) {
LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn};
LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn};
{
auto op = OprAttr::make("BatchNorm");
auto op = OprAttr::make("BatchNormV1");
auto&& attr = op->cast_final_safe<OprAttr>();
Param param;
param.fwd_mode = Param::FwdMode::TRAINING;
......@@ -210,7 +210,7 @@ TEST(TestImperative, BatchNormGrad) {
{false, false, false, false, false, true});
}
{
auto op = OprAttr::make("BatchNorm");
auto op = OprAttr::make("BatchNormV1");
auto&& attr = op->cast_final_safe<OprAttr>();
Param param;
param.fwd_mode = Param::FwdMode::TRAINING;
......
......@@ -59,7 +59,7 @@ TEST(TestImperative, Reduce) {
}
TEST(TestImperative, BatchNorm) {
auto op = OprAttr::make("BatchNorm");
auto op = OprAttr::make("BatchNormV1");
auto&& attr = op->cast_final_safe<OprAttr>();
using Param = opr::BatchNorm::Param;
Param param;
......
......@@ -16,14 +16,13 @@
#include "megbrain/opr/dnn/correlation.h"
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/sliding_window_transpose.h"
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/lsq.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/roi_pooling.h"
#include "megbrain/opr/dnn/sliding_window_transpose.h"
#include "megbrain/opr/dnn/tqt.h"
#include "megbrain/serialization/sereg.h"
#include "megdnn/opr_param_defs.h"
......@@ -390,6 +389,7 @@ struct OprMaker<opr::BatchNorm, 0> {
}
};
// OprMaker in MGB_SEREG_OPR only support unique output opr
template <>
struct OprMaker<opr::BatchNormBackward, 6> {
using Param = opr::BatchNormBackward::Param;
......@@ -398,8 +398,8 @@ struct OprMaker<opr::BatchNormBackward, 6> {
ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4], i[5], param,
config)[0]
return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4], i[5],
param, config)[0]
.node()
->owner_opr();
}
......@@ -575,8 +575,10 @@ MGB_SEREG_OPR(Convolution3DBackwardFilter, 0);
using ConvBiasForwardV4 = ConvBiasForward;
MGB_SEREG_OPR(ConvBiasForwardV4, 0);
MGB_SEREG_OPR(BatchNorm, 0);
MGB_SEREG_OPR(BatchNormBackward, 6);
using BatchNormV1 = BatchNorm;
using BatchNormBackwardV1 = BatchNormBackward;
MGB_SEREG_OPR(BatchNormV1, 0);
MGB_SEREG_OPR(BatchNormBackwardV1, 6);
using LocalShareForwardV1 = LocalShareForward;
using LocalShareBackwardDataV1 = LocalShareBackwardData;
......
......@@ -39,7 +39,7 @@ namespace {
return inst;
}
cg::OperatorNodeBase* dynamic_loader(
OprWithOutputAccessor dynamic_loader(
OprLoadContext &ctx, const cg::VarNodeArray &inputs,
const OperatorNodeConfig &config) {
auto name = ctx.load_buf_with_len();
......@@ -171,4 +171,20 @@ std::vector<std::pair<size_t, std::string>> OprRegistry::dump_registries() {
}
#endif
namespace {
// Identity accessor: returns the operator's output list unchanged.
// Used as the default so that callers always have a valid m_accessor.
const VarNodeArray& default_accessor(const VarNodeArray& outputs) {
return outputs;
}
}
// Wrap an operator with the identity output accessor (no remapping).
OprWithOutputAccessor::OprWithOutputAccessor(cg::OperatorNodeBase* opr) : m_opr(opr){
m_accessor = &default_accessor;
};
// Wrap an operator with a custom output accessor; a null accessor falls
// back to the default (identity) behavior installed by the delegated ctor.
OprWithOutputAccessor::OprWithOutputAccessor(cg::OperatorNodeBase* opr, Accessor accessor)
: OprWithOutputAccessor(opr) {
if (accessor) {
m_accessor = accessor;
}
};
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -207,7 +207,7 @@ cg::OperatorNodeBase* serialization::intl::copy_opr_shallow_default_impl(
registry->dumper(dumper, opr);
OprLoadContextMemory loader{opr.owner_graph(), dumper};
return registry->loader(loader, inputs, config);
return registry->loader(loader, inputs, config).opr();
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -782,7 +782,8 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(
}
// call loader
auto opr = registry->loader(*this, inputs, config);
auto accessor = registry->loader(*this, inputs, config);
auto opr = accessor.opr();
// check opr type; note that:
// 1. registry->type may be empty for dynamic opr loaders or legacy oprs
......@@ -794,7 +795,7 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(
opr ? opr->dyn_typeinfo()->name : nullptr, registry->type->name);
// record output vars; read output names
size_t i = 0;
for (auto ovar : opr->output()) {
for (auto ovar : accessor.output()) {
if (!ovar->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
m_id2varnode.push_back(ovar);
if (fbopr->output_name()) {
......
......@@ -19,16 +19,36 @@ namespace serialization {
class OprDumpContext;
class OprLoadContext;
class OprShallowCopyContext;
//! Wraps a loaded operator together with an output-remapping accessor so
//! legacy serialized graphs can present a compatible output list
//! (e.g. hiding outputs an operator gained in a newer version).
class OprWithOutputAccessor {
cg::OperatorNodeBase* m_opr;  // wrapped operator; not owned
// Maps the operator's raw output array to the output list exposed to callers.
using Accessor = thin_function<const VarNodeArray(const VarNodeArray&)>;
Accessor m_accessor;
public:
OprWithOutputAccessor(cg::OperatorNodeBase* opr);
OprWithOutputAccessor(cg::OperatorNodeBase* opr, Accessor accessor);
//! remapped output var at \p idx; throws via at() if out of range
VarNode* output(size_t idx) const { return output().at(idx); }
//! all outputs after remapping by the accessor
VarNodeArray output() const { return m_accessor(m_opr->output()); }
//! non-volatile (usable) outputs after remapping by the accessor
VarNodeArray usable_output() const { return m_accessor(m_opr->usable_output()); }
//! the underlying wrapped operator
cg::OperatorNodeBase* opr() { return m_opr; }
};
//! dump opr internal params to OprDumpContext
using OprDumper = thin_function<void(
OprDumpContext &ctx, const cg::OperatorNodeBase &opr)>;
//! load and restore operator from OprLoadContext
//! is also used by GraphLoadConfig.
using OprLoader = thin_function<cg::OperatorNodeBase*(
OprLoadContext &ctx, const cg::VarNodeArray &inputs,
const OperatorNodeConfig &config)>;
//! loader that can change opr output map for compatibility
using OprLoaderWrapper = thin_function<OprWithOutputAccessor(
OprLoadContext &ctx, const cg::VarNodeArray &inputs,
const OperatorNodeConfig &config)>;
//! shallow copy function for a single operator
using OprShallowCopy = thin_function<cg::OperatorNodeBase*(
const OprShallowCopyContext &ctx,
......@@ -41,7 +61,7 @@ namespace serialization {
uint64_t persist_type_id;
std::string name;
OprDumper dumper;
OprLoader loader;
OprLoaderWrapper loader;
OprShallowCopy shallow_copy; //!< set to empty to use default impl
uint64_t unversioned_type_id;
......
......@@ -169,11 +169,17 @@ namespace { \
*/
#define MGB_SEREG_OPR(_cls, _arity) \
namespace { \
namespace ser = ::mgb::serialization; \
struct _OprReg##_cls { \
using Impl = ser::OprLoadDumpImpl<_cls, _arity>; \
static ser::OprWithOutputAccessor wrap_loader( \
ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \
const mgb::cg::OperatorNodeConfig& config) { \
return ser::OprWithOutputAccessor( \
Impl::load(ctx, inputs, config)); \
} \
static void entry() { \
using Impl = ::mgb::serialization::OprLoadDumpImpl< \
_cls, _arity>; \
MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, Impl::load); \
MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, wrap_loader); \
} \
}; \
} \
......@@ -187,22 +193,21 @@ template <class T>
struct IsComplete<T, decltype(void(sizeof(T)))> : std::true_type {};
//! call OprRegistry::add with only loader, used for backward compatibility
#define MGB_SEREG_OPR_COMPAT(_name, _load) \
#define MGB_SEREG_OPR_COMPAT_WITH_ACCESSOR(_name, _load, _accessor) \
namespace { \
static_assert(IsComplete<_name>(), \
"need a complete type for MGB_SEREG_OPR_COMPAT"); \
namespace ser = ::mgb::serialization; \
struct _OprReg##_name { \
static cg::OperatorNodeBase* compat_loader( \
serialization::OprLoadContext& ctx, \
const cg::VarNodeArray& inputs, \
const OperatorNodeConfig& config) { \
return _load( \
static_cast<serialization::OprLoadContextRawPOD&>(ctx), \
inputs, config); \
static ser::OprWithOutputAccessor compat_loader( \
ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \
const mgb::cg::OperatorNodeConfig& config) { \
auto&& ctx_ = static_cast<ser::OprLoadContextRawPOD&>(ctx); \
return ser::OprWithOutputAccessor(_load(ctx_, inputs, config), \
_accessor); \
} \
static void entry() { \
::mgb::serialization::OprRegistry::add( \
{nullptr, \
ser::OprRegistry::add({nullptr, \
MGB_HASH_STR(#_name), \
_MGB_SEREG_OPR_NAME_FROM_CLS(_name), \
nullptr, \
......@@ -214,6 +219,9 @@ struct IsComplete<T, decltype(void(sizeof(T)))> : std::true_type {};
} \
MGB_SEREG_OPR_INTL_CALL_ENTRY(_name, _OprReg##_name)
#define MGB_SEREG_OPR_COMPAT(_name, _load) \
MGB_SEREG_OPR_COMPAT_WITH_ACCESSOR(_name, _load, nullptr)
/*!
* \brief use \p _copy to implement shallow copy for given operator
*/
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册