Commit 2f68aeb9 authored by Megvii Engine Team, committed by huangxinda

feat(imperative/jit): let trace support empty IO

GitOrigin-RevId: 97a55242bfe4d23e447ac77842e213ecb21995ee
Parent 809d5056
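For context: a tensor is "empty" when some axis of its shape is 0, so it holds no elements even though its ndim and dtype are valid. This is the case the commit teaches the trace/IO path to accept. A minimal, self-contained C++ sketch of that notion (illustrative only, not MegBrain code):

#include <cstddef>
#include <cstdio>
#include <vector>

// An "empty" shape has a valid ndim but at least one zero-length axis,
// so the element count multiplies out to 0.
static size_t nr_elems(const std::vector<size_t>& shape) {
    size_t n = 1;
    for (size_t s : shape)
        n *= s;
    return n;
}

int main() {
    std::vector<size_t> shape{2, 0, 3, 4};            // ndim == 4, axis 1 is 0
    std::printf("elements: %zu\n", nr_elems(shape));  // prints "elements: 0"
    return 0;
}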
@@ -45,7 +45,10 @@ InputCallback::InputCallback(cg::ComputingGraph& graph, callback_t callback,
if(m_use_static_shape){
mgb_assert(m_output_shape.ndim);
}
add_output(None)->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC).dtype(dt);
add_output(None)
->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
.add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC)
.dtype(dt);
add_output(None)
->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
.add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC)
@@ -109,6 +112,11 @@ void InputCallback::scn_do_execute() {
if (m_use_static_shape) {
mgb_assert(dev_tensor.shape().eq_shape(m_output_shape));
}
if (dev_tensor.empty()) {
auto layout = dev_tensor.layout();
layout.init_contiguous_stride();
dev_tensor.reset(dev_tensor.storage(), layout);
}
output(0)->reset_dev_tensor_from_tensor(dev_tensor);
}
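The branch above re-derives contiguous strides for an empty tensor before handing it to the output var, presumably so the layout passes the contiguous-or-empty check downstream even when the incoming strides are stale. A rough analogue of what row-major stride initialization computes (assumed semantics of init_contiguous_stride; illustrative, not the MegBrain implementation):

#include <cstddef>
#include <vector>

// Recompute row-major strides from a shape: the last axis has stride 1,
// each earlier axis the product of the sizes to its right. This works
// uniformly even when some axis length is 0.
void init_contiguous_stride(const std::vector<size_t>& shape,
                            std::vector<size_t>& stride) {
    stride.assign(shape.size(), 1);
    for (size_t i = shape.size(); i-- > 1;)
        stride[i - 1] = stride[i] * shape[i];
}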
@@ -172,6 +180,7 @@ cg::OperatorNodeBase::NodeProp* OutputCallback::do_make_node_prop() const {
};
m_use_host_value = m_param.prefer_host_value && host_value_avail();
dep_types[0] = m_use_host_value ? NodeProp::DepType::HOST_VALUE : NodeProp::DepType::DEV_VALUE;
dep_types[0] |= NodeProp::DepType::VALUE_ALLOW_EMPTY;
prop->reset_dep_type(input(), dep_types);
return prop;
}
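VALUE_ALLOW_EMPTY is OR-ed onto the existing value dependency rather than replacing it: dep types behave as bit flags, so the receiver still depends on the input's value but additionally opts in to empty tensors. A tiny sketch of the flag pattern (enumerator values are illustrative, not MegBrain's):

#include <cstdint>

// Bit-flag dependency types, mirroring the pattern used above.
enum DepType : uint32_t {
    DEV_VALUE         = 1u << 0,
    HOST_VALUE        = 1u << 1,
    VALUE_ALLOW_EMPTY = 1u << 2,
};

int main() {
    uint32_t dep = HOST_VALUE;           // chosen value dependency
    dep |= VALUE_ALLOW_EMPTY;            // additionally permit empty values
    return (dep & HOST_VALUE) ? 0 : 1;   // the original dep is preserved
}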
@@ -564,11 +564,7 @@ void ProxyGraph::init_output_tensor(const SmallVector<Tensor*>& outputs) {
mgb_assert(var->comp_node() == tensor->comp_node() &&
var->shape().eq_shape(layout) &&
var->dtype() == layout.dtype);
if (!tensor->layout().is_empty()) {
var->assign_dev_tensor_from_tensor(tensor->dev_tensor());
} else {
var->m_dev_tensor.storage({var->comp_node()});
}
var->assign_dev_tensor_from_tensor(tensor->dev_tensor());
++ j;
}
chk.mem_alloc_status.set_from_owner_var();
@@ -361,9 +361,19 @@ VarNode& VarNode::reset_dev_tensor_from_tensor(const DeviceTensorND& value) {
}
void VarNode::assign_dev_tensor_from_tensor(const DeviceTensorND& value) {
mgb_assert(value.layout().is_contiguous() &&
mgb_assert((value.layout().is_contiguous() || value.empty()) &&
m_dev_tensor.dtype() == value.dtype() &&
m_dev_tensor.format() == value.format());
if (value.empty()) {
mgb_assert(value.shape_valid() && value.layout().is_empty());
bool allow_empty = contain_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
auto &&recv = owner_graph()->var_receiver_in_current_comp_seq(this);
mgb_throw_if(!allow_empty || !recv.is_empty_allowed(),
GraphError,
"assign empty tensor to var %s, but allowed=%d, receiver=%s",
cg::dump_var_info({this}).c_str(),
allow_empty, recv.to_string().c_str());
}
if (cg::is_static_var_shape(this)) {
mgb_assert(shape().eq_shape(value.shape()),
"shape mismatch for static inferrable var when setting dev "
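The guard added above is a double gate: an empty value is legal only when the producing var carries ALLOW_EMPTY_SHAPE and its receivers in the compiled sequence all accept empty input. A minimal sketch of that check (types and names are illustrative, not the MegBrain API):

#include <stdexcept>

struct Receiver {
    bool empty_allowed;  // aggregated over all readers of the var
};

// Both the producer-side flag and the receiver-side capability must hold;
// otherwise assigning an empty tensor is an error, mirroring the
// mgb_throw_if in the diff above.
void check_empty_assignment(bool var_allows_empty, const Receiver& recv) {
    if (!(var_allows_empty && recv.empty_allowed))
        throw std::runtime_error(
                "assign empty tensor to var that does not allow empty shape");
}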
@@ -480,8 +480,8 @@ class VarNode final: public GraphNodeBase {
*
* \param src_var the var node to provide dev tensor, which must have
* been initialized, and does not have to be in the same computing
* graph. Its value must be contiguous. It can also be placed on a
* different comp node.
* graph. Its value must be contiguous or empty. It can also be
* placed on a different comp node.
*
* \return whether memory forwarding succeeds; if false is returned, a
* new tensor would be allocated and its value is copied from src
@@ -495,8 +495,8 @@ class VarNode final: public GraphNodeBase {
* This function should only be called by this var's owner operator and
* this var must have NO_SYS_MEM_ALLOC flag
*
* \param value the tensor to be used; it must be contiguous and be
* placed on the same comp node of this var.
* \param value the tensor to be used; it must be contiguous or empty
* and be placed on the same comp node of this var.
*/
VarNode& reset_dev_tensor_from_tensor(const DeviceTensorND &value);
@@ -10,6 +10,7 @@
*/
#include "megbrain/opr/io.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/utility.h"
@@ -2336,4 +2337,95 @@ TEST(TestGraph, DynamicOutput) {
MGB_ASSERT_TENSOR_NEAR(expect_spl_0_0, result_spl_0_0, 1e-4);
}
namespace {
// used to test reset_dev_tensor_from_tensor
MGB_DEFINE_OPR_CLASS(MaybeEmptyTensorOpr, cg::SingleCNOperatorNodeBase)// {
DeviceTensorND m_dv;
void init_output_comp_node() override {
output(0)->comp_node(m_dv.comp_node());
comp_node(m_dv.comp_node());
}
void scn_do_execute() override {
output(0)->reset_dev_tensor_from_tensor(m_dv);
}
void init_output_static_infer_desc() override {
using namespace cg::static_infer;
auto &&mgr = owner_graph()->static_infer_manager();
mgr.register_shape_infer(output(0),
ShapeInferDesc::make_const(m_dv.shape()));
}
public:
MaybeEmptyTensorOpr(ComputingGraph &graph,
const DeviceTensorND &dv, const OperatorNodeConfig &config):
Super(&graph, config, "", {}), m_dv{dv} {
add_output(None)
->add_flag(cg::VarNode::Flag::NO_SYS_MEM_ALLOC)
.add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
.dtype(dv.dtype());
}
static SymbolVar make(ComputingGraph &graph, const DeviceTensorND &dv,
const OperatorNodeConfig &config = {}) {
return graph.insert_opr(std::make_unique<MaybeEmptyTensorOpr>(
graph, dv, config))->output(0);
}
};
} // anonymous namespace
MGB_DYN_TYPE_OBJ_FINAL_IMPL(MaybeEmptyTensorOpr);
TEST(TestMemReuse, ResetEmptyDevTensor) {
// receiver opr allows empty tensors as input
auto allow_empty = [](const TensorShape& inp_shp) {
HostTensorGenerator<> gen;
auto g = ComputingGraph::make();
auto host_x1 = gen(inp_shp),
host_x2 = gen(inp_shp);
DeviceTensorND dev_x1, dev_x2;
dev_x1.copy_from(*host_x1), dev_x2.copy_from(*host_x2);
auto x1 = MaybeEmptyTensorOpr::make(*g, dev_x1, {"x1"}),
x2 = MaybeEmptyTensorOpr::make(*g, dev_x2, {"x2"}),
y = x1 + x2;
HostTensorND host_y;
auto func = g->compile({make_callback_copy(y, host_y)});
auto &&recv = x1.node()->owner_graph()->var_receiver_in_current_comp_seq(x1.node());
ASSERT_TRUE(recv.is_empty_allowed());
ASSERT_NO_THROW(func->execute().wait());
if (inp_shp.is_empty()) {
ASSERT_TRUE(host_y.empty());
ASSERT_TRUE(host_y.shape().is_empty());
}
};
// receiver opr does not allow empty tensors as input
auto forbid_empty = [](const TensorShape& inp_shp) {
HostTensorGenerator<> gen;
auto g = ComputingGraph::make();
auto host_x = gen(inp_shp);
DeviceTensorND dev_x;
dev_x.copy_from(*host_x);
auto x = MaybeEmptyTensorOpr::make(*g, dev_x, {"x"}),
y = opr::Reduce::make(x, {opr::Reduce::Mode::MAX, 0});
HostTensorND host_y;
auto func = g->compile({make_callback_copy(y, host_y)});
auto &&recv = x.node()->owner_graph()->var_receiver_in_current_comp_seq(x.node());
ASSERT_TRUE(!recv.is_empty_allowed());
if (inp_shp.is_empty()) {
ASSERT_ANY_THROW(func->execute().wait());
} else {
ASSERT_NO_THROW(func->execute().wait());
}
};
allow_empty({2, 3, 4, 5});
allow_empty({2, 0, 3, 4});
forbid_empty({4, 5, 6, 7});
forbid_empty({8, 0, 0, 9});
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}