提交 0719c5a8 编写于 作者: N Niu Chong 提交者: Will Zhang

feat: add col_num(that is sequence length) in Blob and BlobDesc (#494)

* feat: add seq_len in Blob and BlobDesc

* fix: rename PieceStatus to BlobHeader and insert it to mem_ptr of Blob

* fix: fix the comments

* fix: fix typos and remove  vim .swp file

* fix: fix the comment

* fix: update max_col_num() to col_num() in Register

* fix: fix due to the comment
上级 0c19fdbc
#include "oneflow/core/register/blob.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/kernel/kernel_util.h"
#include "oneflow/core/job/runtime_context.h"
namespace oneflow {
Blob::Blob(const BlobDesc* blob_desc, char* mem_ptr,
const void* comm_net_token) {
data_id_ptr_ = blob_desc->has_data_id_field() ? mem_ptr : nullptr;
dptr_ = mem_ptr + blob_desc->ByteSizeOfDataIdField();
blob_header_ = reinterpret_cast<BlobHeader*>(mem_ptr);
if (blob_desc->has_data_id_field()) {
data_id_ptr_ = mem_ptr + blob_desc->ByteSizeOfBlobHeaderField();
} else {
data_id_ptr_ = nullptr;
}
if (blob_desc->has_col_num_field()) {
col_num_ptr_ = reinterpret_cast<int32_t*>(
mem_ptr + blob_desc->ByteSizeOfBlobHeaderField()
+ blob_desc->ByteSizeOfDataIdField());
} else {
col_num_ptr_ = nullptr;
}
dptr_ = mem_ptr + blob_desc->ByteSizeOfBlobHeaderField()
+ blob_desc->ByteSizeOfDataIdField()
+ blob_desc->ByteSizeOfColNumField();
blob_desc_ = blob_desc;
comm_net_token_ = comm_net_token;
}
......@@ -17,18 +33,64 @@ const char* Blob::data_id(int32_t no) const {
return data_id_ptr_ + no * JobDesc::Singleton()->SizeOfOneDataId();
}
int32_t Blob::col_num(int32_t no) const {
CHECK_NOTNULL(col_num_ptr_);
return *(col_num_ptr_ + no);
}
void Blob::set_col_num(int32_t no, int32_t val) {
CHECK_NOTNULL(col_num_ptr_);
*(col_num_ptr_ + no) = val;
}
const void* Blob::memory_ptr() const {
return reinterpret_cast<void*>(blob_header_);
}
size_t Blob::ByteSizeOfBlobHeaderField() const {
return blob_desc_->ByteSizeOfDataIdField();
}
size_t Blob::ByteSizeOfDataIdField() const {
return blob_desc_->ByteSizeOfDataIdField();
}
size_t Blob::ByteSizeOfColNumField() const {
return blob_desc_->ByteSizeOfColNumField();
}
size_t Blob::ByteSizeOfDataContentField() const {
return blob_desc_->ByteSizeOfDataContentField();
}
template<DeviceType device_type>
void Blob::CopyBlobHeaderFrom(DeviceCtx* device_ctx, const Blob* rhs) {
if (this == rhs) { return; }
Memcpy<device_type>(device_ctx, blob_header_, rhs->blob_header_,
ByteSizeOfBlobHeaderField());
}
template<DeviceType device_type>
void Blob::CopyDataContentFrom(DeviceCtx* device_ctx, const Blob* rhs) {
if (this == rhs) { return; }
Memcpy<device_type>(device_ctx, dptr_, rhs->dptr_,
ByteSizeOfDataContentField());
}
template<DeviceType device_type>
void Blob::CopyDataIdFrom(DeviceCtx* device_ctx, const Blob* rhs) {
if (this == rhs) { return; }
Memcpy<device_type>(device_ctx, data_id_ptr_, rhs->data_id_ptr_,
ByteSizeOfDataIdField());
}
template<DeviceType device_type>
void Blob::CopyColNumFrom(DeviceCtx* device_ctx, const Blob* rhs) {
if (this == rhs) { return; }
Memcpy<device_type>(device_ctx, col_num_ptr_, rhs->col_num_ptr_,
ByteSizeOfColNumField());
}
template<DeviceType device_type>
void Blob::CopyFrom(DeviceCtx* device_ctx, const Blob* rhs) {
if (this == rhs) { return; }
......@@ -37,8 +99,10 @@ void Blob::CopyFrom(DeviceCtx* device_ctx, const Blob* rhs) {
}
#define INSTANTIATE_BLOB_FUNC(dev_t) \
template void Blob::CopyBlobHeaderFrom<dev_t>(DeviceCtx*, const Blob*); \
template void Blob::CopyDataContentFrom<dev_t>(DeviceCtx*, const Blob*); \
template void Blob::CopyDataIdFrom<dev_t>(DeviceCtx*, const Blob*); \
template void Blob::CopyColNumFrom<dev_t>(DeviceCtx*, const Blob*); \
template void Blob::CopyFrom<dev_t>(DeviceCtx*, const Blob*);
OF_PP_FOR_EACH_TUPLE(INSTANTIATE_BLOB_FUNC, DEVICE_TYPE_SEQ);
......
......@@ -21,9 +21,10 @@ class Blob final {
const char* data_id() const { return data_id(0); }
char* mut_data_id() { return mut_data_id(0); }
const void* memory_ptr() const {
return data_id_ptr_ == nullptr ? dptr_ : static_cast<void*>(data_id_ptr_);
}
int32_t col_num(int32_t no) const;
void set_col_num(int32_t no, int32_t val);
const void* memory_ptr() const;
void* mut_memory_ptr() { return const_cast<void*>(memory_ptr()); }
template<typename T = void>
......@@ -45,21 +46,31 @@ class Blob final {
const Shape& shape() const { return blob_desc_->shape(); }
DataType data_type() const { return blob_desc_->data_type(); }
bool has_data_id_field() const { return blob_desc_->has_data_id_field(); }
size_t ByteSizeOfDataIdField() const {
return blob_desc_->ByteSizeOfDataIdField();
}
size_t ByteSizeOfDataContentField() const {
return blob_desc_->ByteSizeOfDataContentField();
}
bool has_col_num_field() const { return blob_desc_->has_col_num_field(); }
size_t ByteSizeOfBlobHeaderField() const;
size_t ByteSizeOfDataIdField() const;
size_t ByteSizeOfColNumField() const;
size_t ByteSizeOfDataContentField() const;
size_t TotalByteSize() const { return blob_desc_->TotalByteSize(); }
template<DeviceType device_type>
void CopyBlobHeaderFrom(DeviceCtx* device_ctx, const Blob* rhs);
template<DeviceType device_type>
void CopyDataContentFrom(DeviceCtx* device_ctx, const Blob* rhs);
template<DeviceType device_type>
void CopyDataIdFrom(DeviceCtx* device_ctx, const Blob* rhs);
template<DeviceType device_type>
void CopyColNumFrom(DeviceCtx* device_ctx, const Blob* rhs);
template<DeviceType device_type>
void CopyFrom(DeviceCtx* device_ctx, const Blob* rhs);
int64_t col_id() const { return blob_header_->col_id; }
void set_col_id(int64_t val) { blob_header_->col_id = val; }
int64_t max_col_id() const { return blob_header_->max_col_id; }
void set_max_col_id(int64_t val) { blob_header_->max_col_id = val; }
bool IsLastCol() const { return col_id() == max_col_id(); }
private:
template<typename T>
void CheckDataType() const {
......@@ -70,7 +81,9 @@ class Blob final {
<< blob_desc_->data_type() << " " << GetDataType<T>::val;
}
BlobHeader* blob_header_;
char* data_id_ptr_;
int32_t* col_num_ptr_;
void* dptr_;
const void* comm_net_token_;
const BlobDesc* blob_desc_;
......
......@@ -13,6 +13,12 @@ Regst::Regst() {
regst_desc_ = nullptr;
}
bool Regst::HaveNextPieceColStatusOf(const Regst* other) const {
return (piece_id() == other->piece_id())
&& (max_col_id() == other->max_col_id())
&& (col_id() == other->col_id() + 1);
}
Blob* Regst::GetBlobByLbn(const std::string& lbn) {
auto it = lbn2blob_.find(lbn);
if (it != lbn2blob_.end()) {
......
......@@ -21,6 +21,11 @@ class Regst final {
Blob* GetBlobByLbn(const std::string& lbn);
Blob* packed_blob() { return packed_blob_.get(); }
int64_t col_id() const { return FirstBlob()->col_id(); }
int64_t max_col_id() const { return FirstBlob()->max_col_id(); }
bool IsLastCol() const { return FirstBlob()->IsLastCol(); }
bool HaveNextPieceColStatusOf(const Regst* other) const;
// Setters
void set_piece_id(int64_t val) { piece_id_ = val; }
void set_model_version_id(int64_t val) { model_version_id_ = val; }
......@@ -29,6 +34,8 @@ class Regst final {
friend class RegstMgr;
Regst();
const Blob* FirstBlob() const { return lbn2blob_.begin()->second.get(); }
int64_t piece_id_;
int64_t model_version_id_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册