From cbb47089a6b40dc5cc366c7df7d82019a402fb88 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sun, 26 Sep 2021 19:34:40 +0800 Subject: [PATCH] perf(interpreter): add fastpath for GetVarShape GitOrigin-RevId: d1ac4e7fe38c8b8dc4d5141603ebea545a3c396f --- .../src/impl/interpreter/interpreter_impl.cpp | 32 ++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index 685445391..480fc59f5 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -269,7 +269,20 @@ void ChannelImpl::dispatch_default_cpu( uint64_t op_id = Profiler::next_id(); - OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds); + if (op->trait()->apply_on_device_tensornd) { + OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds); + } else { + // proxy to apply_on_physical_tensor + SmallVector input_tensors; + for (auto&& input_tensornd : input_tensornds) { + input_tensors.push_back(Tensor::make( + input_tensornd, HostTensorND::make_proxy(input_tensornd))); + } + auto output_tensors = OpDef::apply_on_physical_tensor(*op, input_tensors); + for (size_t i = 0; i < output_tensors.size(); ++i) { + output_tensornds[i].copy_from_fixlayout(output_tensors[i]->dev_tensor()); + } + } SmallVector output_infos; output_infos.reserve(output_descs.size()); @@ -357,6 +370,17 @@ SmallVector ChannelImpl::apply_op( std::shared_ptr op, const SmallVector& inputs) { MGB_LOCK_GUARD(m_spin); mgb_assert(check_available(), "Channel already closed"); + auto* input = reinterpret_cast(inputs[0]); + if (op->same_type() && input->desc.layout.ndim) { + size_t ndim = input->desc.layout.ndim; + auto& gvs = op->cast_final_safe(); + if (gvs.axis == MEGDNN_MAX_NDIM) { + HostTensorND shape_tensor{input->desc.comp_node, {ndim}, dtype::Int32()}; + DeviceTensorND shape_tensor_device = shape_tensor.proxy_to_default_cpu(); + cg::copy_shape_to_tensor_value(shape_tensor_device, input->desc.layout); + return {reinterpret_cast(put_impl(shape_tensor, false))}; + } + } return apply_op_impl(std::move(op), inputs); } @@ -621,6 +645,12 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(), ptr->dev_tensor().raw_ptr()); // update tensor desc for static infer + if (dest->desc.layout.ndim) { + mgb_assert( + dest->desc.layout.eq_shape(ptr->layout()), + "shape infer error, %s vs %s", dest->desc.layout.to_string().c_str(), + ptr->layout().to_string().c_str()); + } dest->desc.layout = ptr->layout(); dest->desc.comp_node = ptr->comp_node(); dest->memory = ptr->blob()->size(); -- GitLab