From 1c29eb42febbe68b5681ae8ef0213e7e8dc3a56c Mon Sep 17 00:00:00 2001 From: Li Xinqi Date: Mon, 1 Oct 2018 19:09:46 +0800 Subject: [PATCH] Dev pod desc (#1268) * available instance num * import shape.proto * PodProto * rename message * union pod is useless * PodPtr * rename: PodPtr::get() => PodPtr::Get() * BlobDescProto.pod * mv register_desc.time_shape into another pr * pod_helper.h * FieldAlignedByteSize * pod_desc * PodDesc copy constructor * BlobDesc::body_shape_pod_desc_ * add BlobDesc::opaque_header_pod_desc_ * align_shift => alignment * default alignment * add field Blob::header_pod_ptr_ * rename AlignedFieldPodProto => FieldPodProto * bugfix * check * FieldId * simplify RtBlobDesc * simplify Blob * ShapedPod => TensorPod * refine ComputePackedBlobDesc Former-commit-id: 8800da932878813705e4dc8d8dd2b216039c5d5c --- oneflow/core/kernel/kernel.proto | 1 + oneflow/core/operator/op_conf.proto | 8 +- oneflow/core/register/blob.cpp | 26 +-- oneflow/core/register/blob.h | 4 + oneflow/core/register/blob_desc.cpp | 42 +++-- oneflow/core/register/blob_desc.h | 9 +- oneflow/core/register/blob_desc.proto | 2 + oneflow/core/register/logical_blob_id.proto | 9 + oneflow/core/register/pod.proto | 41 +++++ oneflow/core/register/pod_desc.cpp | 178 ++++++++++++++++++++ oneflow/core/register/pod_desc.h | 155 +++++++++++++++++ oneflow/core/register/pod_ptr.cpp | 22 +++ oneflow/core/register/pod_ptr.h | 72 ++++++++ oneflow/core/register/register_desc.proto | 2 +- oneflow/core/register/runtime_blob_desc.cpp | 64 +------ oneflow/core/register/runtime_blob_desc.h | 10 +- 16 files changed, 536 insertions(+), 109 deletions(-) create mode 100644 oneflow/core/register/logical_blob_id.proto create mode 100644 oneflow/core/register/pod.proto create mode 100644 oneflow/core/register/pod_desc.cpp create mode 100644 oneflow/core/register/pod_desc.h create mode 100644 oneflow/core/register/pod_ptr.cpp create mode 100644 oneflow/core/register/pod_ptr.h diff --git a/oneflow/core/kernel/kernel.proto b/oneflow/core/kernel/kernel.proto index cdeeab3ddc..f4a4b0fa9b 100644 --- a/oneflow/core/kernel/kernel.proto +++ b/oneflow/core/kernel/kernel.proto @@ -3,6 +3,7 @@ package oneflow; import "oneflow/core/common/shape.proto"; import "oneflow/core/common/data_type.proto"; +import "oneflow/core/register/logical_blob_id.proto"; import "oneflow/core/operator/op_conf.proto"; message ConvKernelConf { diff --git a/oneflow/core/operator/op_conf.proto b/oneflow/core/operator/op_conf.proto index ad57cbfd0b..bbdb2833db 100644 --- a/oneflow/core/operator/op_conf.proto +++ b/oneflow/core/operator/op_conf.proto @@ -5,6 +5,7 @@ import "oneflow/core/common/shape.proto"; import "oneflow/core/common/data_type.proto"; import "oneflow/core/record/image.proto"; import "oneflow/core/job/resource.proto"; +import "oneflow/core/register/logical_blob_id.proto"; enum ActivationType { kNone = 0; @@ -259,13 +260,6 @@ message BoxSplitConf { message BoxCloneConf { } -message LogicalBlobId { - optional string op_name = 1; - optional string blob_name = 2; - optional int32 clone_id = 4 [default = -1]; - optional bool is_packed_id = 5 [default = false]; -} - message BoxingOpConf { required LogicalBlobId lbi = 1; required int32 in_num = 2; diff --git a/oneflow/core/register/blob.cpp b/oneflow/core/register/blob.cpp index bce1c7c47a..be6e123ede 100644 --- a/oneflow/core/register/blob.cpp +++ b/oneflow/core/register/blob.cpp @@ -6,35 +6,23 @@ namespace oneflow { -Blob::Blob(Regst* regst, const RtBlobDesc* blob_desc, char* header_ptr) { +Blob::Blob(Regst* regst, const RtBlobDesc* blob_desc, char* header_ptr) + : header_pod_ptr_(blob_desc->header_pod_desc(), header_ptr) { Init(regst, blob_desc, header_ptr, header_ptr + blob_desc->ByteSizeOfBlobHeader()); } -Blob::Blob(Regst* regst, const RtBlobDesc* blob_desc, char* header_ptr, char* body_ptr) { +Blob::Blob(Regst* regst, const RtBlobDesc* blob_desc, char* header_ptr, char* body_ptr) + : header_pod_ptr_(blob_desc->header_pod_desc(), header_ptr) { Init(regst, blob_desc, header_ptr, body_ptr); } void Blob::Init(Regst* regst, const RtBlobDesc* blob_desc, char* header_ptr, char* body_ptr) { - if (body_ptr == header_ptr + blob_desc->ByteSizeOfBlobHeader()) { - is_contiguous_ = true; - } else { - is_contiguous_ = false; - } - + is_contiguous_ = (body_ptr == header_ptr + blob_desc->ByteSizeOfBlobHeader()); regst_ = regst; blob_desc_ = blob_desc; header_ptr_ = header_ptr; - if (blob_desc->has_data_id_field()) { - data_id_ptr_ = header_ptr; - } else { - data_id_ptr_ = nullptr; - } - char* offset = header_ptr + blob_desc->ByteSizeOfDataIdField(); - if (blob_desc->has_col_num_field()) { - col_num_ptr_ = reinterpret_cast(offset); - } else { - col_num_ptr_ = nullptr; - } + data_id_ptr_ = header_pod_ptr_.MutTensorPtr(FieldKey::kDataId, nullptr); + col_num_ptr_ = header_pod_ptr_.MutTensorPtr(FieldKey::kColNum, nullptr); dptr_ = body_ptr; } diff --git a/oneflow/core/register/blob.h b/oneflow/core/register/blob.h index 60097f1569..a8c5527c96 100644 --- a/oneflow/core/register/blob.h +++ b/oneflow/core/register/blob.h @@ -9,6 +9,7 @@ #include "oneflow/core/persistence/persistent_in_stream.h" #include "oneflow/core/record/record.pb.h" #include "oneflow/core/record/record_io.h" +#include "oneflow/core/register/pod_ptr.h" namespace oneflow { @@ -87,6 +88,8 @@ class Blob final { void set_max_col_id(int32_t val); bool IsColValid() const { return col_id() <= max_col_id(); } const MemoryCase& mem_case() const; + const PodPtr* header_pod_ptr() const { return &header_pod_ptr_; } + PodPtr* header_pod_ptr() { return &header_pod_ptr_; } private: int64_t GetDptrOffset(int32_t index) const { return 0; } @@ -114,6 +117,7 @@ class Blob final { void* dptr_; const RtBlobDesc* blob_desc_; Regst* regst_; + PodPtr header_pod_ptr_; }; template diff --git a/oneflow/core/register/blob_desc.cpp b/oneflow/core/register/blob_desc.cpp index 6c165ba27f..e043605bce 100644 --- a/oneflow/core/register/blob_desc.cpp +++ b/oneflow/core/register/blob_desc.cpp @@ -23,6 +23,7 @@ BlobDesc::BlobDesc(const BlobDescProto& proto) : body_field_(proto.body()) { has_data_id_ = false; has_col_num_ = false; opaque_header_ = FieldDesc(proto.header().opaque_header()); + opaque_header_pod_desc_.InitFromProto(proto.header().header_pod_desc()); } else { CHECK(proto.header().has_field_header()); header_is_opaque_ = false; @@ -31,16 +32,18 @@ BlobDesc::BlobDesc(const BlobDescProto& proto) : body_field_(proto.body()) { } } -BlobDesc::BlobDesc(int64_t header_byte_size, const Shape& shape, DataType data_type, - int32_t max_col_num) +BlobDesc::BlobDesc(const StructPodDesc& header_pod_desc, int64_t header_byte_size, + const Shape& shape, DataType data_type, int32_t max_col_num) : has_data_id_(false), has_col_num_(false), max_col_num_(max_col_num), blob_mem_id_(-1), body_field_(shape, data_type) { + CHECK_EQ(header_pod_desc.ByteSize(), header_byte_size); if (header_byte_size > 0) { header_is_opaque_ = true; opaque_header_ = FieldDesc(Shape({header_byte_size}), DataType::kChar); + opaque_header_pod_desc_ = header_pod_desc; } else { header_is_opaque_ = false; } @@ -54,16 +57,19 @@ void BlobDesc::set_has_col_num_field(bool val) { CHECK(!header_is_opaque_); has_col_num_ = val; } -void BlobDesc::DataIdFieldToProto(FieldHeaderDesc* proto) const { - FieldDesc data_id_field(Shape({body_field_.shape().At(0), - static_cast(Global::Get()->SizeOfOneDataId())}), - DataType::kChar); +void BlobDesc::DataIdFieldToProto(FieldHeaderDesc* proto, StructPodDesc* header_pod_desc) const { + Shape shape( + {body_field_.shape().At(0), static_cast(Global::Get()->SizeOfOneDataId())}); + FieldDesc data_id_field(shape, DataType::kChar); data_id_field.ToProto(proto->mutable_data_id()); + header_pod_desc->AddField(FieldKey::kDataId, TensorPodDesc(shape, DataType::kChar)); } -void BlobDesc::ColNumFieldToProto(FieldHeaderDesc* proto) const { - FieldDesc col_num_field(Shape({body_field_.shape().At(0)}), DataType::kInt32); +void BlobDesc::ColNumFieldToProto(FieldHeaderDesc* proto, StructPodDesc* header_pod_desc) const { + Shape shape({body_field_.shape().At(0)}); + FieldDesc col_num_field(shape, DataType::kInt32); col_num_field.ToProto(proto->mutable_col_num()); + header_pod_desc->AddField(FieldKey::kColNum, TensorPodDesc(shape, DataType::kInt32)); } void BlobDesc::HeaderToProto(BlobDescProto* proto) const { @@ -71,10 +77,13 @@ void BlobDesc::HeaderToProto(BlobDescProto* proto) const { proto->mutable_header()->set_blob_mem_id(blob_mem_id_); if (!header_is_opaque_) { FieldHeaderDesc* field_header = proto->mutable_header()->mutable_field_header(); - if (has_data_id_field()) { DataIdFieldToProto(field_header); } - if (has_col_num_field()) { ColNumFieldToProto(field_header); } + StructPodDesc header_pod_desc; + if (has_data_id_field()) { DataIdFieldToProto(field_header, &header_pod_desc); } + if (has_col_num_field()) { ColNumFieldToProto(field_header, &header_pod_desc); } + header_pod_desc.ToProto(proto->mutable_header()->mutable_header_pod_desc()); } else { opaque_header_.ToProto(proto->mutable_header()->mutable_opaque_header()); + opaque_header_pod_desc_.ToProto(proto->mutable_header()->mutable_header_pod_desc()); } } @@ -85,6 +94,7 @@ void BlobDesc::ToProto(BlobDescProto* proto) const { bool BlobDesc::operator==(const BlobDesc& rhs) const { return header_is_opaque_ == rhs.header_is_opaque_ && opaque_header_ == rhs.opaque_header_ + && opaque_header_pod_desc_ == rhs.opaque_header_pod_desc_ && has_data_id_ == rhs.has_data_id_ && has_col_num_ == rhs.has_col_num_ && max_col_num_ == rhs.max_col_num_ && blob_mem_id_ == rhs.blob_mem_id_ && body_field_ == rhs.body_field_; @@ -93,6 +103,7 @@ bool BlobDesc::operator==(const BlobDesc& rhs) const { BlobDesc& BlobDesc::operator=(const BlobDesc& blob_desc) { header_is_opaque_ = blob_desc.header_is_opaque_; opaque_header_ = blob_desc.opaque_header_; + opaque_header_pod_desc_ = blob_desc.opaque_header_pod_desc_; has_data_id_ = blob_desc.has_data_id_; has_col_num_ = blob_desc.has_col_num_; max_col_num_ = blob_desc.max_col_num_; @@ -111,11 +122,12 @@ std::unique_ptr ComputePackedBlobDesc( std::unique_ptr ret(new BlobDesc()); const BlobDesc* last_blob_desc = nullptr; HashMap blob_mem_id2size; - + StructPodDesc opaque_header_pod_desc; for (auto& pair : lbi2blob_desc) { BlobDesc* blob_desc = pair.second.get(); RtBlobDesc rt_blob_desc(*blob_desc); header_byte_size += rt_blob_desc.ByteSizeOfBlobHeader(); + *opaque_header_pod_desc.MutStructField(NewFieldId(pair.first)) = rt_blob_desc.header_pod_desc(); int64_t cur_body_byte_size = rt_blob_desc.ByteSizeOfBlobBody(); int32_t blob_mem_id = blob_desc->blob_mem_id(); if (blob_mem_id == -1) { @@ -150,12 +162,12 @@ std::unique_ptr ComputePackedBlobDesc( if (header_byte_size == 0) { ret.reset(new BlobDesc(Shape({total_elem_cnt}), sole_data_type, false, false, max_col_num)); } else { - ret.reset( - new BlobDesc(header_byte_size, Shape({total_elem_cnt}), sole_data_type, max_col_num)); + ret.reset(new BlobDesc(opaque_header_pod_desc, header_byte_size, Shape({total_elem_cnt}), + sole_data_type, max_col_num)); } } else { - ret.reset( - new BlobDesc(header_byte_size, Shape({body_byte_size}), DataType::kChar, max_col_num)); + ret.reset(new BlobDesc(opaque_header_pod_desc, header_byte_size, Shape({body_byte_size}), + DataType::kChar, max_col_num)); } return ret; } diff --git a/oneflow/core/register/blob_desc.h b/oneflow/core/register/blob_desc.h index bedc0cce0d..5bc19a847e 100644 --- a/oneflow/core/register/blob_desc.h +++ b/oneflow/core/register/blob_desc.h @@ -5,6 +5,7 @@ #include "oneflow/core/common/shape.h" #include "oneflow/core/register/field_desc.h" #include "oneflow/core/register/blob_desc.pb.h" +#include "oneflow/core/register/pod_desc.h" #include "oneflow/core/job/job_desc.h" namespace oneflow { @@ -18,7 +19,8 @@ class BlobDesc { BlobDesc(const Shape&, DataType, bool has_data_id, bool has_col_num, int32_t max_col_num); BlobDesc(const Shape& shape) : body_field_(shape) {} BlobDesc(const BlobDescProto& proto); - BlobDesc(int64_t header_byte_size, const Shape&, DataType, int32_t max_col_num); + BlobDesc(const StructPodDesc& header_pod_desc, int64_t header_byte_size, const Shape&, DataType, + int32_t max_col_num); const Shape& shape() const { return body_field_.shape(); } Shape& mut_shape() { return body_field_.mut_shape(); } @@ -46,11 +48,12 @@ class BlobDesc { private: void HeaderToProto(BlobDescProto* proto) const; - void DataIdFieldToProto(FieldHeaderDesc* proto) const; - void ColNumFieldToProto(FieldHeaderDesc* proto) const; + void DataIdFieldToProto(FieldHeaderDesc* proto, StructPodDesc* header_pod_desc) const; + void ColNumFieldToProto(FieldHeaderDesc* proto, StructPodDesc* header_pod_desc) const; bool header_is_opaque_; FieldDesc opaque_header_; + StructPodDesc opaque_header_pod_desc_; bool has_data_id_; bool has_col_num_; diff --git a/oneflow/core/register/blob_desc.proto b/oneflow/core/register/blob_desc.proto index b681edf7aa..c32ce40386 100644 --- a/oneflow/core/register/blob_desc.proto +++ b/oneflow/core/register/blob_desc.proto @@ -2,6 +2,7 @@ syntax = "proto2"; package oneflow; import "oneflow/core/register/field_desc.proto"; +import "oneflow/core/register/pod.proto"; message FieldHeaderDesc { optional FieldDescProto data_id = 1; @@ -15,6 +16,7 @@ message BlobHeaderDescProto { FieldDescProto opaque_header = 3; FieldHeaderDesc field_header = 4; } + required StructPodProto header_pod_desc = 5; } message BlobDescProto { diff --git a/oneflow/core/register/logical_blob_id.proto b/oneflow/core/register/logical_blob_id.proto new file mode 100644 index 0000000000..64a5f579ce --- /dev/null +++ b/oneflow/core/register/logical_blob_id.proto @@ -0,0 +1,9 @@ +syntax = "proto2"; +package oneflow; + +message LogicalBlobId { + optional string op_name = 1; + optional string blob_name = 2; + optional int32 clone_id = 4 [default = -1]; + optional bool is_packed_id = 5 [default = false]; +} diff --git a/oneflow/core/register/pod.proto b/oneflow/core/register/pod.proto new file mode 100644 index 0000000000..556682baa5 --- /dev/null +++ b/oneflow/core/register/pod.proto @@ -0,0 +1,41 @@ +syntax = "proto2"; +package oneflow; + +import "oneflow/core/common/shape.proto"; +import "oneflow/core/common/data_type.proto"; +import "oneflow/core/register/logical_blob_id.proto"; + +message TensorPodProto { + required ShapeProto shape = 1; + required DataType data_type = 2; +} + +message StructPodProto { + repeated FieldPodProto field = 1; +} + +enum FieldKey { + kInvalidFieldKey = 0; + kDataId = 1; + kColNum = 2; +} + +message FieldId { + oneof field_id_type { + FieldKey key = 1; + LogicalBlobId lbi = 2; + } +} + +message FieldPodProto { + required FieldId field_id = 1; + required int32 alignment = 2; + required PodProto pod = 3; +} + +message PodProto { + oneof pod_type { + TensorPodProto tensor_pod = 1; + StructPodProto struct_pod = 2; + } +} diff --git a/oneflow/core/register/pod_desc.cpp b/oneflow/core/register/pod_desc.cpp new file mode 100644 index 0000000000..404f7ee733 --- /dev/null +++ b/oneflow/core/register/pod_desc.cpp @@ -0,0 +1,178 @@ +#include "oneflow/core/register/pod_desc.h" + +namespace oneflow { + +namespace { + +std::unique_ptr NewPodDesc(const PodProto& pod) { + if (pod.has_tensor_pod()) { return std::make_unique(pod.tensor_pod()); } + if (pod.has_struct_pod()) { return std::make_unique(pod.struct_pod()); } + // ignore field pod + UNIMPLEMENTED(); + return std::unique_ptr(); +} + +} // namespace + +FieldId NewFieldId(FieldKey key) { + FieldId ret; + ret.set_key(key); + return ret; +} + +FieldId NewFieldId(const LogicalBlobId& lbi) { + FieldId ret; + *ret.mutable_lbi() = lbi; + return ret; +} + +TensorPodDesc::TensorPodDesc(const TensorPodProto& tensor_pod) { InitFromProto(tensor_pod); } + +TensorPodDesc::TensorPodDesc(const TensorPodDesc& tensor_pod) { + PodProto pod_proto; + tensor_pod.ToProto(&pod_proto); + InitFromProto(pod_proto.tensor_pod()); +} + +void TensorPodDesc::InitFromProto(const TensorPodProto& tensor_pod) { + shape_ = Shape(tensor_pod.shape()); + data_type_ = tensor_pod.data_type(); +} + +size_t TensorPodDesc::ByteSize() const { return shape_.elem_cnt() * GetSizeOfDataType(data_type_); } + +bool TensorPodDesc::operator==(const PodDesc& rhs) const { + const auto* tensor_rhs = dynamic_cast(&rhs); + if (tensor_rhs == nullptr) { return false; } + return shape() == tensor_rhs->shape() && data_type() == tensor_rhs->data_type(); +} + +void TensorPodDesc::ToProto(PodProto* pod_proto) const { + shape_.ToProto(pod_proto->mutable_tensor_pod()->mutable_shape()); + pod_proto->mutable_tensor_pod()->set_data_type(data_type_); +} + +FieldPodDesc::FieldPodDesc(const FieldPodProto& field_pod) { + field_id_ = field_pod.field_id(); + pod_ = std::move(NewPodDesc(field_pod.pod())); + alignment_ = field_pod.alignment(); +} + +size_t FieldPodDesc::ByteSize() const { return RoundUp(pod_->ByteSize(), alignment_); } + +bool FieldPodDesc::operator==(const PodDesc& rhs) const { + const auto* field_rhs = dynamic_cast(&rhs); + if (field_rhs == nullptr) { return false; } + return field_id() == field_rhs->field_id() && pod() == field_rhs->pod() + && alignment_ == field_rhs->alignment_; +} + +void FieldPodDesc::ToProto(FieldPodProto* field_pod_proto) const { + *field_pod_proto->mutable_field_id() = field_id_; + field_pod_proto->set_alignment(alignment_); + pod_->ToProto(field_pod_proto->mutable_pod()); +} + +StructPodDesc::StructPodDesc(const StructPodProto& struct_pod_proto) { + InitFromProto(struct_pod_proto); +} + +StructPodDesc::StructPodDesc(const StructPodDesc& struct_pod_desc) { *this = struct_pod_desc; } + +void StructPodDesc::InitFromProto(const StructPodProto& struct_pod) { + CHECK(field_id2field_idx_.empty()); + CHECK(fields_.empty()); + for (const auto& field : struct_pod.field()) { + std::unique_ptr pod(new FieldPodDesc(field)); + AddField(std::move(pod)); + } +} + +size_t StructPodDesc::ByteSize() const { + size_t size = 0; + for (const auto& field : fields_) { size += field->ByteSize(); } + return size; +} + +bool StructPodDesc::operator==(const PodDesc& rhs) const { + const auto* struct_rhs = dynamic_cast(&rhs); + if (struct_rhs == nullptr) { return false; } + if (field_id2field_idx_ != struct_rhs->field_id2field_idx_) { return false; } + for (int i = 0; i < field_id2field_idx_.size(); ++i) { + if (*fields_.at(i) != *struct_rhs->fields_.at(i)) { return false; } + } + return true; +} + +void StructPodDesc::ToProto(StructPodProto* struct_pod_proto) const { + for (const auto& field : fields_) { field->ToProto(struct_pod_proto->add_field()); } +} + +bool StructPodDesc::HasField(const FieldId& field_id) const { + return field_id2field_idx_.find(field_id) != field_id2field_idx_.end(); +} + +StructPodDesc* StructPodDesc::MutStructField(const FieldId& field_id) { + return MutStructField(field_id, 1); +} + +StructPodDesc* StructPodDesc::MutStructField(const FieldId& field_id, int32_t alignment) { + if (!HasField(field_id)) { AddField(field_id, std::make_unique(), alignment); } + return MutExistedField(field_id)->MutCast(); +} + +PodDesc* StructPodDesc::MutExistedField(const FieldId& field_id) { + return fields_.at(field_id2field_idx_.at(field_id))->mut_pod(); +} + +const PodDesc& StructPodDesc::Field(const FieldId& field_id) const { + return fields_.at(field_id2field_idx_.at(field_id))->pod(); +} + +void StructPodDesc::AddField(FieldKey field_key, const PodDesc& pod_desc) { + return AddField(NewFieldId(field_key), pod_desc); +} + +void StructPodDesc::AddField(const FieldId& field_id, const PodDesc& pod_desc) { + return AddField(field_id, pod_desc, 1); +} + +void StructPodDesc::AddField(const FieldId& field_id, const PodDesc& pod_desc, size_t alignment) { + AddField(field_id, pod_desc.Clone(), alignment); +} + +void StructPodDesc::AddField(const FieldId& field_id, std::unique_ptr&& field, + size_t alignment) { + auto* pod = new FieldPodDesc(field_id, std::move(field), alignment); + AddField(std::unique_ptr(pod)); +} + +void StructPodDesc::AddField(std::unique_ptr&& field) { + CHECK(field_id2field_idx_.emplace(field->field_id(), fields_.size()).second); + fields_.emplace_back(std::move(field)); +} + +size_t StructPodDesc::ByteOffset4Field(const FieldId& field_id) const { + CHECK(HasField(field_id)); + size_t offset = 0; + for (int32_t i = 0; i < field_id2field_idx_.at(field_id); ++i) { + offset += fields_.at(i)->ByteSize(); + } + return offset; +} + +StructPodDesc& StructPodDesc::operator=(const StructPodDesc& struct_pod_desc) { + Clear(); + StructPodProto struct_pod_proto; + struct_pod_desc.ToProto(&struct_pod_proto); + InitFromProto(struct_pod_proto); + return *this; +} + +void StructPodDesc::Clear() { + CHECK_EQ(fields_.size(), field_id2field_idx_.size()); + fields_.clear(); + field_id2field_idx_.clear(); +} + +} // namespace oneflow diff --git a/oneflow/core/register/pod_desc.h b/oneflow/core/register/pod_desc.h new file mode 100644 index 0000000000..19444c2077 --- /dev/null +++ b/oneflow/core/register/pod_desc.h @@ -0,0 +1,155 @@ +#ifndef ONEFLOW_CORE_REGISTER_POD_DESC_H_ +#define ONEFLOW_CORE_REGISTER_POD_DESC_H_ + +#include "oneflow/core/common/util.h" +#include "oneflow/core/common/data_type.h" +#include "oneflow/core/common/protobuf.h" +#include "oneflow/core/common/shape.h" +#include "oneflow/core/register/pod.pb.h" + +namespace std { + +template<> +struct hash { + size_t operator()(const oneflow::FieldId& field_id) const { + if (field_id.has_key()) { return std::hash()(field_id.key()); } + if (field_id.has_lbi()) { return std::hash()(field_id.lbi()); } + UNIMPLEMENTED(); + } +}; + +} // namespace std + +namespace oneflow { + +FieldId NewFieldId(FieldKey key); +FieldId NewFieldId(const LogicalBlobId& lbi); +inline bool operator==(const FieldId& lhs, const FieldId& rhs) { + PbMd message_diff; + return message_diff.Equivalent(lhs, rhs); +} + +class PodDesc { + public: + OF_DISALLOW_COPY_AND_MOVE(PodDesc); + PodDesc() = default; + virtual ~PodDesc() = default; + + template + const T& Cast() const; + template + T* MutCast(); + + virtual size_t ByteSize() const = 0; + virtual void ToProto(PodProto* pod_proto) const = 0; + virtual std::unique_ptr Clone() const = 0; + virtual bool operator==(const PodDesc& rhs) const = 0; + bool operator!=(const PodDesc& rhs) const { return !(*this == rhs); } +}; + +class TensorPodDesc final : public PodDesc { + public: + TensorPodDesc() = default; + TensorPodDesc(const Shape& shape, DataType data_type) : shape_(shape), data_type_(data_type) {} + explicit TensorPodDesc(const TensorPodProto& shape_pod_proto); + explicit TensorPodDesc(const TensorPodDesc& shape_pod); + ~TensorPodDesc() = default; + const Shape& shape() const { return shape_; } + DataType data_type() const { return data_type_; } + Shape* mut_shape() { return &shape_; } + void set_data_type(DataType data_type) { data_type_ = data_type; } + + void InitFromProto(const TensorPodProto& shape_pod); + + size_t ByteSize() const override; + void ToProto(PodProto* pod_proto) const override; + std::unique_ptr Clone() const override { return std::make_unique(*this); } + bool operator==(const PodDesc& rhs) const override; + + private: + Shape shape_; + DataType data_type_; +}; + +class FieldPodDesc; + +class StructPodDesc final : public PodDesc { + public: + StructPodDesc() = default; + explicit StructPodDesc(const StructPodProto&); + explicit StructPodDesc(const StructPodDesc&); + ~StructPodDesc() = default; + + StructPodDesc* MutStructField(const FieldId& field_id); + const PodDesc& Field(FieldKey field_key) const { return Field(NewFieldId(field_key)); } + const PodDesc& Field(const FieldId& field_id) const; + void AddField(FieldKey field_key, const PodDesc& pod_desc); + void AddField(const FieldId& field_id, const PodDesc& pod_desc); + size_t ByteSize() const override; + void InitFromProto(const StructPodProto& struct_pod); + + bool HasField(FieldKey field_key) const { return HasField(NewFieldId(field_key)); } + bool HasField(const FieldId& field_id) const; + StructPodDesc& operator=(const StructPodDesc&); + std::unique_ptr Clone() const override { return std::make_unique(*this); } + void ToProto(PodProto* pod_proto) const override { ToProto(pod_proto->mutable_struct_pod()); } + void ToProto(StructPodProto* pod_proto) const; + StructPodDesc* MutStructField(const FieldId& field_id, int32_t default_alignment); + void AddField(const FieldId& field_id, const PodDesc& pod_desc, size_t alignment); + bool operator==(const PodDesc& rhs) const override; + size_t ByteOffset4Field(const FieldId& field_name) const; + + private: + void Clear(); + PodDesc* MutExistedField(const FieldId& field_id); + void AddField(std::unique_ptr&& field); + void AddField(const FieldId& field_id, std::unique_ptr&& field); + void AddField(const FieldId& field_id, std::unique_ptr&& field, size_t alignment); + + std::vector> fields_; + HashMap field_id2field_idx_; +}; + +class FieldPodDesc final : public PodDesc { + public: + OF_DISALLOW_COPY_AND_MOVE(FieldPodDesc); + ~FieldPodDesc() = default; + + private: + friend class StructPodDesc; + FieldPodDesc(const FieldId& field_id, std::unique_ptr&& pod, size_t alignment) + : PodDesc(), field_id_(field_id), pod_(std::move(pod)), alignment_(alignment) {} + explicit FieldPodDesc(const FieldPodProto& field_pod_proto); + + size_t ByteSize() const override; + void ToProto(PodProto* pod_proto) const override { UNIMPLEMENTED(); } + std::unique_ptr Clone() const override { UNIMPLEMENTED(); } + void ToProto(FieldPodProto* field_proto) const; + bool operator==(const PodDesc& rhs) const override; + + const PodDesc& pod() const { return *pod_; } + const FieldId& field_id() const { return field_id_; } + PodDesc* mut_pod() { return pod_.get(); } + + FieldId field_id_; + std::unique_ptr pod_; + size_t alignment_; +}; + +template +const T& PodDesc::Cast() const { + static_assert(std::is_same::value || std::is_same::value, + "only TensorPodDesc and StructPodDesc supported"); + return *dynamic_cast(this); +} + +template +T* PodDesc::MutCast() { + static_assert(std::is_same::value || std::is_same::value, + "only TensorPodDesc and StructPodDesc supported"); + return dynamic_cast(this); +} + +} // namespace oneflow + +#endif // ONEFLOW_CORE_REGISTER_POD_DESC_H_ diff --git a/oneflow/core/register/pod_ptr.cpp b/oneflow/core/register/pod_ptr.cpp new file mode 100644 index 0000000000..ef64ecca51 --- /dev/null +++ b/oneflow/core/register/pod_ptr.cpp @@ -0,0 +1,22 @@ +#include "oneflow/core/register/pod_ptr.h" + +namespace oneflow { + +PodPtr PodPtrField(const PodDesc* pod_desc, const FieldId& field_id, char* ptr) { + const auto* struct_pod = dynamic_cast(pod_desc); + CHECK_NOTNULL(struct_pod); + return PodPtr(struct_pod->Field(field_id), ptr + struct_pod->ByteOffset4Field(field_id)); +} + +bool PodPtr::HasField(const FieldId& field_id) const { + const auto* struct_pod = dynamic_cast(pod_desc_); + return struct_pod && struct_pod->HasField(field_id); +} + +const PodPtr PodPtr::Field(const FieldId& field_id) const { + return PodPtrField(pod_desc_, field_id, ptr_); +} + +PodPtr PodPtr::MutField(const FieldId& field_id) { return PodPtrField(pod_desc_, field_id, ptr_); } + +} // namespace oneflow diff --git a/oneflow/core/register/pod_ptr.h b/oneflow/core/register/pod_ptr.h new file mode 100644 index 0000000000..3728c36a7f --- /dev/null +++ b/oneflow/core/register/pod_ptr.h @@ -0,0 +1,72 @@ +#ifndef ONEFLOW_CORE_REGISTER_POD_PTR_H_ +#define ONEFLOW_CORE_REGISTER_POD_PTR_H_ +#include "oneflow/core/register/pod_desc.h" +#include "oneflow/core/common/util.h" +#include "oneflow/core/common/data_type.h" + +namespace oneflow { + +class PodPtr final { + public: + PodPtr(const PodDesc& pod_desc, char* ptr) : pod_desc_(&pod_desc), ptr_(ptr) {} + ~PodPtr() = default; + + template + const T* TensorPtr() const; + template + const T* TensorPtr(FieldKey field_key, const T* default_ptr) const; + + template + T* MutTensorPtr(); + template + T* MutTensorPtr(FieldKey field_key, T* default_ptr); + + const PodDesc& pod_desc() const { return *pod_desc_; } + char* ptr() const { return ptr_; } + bool HasField(FieldKey field_key) const { return HasField(NewFieldId(field_key)); } + const PodPtr Field(FieldKey field_key) const { return Field(NewFieldId(field_key)); } + PodPtr MutField(FieldKey field_key) { return MutField(NewFieldId(field_key)); } + + bool HasField(const FieldId& field_id) const; + const PodPtr Field(const FieldId& field_id) const; + PodPtr MutField(const FieldId& field_id); + + private: + template + void CheckDataType() { + const auto* tensor_pod = dynamic_cast(pod_desc_); + CHECK_NOTNULL(tensor_pod); + CHECK_EQ(tensor_pod->data_type(), GetDataType::value); + } + + const PodDesc* const pod_desc_; + char* const ptr_; +}; + +template +const T* PodPtr::TensorPtr(FieldKey field_key, const T* default_ptr) const { + if (!HasField(field_key)) { return default_ptr; } + return Field(field_key).TensorPtr(); +} + +template +T* PodPtr::MutTensorPtr(FieldKey field_key, T* default_ptr) { + if (!HasField(field_key)) { return default_ptr; } + return MutField(field_key).MutTensorPtr(); +} + +template +const T* PodPtr::TensorPtr() const { + CheckDataType(); + return reinterpret_cast(ptr_); +} + +template +T* PodPtr::MutTensorPtr() { + CheckDataType(); + return reinterpret_cast(ptr_); +} + +} // namespace oneflow + +#endif // ONEFLOW_CORE_REGISTER_POD_PTR_H_ diff --git a/oneflow/core/register/register_desc.proto b/oneflow/core/register/register_desc.proto index 6d8d504e46..f8a4c6fff1 100644 --- a/oneflow/core/register/register_desc.proto +++ b/oneflow/core/register/register_desc.proto @@ -2,8 +2,8 @@ syntax = "proto2"; package oneflow; import "oneflow/core/register/blob_desc.proto"; +import "oneflow/core/register/logical_blob_id.proto"; import "oneflow/core/memory/memory_case.proto"; -import "oneflow/core/operator/op_conf.proto"; message LbiBlobDescPair { required LogicalBlobId lbi = 1; diff --git a/oneflow/core/register/runtime_blob_desc.cpp b/oneflow/core/register/runtime_blob_desc.cpp index ebe720653c..c495db8808 100644 --- a/oneflow/core/register/runtime_blob_desc.cpp +++ b/oneflow/core/register/runtime_blob_desc.cpp @@ -13,56 +13,29 @@ RtBlobDesc::RtBlobDesc(const BlobDescProto& blob_desc_proto) { InitFromProto(blo void RtBlobDesc::InitFromProto(const BlobDescProto& blob_desc_proto) { blob_desc_proto_ = blob_desc_proto; body_desc_ = FieldDesc(blob_desc_proto.body()); - if (blob_desc_proto.header().has_opaque_header()) { - CHECK(header_desc_.emplace("opaque_header", FieldDesc(blob_desc_proto.header().opaque_header())) - .second); - } else { - CHECK(blob_desc_proto.header().has_field_header()); - if (blob_desc_proto.header().field_header().has_data_id()) { - CHECK(header_desc_ - .emplace("data_id", FieldDesc(blob_desc_proto.header().field_header().data_id())) - .second); - } - if (blob_desc_proto.header().field_header().has_col_num()) { - CHECK(header_desc_ - .emplace("col_num", FieldDesc(blob_desc_proto.header().field_header().col_num())) - .second); - } - } + header_pod_desc_.InitFromProto(blob_desc_proto.header().header_pod_desc()); } const Shape& RtBlobDesc::shape() const { return body_desc_.shape(); } DataType RtBlobDesc::data_type() const { return body_desc_.data_type(); } -const Shape& RtBlobDesc::shape(const std::string& field_name) const { - auto field_it = GetFieldIteratorOrFail(field_name); - return field_it->second.shape(); -} - -DataType RtBlobDesc::data_type(const std::string& field_name) const { - auto field_it = GetFieldIteratorOrFail(field_name); - return field_it->second.data_type(); -} +bool RtBlobDesc::has_data_id_field() const { return header_pod_desc_.HasField(FieldKey::kDataId); } -bool RtBlobDesc::has_data_id_field() const { return HasField("data_id"); } +bool RtBlobDesc::has_col_num_field() const { return header_pod_desc_.HasField(FieldKey::kColNum); } -bool RtBlobDesc::has_col_num_field() const { return HasField("col_num"); } - -size_t RtBlobDesc::ByteSizeOfBlobHeader() const { - size_t header_size = 0; - for (auto& pair : header_desc_) { header_size += ByteSizeOfField(pair.first); } - return header_size; -} +size_t RtBlobDesc::ByteSizeOfBlobHeader() const { return header_pod_desc_.ByteSize(); } size_t RtBlobDesc::ByteSizeOfBlobBody() const { return body_desc_.AlignedByteSize(); } size_t RtBlobDesc::ByteSizeOfDataIdField() const { - return HasField("data_id") ? ByteSizeOfField("data_id") : 0; + if (!has_data_id_field()) { return 0; } + return header_pod_desc_.Field(FieldKey::kDataId).ByteSize(); } size_t RtBlobDesc::ByteSizeOfColNumField() const { - return HasField("col_num") ? ByteSizeOfField("col_num") : 0; + if (!has_col_num_field()) { return 0; } + return header_pod_desc_.Field(FieldKey::kColNum).ByteSize(); } size_t RtBlobDesc::ByteSizeOfDataContentField() const { return body_desc_.ByteSize(); } @@ -74,25 +47,4 @@ bool RtBlobDesc::operator==(const RtBlobDesc& rhs) const { return message_diff.Equals(blob_desc_proto_, rhs.blob_desc_proto()); } -HashMap::const_iterator RtBlobDesc::GetFieldIteratorOrFail( - const std::string& field_name) const { - auto field_it = header_desc_.find(field_name); - CHECK(field_it != header_desc_.end()); - return field_it; -} - -bool RtBlobDesc::HasField(const std::string& field_name) const { - return header_desc_.find(field_name) != header_desc_.end(); -} - -size_t RtBlobDesc::ByteSizeOfField(const std::string& field_name) const { - auto field_it = GetFieldIteratorOrFail(field_name); - return field_it->second.ByteSize(); -} - -size_t RtBlobDesc::AlignedByteSizeOfField(const std::string& field_name) const { - auto field_it = GetFieldIteratorOrFail(field_name); - return field_it->second.AlignedByteSize(); -} - } // namespace oneflow diff --git a/oneflow/core/register/runtime_blob_desc.h b/oneflow/core/register/runtime_blob_desc.h index 06c086e607..202055fc81 100644 --- a/oneflow/core/register/runtime_blob_desc.h +++ b/oneflow/core/register/runtime_blob_desc.h @@ -20,11 +20,10 @@ class RtBlobDesc { const BlobDescProto& blob_desc_proto() const { return blob_desc_proto_; } const Shape& shape() const; // body shape DataType data_type() const; // body data type - const Shape& shape(const std::string& field_name) const; - DataType data_type(const std::string& field_name) const; bool has_data_id_field() const; bool has_col_num_field() const; + const StructPodDesc& header_pod_desc() const { return header_pod_desc_; } int32_t max_col_num() const { return blob_desc_proto_.header().max_col_num(); } @@ -40,15 +39,10 @@ class RtBlobDesc { private: void InitFromProto(const BlobDescProto& proto); - HashMap::const_iterator GetFieldIteratorOrFail( - const std::string& field_name) const; - bool HasField(const std::string& field_name) const; - size_t ByteSizeOfField(const std::string& field_name) const; - size_t AlignedByteSizeOfField(const std::string& field_name) const; BlobDescProto blob_desc_proto_; - HashMap header_desc_; FieldDesc body_desc_; + StructPodDesc header_pod_desc_; }; } // namespace oneflow -- GitLab