Commit 48db45d1 authored by Megvii Engine Team

perf(interpreter): try put device value with host to reduce d2h

GitOrigin-RevId: 63d36e770609f7666e823beca579d53d90b4e6b0
Parent a605f38b
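The intent of the change, as read from the commit message: when the host-side value of a tensor is already known at the moment its device value is handed to the interpreter, passing the host value along lets the interpreter cache it, so a later host read can be answered without a device-to-host (d2h) copy. Below is a minimal, self-contained C++ sketch of that caching pattern; DeviceBuffer, HostBuffer, CachedTensor and copy_device_to_host are hypothetical names for illustration, not MegEngine types.

#include <optional>
#include <utility>

// Hypothetical stand-ins for DeviceTensorND / HostTensorND; this sketches
// the caching idea only, it is not the MegEngine API.
struct DeviceBuffer {};
struct HostBuffer {
    bool valid = false;
    bool empty() const { return !valid; }
};

// Stands in for the expensive device-to-host (d2h) copy.
HostBuffer copy_device_to_host(const DeviceBuffer&) { return HostBuffer{true}; }

class CachedTensor {
public:
    // When the caller already holds the host value, pass it in so it is
    // cached next to the device value (mirrors put(dev, hvalue)).
    explicit CachedTensor(DeviceBuffer dev, HostBuffer host = {})
            : m_dev(std::move(dev)) {
        if (!host.empty()) {
            m_host = std::move(host);
        }
    }

    // A host read is served from the cache when possible; the d2h copy
    // only happens when no host value was supplied up front.
    const HostBuffer& get_value() {
        if (!m_host) {
            m_host = copy_device_to_host(m_dev);
        }
        return *m_host;
    }

private:
    DeviceBuffer m_dev;
    std::optional<HostBuffer> m_host;
};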
@@ -44,7 +44,7 @@ PyObject *cpp_apply_backward_varnode;
 std::shared_ptr<Tensor> make_const(imperative::TensorPtr value) {
     if (!(ApplyContext::global_enable & Tensor::Flags::TRACE)) {
-        return std::make_shared<Tensor>(interpreter_for_py->put(value->dev_tensor()));
+        return std::make_shared<Tensor>(interpreter_for_py->put(value->dev_tensor(), value->get_value()));
     }
     py::tuple tup(6);
     auto data = value->get_value();
@@ -248,7 +248,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
         // for DeviceTensorND
         if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) {
             auto dv = py::handle(arg0).cast<DeviceTensorND>();
-            interpreter::Interpreter::Handle handle = interpreter_for_py->put(dv);
+            interpreter::Interpreter::Handle handle = interpreter_for_py->put(dv, {});
             m_tensor = std::make_shared<Tensor>(handle);
         } else {
             throw py::type_error("single argument is not tensor, varnode or devicetensor");
@@ -347,7 +347,6 @@ SET_GET_NAME(user_custom_name)
 SET_GET_NAME(automatic_name)
 #undef SET_GET_NAME
 PyObject* TensorWrapper::handle() {
     return py::cast(m_tensor->m_handle).release().ptr();
 }
@@ -532,7 +531,7 @@ PyObject* TensorWrapper::_dev_tensor(){
         // set m_handle to make it a real tensor
         auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor);
-        auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>());
+        auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>(), {});
         m_tensor->m_handle = std::move(SharedHandle(sh));
         // compiled info is useless after m_handle is set
@@ -135,7 +135,7 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
     return info;
 }
-Handle ChannelImpl::put(const DeviceTensorND& data) {
+Handle ChannelImpl::put(const DeviceTensorND& data, const HostTensorND& hvalue) {
     MGB_LOCK_GUARD(m_spin);
     auto& state = get_channel_state();
     mgb_assert(check_available(), "Channel already closed");
@@ -144,7 +144,7 @@ Handle ChannelImpl::put(const DeviceTensorND& data) {
     RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandEvent::Put);
     init(info, {data.layout(), data.comp_node()});
     info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
-    info->ptr = Tensor::make(data);
+    info->ptr = Tensor::make(data, hvalue);
     RECORD_EVENT(TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node, data.raw_ptr());
     info->status = TensorInfo::Produced;
     RECORD_EVENT(TensorCommandFinishEvent, info->id, TensorCommandFinishEvent::Put);
@@ -42,7 +42,7 @@ struct ChannelImpl : Interpreter::Channel {
     ~ChannelImpl() override;
     Handle put(const HostTensorND& value, bool no_cache) override;
-    Handle put(const DeviceTensorND& value) override;
+    Handle put(const DeviceTensorND& value, const HostTensorND& hvalue) override;
     void del(Handle) override;
     void swap_in(Handle) override;
@@ -23,7 +23,7 @@ struct Interpreter {
    virtual ~Channel() = default;
    virtual Handle put(const HostTensorND& value, bool no_cache) = 0;
-   virtual Handle put(const DeviceTensorND& value) = 0;
+   virtual Handle put(const DeviceTensorND& value, const HostTensorND& hvalue) = 0;
    virtual void del(Handle) = 0;
    virtual void swap_in(Handle) = 0;
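In the call sites above the same split is visible: make_const already holds the host value (value->get_value()) and forwards it, while the DeviceTensorND-only entry points pass an empty HostTensorND ({}), so for them the d2h path remains the fallback and behavior is unchanged.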