From 48db45d123c1c474c2a3de7caef89e7e09f7f5e9 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 2 Aug 2021 16:04:36 +0800 Subject: [PATCH] perf(interpreter): try put device value with host to reduce d2h GitOrigin-RevId: 63d36e770609f7666e823beca579d53d90b4e6b0 --- imperative/python/src/tensor.cpp | 7 +++---- imperative/src/impl/interpreter/interpreter_impl.cpp | 4 ++-- imperative/src/impl/interpreter/interpreter_impl.h | 2 +- imperative/src/include/megbrain/imperative/interpreter.h | 2 +- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 73d3c2697..3fe21c4ed 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -44,7 +44,7 @@ PyObject *cpp_apply_backward_varnode; std::shared_ptr make_const(imperative::TensorPtr value) { if (!(ApplyContext::global_enable & Tensor::Flags::TRACE)) { - return std::make_shared(interpreter_for_py->put(value->dev_tensor())); + return std::make_shared(interpreter_for_py->put(value->dev_tensor(), value->get_value())); } py::tuple tup(6); auto data = value->get_value(); @@ -248,7 +248,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { // for DeviceTensorND if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) { auto dv = py::handle(arg0).cast(); - interpreter::Interpreter::Handle handle = interpreter_for_py->put(dv); + interpreter::Interpreter::Handle handle = interpreter_for_py->put(dv, {}); m_tensor = std::make_shared(handle); } else { throw py::type_error("single argument is not tensor, varnode or devicetensor"); @@ -347,7 +347,6 @@ SET_GET_NAME(user_custom_name) SET_GET_NAME(automatic_name) #undef SET_GET_NAME - PyObject* TensorWrapper::handle() { return py::cast(m_tensor->m_handle).release().ptr(); } @@ -532,7 +531,7 @@ PyObject* TensorWrapper::_dev_tensor(){ // set m_handle to make it a real tensor auto py_dev_tensor = py::reinterpret_borrow(dev_tensor); - auto sh = interpreter_for_py->put(py_dev_tensor.cast()); + auto sh = interpreter_for_py->put(py_dev_tensor.cast(), {}); m_tensor->m_handle = std::move(SharedHandle(sh)); // compiled info is useless after m_handle is set diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index 89379c8bc..2de9b78fb 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -135,7 +135,7 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) { return info; } -Handle ChannelImpl::put(const DeviceTensorND& data) { +Handle ChannelImpl::put(const DeviceTensorND& data, const HostTensorND& hvalue) { MGB_LOCK_GUARD(m_spin); auto& state = get_channel_state(); mgb_assert(check_available(), "Channel already closed"); @@ -144,7 +144,7 @@ Handle ChannelImpl::put(const DeviceTensorND& data) { RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandEvent::Put); init(info, {data.layout(), data.comp_node()}); info->mem_desc.id = StorageIdentifier::make(++m_storage_id); - info->ptr = Tensor::make(data); + info->ptr = Tensor::make(data, hvalue); RECORD_EVENT(TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node, data.raw_ptr()); info->status = TensorInfo::Produced; RECORD_EVENT(TensorCommandFinishEvent, info->id, TensorCommandFinishEvent::Put); diff --git a/imperative/src/impl/interpreter/interpreter_impl.h b/imperative/src/impl/interpreter/interpreter_impl.h index a79988aef..4422f55dc 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.h +++ b/imperative/src/impl/interpreter/interpreter_impl.h @@ -42,7 +42,7 @@ struct ChannelImpl : Interpreter::Channel { ~ChannelImpl() override; Handle put(const HostTensorND& value, bool no_cache) override; - Handle put(const DeviceTensorND& value) override; + Handle put(const DeviceTensorND& value, const HostTensorND& hvalue) override; void del(Handle) override; void swap_in(Handle) override; diff --git a/imperative/src/include/megbrain/imperative/interpreter.h b/imperative/src/include/megbrain/imperative/interpreter.h index f7f5028ff..92de64ed3 100644 --- a/imperative/src/include/megbrain/imperative/interpreter.h +++ b/imperative/src/include/megbrain/imperative/interpreter.h @@ -23,7 +23,7 @@ struct Interpreter { virtual ~Channel() = default; virtual Handle put(const HostTensorND& value, bool no_cache) = 0; - virtual Handle put(const DeviceTensorND& value) = 0; + virtual Handle put(const DeviceTensorND& value, const HostTensorND& hvalue) = 0; virtual void del(Handle) = 0; virtual void swap_in(Handle) = 0; -- GitLab