提交 a09fc5f7 编写于 作者: M Megvii Engine Team

fix(mgb/serialization): disable inplace arith graph opt in graph load

GitOrigin-RevId: d63baf8356d345013886692e464f8e4f49594887
上级 cd7090ac
......@@ -755,8 +755,10 @@ class trace:
h2v = {}
graph = G.Graph()
# only graph_opt_level takes effect in dump
self._apply_graph_options(graph)
# apply graph_opt_level in dump
if self._graph_opt_level is not None:
graph.options.graph_opt_level = self._graph_opt_level
for i, h in enumerate(self._arg_bindings):
info = self._tinfo[h]
......
......@@ -244,7 +244,6 @@ def test_goptions_log_sum_exp():
np.testing.assert_almost_equal(g(d, o), val)
@pytest.mark.skip(reason="could not use opt_level=0 with dump")
def test_goptions_log_exp():
@trace(symbolic=True, opt_level=0, capture_as_const=True)
def f(x):
......
......@@ -355,6 +355,14 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
*/
int16_t graph_opt_level = 2;
/*!
* disable inplace arith transformations during graph
* construction
* it effectively disables level-1 graph optimization;
* only for internal use during de-serialization
*/
bool disable_inplace_arith_opt = false;
/*!
* max size of allreduce packs in MB
* set this option to zero to disable PackAllReducePass
......
......@@ -221,7 +221,8 @@ SymbolVar Elemwise::make(const VarNodeArrayView& inputs, Param param,
trait.name, cg::dump_var_info(inputs).c_str());
#if !MGB_BUILD_SLIM_SERVING
if (inputs[0]->owner_graph()->options().graph_opt_level) {
auto&& options = inputs[0]->owner_graph()->options();
if (options.graph_opt_level && !(options.disable_inplace_arith_opt)) {
auto repl = gopt::optimize_elemwise_expr_inplace(dtp.get_vars(), param,
config);
if (repl)
......
......@@ -756,10 +756,16 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(
GraphLoader::LoadResult GraphLoaderOSS::OprLoadContextImpl::load_oprs() {
// load oprs
const auto* oprs = m_loader->m_graph->oprs();
{
// inplace arith graph optimization is disabled during opr load
// it tries to restore the same graph as it was dumped
// see test TestSerializer2.LOGEXP for example
GraphLoader::ScopedGraphOptDisabler _(m_graph);
for (flatbuffers::uoffset_t i = 0; i < oprs->size(); ++i) {
m_current_opr = oprs->Get(i);
load_single_opr(m_current_opr);
}
}
// batched loading device values
m_device_value_loader.apply();
......
......@@ -61,6 +61,21 @@ namespace serialization {
const ComputingGraph::OutputSpec &outspec);
};
//! RAII guard that disables inplace arith graph optimization for its
//! lifetime; used during de-serialization so the loaded graph keeps the
//! same opr structure that was dumped (see disable_inplace_arith_opt).
struct ScopedGraphOptDisabler {
    bool option_saved;  //!< value of disable_inplace_arith_opt to restore
    std::shared_ptr<ComputingGraph> cg;

    //! \param cg_p graph whose option is temporarily forced to true; a
    //!        shared_ptr copy is kept so the graph outlives this guard.
    //! const& (rather than non-const&) also accepts rvalues and makes
    //! clear the pointer itself is never reseated.
    explicit ScopedGraphOptDisabler(const std::shared_ptr<ComputingGraph>& cg_p)
            : option_saved(std::exchange(
                      cg_p->options().disable_inplace_arith_opt, true)),
              cg(cg_p) {}

    // non-copyable: a copy would restore option_saved a second time and
    // could re-enable the optimization while the first guard is alive
    ScopedGraphOptDisabler(const ScopedGraphOptDisabler&) = delete;
    ScopedGraphOptDisabler& operator=(const ScopedGraphOptDisabler&) = delete;

    ~ScopedGraphOptDisabler() {
        cg->options().disable_inplace_arith_opt = option_saved;
    }
};
//! mem_node => tensor_value
using SharedTensorMapEntry =
ThinHashMap<MemNode, std::shared_ptr<DeviceTensorND>>;
......
......@@ -761,4 +761,41 @@ TEST(TestSerializer2, HasOutputDtype) {
load();
}
// Round-trips log(exp(x)) through dump/load, both with the inplace arith
// optimization enabled (collapses to one opr) and disabled (three oprs).
TEST(TestSerializer2, LOGEXP) {
    auto fname = GET_OUTPUT_FILE();
    TensorShape shape{2, 3};
    using Mode = opr::Elemwise::Mode;
    bool inplace_opt = true;
    auto dump = [&]() {
        auto cn = CompNode::load("xpu0");
        auto host_x = std::make_shared<HostTensorND>(cn, shape);
        // all-zero input: log(exp(0)) is exactly 0, so no NaN can appear
        auto* ptr = host_x->ptr<float>();
        auto nr_elems = shape.total_nr_elems();
        for (size_t i = 0; i < nr_elems; ++i) {
            ptr[i] = 0.f;
        }
        auto graph = ComputingGraph::make();
        if (!inplace_opt) {
            // opt level 0 keeps EXP and LOG as two separate oprs
            graph->options().graph_opt_level = 0;
        }
        auto x = opr::Host2DeviceCopy::make(*graph, host_x, {"x"});
        auto y = opr::Elemwise::make({x}, Mode::EXP);
        auto z = opr::Elemwise::make({y}, Mode::LOG);
        auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()),
                                        GraphDumpFormat::FLATBUFFERS);
        auto rst = dumper->dump({z.rename("z"), z});
        // inplace opt folds log(exp(x)) into a single opr at build time
        size_t expected_nr_opr = inplace_opt ? 1 : 3;
        ASSERT_EQ(expected_nr_opr, rst.nr_opr);
    };
    auto load = [&]() {
        auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()),
                                        GraphDumpFormat::FLATBUFFERS);
        auto rst = loader->load();
    };
    // exercise both settings of the inplace optimization
    dump();
    load();
    inplace_opt = !inplace_opt;
    dump();
    load();
}
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册