提交 dbe2b893 编写于 作者: M Megvii Engine Team

fix(midout): fix brain opr midout 2/2 (also see a6aa1574)

fix extern_c_opr midout

GitOrigin-RevId: 7de4f650d1c8fc3a6a4cedb04b4386c4cd66600f
上级 888895e9
...@@ -224,7 +224,7 @@ SymbolVarArray _Opr::extern_c_opr_placeholder( ...@@ -224,7 +224,7 @@ SymbolVarArray _Opr::extern_c_opr_placeholder(
} }
} }
auto opr = serialization::ExternCOprRunner::make_placeholder( auto opr = opr::ExternCOprRunner::make_placeholder(
inputs, cpp_output_shapes, dump_name, PyBytes_AsString(data_bytes), inputs, cpp_output_shapes, dump_name, PyBytes_AsString(data_bytes),
PyBytes_Size(data_bytes), config, cpp_output_dtypes); PyBytes_Size(data_bytes), config, cpp_output_dtypes);
SymbolVarArray ret; SymbolVarArray ret;
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
using namespace mgb; using namespace mgb;
using namespace serialization; using namespace serialization;
using namespace opr;
namespace { namespace {
......
...@@ -16,18 +16,19 @@ namespace mgb { ...@@ -16,18 +16,19 @@ namespace mgb {
namespace serialization { namespace serialization {
template <> template <>
struct OprLoadDumpImpl<ExternCOprRunner, 0> { struct OprLoadDumpImpl<opr::ExternCOprRunner, 0> {
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
ExternCOprRunner::dump(ctx, opr); opr::ExternCOprRunner::dump(ctx, opr);
} }
static cg::OperatorNodeBase* load(OprLoadContext& ctx, static cg::OperatorNodeBase* load(OprLoadContext& ctx,
const cg::VarNodeArray& inputs, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) { const OperatorNodeConfig& config) {
return ExternCOprRunner::load(ctx, inputs, config); return opr::ExternCOprRunner::load(ctx, inputs, config);
} }
}; };
using ExternCOprRunner = opr::ExternCOprRunner;
MGB_SEREG_OPR(ExternCOprRunner, 0); MGB_SEREG_OPR(ExternCOprRunner, 0);
MGB_REG_OPR_SHALLOW_COPY(ExternCOprRunner, ExternCOprRunner::shallow_copy); MGB_REG_OPR_SHALLOW_COPY(ExternCOprRunner, ExternCOprRunner::shallow_copy);
} // namespace serialization } // namespace serialization
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include "megbrain/serialization/opr_registry.h" #include "megbrain/serialization/opr_registry.h"
namespace mgb { namespace mgb {
namespace serialization { namespace opr {
//! an operator to run extern C oprs //! an operator to run extern C oprs
MGB_DEFINE_OPR_CLASS(ExternCOprRunner, MGB_DEFINE_OPR_CLASS(ExternCOprRunner,
...@@ -68,10 +68,11 @@ public: ...@@ -68,10 +68,11 @@ public:
static bool unregister_loader(const char* name); static bool unregister_loader(const char* name);
//! impl for serialization dump //! impl for serialization dump
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr); static void dump(serialization::OprDumpContext& ctx,
const cg::OperatorNodeBase& opr);
//! impl for serialization load //! impl for serialization load
static cg::OperatorNodeBase* load(OprLoadContext& ctx, static cg::OperatorNodeBase* load(serialization::OprLoadContext& ctx,
const cg::VarNodeArray& inputs, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config); const OperatorNodeConfig& config);
...@@ -88,7 +89,7 @@ public: ...@@ -88,7 +89,7 @@ public:
static TensorShape tensor_shape_from_c(const MGBTensorShape& shape); static TensorShape tensor_shape_from_c(const MGBTensorShape& shape);
}; };
} // namespace serialization } // namespace opr
} // namespace mgb } // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -179,9 +179,18 @@ namespace { \ ...@@ -179,9 +179,18 @@ namespace { \
} \ } \
MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls) MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls)
//! use to check type is complete or not, midout need a complete type
template <class T, class = void>
struct IsComplete : std::false_type {};
template <class T>
struct IsComplete<T, decltype(void(sizeof(T)))> : std::true_type {};
//! call OprRegistry::add with only loader, used for backward compatibility //! call OprRegistry::add with only loader, used for backward compatibility
#define MGB_SEREG_OPR_COMPAT(_name, _load) \ #define MGB_SEREG_OPR_COMPAT(_name, _load) \
namespace { \ namespace { \
static_assert(IsComplete<_name>(), \
"need a complete type for MGB_SEREG_OPR_COMPAT"); \
struct _OprReg##_name { \ struct _OprReg##_name { \
static cg::OperatorNodeBase* compat_loader( \ static cg::OperatorNodeBase* compat_loader( \
serialization::OprLoadContext& ctx, \ serialization::OprLoadContext& ctx, \
......
...@@ -182,7 +182,7 @@ std::vector<uint8_t> create_graph_dump(float bias, float extra_scale, ...@@ -182,7 +182,7 @@ std::vector<uint8_t> create_graph_dump(float bias, float extra_scale,
auto x = opr::Host2DeviceCopy::make(*graph, host_x); auto x = opr::Host2DeviceCopy::make(*graph, host_x);
if (sleep) if (sleep)
x = opr::Sleep::make(x, sleep); x = opr::Sleep::make(x, sleep);
x = serialization::ExternCOprRunner::make_placeholder( x = opr::ExternCOprRunner::make_placeholder(
{x}, {TensorShape{1}}, {x}, {TensorShape{1}},
dtype == MGB_DTYPE_FLOAT32 dtype == MGB_DTYPE_FLOAT32
? "bias_adder_dump" ? "bias_adder_dump"
...@@ -280,7 +280,7 @@ TEST(TestExternCOpr, Dedup) { ...@@ -280,7 +280,7 @@ TEST(TestExternCOpr, Dedup) {
auto graph = ComputingGraph::make(); auto graph = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make(*graph, host_x); auto x = opr::Host2DeviceCopy::make(*graph, host_x);
auto make_opr = [x](float bias) { auto make_opr = [x](float bias) {
return ExternCOprRunner::make_from_desc( return opr::ExternCOprRunner::make_from_desc(
{x.node()}, MGBOprDescImpl<>::make(bias)); {x.node()}, MGBOprDescImpl<>::make(bias));
}; };
auto y0 = make_opr(0.5), y1 = make_opr(0.6), y2 = make_opr(0.5); auto y0 = make_opr(0.5), y1 = make_opr(0.6), y2 = make_opr(0.5);
......
...@@ -115,7 +115,7 @@ std::vector<uint8_t> create_graph_dump(float bias, float extra_scale, ...@@ -115,7 +115,7 @@ std::vector<uint8_t> create_graph_dump(float bias, float extra_scale,
auto x = opr::Host2DeviceCopy::make(*graph, host_x); auto x = opr::Host2DeviceCopy::make(*graph, host_x);
if (sleep) if (sleep)
x = opr::Sleep::make(x, sleep); x = opr::Sleep::make(x, sleep);
x = serialization::ExternCOprRunner::make_placeholder( x = opr::ExternCOprRunner::make_placeholder(
{x}, {TensorShape{1}}, "bias_adder_dump_v23", &bias, sizeof(bias)) {x}, {TensorShape{1}}, "bias_adder_dump_v23", &bias, sizeof(bias))
->output(0); ->output(0);
if (extra_scale) if (extra_scale)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册