提交 21849d79 编写于 作者: M Megvii Engine Team

fix(imperative): fix the problem that stack and concat will crash under dtr

GitOrigin-RevId: a08da0ff12cbfab5f70bda528687a6a1bee9b5c9
上级 e7a862aa
...@@ -855,6 +855,9 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) { ...@@ -855,6 +855,9 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
} else { } else {
// i may be null // i may be null
validated = false; validated = false;
for (auto i : cmd.outputs) {
output_descs.push_back({});
}
} }
// Here std::move is REQUIRED for removing duplicated references. // Here std::move is REQUIRED for removing duplicated references.
auto outputs = apply_on_physical_tensor( auto outputs = apply_on_physical_tensor(
......
...@@ -111,13 +111,14 @@ SmallVector<TensorPtr> apply_on_physical_tensor( ...@@ -111,13 +111,14 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
int axis = op_def.axis >= 0 ? op_def.axis : op_def.axis + inputs[0]->layout().ndim; int axis = op_def.axis >= 0 ? op_def.axis : op_def.axis + inputs[0]->layout().ndim;
CompNode& oup_cn = output_descs[0].comp_node; CompNode& oup_cn = output_descs[0].comp_node;
TensorLayout& oup_layout = output_descs[0].layout;
if (validated) {
if (op_def.comp_node.valid()) { if (op_def.comp_node.valid()) {
mgb_assert(op_def.comp_node == oup_cn, "Concat compnode infer error"); mgb_assert(op_def.comp_node == oup_cn, "Concat compnode infer error");
} }
} else {
// prepare inputs and output layout // prepare inputs and output layout
TensorLayout& oup_layout = output_descs[0].layout; oup_cn = inputs[0]->comp_node();
if (!validated) {
SmallVector<const TensorLayout*> inputs_holder(inputs.size()); SmallVector<const TensorLayout*> inputs_holder(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
inputs_holder[i] = &inputs[i]->layout(); inputs_holder[i] = &inputs[i]->layout();
...@@ -213,13 +214,14 @@ SmallVector<TensorPtr> apply_on_physical_tensor( ...@@ -213,13 +214,14 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
op_def.axis >= 0 ? op_def.axis : op_def.axis + inputs[0]->layout().ndim + 1; op_def.axis >= 0 ? op_def.axis : op_def.axis + inputs[0]->layout().ndim + 1;
CompNode& oup_cn = output_descs[0].comp_node; CompNode& oup_cn = output_descs[0].comp_node;
TensorLayout& oup_layout = output_descs[0].layout;
if (validated) {
if (op_def.comp_node.valid()) { if (op_def.comp_node.valid()) {
mgb_assert(op_def.comp_node == oup_cn, "Stack compnode infer error"); mgb_assert(op_def.comp_node == oup_cn, "Stack compnode infer error");
} }
} else {
// prepare inputs and output layout // prepare inputs and output layout
TensorLayout& oup_layout = output_descs[0].layout; oup_cn = inputs[0]->comp_node();
if (!validated) {
SmallVector<const TensorLayout*> inputs_holder(inputs.size()); SmallVector<const TensorLayout*> inputs_holder(inputs.size());
for (size_t i = 0; i < nr_inp; ++i) { for (size_t i = 0; i < nr_inp; ++i) {
inputs_holder[i] = &inputs[i]->layout(); inputs_holder[i] = &inputs[i]->layout();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册