提交 896b0193 编写于 作者: M Megvii Engine Team

fix(mgb): fix zero-copy error when the model ends with a memory-forward opr

GitOrigin-RevId: 2eba697d85a9970ba8e947cc66cdc8dbbcc32242
上级 babecba2
...@@ -435,31 +435,6 @@ void NetworkImplDft::cross_compnode_model_detect() { ...@@ -435,31 +435,6 @@ void NetworkImplDft::cross_compnode_model_detect() {
m_nr_device_type = nr_used_device_type.size(); m_nr_device_type = nr_used_device_type.size();
} }
//! Turn off force_output_use_user_specified_memory (in both the user config
//! and the graph options) as soon as one model output is produced by an
//! operator type listed below, and warn about the offending operator.
void NetworkImplDft::adapt_option_valid() {
    auto&& graph_options = m_load_config.comp_graph->options();
    if (!m_user_config->options.force_output_use_user_specified_memory) {
        return;
    }
    for (auto&& output : m_load_result.output_var_list) {
        auto owner = output.node()->owner_opr();
        //! destination operators that inherit from ReadonlyFwdHelper can not
        //! work with the force_output_use_user_specified_memory option
        const bool unsupported_dest =
                owner->try_cast_final<mgb::opr::Reshape>() ||
                owner->try_cast_final<mgb::opr::Broadcast>() ||
                owner->try_cast_final<mgb::opr::Subtensor>() ||
                owner->try_cast_final<mgb::opr::AxisAddRemove>() ||
                owner->try_cast_final<mgb::opr::Dimshuffle>();
        if (!unsupported_dest) {
            continue;
        }
        m_user_config->options.force_output_use_user_specified_memory = false;
        graph_options.force_output_use_user_specified_memory = false;
        LITE_WARN(
                "detect the unsupported dest operator %s when config "
                "force_output_use_user_specified_memory, set "
                "force_output_use_user_specified_memory to false\n",
                owner->cname());
        //! a single offending output is enough, no need to look further
        return;
    }
}
void NetworkImplDft::layout_transform_optimization() { void NetworkImplDft::layout_transform_optimization() {
if (m_set_layout_transform) { if (m_set_layout_transform) {
mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> out_var_map; mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> out_var_map;
...@@ -611,10 +586,6 @@ void NetworkImplDft::configure_after_loaded() { ...@@ -611,10 +586,6 @@ void NetworkImplDft::configure_after_loaded() {
layout_transform_optimization(); layout_transform_optimization();
//! some optimization option maybe invalid in some case, so here just
//! auto determine whether some options will apply.
adapt_option_valid();
//! find how many compnode the model has, this should call before update_io //! find how many compnode the model has, this should call before update_io
cross_compnode_model_detect(); cross_compnode_model_detect();
......
...@@ -239,9 +239,6 @@ private: ...@@ -239,9 +239,6 @@ private:
//! optimized output tensor copy //! optimized output tensor copy
void output_tensor_copy_optimize(Var var, std::shared_ptr<Tensor> tensor); void output_tensor_copy_optimize(Var var, std::shared_ptr<Tensor> tensor);
//! adapt option valid: automatically disable options that can not be applied
//! to the loaded model.
//! NOTE(review): the previous comment said this should be called after
//! update_io, but configure_after_loaded() calls it right before
//! cross_compnode_model_detect(), which is documented to run before
//! update_io -- confirm the intended ordering.
void adapt_option_valid();
//! configure and optimize network after loaded //! configure and optimize network after loaded
void configure_after_loaded(); void configure_after_loaded();
......
#include "./network.h" #include "./network.h"
#include "megbrain/opr/tensor_manip.h"
using namespace mgb; using namespace mgb;
...@@ -137,6 +138,35 @@ SymbolVar Network::add_concat(SymbolVar f, SymbolVar g, int axis) { ...@@ -137,6 +138,35 @@ SymbolVar Network::add_concat(SymbolVar f, SymbolVar g, int axis) {
return opr::Concat::make({f, g}, axis); return opr::Concat::make({f, g}, axis);
} }
//! Append a Dimshuffle operator built from \p pattern on top of \p f and
//! return its output var.
SymbolVar Network::add_dimshuffle(SymbolVar f, std::vector<int> pattern) {
    SymbolVar shuffled = opr::Dimshuffle::make(f, pattern);
    return shuffled;
}
//! Append an AxisAddRemove operator that removes axis 0 of \p f and return
//! its output var.
SymbolVar Network::add_axisaddremove(SymbolVar f) {
    using Method = opr::AxisAddRemove::AxisDesc::Method;
    return opr::AxisAddRemove::make(f, {{Method::REMOVE, {0}}});
}
//! Append a Subtensor operator indexing axis 0 of \p f with an interval
//! that begins at scalar 0 and leaves end and step unspecified.
SymbolVar Network::add_subtensor(SymbolVar f) {
    using AIdx = opr::indexing::AxisIndexer;
    auto axis0_interval = AIdx::make_interval(0, f.make_scalar(0), None, None);
    return opr::Subtensor::make(f, {axis0_interval});
}
//! Append a Reshape operator whose target shape is the GetVarShape of \p f
//! itself and return its output var.
SymbolVar Network::add_reshape(SymbolVar f) {
    return opr::Reshape::make(f, opr::GetVarShape::make(f));
}
//! Append a Broadcast operator whose target shape is the GetVarShape of
//! \p f itself and return its output var.
SymbolVar Network::add_broadcast(SymbolVar f) {
    return opr::Broadcast::make(f, opr::GetVarShape::make(f));
}
//! Append a Copy operator on top of \p f and return its output var.
SymbolVar Network::add_copy(SymbolVar f) {
    SymbolVar copied = opr::Copy::make(f);
    return copied;
}
SymbolVar mgb::create_block( SymbolVar mgb::create_block(
Network& network, SymbolVar f_in, size_t stride, size_t num_outputs1, Network& network, SymbolVar f_in, size_t stride, size_t num_outputs1,
bool has_proj, DType out_dtype) { bool has_proj, DType out_dtype) {
......
...@@ -53,6 +53,12 @@ public: ...@@ -53,6 +53,12 @@ public:
opr::Pooling::Param::Mode mode = opr::Pooling::Param::Mode::MAX); opr::Pooling::Param::Mode mode = opr::Pooling::Param::Mode::MAX);
SymbolVar add_type_cvt(SymbolVar f, DType out_dtype = dtype::Float32()); SymbolVar add_type_cvt(SymbolVar f, DType out_dtype = dtype::Float32());
SymbolVar add_concat(SymbolVar f, SymbolVar g, int axis = 0); SymbolVar add_concat(SymbolVar f, SymbolVar g, int axis = 0);
//! append a Dimshuffle operator built from the given pattern
SymbolVar add_dimshuffle(SymbolVar f, std::vector<int> pattern);
//! append an AxisAddRemove operator that removes axis 0
SymbolVar add_axisaddremove(SymbolVar f);
//! append a Subtensor operator indexing axis 0 with an interval that begins
//! at scalar 0 and has no explicit end or step
SymbolVar add_subtensor(SymbolVar f);
//! append a Reshape operator whose target shape is the shape of f itself
SymbolVar add_reshape(SymbolVar f);
//! append a Broadcast operator whose target shape is the shape of f itself
SymbolVar add_broadcast(SymbolVar f);
//! append a Copy operator
SymbolVar add_copy(SymbolVar f);
}; };
SymbolVar create_block( SymbolVar create_block(
......
...@@ -45,6 +45,35 @@ struct TestGraph { ...@@ -45,6 +45,35 @@ struct TestGraph {
m_out_var = m_network->add_concat(f, -f); m_out_var = m_network->add_concat(f, -f);
} }
//! Build a conv -> exp -> conv -> pooling graph, optionally followed by one
//! operator selected by \p mem_forward_opr_type (0: Dimshuffle,
//! 1: Broadcast, 2: Subtensor, 3: AxisAddRemove, 4: Reshape; any other value
//! adds nothing), and finish with a Copy whose result becomes m_out_var.
void create_relayout_out_graph(int mem_forward_opr_type) {
    input_tensor = m_gen({1, 3, 32, 32}, m_cn);
    auto input = opr::Host2DeviceCopy::make(*m_network->graph, input_tensor, m_cn)
                         .rename("input");
    auto feat = m_network->add_conv(
            input, 4, {3, 3}, dtype::Float32(), true, {2, 2}, {0, 0});
    feat = m_network->add_elemwise(
            {feat}, dtype::Float32(), opr::Elemwise::Param::Mode::EXP);
    feat = m_network->add_conv(
            feat, 8, {3, 3}, dtype::Float32(), true, {1, 1}, {1, 1});
    feat = m_network->add_pooling(feat, {2, 2}, {2, 2});
    switch (mem_forward_opr_type) {
        case 0:  //! Dimshuffle
            feat = m_network->add_dimshuffle(feat, {0, 2, 3, 1});
            break;
        case 1:  //! Broadcast
            feat = m_network->add_broadcast(feat);
            break;
        case 2:  //! Subtensor
            feat = m_network->add_subtensor(feat);
            break;
        case 3:  //! AxisAddRemove
            feat = m_network->add_axisaddremove(feat);
            break;
        case 4:  //! Reshape
            feat = m_network->add_reshape(feat);
            break;
        default:  //! unknown code: keep the pooling output untouched
            break;
    }
    m_out_var = m_network->add_copy(feat);
}
void create_graph_with_subtensor_forward() { void create_graph_with_subtensor_forward() {
input_tensor = m_gen({2, 3, 32, 32}, m_cn); input_tensor = m_gen({2, 3, 32, 32}, m_cn);
auto input = opr::Host2DeviceCopy::make(*m_network->graph, input_tensor, m_cn) auto input = opr::Host2DeviceCopy::make(*m_network->graph, input_tensor, m_cn)
...@@ -211,6 +240,67 @@ TEST(TestNoCopy, IONoCopyPtrEQ) { ...@@ -211,6 +240,67 @@ TEST(TestNoCopy, IONoCopyPtrEQ) {
} }
} }
namespace {
auto test_memory_forward_io_no_copy(int opr_type, TensorShape shape) {
auto test_graph = TestGraph();
auto compute_graph = test_graph.m_network->graph;
compute_graph->options().force_output_use_user_specified_memory = true;
test_graph.create_relayout_out_graph(opr_type);
HostTensorND truth;
auto func = test_graph.compile_without_copy();
//! because the output tensor not assign user memory, so it will wrong
ASSERT_THROW(func->execute(), MegBrainError);
auto&& outvar = func->get_output_vars()[0];
ASSERT_EQ(outvar, test_graph.m_out_var.node());
size_t times = 10;
for (size_t i = 0; i < times; i++) {
auto input_tensor = test_graph.input_tensor;
auto layout = input_tensor->layout();
size_t length = layout.total_nr_elems();
auto storage = TensorStorage<HostTensorStorageTrait>(test_graph.m_cn);
storage.ensure_size(length * sizeof(float));
float* ptr = storage.ptr()->as<float>();
for (size_t d = 0; d < length; d++) {
ptr[d] = i / 5 + 3;
}
input_tensor->reset(storage, layout);
DeviceTensorND dv(test_graph.m_cn, shape);
outvar->init_mem_plan(&dv);
outvar->reset_dev_tensor_from_tensor(dv);
func->execute();
func->wait();
if (i % 5 == 0) {
truth.copy_from(func->get_output_vars()[0]->dev_tensor()).sync();
continue;
}
HostTensorND to_check;
to_check.copy_from(func->get_output_vars()[0]->dev_tensor()).sync();
MGB_ASSERT_TENSOR_EQ(to_check, truth);
}
}
} // namespace
//! user-specified output memory with a graph built from operator code 0
//! (Dimshuffle in create_relayout_out_graph)
TEST(TestNoCopy, IONoCopyEndWithDimshuffle) {
    const int dimshuffle_type = 0;
    const TensorShape user_output_shape = {1, 7, 7, 8};
    test_memory_forward_io_no_copy(dimshuffle_type, user_output_shape);
}
//! user-specified output memory with a graph built from operator code 4
//! (Reshape in create_relayout_out_graph)
TEST(TestNoCopy, IONoCopyEndWithReshape) {
    const int reshape_type = 4;
    const TensorShape user_output_shape = {1, 8, 7, 7};
    test_memory_forward_io_no_copy(reshape_type, user_output_shape);
}
//! user-specified output memory with a graph built from operator code 3
//! (AxisAddRemove in create_relayout_out_graph)
TEST(TestNoCopy, IONoCopyEndWithAxisAddRemove) {
    const int axis_add_remove_type = 3;
    const TensorShape user_output_shape = {8, 7, 7};
    test_memory_forward_io_no_copy(axis_add_remove_type, user_output_shape);
}
//! user-specified output memory with a graph built from operator code 1
//! (Broadcast in create_relayout_out_graph)
TEST(TestNoCopy, IONoCopyEndWithBroadCast) {
    const int broadcast_type = 1;
    const TensorShape user_output_shape = {1, 8, 7, 7};
    test_memory_forward_io_no_copy(broadcast_type, user_output_shape);
}
//! user-specified output memory with a graph built from operator code 2
//! (Subtensor in create_relayout_out_graph)
TEST(TestNoCopy, IONoCopyEndWithSubtensor) {
    const int subtensor_type = 2;
    const TensorShape user_output_shape = {1, 8, 7, 7};
    test_memory_forward_io_no_copy(subtensor_type, user_output_shape);
}
TEST(TestNoCopy, IONoCopyCorrect) { TEST(TestNoCopy, IONoCopyCorrect) {
auto test_graph = TestGraph(); auto test_graph = TestGraph();
auto compute_graph = test_graph.m_network->graph; auto compute_graph = test_graph.m_network->graph;
......
#include "megbrain/serialization/serializer.h" #include "megbrain/serialization/serializer.h"
#include "megbrain/gopt/inference.h" #include "megbrain/gopt/inference.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h" #include "megbrain/opr/utility.h"
namespace {
//! Return true when \p var is non-null and its owner operator is a Reshape,
//! Broadcast, Subtensor, AxisAddRemove or Dimshuffle; false otherwise.
bool is_opr_memforward_var(mgb::VarNode* var) {
    if (!var) {
        return false;
    }
    auto owner = var->owner_opr();
    return owner->try_cast_final<mgb::opr::Reshape>() ||
           owner->try_cast_final<mgb::opr::Broadcast>() ||
           owner->try_cast_final<mgb::opr::Subtensor>() ||
           owner->try_cast_final<mgb::opr::AxisAddRemove>() ||
           owner->try_cast_final<mgb::opr::Dimshuffle>();
}
}  // namespace
namespace mgb { namespace mgb {
namespace serialization { namespace serialization {
...@@ -42,6 +60,14 @@ void GraphLoader::LoadResult::graph_compile_ahead() { ...@@ -42,6 +60,14 @@ void GraphLoader::LoadResult::graph_compile_ahead() {
//! just do basic optimize_for_inference ahead, and replace the var in //! just do basic optimize_for_inference ahead, and replace the var in
//! LoadResult //! LoadResult
if (graph->options().force_output_use_user_specified_memory) { if (graph->options().force_output_use_user_specified_memory) {
//! if the output var is like dimshuffle, reshape, it maybe memory forward to
//! the output, so add a Copy operator in the end.
for (auto& var : output_var_list) {
if (is_opr_memforward_var(var.node())) {
std::string name = var.node()->name();
var = opr::Copy::make(var, name);
}
}
auto options = gopt::OptimizeForInferenceOptions{}; auto options = gopt::OptimizeForInferenceOptions{};
auto new_vars = gopt::optimize_for_inference(output_var_list, options); auto new_vars = gopt::optimize_for_inference(output_var_list, options);
output_var_list = new_vars; output_var_list = new_vars;
...@@ -62,6 +88,7 @@ void GraphLoader::LoadResult::graph_compile_ahead() { ...@@ -62,6 +88,7 @@ void GraphLoader::LoadResult::graph_compile_ahead() {
found, "can't find var name %s when optimize_for_inference. ", found, "can't find var name %s when optimize_for_inference. ",
var.node()->cname()); var.node()->cname());
} }
output_var_map_id = var_map_id;
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册