提交 6c413ba9 编写于 作者: M Megvii Engine Team

refactor(mge): refactor physical tensor

GitOrigin-RevId: 93ba67ca5aa3ab06d47c985ca98d6df4927ae393
上级 d56570d9
...@@ -949,6 +949,7 @@ std::tuple<std::vector<int32_t>, bool> tuple2vector(py::object shape) { ...@@ -949,6 +949,7 @@ std::tuple<std::vector<int32_t>, bool> tuple2vector(py::object shape) {
} }
bool enable_fastpath(py::handle inp) { bool enable_fastpath(py::handle inp) {
// FIXME: the way to judge whether it is in traced module is inaccurate
if (!TensorWrapper::try_cast(inp.ptr()) || if (!TensorWrapper::try_cast(inp.ptr()) ||
TransformationManager::get_instance() TransformationManager::get_instance()
.segments[TransformationManager::Segment::Trace] .segments[TransformationManager::Segment::Trace]
......
...@@ -113,12 +113,7 @@ Tensor::Tensor(const HostTensorND& hv) : Tensor(hv.layout(), hv.comp_node()) { ...@@ -113,12 +113,7 @@ Tensor::Tensor(const HostTensorND& hv) : Tensor(hv.layout(), hv.comp_node()) {
MGB_RECORD_EVENT( MGB_RECORD_EVENT(
profiler::HostToDeviceEvent, hv.layout(), hv.comp_node(), hv.raw_ptr(), profiler::HostToDeviceEvent, hv.layout(), hv.comp_node(), hv.raw_ptr(),
dev_tensor().raw_ptr()); dev_tensor().raw_ptr());
DeviceTensorStorage storage; dev_tensor(false).copy_from_fixlayout(hv);
storage.reset(m_cn, m_blob->size(), m_blob->storage());
storage = storage.sub(m_offset);
DeviceTensorND dv;
dv.reset(storage, m_layout);
dv.copy_from_fixlayout(hv);
// even though hv is saved in m_value, Tensor itself could be // even though hv is saved in m_value, Tensor itself could be
// released before copy completes // released before copy completes
MGB_RECORD_EVENT( MGB_RECORD_EVENT(
...@@ -218,15 +213,9 @@ megdnn::TensorND Tensor::dnn_tensor() { ...@@ -218,15 +213,9 @@ megdnn::TensorND Tensor::dnn_tensor() {
} }
void Tensor::fetch_value() { void Tensor::fetch_value() {
MGB_LOCK_GUARD(m_blob_mtx);
MGB_LOCK_GUARD(m_value_mtx); MGB_LOCK_GUARD(m_value_mtx);
if (m_value.empty()) { if (m_value.empty()) {
DeviceTensorStorage storage; m_value.copy_from(dev_tensor(false));
storage.reset(m_cn, m_blob->size(), m_blob->storage());
storage = storage.sub(m_offset);
DeviceTensorND dv;
dv.reset(storage, m_layout);
m_value.copy_from(dv);
m_value_ready.reset(EventPool::without_timer().alloc(comp_node())); m_value_ready.reset(EventPool::without_timer().alloc(comp_node()));
m_value_ready->record(); m_value_ready->record();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册