Commit 48db45d1 authored by Megvii Engine Team

perf(interpreter): try put device value with host to reduce d2h

GitOrigin-RevId: 63d36e770609f7666e823beca579d53d90b4e6b0
Parent a605f38b
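The change adds a host-value parameter to the interpreter's `DeviceTensorND` overload of `put`: a caller that already holds a host copy of the data hands it over together with the device tensor, so a later read of the tensor's value can be served from the cached host copy instead of triggering a blocking device-to-host (d2h) copy. A minimal, self-contained sketch of the caching idea, with simplified stand-in types rather than MegEngine's actual classes:

#include <cassert>
#include <optional>
#include <utility>
#include <vector>

// Stand-ins for HostTensorND / DeviceTensorND.
struct HostBuf   { std::vector<float> data; };
struct DeviceBuf { /* opaque device-side storage */ };

class CachedTensor {
public:
    // A caller that already holds the host data passes it alongside the
    // device buffer; an empty optional means "device only".
    CachedTensor(DeviceBuf dev, std::optional<HostBuf> host = {})
            : m_dev(std::move(dev)), m_host(std::move(host)) {}

    // Reading the value hits the host cache when present; only a cache
    // miss pays for the synchronous device-to-host copy.
    const HostBuf& value() {
        if (!m_host) {
            m_host = d2h_copy(m_dev);
        }
        return *m_host;
    }

private:
    static HostBuf d2h_copy(const DeviceBuf&) { return {}; }  // stand-in for the sync copy
    DeviceBuf m_dev;
    std::optional<HostBuf> m_host;
};

int main() {
    HostBuf host{{1.f, 2.f}};
    CachedTensor t{DeviceBuf{}, host};    // host mirror supplied up front
    assert(t.value().data.size() == 2);   // served from the cache, no d2h
}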
@@ -44,7 +44,7 @@ PyObject *cpp_apply_backward_varnode;
 std::shared_ptr<Tensor> make_const(imperative::TensorPtr value) {
     if (!(ApplyContext::global_enable & Tensor::Flags::TRACE)) {
-        return std::make_shared<Tensor>(interpreter_for_py->put(value->dev_tensor()));
+        return std::make_shared<Tensor>(interpreter_for_py->put(value->dev_tensor(), value->get_value()));
     }
     py::tuple tup(6);
     auto data = value->get_value();
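In `make_const` the imperative tensor already exposes its host value through `value->get_value()` (the same call used a few lines below), so forwarding it to `put` adds no extra transfer and seeds the interpreter-side tensor with a host mirror; this is presumably the call site the commit title refers to.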
@@ -248,7 +248,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
             // for DeviceTensorND
             if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) {
                 auto dv = py::handle(arg0).cast<DeviceTensorND>();
-                interpreter::Interpreter::Handle handle = interpreter_for_py->put(dv);
+                interpreter::Interpreter::Handle handle = interpreter_for_py->put(dv, {});
                 m_tensor = std::make_shared<Tensor>(handle);
             } else {
                 throw py::type_error("single argument is not tensor, varnode or devicetensor");
@@ -347,7 +347,6 @@ SET_GET_NAME(user_custom_name)
 SET_GET_NAME(automatic_name)
 #undef SET_GET_NAME
 
-
 PyObject* TensorWrapper::handle() {
     return py::cast(m_tensor->m_handle).release().ptr();
 }
@@ -532,7 +531,7 @@ PyObject* TensorWrapper::_dev_tensor(){
         // set m_handle to make it a real tensor
         auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor);
-        auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>());
+        auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>(), {});
         m_tensor->m_handle = std::move(SharedHandle(sh));
         // compiled info is useless after m_handle is set
......
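The two `TensorWrapper` call sites above only have a `DeviceTensorND` in hand, so they pass an empty `HostTensorND` (`{}`). An empty host value signals "no cached copy", and reading such a tensor's value presumably falls back to the previous d2h-on-read behavior.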
@@ -135,7 +135,7 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
     return info;
 }
 
-Handle ChannelImpl::put(const DeviceTensorND& data) {
+Handle ChannelImpl::put(const DeviceTensorND& data, const HostTensorND& hvalue) {
     MGB_LOCK_GUARD(m_spin);
     auto& state = get_channel_state();
     mgb_assert(check_available(), "Channel already closed");
@@ -144,7 +144,7 @@ Handle ChannelImpl::put(const DeviceTensorND& data) {
     RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandEvent::Put);
     init(info, {data.layout(), data.comp_node()});
     info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
-    info->ptr = Tensor::make(data);
+    info->ptr = Tensor::make(data, hvalue);
     RECORD_EVENT(TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node, data.raw_ptr());
     info->status = TensorInfo::Produced;
     RECORD_EVENT(TensorCommandFinishEvent, info->id, TensorCommandFinishEvent::Put);
......
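The diff does not show the matching `Tensor::make` overload; plausibly it stores the host value alongside the device storage so a later `get_value()` hits the cache. A hedged sketch under that assumption, where the `m_value` member and the seeding step are hypothetical, not confirmed by this commit:

// Hypothetical sketch -- the real Tensor::make body is not part of this diff.
static TensorPtr make(const DeviceTensorND& data, const HostTensorND& hvalue) {
    auto t = std::make_shared<Tensor>(data);
    if (!hvalue.empty()) {
        t->m_value = hvalue;  // assumed member: seed the host cache, skip a later d2h
    }
    return t;
}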
@@ -42,7 +42,7 @@ struct ChannelImpl : Interpreter::Channel {
     ~ChannelImpl() override;
 
     Handle put(const HostTensorND& value, bool no_cache) override;
-    Handle put(const DeviceTensorND& value) override;
+    Handle put(const DeviceTensorND& value, const HostTensorND& hvalue) override;
     void del(Handle) override;
     void swap_in(Handle) override;
......
@@ -23,7 +23,7 @@ struct Interpreter {
         virtual ~Channel() = default;
 
         virtual Handle put(const HostTensorND& value, bool no_cache) = 0;
-        virtual Handle put(const DeviceTensorND& value) = 0;
+        virtual Handle put(const DeviceTensorND& value, const HostTensorND& hvalue) = 0;
        virtual void del(Handle) = 0;
         virtual void swap_in(Handle) = 0;
......
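Because `put` is pure virtual on `Interpreter::Channel`, the extra parameter ripples through every implementer, and callers choose per site whether a host mirror is available. A call-site sketch, assuming MegEngine headers and a `channel` pointer to some `Interpreter::Channel`:

auto h1 = channel->put(host_nd, /*no_cache=*/false);  // host-only path, unchanged
auto h2 = channel->put(dev_nd, host_mirror);          // device value plus cached host copy
auto h3 = channel->put(dev_nd, {});                   // device only: d2h on first value read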