提交 1f7bf1ad 编写于 作者: M Megvii Engine Team

fix(opr): fix the compatibility of elemwise multitype new mode

GitOrigin-RevId: ee58271276ee4a31e11aa26c53c466f5c07dd019
上级 b3a7d149
......@@ -40,16 +40,16 @@ mgb::cg::OperatorNodeBase* custom_loader(
} // namespace serialization
} // namespace mgb
#define CUSTOM_OP_SEREG_REG(cls) \
namespace { \
struct _OprReg##cls { \
static void entry() { \
MGB_SEREG_OPR_INTL_CALL_ADD( \
cls, ::mgb::serialization::custom_dumper, \
::mgb::serialization::custom_loader); \
} \
}; \
} \
#define CUSTOM_OP_SEREG_REG(cls) \
namespace { \
struct _OprReg##cls { \
static void entry() { \
MGB_SEREG_OPR_INTL_CALL_ADD( \
cls, ::mgb::serialization::custom_dumper, \
::mgb::serialization::custom_loader, true); \
} \
}; \
} \
MGB_SEREG_OPR_INTL_CALL_ENTRY(cls, _OprReg##cls)
#define CUSTOM_OP_SEREG_REG_V2(cls, _version_min, _version_max) \
......
......@@ -131,10 +131,10 @@ cg::OperatorNodeBase* serialization::opr_shallow_copy_loop(
}
void LoopSerializer::reg_all() {
MGB_SEREG_OPR_INTL_CALL_ADD(opr::Loop, dump_loop, load_loop);
MGB_SEREG_OPR_INTL_CALL_ADD(InputMaker, dump_input_maker, load_input_maker);
MGB_SEREG_OPR_INTL_CALL_ADD(opr::Loop, dump_loop, load_loop, true);
MGB_SEREG_OPR_INTL_CALL_ADD(InputMaker, dump_input_maker, load_input_maker, true);
MGB_SEREG_OPR_INTL_CALL_ADD(
CounterProvider, dump_counter_provider, load_counter_provider);
CounterProvider, dump_counter_provider, load_counter_provider, true);
MGB_SEREG_OPR_INTL_CALL_ADD_V2(
opr::Loop, dump_loop, load_loop, nullptr, 2, CURRENT_VERSION);
......
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/serialization/sereg.h"
......@@ -7,10 +8,74 @@ template <>
struct OprMaker<opr::ElemwiseMultiType, 0>
: public OprMakerVariadic<opr::ElemwiseMultiType> {};
template <>
struct OprLoadDumpImplV2<opr::ElemwiseMultiType, 0> {
using Opr = opr::ElemwiseMultiType;
using PersisParam = opr::ElemwiseMultiType::Param;
using PersisElemwseiParam = opr::Elemwise::Param;
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
ctx.write_param<PersisParam>(opr.cast_final_safe<Opr>().param());
}
static cg::OperatorNodeBase* replace_opr(
cg::OperatorNodeBase* opr, const VarNodeArray& inputs) {
auto mode = opr->cast_final_safe<Opr>().param().mode;
auto change_to_elemwise_mode = [&](PersisParam::Mode multitype_mode) {
if (multitype_mode == PersisParam::Mode::EQ) {
return PersisElemwseiParam::Mode::EQ;
} else if (multitype_mode == PersisParam::Mode::LT) {
return PersisElemwseiParam::Mode::LT;
} else if (multitype_mode == PersisParam::Mode::LEQ) {
return PersisElemwseiParam::Mode::LEQ;
}
mgb_assert(0, "no supported model.");
};
if (PersisParam::Mode::EQ == mode || PersisParam::Mode::LT == mode ||
PersisParam::Mode::LEQ == mode) {
auto elemwise_mode = change_to_elemwise_mode(mode);
auto elemiwse_out = opr::Elemwise::make(inputs, {elemwise_mode});
return opr::TypeCvt::make(elemiwse_out, dtype::Bool()).node()->owner_opr();
} else if (PersisParam::Mode::NEQ == mode) {
auto elemiwse_out =
opr::Elemwise::make(inputs, {PersisElemwseiParam::Mode::EQ});
auto bool_out = opr::TypeCvt::make(elemiwse_out, dtype::Bool());
return opr::Elemwise::make({bool_out}, {PersisElemwseiParam::Mode::NOT})
.node()
->owner_opr();
} else if (PersisParam::Mode::ISNAN == mode) {
auto elemiwse_out = opr::Elemwise::make(
{inputs[0], inputs[0]}, {PersisElemwseiParam::Mode::EQ});
auto bool_out = opr::TypeCvt::make(elemiwse_out, dtype::Bool());
return opr::Elemwise::make({bool_out}, {PersisElemwseiParam::Mode::NOT})
.node()
->owner_opr();
} else if (PersisParam::Mode::ISINF == mode) {
auto input_var = SymbolVar{inputs[0]};
auto inf_var = input_var.make_scalar(INFINITY);
auto float_out = opr::TypeCvt::make(inputs[0], dtype::Float32());
auto elemiwse_out = opr::Elemwise::make(
{float_out, inf_var}, {PersisElemwseiParam::Mode::EQ});
return opr::TypeCvt::make(elemiwse_out, dtype::Bool()).node()->owner_opr();
}
return opr;
}
static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
return OprMaker<opr::ElemwiseMultiType, 0>::make(
ctx.read_param<PersisParam>(), inputs, ctx.graph(), config);
}
};
} // namespace serialization
namespace opr {
MGB_SEREG_OPR(ElemwiseMultiType, 0);
MGB_SEREG_OPR_CONDITION(ElemwiseMultiType, 0, false);
MGB_SEREG_OPR_V2(
ElemwiseMultiType, 0,
(mgb::serialization::OprLoadDumpImplV2<opr::ElemwiseMultiType, 0>::replace_opr),
VERSION_1, VERSION_1);
MGB_SEREG_OPR(AffineInt, 3);
} // namespace opr
} // namespace mgb
......
......@@ -125,16 +125,18 @@ ComputingGraph* serialization::OprShallowCopyContext::owner_graph(
cg::OperatorNodeBase* serialization::copy_opr_shallow(
const cg::OperatorNodeBase& opr, const VarNodeArray& inputs,
const OperatorNodeConfig& config, const OprShallowCopyContext& ctx) {
auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo());
mgb_assert(
registry, "could not find OprReceiver to copy opr %s{%s}", opr.cname(),
opr.dyn_typeinfo()->name);
OprShallowCopy shallow_copy = nullptr;
if (auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo())) {
shallow_copy = registry->shallow_copy;
} else {
shallow_copy = intl::copy_opr_shallow_default_impl;
}
mgb_assert(inputs.size() == opr.input().size());
auto dst_og = ctx.owner_graph(opr, inputs);
auto do_copy = [&]() {
auto nr_opr_before = opr.owner_graph()->nr_oprs_in_graph();
auto ret = registry->shallow_copy(ctx, opr, inputs, config);
auto ret = shallow_copy(ctx, opr, inputs, config);
if (dst_og != opr.owner_graph() ||
opr.owner_graph()->nr_oprs_in_graph() != nr_opr_before) {
......@@ -188,18 +190,28 @@ cg::OperatorNodeBase* serialization::intl::copy_opr_shallow_default_impl(
const OprShallowCopyContext& ctx, const cg::OperatorNodeBase& opr,
const VarNodeArray& inputs, const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(ctx);
OprDumper opr_dumper = nullptr;
OprLoaderWrapper opr_loader = nullptr;
auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo());
if (auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo())) {
opr_loader = registry->loader;
opr_dumper = registry->dumper;
} else {
auto registryv2 = OprRegistryV2::versioned_find_by_typeinfo(
opr.dyn_typeinfo(), CURRENT_VERSION);
opr_loader = registryv2->loader;
opr_dumper = registryv2->dumper;
}
mgb_assert(
registry && registry->dumper && registry->loader,
opr_dumper && opr_loader,
"can not shallow_copy operator %s{%s}: "
"no dumper/loader registered",
opr.cname(), opr.dyn_typeinfo()->name);
OprDumpContextMemory dumper;
registry->dumper(dumper, opr);
OprDumpContextMemory memory_dumper;
opr_dumper(memory_dumper, opr);
OprLoadContextMemory loader{opr.owner_graph(), dumper};
return registry->loader(loader, inputs, config).opr();
OprLoadContextMemory loader{opr.owner_graph(), memory_dumper};
return opr_loader(loader, inputs, config).opr();
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -358,6 +358,11 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump(
auto new_output_vars = output_vars;
if (!config.no_change_graph) {
new_output_vars = converter_all_opr_to_compatiable(output_vars);
mgb_assert(output_vars.size() == new_output_vars.size());
for (size_t id = 0; id < output_vars.size(); id++) {
auto& new_var = new_output_vars[id];
new_var.rename(output_vars[id].node()->name());
}
}
auto begin_pos = m_file->tell();
......
......@@ -151,20 +151,22 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {};
//! call OprRegistryV2::versioned_add for the new serialization which is
//! compatible with the old serialization; convert is nullptr, and this
//! registry is only for version 1
#define MGB_SEREG_OPR_INTL_CALL_ADD(_cls, _dump, _load) \
do { \
::mgb::serialization::OprRegistry::add( \
{_cls::typeinfo(), \
MGB_HASH_STR(#_cls), \
_MGB_SEREG_OPR_NAME_FROM_CLS(_cls), \
_dump, \
_load, \
{}, \
MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls)}); \
::mgb::serialization::OprRegistryV2::versioned_add( \
{_cls::typeinfo(), MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls), \
_MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, nullptr}, \
::mgb::VERSION_1, ::mgb::VERSION_1); \
#define MGB_SEREG_OPR_INTL_CALL_ADD(_cls, _dump, _load, _registerv2) \
do { \
::mgb::serialization::OprRegistry::add( \
{_cls::typeinfo(), \
MGB_HASH_STR(#_cls), \
_MGB_SEREG_OPR_NAME_FROM_CLS(_cls), \
_dump, \
_load, \
{}, \
MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls)}); \
if (_registerv2) { \
::mgb::serialization::OprRegistryV2::versioned_add( \
{_cls::typeinfo(), MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls), \
_MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, nullptr}, \
::mgb::VERSION_1, ::mgb::VERSION_1); \
} \
} while (0)
//! call OprRegistryV2::versioned_add for new serialization, in which convert the
......@@ -181,23 +183,25 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {};
/*!
* \brief register opr serialization methods
*/
#define MGB_SEREG_OPR(_cls, _arity) \
namespace { \
namespace ser = ::mgb::serialization; \
struct _OprReg##_cls { \
using Impl = ser::OprLoadDumpImpl<_cls, _arity>; \
static ser::OprWithOutputAccessor wrap_loader( \
ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \
const mgb::cg::OperatorNodeConfig& config) { \
return ser::OprWithOutputAccessor(Impl::load(ctx, inputs, config)); \
} \
static void entry() { \
MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, wrap_loader); \
} \
}; \
} \
#define MGB_SEREG_OPR_CONDITION(_cls, _arity, _registerv2) \
namespace { \
namespace ser = ::mgb::serialization; \
struct _OprReg##_cls { \
using Impl = ser::OprLoadDumpImpl<_cls, _arity>; \
static ser::OprWithOutputAccessor wrap_loader( \
ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \
const mgb::cg::OperatorNodeConfig& config) { \
return ser::OprWithOutputAccessor(Impl::load(ctx, inputs, config)); \
} \
static void entry() { \
MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, wrap_loader, _registerv2); \
} \
}; \
} \
MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls)
#define MGB_SEREG_OPR(_cls, _arity) MGB_SEREG_OPR_CONDITION(_cls, _arity, true)
//! new dump/load functions should be implemented in OprLoadDumpImplV2; the
//! _converter argument is optional -- pass nullptr if it is not implemented
#define MGB_SEREG_OPR_V2(_cls, _arity, _converter, _version_min, _version_max) \
......
#include "megbrain/opr/nn_int.h"
#if MGB_ENABLE_FBS_SERIALIZATION
#include "megbrain/opr/basic_arith_wrapper.h"
......@@ -1016,4 +1017,107 @@ TEST(TestSerializer2, TestSoftMaxLoadDump) {
load();
}
TEST(TestSerializer2, TestElemwiseMultiTypeLoadDump) {
auto fname = GET_OUTPUT_FILE(GraphDumpFormat::FLATBUFFERS_V2);
TensorShape shape{3};
auto cn = CompNode::load("xpu0");
std::shared_ptr<HostTensorND> host0 =
std::make_shared<HostTensorND>(cn, shape, dtype::Float32{});
std::shared_ptr<HostTensorND> host1 =
std::make_shared<HostTensorND>(cn, shape, dtype::Float32{});
HostTensorND dst_truth;
host0->ptr<float>()[0] = 2;
host0->ptr<float>()[1] = 2;
host0->ptr<float>()[2] = -1;
host1->ptr<float>()[0] = 1;
host1->ptr<float>()[1] = 2;
host1->ptr<float>()[2] = 3;
auto dump = [&](opr::ElemwiseMultiType::Param::Mode mode, size_t nr_opr) {
auto graph = ComputingGraph::make();
OperatorNodeConfig config;
config.name("input0");
auto h2d0 = opr::Host2DeviceCopy::make(*graph, host0, config);
config.name("input1");
auto h2d1 = opr::Host2DeviceCopy::make(*graph, host1, config);
auto x = opr::ElemwiseMultiType::make(
{h2d0, h2d1}, {mode}, OperatorNodeConfig{dtype::Bool()});
x.rename("out");
auto func = graph->compile({make_callback_copy(x, dst_truth)});
auto dumper = GraphDumper::make(
OutputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS_V2);
auto rst = dumper->dump({x});
func->execute().wait();
ASSERT_EQ(rst.nr_opr, nr_opr);
};
auto load = [&]() {
auto loader = GraphLoader::make(
InputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS_V2);
auto rst = loader->load();
ASSERT_EQ(rst.tensor_map.size(), 2);
ASSERT_EQ(rst.output_var_map.count("out"), 1);
HostTensorND host_x;
auto func =
rst.graph_compile({make_callback_copy(rst.output_var_list[0], host_x)});
for (auto& input : rst.tensor_map) {
if (input.first == "input0") {
input.second->copy_from(*host0).sync();
} else if (input.first == "input1") {
input.second->copy_from(*host1).sync();
}
}
func->execute().wait();
for (int i = 0; i < 3; i++) {
EXPECT_EQ(host_x.ptr<bool>()[i], dst_truth.ptr<bool>()[i]);
}
};
dump(opr::ElemwiseMultiType::Param::Mode::EQ, 4);
load();
dump(opr::ElemwiseMultiType::Param::Mode::LT, 4);
load();
dump(opr::ElemwiseMultiType::Param::Mode::LEQ, 4);
load();
dump(opr::ElemwiseMultiType::Param::Mode::NEQ, 5);
load();
auto dump_single_input = [&](opr::ElemwiseMultiType::Param::Mode mode,
size_t nr_opr) {
auto graph = ComputingGraph::make();
auto h2d0 = opr::Host2DeviceCopy::make(*graph, host0);
auto x = opr::ElemwiseMultiType::make(
{h2d0}, {mode}, OperatorNodeConfig{dtype::Bool()});
x.rename("out");
auto func = graph->compile({make_callback_copy(x, dst_truth)});
auto dumper = GraphDumper::make(
OutputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS_V2);
auto rst = dumper->dump({x});
func->execute().wait();
ASSERT_EQ(rst.nr_opr, nr_opr);
};
auto load_single_input = [&]() {
auto loader = GraphLoader::make(
InputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS_V2);
auto rst = loader->load();
ASSERT_EQ(rst.tensor_map.size(), 1);
ASSERT_EQ(rst.output_var_map.count("out"), 1);
HostTensorND host_x;
auto func =
rst.graph_compile({make_callback_copy(rst.output_var_list[0], host_x)});
rst.tensor_map.begin()->second->copy_from(*host0).sync();
func->execute().wait();
for (int i = 0; i < 3; i++) {
EXPECT_EQ(host_x.ptr<bool>()[i], dst_truth.ptr<bool>()[i]);
}
};
host0->ptr<float>()[2] = INFINITY;
dump_single_input(opr::ElemwiseMultiType::Param::Mode::ISINF, 4);
load_single_input();
host0->ptr<float>()[2] = NAN;
dump_single_input(opr::ElemwiseMultiType::Param::Mode::ISNAN, 4);
load_single_input();
}
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册