diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index 73d3c2697342c646356d1655f08fd852f312bd37..3fe21c4edad89b04932c629c6f93237db260c50e 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -44,7 +44,7 @@ PyObject *cpp_apply_backward_varnode;
 
 std::shared_ptr<Tensor> make_const(imperative::TensorPtr value) {
     if (!(ApplyContext::global_enable & Tensor::Flags::TRACE)) {
-        return std::make_shared<Tensor>(interpreter_for_py->put(value->dev_tensor()));
+        return std::make_shared<Tensor>(interpreter_for_py->put(value->dev_tensor(), value->get_value()));
     }
     py::tuple tup(6);
     auto data = value->get_value();
@@ -248,7 +248,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
             // for DeviceTensorND
             if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) {
                 auto dv = py::handle(arg0).cast<DeviceTensorND>();
-                interpreter::Interpreter::Handle handle = interpreter_for_py->put(dv);
+                interpreter::Interpreter::Handle handle = interpreter_for_py->put(dv, {});
                 m_tensor = std::make_shared<Tensor>(handle);
             } else {
                 throw py::type_error("single argument is not tensor, varnode or devicetensor");
@@ -347,7 +347,6 @@ SET_GET_NAME(user_custom_name)
 SET_GET_NAME(automatic_name)
 #undef SET_GET_NAME
 
-
 PyObject* TensorWrapper::handle() {
     return py::cast(m_tensor->m_handle).release().ptr();
 }
@@ -532,7 +531,7 @@ PyObject* TensorWrapper::_dev_tensor(){
 
         // set m_handle to make it a real tensor
         auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor);
-        auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>());
+        auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>(), {});
         m_tensor->m_handle = std::move(SharedHandle(sh));
 
         // compiled info is useless after m_handle is set
diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp
index 89379c8bc0aab2181952c229c041e5861edf6f0e..2de9b78fbc682c4e6fa2acca81b31666aa69d14e 100644
--- a/imperative/src/impl/interpreter/interpreter_impl.cpp
+++ b/imperative/src/impl/interpreter/interpreter_impl.cpp
@@ -135,7 +135,7 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
     return info;
 }
 
-Handle ChannelImpl::put(const DeviceTensorND& data) {
+Handle ChannelImpl::put(const DeviceTensorND& data, const HostTensorND& hvalue) {
     MGB_LOCK_GUARD(m_spin);
     auto& state = get_channel_state();
     mgb_assert(check_available(), "Channel already closed");
@@ -144,7 +144,7 @@ Handle ChannelImpl::put(const DeviceTensorND& data) {
     RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandEvent::Put);
     init(info, {data.layout(), data.comp_node()});
     info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
-    info->ptr = Tensor::make(data);
+    info->ptr = Tensor::make(data, hvalue);
     RECORD_EVENT(TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node, data.raw_ptr());
     info->status = TensorInfo::Produced;
     RECORD_EVENT(TensorCommandFinishEvent, info->id, TensorCommandFinishEvent::Put);
diff --git a/imperative/src/impl/interpreter/interpreter_impl.h b/imperative/src/impl/interpreter/interpreter_impl.h
index a79988aef574c6c4b96aa4e3da732394e25c0072..4422f55dcba865edabc2dfa44e47e2246306d332 100644
--- a/imperative/src/impl/interpreter/interpreter_impl.h
+++ b/imperative/src/impl/interpreter/interpreter_impl.h
@@ -42,7 +42,7 @@ struct ChannelImpl : Interpreter::Channel {
     ~ChannelImpl() override;
 
     Handle put(const HostTensorND& value, bool no_cache) override;
-    Handle put(const DeviceTensorND& value) override;
+    Handle put(const DeviceTensorND& value, const HostTensorND& hvalue) override;
 
     void del(Handle) override;
     void swap_in(Handle) override;
diff --git a/imperative/src/include/megbrain/imperative/interpreter.h b/imperative/src/include/megbrain/imperative/interpreter.h
index f7f5028ff4d5c76707287a02ec6ac81b2ce95b07..92de64ed378040b382e89746d23aa0d90627cf90 100644
--- a/imperative/src/include/megbrain/imperative/interpreter.h
+++ b/imperative/src/include/megbrain/imperative/interpreter.h
@@ -23,7 +23,7 @@ struct Interpreter {
         virtual ~Channel() = default;
 
         virtual Handle put(const HostTensorND& value, bool no_cache) = 0;
-        virtual Handle put(const DeviceTensorND& value) = 0;
+        virtual Handle put(const DeviceTensorND& value, const HostTensorND& hvalue) = 0;
 
         virtual void del(Handle) = 0;
         virtual void swap_in(Handle) = 0;
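
Not part of the patch above: a minimal, self-contained C++ sketch of the call-site pattern the widened `put` overload enables. `MockHostTensor`, `MockDeviceTensor`, and `MockChannel` are hypothetical stand-ins, not MegEngine's `HostTensorND`, `DeviceTensorND`, or `Interpreter::Channel`; the point is only that a caller which already holds a host-side copy passes it along with the device tensor (as `make_const` now does with `value->get_value()`), while a caller without one passes an empty host tensor, mirroring the `{}` arguments in `tensor.cpp`.

```cpp
#include <cstdio>
#include <vector>

// Hypothetical stand-ins for illustration only; not MegEngine types.
struct MockHostTensor {
    std::vector<float> data;
    bool empty() const { return data.empty(); }  // "no cached host value"
};

struct MockDeviceTensor {
    std::vector<float> data;
};

struct MockChannel {
    using Handle = void*;

    // Mirrors the changed signature: a device tensor plus an optional
    // host-side value; an empty host tensor means "no host value known".
    Handle put(const MockDeviceTensor& dev, const MockHostTensor& hvalue) {
        std::printf("put: %zu device elems, host value %s\n",
                    dev.data.size(), hvalue.empty() ? "absent" : "present");
        return nullptr;  // a real channel would return a tensor handle
    }
};

int main() {
    MockChannel chan;
    MockDeviceTensor dev{{1.f, 2.f, 3.f}};

    chan.put(dev, {});                               // no host value cached
    chan.put(dev, MockHostTensor{{1.f, 2.f, 3.f}});  // host value supplied
    return 0;
}
```

The diff itself stores the host value next to the device data (`Tensor::make(data, hvalue)`), which suggests the cached host copy can be reused later instead of being read back from the device.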