提交 4e4abed6 编写于 作者: L Li Xinqi 提交者: Jinhui Yuan

blob slice dptr (#1225)

* enable dptr<T>(...) if T is not void

* simplify dptr(...) by parameter packing


Former-commit-id: 642f1ba8
上级 cbf36fb9
......@@ -49,6 +49,18 @@ class Blob final {
return static_cast<T*>(dptr_);
}
template<typename T, typename... Int64s>
typename std::enable_if<!std::is_same<T, void>::value, const T*>::type dptr(
int64_t dim0, Int64s... remainder_dims) const {
return dptr<T>() + GetDptrOffset(0, dim0, remainder_dims...);
}
template<typename T, typename... Int64s>
typename std::enable_if<!std::is_same<T, void>::value, T*>::type mut_dptr(
int64_t dim0, Int64s... remainder_dims) {
return mut_dptr<T>() + GetDptrOffset(0, dim0, remainder_dims...);
}
const RtBlobDesc& blob_desc() const { return *blob_desc_; }
const RtBlobDesc* blob_desc_ptr() const { return blob_desc_; }
const Shape& shape() const { return blob_desc_->shape(); }
......@@ -77,6 +89,15 @@ class Blob final {
const MemoryCase& mem_case() const;
private:
int64_t GetDptrOffset(int32_t index) const { return 0; }
template<typename... Int64s>
int64_t GetDptrOffset(int32_t index, int64_t cur_dim, Int64s... remainder) const {
CHECK_GE(shape().NumAxes(), index + 1);
CHECK_GE(cur_dim, 0);
CHECK_LT(cur_dim, shape().At(index));
return cur_dim * shape().Count(index + 1) + GetDptrOffset(index + 1, remainder...);
}
template<typename T>
void CheckDataType() const {
LOG_IF(FATAL, (std::is_same<T, void>::value == false && std::is_same<T, char>::value == false
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册