From f9ed8d7172cdc72c1f8c64ac4b2eef5e2972364f Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 16 Mar 2021 11:59:37 +0800 Subject: [PATCH] fix(imperative/tensor): init m_offset when constructing a Tensor with DeviceTensorND GitOrigin-RevId: b340e27c47e38ae92a57ded312d4e22eb259f872 --- imperative/python/test/unit/jit/test_tracing.py | 4 ++-- imperative/src/impl/physical_tensor.cpp | 2 +- src/core/include/megbrain/tensor.h | 7 +++++++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/imperative/python/test/unit/jit/test_tracing.py b/imperative/python/test/unit/jit/test_tracing.py index 480587a92..6bc9e339a 100644 --- a/imperative/python/test/unit/jit/test_tracing.py +++ b/imperative/python/test/unit/jit/test_tracing.py @@ -570,9 +570,9 @@ def test_random(shape_mode): def test_trace_advance_indexing(shape_mode): funcs = [ lambda x, i: x[i], - # lambda x, i, j: x[i, j], # FIXME + lambda x, i, j: x[i, j], lambda x, i, j: x[i, :, j, ...], - # lambda x, start, end: x[start:end], # FIXME + lambda x, start, end: x[start:end], lambda x, start, end: x[:, 0, start:end, ..., 1], lambda x, vec: x[vec], lambda x, vec: x[vec, ..., 0, 1:3], diff --git a/imperative/src/impl/physical_tensor.cpp b/imperative/src/impl/physical_tensor.cpp index 778a0fb11..5903b3dcb 100644 --- a/imperative/src/impl/physical_tensor.cpp +++ b/imperative/src/impl/physical_tensor.cpp @@ -253,7 +253,7 @@ Tensor::Tensor(const DeviceTensorND &dv, const HostTensorND& hv) { } m_layout = dv.layout(); m_blob = Blob::make(dv.storage()); - m_offset = 0; + m_offset = dv.storage().offset(); } Tensor::Tensor(const TensorLayout& layout, const CompNode& cn) diff --git a/src/core/include/megbrain/tensor.h b/src/core/include/megbrain/tensor.h index fe73be174..e22133d11 100644 --- a/src/core/include/megbrain/tensor.h +++ b/src/core/include/megbrain/tensor.h @@ -176,6 +176,13 @@ class TensorStorage { return m_size; } + /*! + * \brief offset on allocated block in bytes + */ + size_t offset() const { + return m_offset; + } + //! get underlying comp node; error would be raised if it is invalid CompNode comp_node() const { check_comp_node_valid(); -- GitLab