diff --git a/lite/src/mge/network_impl.cpp b/lite/src/mge/network_impl.cpp
index 2d6840f9ddd378f46f48be7bcd82248fca3adeec..10480c44c651b83e24aeefe543e356d0568d2ecd 100644
--- a/lite/src/mge/network_impl.cpp
+++ b/lite/src/mge/network_impl.cpp
@@ -435,31 +435,6 @@ void NetworkImplDft::cross_compnode_model_detect() {
     m_nr_device_type = nr_used_device_type.size();
 }
 
-void NetworkImplDft::adapt_option_valid() {
-    auto&& options = m_load_config.comp_graph->options();
-    if (m_user_config->options.force_output_use_user_specified_memory) {
-        for (auto&& out : m_load_result.output_var_list) {
-            auto opr = out.node()->owner_opr();
-            //! all the dest operator inherit from ReadonlyFwdHelper can't
-            //! support force_output_use_user_specified_memory options
-            if (opr->try_cast_final<mgb::opr::Dimshuffle>() ||
-                opr->try_cast_final<mgb::opr::AxisAddRemove>() ||
-                opr->try_cast_final<mgb::opr::Subtensor>() ||
-                opr->try_cast_final<mgb::opr::Reshape>() ||
-                opr->try_cast_final<mgb::opr::Broadcast>()) {
-                m_user_config->options.force_output_use_user_specified_memory = false;
-                options.force_output_use_user_specified_memory = false;
-                LITE_WARN(
-                        "detect the unsupported dest operator %s when config "
-                        "force_output_use_user_specified_memory, set "
-                        "force_output_use_user_specified_memory to false\n",
-                        opr->cname());
-                break;
-            }
-        }
-    }
-}
-
 void NetworkImplDft::layout_transform_optimization() {
     if (m_set_layout_transform) {
         mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> out_var_map;
@@ -611,10 +586,6 @@ void NetworkImplDft::configure_after_loaded() {
 
     layout_transform_optimization();
 
-    //! some optimization option maybe invalid in some case, so here just
-    //! auto determine whether some options will apply.
-    adapt_option_valid();
-
     //! find how many compnode the model has, this should call before update_io
     cross_compnode_model_detect();
 
diff --git a/lite/src/mge/network_impl.h b/lite/src/mge/network_impl.h
index e2add83b8888b453c0eb2d89eb9c8054b4b1ea77..ff78c448df8664d0e59593519f2e03a124b477c3 100644
--- a/lite/src/mge/network_impl.h
+++ b/lite/src/mge/network_impl.h
@@ -239,9 +239,6 @@ private:
     //! optimized output tensor copy
    void output_tensor_copy_optimize(Var var, std::shared_ptr<Tensor> tensor);
 
-    //! adapt option valid, it should call after update_io
-    void adapt_option_valid();
-
    //! configure and optimize network after loaded
    void configure_after_loaded();
 
diff --git a/src/gopt/test/network.cpp b/src/gopt/test/network.cpp
index 98c7f67c0b3b03d0286bfd7f840d1a036b985565..8134b1e3665c7e0bd9674e00f59386d4ebaa5a8c 100644
--- a/src/gopt/test/network.cpp
+++ b/src/gopt/test/network.cpp
@@ -1,4 +1,5 @@
 #include "./network.h"
+#include "megbrain/opr/tensor_manip.h"
 
 using namespace mgb;
 
@@ -137,6 +138,35 @@ SymbolVar Network::add_concat(SymbolVar f, SymbolVar g, int axis) {
     return opr::Concat::make({f, g}, axis);
 }
 
+SymbolVar Network::add_dimshuffle(SymbolVar f, std::vector<int> pattern) {
+    return opr::Dimshuffle::make(f, pattern);
+}
+
+SymbolVar Network::add_axisaddremove(SymbolVar f) {
+    return opr::AxisAddRemove::make(
+            f, {{opr::AxisAddRemove::AxisDesc::Method::REMOVE, {0}}});
+}
+
+SymbolVar Network::add_subtensor(SymbolVar f) {
+    using AIdx = opr::indexing::AxisIndexer;
+    return opr::Subtensor::make(
+            f, {AIdx::make_interval(0, f.make_scalar(0), None, None)});
+}
+
+SymbolVar Network::add_reshape(SymbolVar f) {
+    auto shp = opr::GetVarShape::make(f);
+    return opr::Reshape::make(f, shp);
+}
+
+SymbolVar Network::add_broadcast(SymbolVar f) {
+    auto shp = opr::GetVarShape::make(f);
+    return opr::Broadcast::make(f, shp);
+}
+
+SymbolVar Network::add_copy(SymbolVar f) {
+    return opr::Copy::make(f);
+}
+
 SymbolVar mgb::create_block(
         Network& network, SymbolVar f_in, size_t stride, size_t num_outputs1,
         bool has_proj, DType out_dtype) {
diff --git a/src/gopt/test/network.h b/src/gopt/test/network.h
index dd065469951ba4dec4e2ae3c9b367ff44d0e0809..4385bf1bfec985a6b6edc926616585077648c5e0 100644
--- a/src/gopt/test/network.h
+++ b/src/gopt/test/network.h
@@ -53,6 +53,12 @@ public:
             opr::Pooling::Param::Mode mode = opr::Pooling::Param::Mode::MAX);
     SymbolVar add_type_cvt(SymbolVar f, DType out_dtype = dtype::Float32());
     SymbolVar add_concat(SymbolVar f, SymbolVar g, int axis = 0);
+    SymbolVar add_dimshuffle(SymbolVar f, std::vector<int> pattern);
+    SymbolVar add_axisaddremove(SymbolVar f);
+    SymbolVar add_subtensor(SymbolVar f);
+    SymbolVar add_reshape(SymbolVar f);
+    SymbolVar add_broadcast(SymbolVar f);
+    SymbolVar add_copy(SymbolVar f);
 };
 
 SymbolVar create_block(
diff --git a/src/gopt/test/no_memory_copy.cpp b/src/gopt/test/no_memory_copy.cpp
index 09d2469c7961c42e88668362b237c2da6081fb71..ce475d5f3355957347d96b11a8bf745c7b81a33e 100644
--- a/src/gopt/test/no_memory_copy.cpp
+++ b/src/gopt/test/no_memory_copy.cpp
@@ -45,6 +45,35 @@ struct TestGraph {
         m_out_var = m_network->add_concat(f, -f);
     }
 
+    void create_relayout_out_graph(int mem_forward_opr_type) {
+        input_tensor = m_gen({1, 3, 32, 32}, m_cn);
+        auto input = opr::Host2DeviceCopy::make(*m_network->graph, input_tensor, m_cn)
+                             .rename("input");
+        auto f = m_network->add_conv(
+                input, 4, {3, 3}, dtype::Float32(), true, {2, 2}, {0, 0});
+        f = m_network->add_elemwise(
+                {f}, dtype::Float32(), opr::Elemwise::Param::Mode::EXP);
+        f = m_network->add_conv(f, 8, {3, 3}, dtype::Float32(), true, {1, 1}, {1, 1});
+        f = m_network->add_pooling(f, {2, 2}, {2, 2});
+        //! dimshuffle
+        if (mem_forward_opr_type == 0) {
+            f = m_network->add_dimshuffle(f, {0, 2, 3, 1});
+            //! BroadCast
+        } else if (mem_forward_opr_type == 1) {
+            f = m_network->add_broadcast(f);
+            //! Subtensor
+        } else if (mem_forward_opr_type == 2) {
+            f = m_network->add_subtensor(f);
+            //! AxisAddRemove
+        } else if (mem_forward_opr_type == 3) {
+            f = m_network->add_axisaddremove(f);
+            //! Reshape
+        } else if (mem_forward_opr_type == 4) {
+            f = m_network->add_reshape(f);
+        }
+        m_out_var = m_network->add_copy(f);
+    }
+
     void create_graph_with_subtensor_forward() {
         input_tensor = m_gen({2, 3, 32, 32}, m_cn);
         auto input = opr::Host2DeviceCopy::make(*m_network->graph, input_tensor, m_cn)
@@ -211,6 +240,67 @@ TEST(TestNoCopy, IONoCopyPtrEQ) {
     }
 }
 
+namespace {
+auto test_memory_forward_io_no_copy(int opr_type, TensorShape shape) {
+    auto test_graph = TestGraph();
+    auto compute_graph = test_graph.m_network->graph;
+    compute_graph->options().force_output_use_user_specified_memory = true;
+    test_graph.create_relayout_out_graph(opr_type);
+    HostTensorND truth;
+    auto func = test_graph.compile_without_copy();
+    //! because the output tensor not assign user memory, so it will wrong
+    ASSERT_THROW(func->execute(), MegBrainError);
+    auto&& outvar = func->get_output_vars()[0];
+    ASSERT_EQ(outvar, test_graph.m_out_var.node());
+    size_t times = 10;
+    for (size_t i = 0; i < times; i++) {
+        auto input_tensor = test_graph.input_tensor;
+        auto layout = input_tensor->layout();
+        size_t length = layout.total_nr_elems();
+        auto storage = TensorStorage<HostTensorStorage>(test_graph.m_cn);
+        storage.ensure_size(length * sizeof(float));
+        float* ptr = storage.ptr()->as<float>();
+        for (size_t d = 0; d < length; d++) {
+            ptr[d] = i / 5 + 3;
+        }
+        input_tensor->reset(storage, layout);
+        DeviceTensorND dv(test_graph.m_cn, shape);
+        outvar->init_mem_plan(&dv);
+        outvar->reset_dev_tensor_from_tensor(dv);
+
+        func->execute();
+        func->wait();
+        if (i % 5 == 0) {
+            truth.copy_from(func->get_output_vars()[0]->dev_tensor()).sync();
+            continue;
+        }
+        HostTensorND to_check;
+        to_check.copy_from(func->get_output_vars()[0]->dev_tensor()).sync();
+        MGB_ASSERT_TENSOR_EQ(to_check, truth);
+    }
+}
+}  // namespace
+
+TEST(TestNoCopy, IONoCopyEndWithDimshuffle) {
+    test_memory_forward_io_no_copy(0, {1, 7, 7, 8});
+}
+
+TEST(TestNoCopy, IONoCopyEndWithReshape) {
+    test_memory_forward_io_no_copy(4, {1, 8, 7, 7});
+}
+
+TEST(TestNoCopy, IONoCopyEndWithAxisAddRemove) {
+    test_memory_forward_io_no_copy(3, {8, 7, 7});
+}
+
+TEST(TestNoCopy, IONoCopyEndWithBroadCast) {
+    test_memory_forward_io_no_copy(1, {1, 8, 7, 7});
+}
+
+TEST(TestNoCopy, IONoCopyEndWithSubtensor) {
+    test_memory_forward_io_no_copy(2, {1, 8, 7, 7});
+}
+
 TEST(TestNoCopy, IONoCopyCorrect) {
     auto test_graph = TestGraph();
     auto compute_graph = test_graph.m_network->graph;
diff --git a/src/serialization/impl/serializer.cpp b/src/serialization/impl/serializer.cpp
index b44800cd4f7837ccd18263f0fb6a02124f704112..fd9d4c9153649e3d9509d96b9381f8aed1135d27 100644
--- a/src/serialization/impl/serializer.cpp
+++ b/src/serialization/impl/serializer.cpp
@@ -1,7 +1,25 @@
 #include "megbrain/serialization/serializer.h"
 #include "megbrain/gopt/inference.h"
+#include "megbrain/opr/io.h"
+#include "megbrain/opr/tensor_manip.h"
 #include "megbrain/opr/utility.h"
 
+namespace {
+bool is_opr_memforward_var(mgb::VarNode* var) {
+    if (var) {
+        auto opr = var->owner_opr();
+        if (opr->try_cast_final<mgb::opr::Dimshuffle>() ||
+            opr->try_cast_final<mgb::opr::AxisAddRemove>() ||
+            opr->try_cast_final<mgb::opr::Subtensor>() ||
+            opr->try_cast_final<mgb::opr::Reshape>() ||
+            opr->try_cast_final<mgb::opr::Broadcast>()) {
+            return true;
+        }
+    };
+    return false;
+}
+}  // namespace
+
 namespace mgb {
 namespace serialization {
 
@@ -42,6 +60,14 @@ void GraphLoader::LoadResult::graph_compile_ahead() {
     //! just do basic optimize_for_inference ahead, and replace the var in
     //! LoadResult
     if (graph->options().force_output_use_user_specified_memory) {
+        //! if the output var is like dimshuffle, reshape, it maybe memory forward to
+        //! the output, so add a Copy operator in the end.
+        for (auto& var : output_var_list) {
+            if (is_opr_memforward_var(var.node())) {
+                std::string name = var.node()->name();
+                var = opr::Copy::make(var, name);
+            }
+        }
         auto options = gopt::OptimizeForInferenceOptions{};
         auto new_vars = gopt::optimize_for_inference(output_var_list, options);
         output_var_list = new_vars;
@@ -62,6 +88,7 @@ void GraphLoader::LoadResult::graph_compile_ahead() {
                 found, "can't find var name %s when optimize_for_inference. ",
                 var.node()->cname());
     }
+    output_var_map_id = var_map_id;
 }