diff --git a/oneflow/core/common/shape_view.h b/oneflow/core/common/shape_view.h index 128be1532860745f5e893ea7d9055ed4a1636f71..d2184655aeb022adef090fb941b95ada64c42af1 100644 --- a/oneflow/core/common/shape_view.h +++ b/oneflow/core/common/shape_view.h @@ -82,13 +82,36 @@ class ShapeViewBase { */ int64_t elem_cnt() const; + /** + * @brief The shape of tensor is stored as array in the memory and the method ptr + * will get the pointer to the begin of the array buffer. + * + * @return const DimType* (a.k.a const DimT*) + */ const DimType* ptr() const { return ptr_; } bool operator==(const ShapeViewBase& rhs) const; std::string ToString() const; + + /** + * @brief Convert ShapeViewBase to DimVector + * + * @param dim_vec + */ void ToDimVector(DimVector* dim_vec) const; + + /** + * @brief Convert ShapeViewBase to Shape + * + * @param shape + */ void ToShape(Shape* shape) const; + /** + * @brief Set the pointer to the array buffer + * + * @param ptr + */ void set_ptr(DimType* ptr) { ptr_ = ptr; } protected: