diff --git a/oneflow/core/register/blob.h b/oneflow/core/register/blob.h index d6cf1e8e6b6b1f6cc494863233764165f9545cbe..60097f1569bc09dbf0ef383a07fb8ac9aa68797b 100644 --- a/oneflow/core/register/blob.h +++ b/oneflow/core/register/blob.h @@ -49,6 +49,18 @@ class Blob final { return static_cast(dptr_); } + template + typename std::enable_if::value, const T*>::type dptr( + int64_t dim0, Int64s... remainder_dims) const { + return dptr() + GetDptrOffset(0, dim0, remainder_dims...); + } + + template + typename std::enable_if::value, T*>::type mut_dptr( + int64_t dim0, Int64s... remainder_dims) { + return mut_dptr() + GetDptrOffset(0, dim0, remainder_dims...); + } + const RtBlobDesc& blob_desc() const { return *blob_desc_; } const RtBlobDesc* blob_desc_ptr() const { return blob_desc_; } const Shape& shape() const { return blob_desc_->shape(); } @@ -77,6 +89,15 @@ class Blob final { const MemoryCase& mem_case() const; private: + int64_t GetDptrOffset(int32_t index) const { return 0; } + template + int64_t GetDptrOffset(int32_t index, int64_t cur_dim, Int64s... remainder) const { + CHECK_GE(shape().NumAxes(), index + 1); + CHECK_GE(cur_dim, 0); + CHECK_LT(cur_dim, shape().At(index)); + return cur_dim * shape().Count(index + 1) + GetDptrOffset(index + 1, remainder...); + } + template void CheckDataType() const { LOG_IF(FATAL, (std::is_same::value == false && std::is_same::value == false