提交 125129cf 编写于 作者: W willzhang4a58

fix: has_header_field

Former-commit-id: cf20a9f8
上级 088423ba
......@@ -30,6 +30,7 @@ void BasicDataLoaderOp::InferBlobDescs(
}
out->mut_shape() = Shape(dim_vec);
out->set_data_type(conf.data_type());
out->set_has_header_field(true);
out->set_has_data_id_field(JobDesc::Singleton()->SizeOfOneDataId() > 0);
out->set_has_col_num_field(conf.max_sequence_size() > 1);
out->set_max_col_num(conf.max_sequence_size());
......
......@@ -13,7 +13,7 @@ void InferBasicRnnCellBlobDesc(
int64_t piece_size = in_blob_desc->shape().At(0);
BlobDesc data_tmp_blob_desc =
BlobDesc(Shape({embedding_size, hidden_size}),
JobDesc::Singleton()->DefaultDataType(), false, false,
JobDesc::Singleton()->DefaultDataType(), false, false, false,
in_blob_desc->max_col_num());
*GetBlobDesc4BnInOp("in_ip_op_out") = data_tmp_blob_desc;
*GetBlobDesc4BnInOp("hidden_ip_op_out") = data_tmp_blob_desc;
......
......@@ -7,8 +7,12 @@ namespace oneflow {
Blob::Blob(const BlobDesc* blob_desc, char* mem_ptr,
const void* comm_net_token) {
blob_header_ = reinterpret_cast<BlobHeader*>(mem_ptr);
mem_ptr_ = mem_ptr;
if (blob_desc->has_header_field()) {
blob_header_ = reinterpret_cast<BlobHeader*>(mem_ptr);
} else {
blob_header_ = nullptr;
}
if (blob_desc->has_data_id_field()) {
data_id_ptr_ = mem_ptr + blob_desc->ByteSizeOfBlobHeaderField();
} else {
......@@ -43,10 +47,6 @@ void Blob::set_col_num(int32_t no, int32_t val) {
*(col_num_ptr_ + no) = val;
}
// Returns the address stored in blob_header_, reinterpreted as an untyped pointer.
// NOTE(review): the Blob constructor assigns nullptr to blob_header_ when the
// descriptor has no header field, so this definition would return nullptr in
// that case. The class body also shows an inline memory_ptr() that returns
// mem_ptr_ -- confirm which of the two definitions is the current one.
const void* Blob::memory_ptr() const {
return reinterpret_cast<void*>(blob_header_);
}
// Byte size of this blob's header field, as reported by the blob's descriptor.
size_t Blob::ByteSizeOfBlobHeaderField() const {
  const auto& desc = *blob_desc_;
  return desc.ByteSizeOfBlobHeaderField();
}
......
......@@ -24,8 +24,8 @@ class Blob final {
int32_t col_num(int32_t no) const;
void set_col_num(int32_t no, int32_t val);
const void* memory_ptr() const;
void* mut_memory_ptr() { return const_cast<void*>(memory_ptr()); }
// Read-only view of mem_ptr_, the raw buffer address this Blob stores.
const void* memory_ptr() const { return mem_ptr_; }
// Mutable view of the same mem_ptr_ address.
void* mut_memory_ptr() { return mem_ptr_; }
template<typename T = void>
const T* dptr() const {
......@@ -82,6 +82,7 @@ class Blob final {
<< blob_desc_->data_type() << " " << GetDataType<T>::val;
}
void* mem_ptr_;
BlobHeader* blob_header_;
char* data_id_ptr_;
int32_t* col_num_ptr_;
......
......@@ -5,12 +5,14 @@ namespace oneflow {
BlobDesc::BlobDesc()
: BlobDesc(Shape(), JobDesc::Singleton()->DefaultDataType(), false, false,
1) {}
false, 1) {}
BlobDesc::BlobDesc(Shape shape, DataType data_type, bool has_data_id_field,
bool has_col_num_field, int32_t max_col_num)
BlobDesc::BlobDesc(Shape shape, DataType data_type, bool has_header_field,
bool has_data_id_field, bool has_col_num_field,
int32_t max_col_num)
: shape_(shape),
data_type_(data_type),
has_header_field_(has_header_field),
has_data_id_field_(has_data_id_field),
has_col_num_field_(has_col_num_field),
max_col_num_(max_col_num) {}
......@@ -18,6 +20,7 @@ BlobDesc::BlobDesc(Shape shape, DataType data_type, bool has_data_id_field,
BlobDesc::BlobDesc(const BlobDescProto& proto) {
shape_ = Shape(proto.shape());
data_type_ = proto.data_type();
has_header_field_ = proto.has_header_field();
has_data_id_field_ = proto.has_data_id_field();
has_col_num_field_ = proto.has_col_num_field();
max_col_num_ = proto.max_col_num();
......@@ -26,11 +29,20 @@ BlobDesc::BlobDesc(const BlobDescProto& proto) {
// Writes every member of this descriptor into `proto`: the shape, the data
// type, the three has_*_field flags and max_col_num.
void BlobDesc::ToProto(BlobDescProto* proto) const {
  BlobDescProto& out = *proto;

  // Shape serializes itself into the message's nested shape field.
  shape_.ToProto(out.mutable_shape());
  out.set_data_type(data_type_);

  // Flags recording which optional fields the blob carries.
  out.set_has_header_field(has_header_field_);
  out.set_has_data_id_field(has_data_id_field_);
  out.set_has_col_num_field(has_col_num_field_);

  out.set_max_col_num(max_col_num_);
}
// Bytes taken by the header field: sizeof(BlobHeader) when the descriptor
// has a header field, otherwise 0.
size_t BlobDesc::ByteSizeOfBlobHeaderField() const {
  return has_header_field_ ? sizeof(BlobHeader) : 0;
}
size_t BlobDesc::ByteSizeOfDataIdField() const {
if (has_data_id_field_) {
return shape_.At(0) * JobDesc::Singleton()->SizeOfOneDataId();
......@@ -58,6 +70,7 @@ size_t BlobDesc::TotalByteSize() const {
bool BlobDesc::operator==(const BlobDesc& rhs) const {
return shape_ == rhs.shape_ && data_type_ == rhs.data_type_
&& has_header_field_ == rhs.has_header_field_
&& has_data_id_field_ == rhs.has_data_id_field_
&& has_col_num_field_ == rhs.has_col_num_field_
&& max_col_num_ == rhs.max_col_num_;
......@@ -67,6 +80,7 @@ BlobDesc ComputePackedBlobDesc(std::function<const BlobDesc*()> NextBlobDesc) {
int64_t total_byte_size = 0;
int64_t total_data_content_byte_size = 0;
HashSet<int> data_type_set;
bool has_header_field = false;
bool has_data_id_field = false;
bool has_col_num_field = false;
int32_t max_col_num = -1;
......@@ -78,6 +92,7 @@ BlobDesc ComputePackedBlobDesc(std::function<const BlobDesc*()> NextBlobDesc) {
data_type_set.insert(static_cast<int>(blob_desc->data_type()));
has_data_id_field = has_data_id_field || blob_desc->has_data_id_field();
has_col_num_field = has_col_num_field || blob_desc->has_col_num_field();
has_header_field = has_header_field || blob_desc->has_header_field();
if (max_col_num == -1) {
max_col_num = blob_desc->max_col_num();
} else {
......@@ -90,7 +105,8 @@ BlobDesc ComputePackedBlobDesc(std::function<const BlobDesc*()> NextBlobDesc) {
if (blob_desc_cnt <= 1) { return ret; }
CHECK_EQ(has_col_num_field, false);
CHECK_EQ(max_col_num, 1);
if (has_data_id_field == false && data_type_set.size() == 1) {
if (has_header_field == false && has_data_id_field == false
&& data_type_set.size() == 1) {
DataType sole_data_type = static_cast<DataType>(*(data_type_set.begin()));
int64_t size_of_one_elem = GetSizeOfDataType(sole_data_type);
CHECK_EQ(total_data_content_byte_size % size_of_one_elem, 0);
......
......@@ -19,7 +19,7 @@ class BlobDesc final {
~BlobDesc() = default;
BlobDesc();
BlobDesc(Shape shape, DataType data_type, bool has_data_id_field,
BlobDesc(Shape, DataType, bool has_header_field, bool has_data_id_field,
bool has_col_num_field, int32_t max_col_num);
BlobDesc(Shape shape) : BlobDesc() { shape_ = shape; }
BlobDesc(const BlobDescProto& proto);
......@@ -30,6 +30,9 @@ class BlobDesc final {
// Element data type of the blob.
DataType data_type() const { return data_type_; }
void set_data_type(DataType val) { data_type_ = val; }
// Whether the blob has a header field; when false,
// ByteSizeOfBlobHeaderField() returns 0.
bool has_header_field() const { return has_header_field_; }
void set_has_header_field(bool val) { has_header_field_ = val; }
// Whether the blob has a data-id field.
bool has_data_id_field() const { return has_data_id_field_; }
void set_has_data_id_field(bool val) { has_data_id_field_ = val; }
......@@ -40,7 +43,7 @@ class BlobDesc final {
void set_max_col_num(int32_t val) { max_col_num_ = val; }
void ToProto(BlobDescProto* proto) const;
size_t ByteSizeOfBlobHeaderField() const { return sizeof(BlobHeader); }
size_t ByteSizeOfBlobHeaderField() const;
size_t ByteSizeOfDataIdField() const;
size_t ByteSizeOfColNumField() const;
size_t ByteSizeOfDataContentField() const;
......@@ -50,6 +53,7 @@ class BlobDesc final {
private:
Shape shape_;
DataType data_type_;
bool has_header_field_;
bool has_data_id_field_;
bool has_col_num_field_;
int64_t max_col_num_;
......
......@@ -7,7 +7,8 @@ import "oneflow/core/common/data_type.proto";
message BlobDescProto {
required ShapeProto shape = 1;
required DataType data_type = 2;
required bool has_data_id_field = 3;
required bool has_col_num_field = 4;
required int32 max_col_num = 5;
required bool has_header_field = 3;
required bool has_data_id_field = 4;
required bool has_col_num_field = 5;
required int32 max_col_num = 6;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册