diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index 1c11b9370b2570b7191247c6b5dd707b70c86c88..dac2726de97a59766b5363e5de569af0f56340bd 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -855,6 +855,9 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) { } else { // i may be null validated = false; + for (auto i : cmd.outputs) { + output_descs.push_back({}); + } } // Here std::move is REQUIRED for removing duplicated references. auto outputs = apply_on_physical_tensor( diff --git a/imperative/src/impl/ops/concatenate.cpp b/imperative/src/impl/ops/concatenate.cpp index a6a9c3a7122b6d44832be50cc2772a1e7c816f54..cf0e34d7a1d7f0dcde2c2832d6d21b11f32e8a90 100644 --- a/imperative/src/impl/ops/concatenate.cpp +++ b/imperative/src/impl/ops/concatenate.cpp @@ -111,13 +111,14 @@ SmallVector apply_on_physical_tensor( int axis = op_def.axis >= 0 ? op_def.axis : op_def.axis + inputs[0]->layout().ndim; CompNode& oup_cn = output_descs[0].comp_node; - if (op_def.comp_node.valid()) { - mgb_assert(op_def.comp_node == oup_cn, "Concat compnode infer error"); - } - - // prepare inputs and output layout TensorLayout& oup_layout = output_descs[0].layout; - if (!validated) { + if (validated) { + if (op_def.comp_node.valid()) { + mgb_assert(op_def.comp_node == oup_cn, "Concat compnode infer error"); + } + } else { + // prepare inputs and output layout + oup_cn = inputs[0]->comp_node(); SmallVector inputs_holder(inputs.size()); for (size_t i = 0; i < inputs.size(); ++i) { inputs_holder[i] = &inputs[i]->layout(); @@ -213,13 +214,14 @@ SmallVector apply_on_physical_tensor( op_def.axis >= 0 ? op_def.axis : op_def.axis + inputs[0]->layout().ndim + 1; CompNode& oup_cn = output_descs[0].comp_node; - if (op_def.comp_node.valid()) { - mgb_assert(op_def.comp_node == oup_cn, "Stack compnode infer error"); - } - - // prepare inputs and output layout TensorLayout& oup_layout = output_descs[0].layout; - if (!validated) { + if (validated) { + if (op_def.comp_node.valid()) { + mgb_assert(op_def.comp_node == oup_cn, "Stack compnode infer error"); + } + } else { + // prepare inputs and output layout + oup_cn = inputs[0]->comp_node(); SmallVector inputs_holder(inputs.size()); for (size_t i = 0; i < nr_inp; ++i) { inputs_holder[i] = &inputs[i]->layout();