diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index 6854453914e2cd8b3af1ca360c14ca6693fca34b..480fc59f579bbf6ded84f8692a680e3944c41d0e 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -269,7 +269,20 @@ void ChannelImpl::dispatch_default_cpu( uint64_t op_id = Profiler::next_id(); - OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds); + if (op->trait()->apply_on_device_tensornd) { + OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds); + } else { + // proxy to apply_on_physical_tensor + SmallVector input_tensors; + for (auto&& input_tensornd : input_tensornds) { + input_tensors.push_back(Tensor::make( + input_tensornd, HostTensorND::make_proxy(input_tensornd))); + } + auto output_tensors = OpDef::apply_on_physical_tensor(*op, input_tensors); + for (size_t i = 0; i < output_tensors.size(); ++i) { + output_tensornds[i].copy_from_fixlayout(output_tensors[i]->dev_tensor()); + } + } SmallVector output_infos; output_infos.reserve(output_descs.size()); @@ -357,6 +370,17 @@ SmallVector ChannelImpl::apply_op( std::shared_ptr op, const SmallVector& inputs) { MGB_LOCK_GUARD(m_spin); mgb_assert(check_available(), "Channel already closed"); + auto* input = reinterpret_cast(inputs[0]); + if (op->same_type() && input->desc.layout.ndim) { + size_t ndim = input->desc.layout.ndim; + auto& gvs = op->cast_final_safe(); + if (gvs.axis == MEGDNN_MAX_NDIM) { + HostTensorND shape_tensor{input->desc.comp_node, {ndim}, dtype::Int32()}; + DeviceTensorND shape_tensor_device = shape_tensor.proxy_to_default_cpu(); + cg::copy_shape_to_tensor_value(shape_tensor_device, input->desc.layout); + return {reinterpret_cast(put_impl(shape_tensor, false))}; + } + } return apply_op_impl(std::move(op), inputs); } @@ -621,6 +645,12 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(), ptr->dev_tensor().raw_ptr()); // update tensor desc for static infer + if (dest->desc.layout.ndim) { + mgb_assert( + dest->desc.layout.eq_shape(ptr->layout()), + "shape infer error, %s vs %s", dest->desc.layout.to_string().c_str(), + ptr->layout().to_string().c_str()); + } dest->desc.layout = ptr->layout(); dest->desc.comp_node = ptr->comp_node(); dest->memory = ptr->blob()->size();