提交 125129cf 编写于 作者: W willzhang4a58

fix: has_header_field

Former-commit-id: cf20a9f8
上级 088423ba
...@@ -30,6 +30,7 @@ void BasicDataLoaderOp::InferBlobDescs( ...@@ -30,6 +30,7 @@ void BasicDataLoaderOp::InferBlobDescs(
} }
out->mut_shape() = Shape(dim_vec); out->mut_shape() = Shape(dim_vec);
out->set_data_type(conf.data_type()); out->set_data_type(conf.data_type());
out->set_has_header_field(true);
out->set_has_data_id_field(JobDesc::Singleton()->SizeOfOneDataId() > 0); out->set_has_data_id_field(JobDesc::Singleton()->SizeOfOneDataId() > 0);
out->set_has_col_num_field(conf.max_sequence_size() > 1); out->set_has_col_num_field(conf.max_sequence_size() > 1);
out->set_max_col_num(conf.max_sequence_size()); out->set_max_col_num(conf.max_sequence_size());
......
...@@ -13,7 +13,7 @@ void InferBasicRnnCellBlobDesc( ...@@ -13,7 +13,7 @@ void InferBasicRnnCellBlobDesc(
int64_t piece_size = in_blob_desc->shape().At(0); int64_t piece_size = in_blob_desc->shape().At(0);
BlobDesc data_tmp_blob_desc = BlobDesc data_tmp_blob_desc =
BlobDesc(Shape({embedding_size, hidden_size}), BlobDesc(Shape({embedding_size, hidden_size}),
JobDesc::Singleton()->DefaultDataType(), false, false, JobDesc::Singleton()->DefaultDataType(), false, false, false,
in_blob_desc->max_col_num()); in_blob_desc->max_col_num());
*GetBlobDesc4BnInOp("in_ip_op_out") = data_tmp_blob_desc; *GetBlobDesc4BnInOp("in_ip_op_out") = data_tmp_blob_desc;
*GetBlobDesc4BnInOp("hidden_ip_op_out") = data_tmp_blob_desc; *GetBlobDesc4BnInOp("hidden_ip_op_out") = data_tmp_blob_desc;
......
...@@ -7,8 +7,12 @@ namespace oneflow { ...@@ -7,8 +7,12 @@ namespace oneflow {
Blob::Blob(const BlobDesc* blob_desc, char* mem_ptr, Blob::Blob(const BlobDesc* blob_desc, char* mem_ptr,
const void* comm_net_token) { const void* comm_net_token) {
mem_ptr_ = mem_ptr;
if (blob_desc->has_header_field()) {
blob_header_ = reinterpret_cast<BlobHeader*>(mem_ptr); blob_header_ = reinterpret_cast<BlobHeader*>(mem_ptr);
} else {
blob_header_ = nullptr;
}
if (blob_desc->has_data_id_field()) { if (blob_desc->has_data_id_field()) {
data_id_ptr_ = mem_ptr + blob_desc->ByteSizeOfBlobHeaderField(); data_id_ptr_ = mem_ptr + blob_desc->ByteSizeOfBlobHeaderField();
} else { } else {
...@@ -43,10 +47,6 @@ void Blob::set_col_num(int32_t no, int32_t val) { ...@@ -43,10 +47,6 @@ void Blob::set_col_num(int32_t no, int32_t val) {
*(col_num_ptr_ + no) = val; *(col_num_ptr_ + no) = val;
} }
const void* Blob::memory_ptr() const {
return reinterpret_cast<void*>(blob_header_);
}
size_t Blob::ByteSizeOfBlobHeaderField() const { size_t Blob::ByteSizeOfBlobHeaderField() const {
return blob_desc_->ByteSizeOfBlobHeaderField(); return blob_desc_->ByteSizeOfBlobHeaderField();
} }
......
...@@ -24,8 +24,8 @@ class Blob final { ...@@ -24,8 +24,8 @@ class Blob final {
int32_t col_num(int32_t no) const; int32_t col_num(int32_t no) const;
void set_col_num(int32_t no, int32_t val); void set_col_num(int32_t no, int32_t val);
const void* memory_ptr() const; const void* memory_ptr() const { return mem_ptr_; }
void* mut_memory_ptr() { return const_cast<void*>(memory_ptr()); } void* mut_memory_ptr() { return mem_ptr_; }
template<typename T = void> template<typename T = void>
const T* dptr() const { const T* dptr() const {
...@@ -82,6 +82,7 @@ class Blob final { ...@@ -82,6 +82,7 @@ class Blob final {
<< blob_desc_->data_type() << " " << GetDataType<T>::val; << blob_desc_->data_type() << " " << GetDataType<T>::val;
} }
void* mem_ptr_;
BlobHeader* blob_header_; BlobHeader* blob_header_;
char* data_id_ptr_; char* data_id_ptr_;
int32_t* col_num_ptr_; int32_t* col_num_ptr_;
......
...@@ -5,12 +5,14 @@ namespace oneflow { ...@@ -5,12 +5,14 @@ namespace oneflow {
BlobDesc::BlobDesc() BlobDesc::BlobDesc()
: BlobDesc(Shape(), JobDesc::Singleton()->DefaultDataType(), false, false, : BlobDesc(Shape(), JobDesc::Singleton()->DefaultDataType(), false, false,
1) {} false, 1) {}
BlobDesc::BlobDesc(Shape shape, DataType data_type, bool has_data_id_field, BlobDesc::BlobDesc(Shape shape, DataType data_type, bool has_header_field,
bool has_col_num_field, int32_t max_col_num) bool has_data_id_field, bool has_col_num_field,
int32_t max_col_num)
: shape_(shape), : shape_(shape),
data_type_(data_type), data_type_(data_type),
has_header_field_(has_header_field),
has_data_id_field_(has_data_id_field), has_data_id_field_(has_data_id_field),
has_col_num_field_(has_col_num_field), has_col_num_field_(has_col_num_field),
max_col_num_(max_col_num) {} max_col_num_(max_col_num) {}
...@@ -18,6 +20,7 @@ BlobDesc::BlobDesc(Shape shape, DataType data_type, bool has_data_id_field, ...@@ -18,6 +20,7 @@ BlobDesc::BlobDesc(Shape shape, DataType data_type, bool has_data_id_field,
BlobDesc::BlobDesc(const BlobDescProto& proto) { BlobDesc::BlobDesc(const BlobDescProto& proto) {
shape_ = Shape(proto.shape()); shape_ = Shape(proto.shape());
data_type_ = proto.data_type(); data_type_ = proto.data_type();
has_header_field_ = proto.has_header_field();
has_data_id_field_ = proto.has_data_id_field(); has_data_id_field_ = proto.has_data_id_field();
has_col_num_field_ = proto.has_col_num_field(); has_col_num_field_ = proto.has_col_num_field();
max_col_num_ = proto.max_col_num(); max_col_num_ = proto.max_col_num();
...@@ -26,11 +29,20 @@ BlobDesc::BlobDesc(const BlobDescProto& proto) { ...@@ -26,11 +29,20 @@ BlobDesc::BlobDesc(const BlobDescProto& proto) {
void BlobDesc::ToProto(BlobDescProto* proto) const { void BlobDesc::ToProto(BlobDescProto* proto) const {
shape_.ToProto(proto->mutable_shape()); shape_.ToProto(proto->mutable_shape());
proto->set_data_type(data_type_); proto->set_data_type(data_type_);
proto->set_has_header_field(has_header_field_);
proto->set_has_data_id_field(has_data_id_field_); proto->set_has_data_id_field(has_data_id_field_);
proto->set_has_col_num_field(has_col_num_field_); proto->set_has_col_num_field(has_col_num_field_);
proto->set_max_col_num(max_col_num_); proto->set_max_col_num(max_col_num_);
} }
size_t BlobDesc::ByteSizeOfBlobHeaderField() const {
if (has_header_field_) {
return sizeof(BlobHeader);
} else {
return 0;
}
}
size_t BlobDesc::ByteSizeOfDataIdField() const { size_t BlobDesc::ByteSizeOfDataIdField() const {
if (has_data_id_field_) { if (has_data_id_field_) {
return shape_.At(0) * JobDesc::Singleton()->SizeOfOneDataId(); return shape_.At(0) * JobDesc::Singleton()->SizeOfOneDataId();
...@@ -58,6 +70,7 @@ size_t BlobDesc::TotalByteSize() const { ...@@ -58,6 +70,7 @@ size_t BlobDesc::TotalByteSize() const {
bool BlobDesc::operator==(const BlobDesc& rhs) const { bool BlobDesc::operator==(const BlobDesc& rhs) const {
return shape_ == rhs.shape_ && data_type_ == rhs.data_type_ return shape_ == rhs.shape_ && data_type_ == rhs.data_type_
&& has_header_field_ == rhs.has_header_field_
&& has_data_id_field_ == rhs.has_data_id_field_ && has_data_id_field_ == rhs.has_data_id_field_
&& has_col_num_field_ == rhs.has_col_num_field_ && has_col_num_field_ == rhs.has_col_num_field_
&& max_col_num_ == rhs.max_col_num_; && max_col_num_ == rhs.max_col_num_;
...@@ -67,6 +80,7 @@ BlobDesc ComputePackedBlobDesc(std::function<const BlobDesc*()> NextBlobDesc) { ...@@ -67,6 +80,7 @@ BlobDesc ComputePackedBlobDesc(std::function<const BlobDesc*()> NextBlobDesc) {
int64_t total_byte_size = 0; int64_t total_byte_size = 0;
int64_t total_data_content_byte_size = 0; int64_t total_data_content_byte_size = 0;
HashSet<int> data_type_set; HashSet<int> data_type_set;
bool has_header_field = false;
bool has_data_id_field = false; bool has_data_id_field = false;
bool has_col_num_field = false; bool has_col_num_field = false;
int32_t max_col_num = -1; int32_t max_col_num = -1;
...@@ -78,6 +92,7 @@ BlobDesc ComputePackedBlobDesc(std::function<const BlobDesc*()> NextBlobDesc) { ...@@ -78,6 +92,7 @@ BlobDesc ComputePackedBlobDesc(std::function<const BlobDesc*()> NextBlobDesc) {
data_type_set.insert(static_cast<int>(blob_desc->data_type())); data_type_set.insert(static_cast<int>(blob_desc->data_type()));
has_data_id_field = has_data_id_field || blob_desc->has_data_id_field(); has_data_id_field = has_data_id_field || blob_desc->has_data_id_field();
has_col_num_field = has_col_num_field || blob_desc->has_col_num_field(); has_col_num_field = has_col_num_field || blob_desc->has_col_num_field();
has_header_field = has_header_field || blob_desc->has_header_field();
if (max_col_num == -1) { if (max_col_num == -1) {
max_col_num = blob_desc->max_col_num(); max_col_num = blob_desc->max_col_num();
} else { } else {
...@@ -90,7 +105,8 @@ BlobDesc ComputePackedBlobDesc(std::function<const BlobDesc*()> NextBlobDesc) { ...@@ -90,7 +105,8 @@ BlobDesc ComputePackedBlobDesc(std::function<const BlobDesc*()> NextBlobDesc) {
if (blob_desc_cnt <= 1) { return ret; } if (blob_desc_cnt <= 1) { return ret; }
CHECK_EQ(has_col_num_field, false); CHECK_EQ(has_col_num_field, false);
CHECK_EQ(max_col_num, 1); CHECK_EQ(max_col_num, 1);
if (has_data_id_field == false && data_type_set.size() == 1) { if (has_header_field == false && has_data_id_field == false
&& data_type_set.size() == 1) {
DataType sole_data_type = static_cast<DataType>(*(data_type_set.begin())); DataType sole_data_type = static_cast<DataType>(*(data_type_set.begin()));
int64_t size_of_one_elem = GetSizeOfDataType(sole_data_type); int64_t size_of_one_elem = GetSizeOfDataType(sole_data_type);
CHECK_EQ(total_data_content_byte_size % size_of_one_elem, 0); CHECK_EQ(total_data_content_byte_size % size_of_one_elem, 0);
......
...@@ -19,7 +19,7 @@ class BlobDesc final { ...@@ -19,7 +19,7 @@ class BlobDesc final {
~BlobDesc() = default; ~BlobDesc() = default;
BlobDesc(); BlobDesc();
BlobDesc(Shape shape, DataType data_type, bool has_data_id_field, BlobDesc(Shape, DataType, bool has_header_field, bool has_data_id_field,
bool has_col_num_field, int32_t max_col_num); bool has_col_num_field, int32_t max_col_num);
BlobDesc(Shape shape) : BlobDesc() { shape_ = shape; } BlobDesc(Shape shape) : BlobDesc() { shape_ = shape; }
BlobDesc(const BlobDescProto& proto); BlobDesc(const BlobDescProto& proto);
...@@ -30,6 +30,9 @@ class BlobDesc final { ...@@ -30,6 +30,9 @@ class BlobDesc final {
DataType data_type() const { return data_type_; } DataType data_type() const { return data_type_; }
void set_data_type(DataType val) { data_type_ = val; } void set_data_type(DataType val) { data_type_ = val; }
bool has_header_field() const { return has_header_field_; }
void set_has_header_field(bool val) { has_header_field_ = val; }
bool has_data_id_field() const { return has_data_id_field_; } bool has_data_id_field() const { return has_data_id_field_; }
void set_has_data_id_field(bool val) { has_data_id_field_ = val; } void set_has_data_id_field(bool val) { has_data_id_field_ = val; }
...@@ -40,7 +43,7 @@ class BlobDesc final { ...@@ -40,7 +43,7 @@ class BlobDesc final {
void set_max_col_num(int32_t val) { max_col_num_ = val; } void set_max_col_num(int32_t val) { max_col_num_ = val; }
void ToProto(BlobDescProto* proto) const; void ToProto(BlobDescProto* proto) const;
size_t ByteSizeOfBlobHeaderField() const { return sizeof(BlobHeader); } size_t ByteSizeOfBlobHeaderField() const;
size_t ByteSizeOfDataIdField() const; size_t ByteSizeOfDataIdField() const;
size_t ByteSizeOfColNumField() const; size_t ByteSizeOfColNumField() const;
size_t ByteSizeOfDataContentField() const; size_t ByteSizeOfDataContentField() const;
...@@ -50,6 +53,7 @@ class BlobDesc final { ...@@ -50,6 +53,7 @@ class BlobDesc final {
private: private:
Shape shape_; Shape shape_;
DataType data_type_; DataType data_type_;
bool has_header_field_;
bool has_data_id_field_; bool has_data_id_field_;
bool has_col_num_field_; bool has_col_num_field_;
int64_t max_col_num_; int64_t max_col_num_;
......
...@@ -7,7 +7,8 @@ import "oneflow/core/common/data_type.proto"; ...@@ -7,7 +7,8 @@ import "oneflow/core/common/data_type.proto";
message BlobDescProto { message BlobDescProto {
required ShapeProto shape = 1; required ShapeProto shape = 1;
required DataType data_type = 2; required DataType data_type = 2;
required bool has_data_id_field = 3; required bool has_header_field = 3;
required bool has_col_num_field = 4; required bool has_data_id_field = 4;
required int32 max_col_num = 5; required bool has_col_num_field = 5;
required int32 max_col_num = 6;
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册