提交 1c29eb42 编写于 作者: L Li Xinqi 提交者: Jinhui Yuan

Dev pod desc (#1268)

* available instance num

* import shape.proto

* PodProto

* rename message

* union pod is useless

* PodPtr

* rename: PodPtr::get() => PodPtr::Get()

* BlobDescProto.pod

* mv register_desc.time_shape into another pr

* pod_helper.h

* FieldAlignedByteSize

* pod_desc

* PodDesc copy constructor

* BlobDesc::body_shape_pod_desc_

* add BlobDesc::opaque_header_pod_desc_

* align_shift => alignment

* default alignment

* add field Blob::header_pod_ptr_

* rename AlignedFieldPodProto => FieldPodProto

* bugfix

* check

* FieldId

* simplify RtBlobDesc

* simplify Blob

* ShapedPod => TensorPod

* refine ComputePackedBlobDesc


Former-commit-id: 8800da93
上级 09761973
......@@ -3,6 +3,7 @@ package oneflow;
import "oneflow/core/common/shape.proto";
import "oneflow/core/common/data_type.proto";
import "oneflow/core/register/logical_blob_id.proto";
import "oneflow/core/operator/op_conf.proto";
message ConvKernelConf {
......
......@@ -5,6 +5,7 @@ import "oneflow/core/common/shape.proto";
import "oneflow/core/common/data_type.proto";
import "oneflow/core/record/image.proto";
import "oneflow/core/job/resource.proto";
import "oneflow/core/register/logical_blob_id.proto";
enum ActivationType {
kNone = 0;
......@@ -259,13 +260,6 @@ message BoxSplitConf {
message BoxCloneConf {
}
message LogicalBlobId {
optional string op_name = 1;
optional string blob_name = 2;
optional int32 clone_id = 4 [default = -1];
optional bool is_packed_id = 5 [default = false];
}
message BoxingOpConf {
required LogicalBlobId lbi = 1;
required int32 in_num = 2;
......
......@@ -6,35 +6,23 @@
namespace oneflow {
Blob::Blob(Regst* regst, const RtBlobDesc* blob_desc, char* header_ptr) {
Blob::Blob(Regst* regst, const RtBlobDesc* blob_desc, char* header_ptr)
: header_pod_ptr_(blob_desc->header_pod_desc(), header_ptr) {
Init(regst, blob_desc, header_ptr, header_ptr + blob_desc->ByteSizeOfBlobHeader());
}
Blob::Blob(Regst* regst, const RtBlobDesc* blob_desc, char* header_ptr, char* body_ptr) {
Blob::Blob(Regst* regst, const RtBlobDesc* blob_desc, char* header_ptr, char* body_ptr)
: header_pod_ptr_(blob_desc->header_pod_desc(), header_ptr) {
Init(regst, blob_desc, header_ptr, body_ptr);
}
void Blob::Init(Regst* regst, const RtBlobDesc* blob_desc, char* header_ptr, char* body_ptr) {
if (body_ptr == header_ptr + blob_desc->ByteSizeOfBlobHeader()) {
is_contiguous_ = true;
} else {
is_contiguous_ = false;
}
is_contiguous_ = (body_ptr == header_ptr + blob_desc->ByteSizeOfBlobHeader());
regst_ = regst;
blob_desc_ = blob_desc;
header_ptr_ = header_ptr;
if (blob_desc->has_data_id_field()) {
data_id_ptr_ = header_ptr;
} else {
data_id_ptr_ = nullptr;
}
char* offset = header_ptr + blob_desc->ByteSizeOfDataIdField();
if (blob_desc->has_col_num_field()) {
col_num_ptr_ = reinterpret_cast<int32_t*>(offset);
} else {
col_num_ptr_ = nullptr;
}
data_id_ptr_ = header_pod_ptr_.MutTensorPtr<char>(FieldKey::kDataId, nullptr);
col_num_ptr_ = header_pod_ptr_.MutTensorPtr<int32_t>(FieldKey::kColNum, nullptr);
dptr_ = body_ptr;
}
......
......@@ -9,6 +9,7 @@
#include "oneflow/core/persistence/persistent_in_stream.h"
#include "oneflow/core/record/record.pb.h"
#include "oneflow/core/record/record_io.h"
#include "oneflow/core/register/pod_ptr.h"
namespace oneflow {
......@@ -87,6 +88,8 @@ class Blob final {
void set_max_col_id(int32_t val);
bool IsColValid() const { return col_id() <= max_col_id(); }
const MemoryCase& mem_case() const;
const PodPtr* header_pod_ptr() const { return &header_pod_ptr_; }
PodPtr* header_pod_ptr() { return &header_pod_ptr_; }
private:
int64_t GetDptrOffset(int32_t index) const { return 0; }
......@@ -114,6 +117,7 @@ class Blob final {
void* dptr_;
const RtBlobDesc* blob_desc_;
Regst* regst_;
PodPtr header_pod_ptr_;
};
template<typename RecordType>
......
......@@ -23,6 +23,7 @@ BlobDesc::BlobDesc(const BlobDescProto& proto) : body_field_(proto.body()) {
has_data_id_ = false;
has_col_num_ = false;
opaque_header_ = FieldDesc(proto.header().opaque_header());
opaque_header_pod_desc_.InitFromProto(proto.header().header_pod_desc());
} else {
CHECK(proto.header().has_field_header());
header_is_opaque_ = false;
......@@ -31,16 +32,18 @@ BlobDesc::BlobDesc(const BlobDescProto& proto) : body_field_(proto.body()) {
}
}
BlobDesc::BlobDesc(int64_t header_byte_size, const Shape& shape, DataType data_type,
int32_t max_col_num)
BlobDesc::BlobDesc(const StructPodDesc& header_pod_desc, int64_t header_byte_size,
const Shape& shape, DataType data_type, int32_t max_col_num)
: has_data_id_(false),
has_col_num_(false),
max_col_num_(max_col_num),
blob_mem_id_(-1),
body_field_(shape, data_type) {
CHECK_EQ(header_pod_desc.ByteSize(), header_byte_size);
if (header_byte_size > 0) {
header_is_opaque_ = true;
opaque_header_ = FieldDesc(Shape({header_byte_size}), DataType::kChar);
opaque_header_pod_desc_ = header_pod_desc;
} else {
header_is_opaque_ = false;
}
......@@ -54,16 +57,19 @@ void BlobDesc::set_has_col_num_field(bool val) {
CHECK(!header_is_opaque_);
has_col_num_ = val;
}
void BlobDesc::DataIdFieldToProto(FieldHeaderDesc* proto) const {
FieldDesc data_id_field(Shape({body_field_.shape().At(0),
static_cast<int64_t>(Global<JobDesc>::Get()->SizeOfOneDataId())}),
DataType::kChar);
void BlobDesc::DataIdFieldToProto(FieldHeaderDesc* proto, StructPodDesc* header_pod_desc) const {
Shape shape(
{body_field_.shape().At(0), static_cast<int64_t>(Global<JobDesc>::Get()->SizeOfOneDataId())});
FieldDesc data_id_field(shape, DataType::kChar);
data_id_field.ToProto(proto->mutable_data_id());
header_pod_desc->AddField(FieldKey::kDataId, TensorPodDesc(shape, DataType::kChar));
}
void BlobDesc::ColNumFieldToProto(FieldHeaderDesc* proto) const {
FieldDesc col_num_field(Shape({body_field_.shape().At(0)}), DataType::kInt32);
void BlobDesc::ColNumFieldToProto(FieldHeaderDesc* proto, StructPodDesc* header_pod_desc) const {
Shape shape({body_field_.shape().At(0)});
FieldDesc col_num_field(shape, DataType::kInt32);
col_num_field.ToProto(proto->mutable_col_num());
header_pod_desc->AddField(FieldKey::kColNum, TensorPodDesc(shape, DataType::kInt32));
}
void BlobDesc::HeaderToProto(BlobDescProto* proto) const {
......@@ -71,10 +77,13 @@ void BlobDesc::HeaderToProto(BlobDescProto* proto) const {
proto->mutable_header()->set_blob_mem_id(blob_mem_id_);
if (!header_is_opaque_) {
FieldHeaderDesc* field_header = proto->mutable_header()->mutable_field_header();
if (has_data_id_field()) { DataIdFieldToProto(field_header); }
if (has_col_num_field()) { ColNumFieldToProto(field_header); }
StructPodDesc header_pod_desc;
if (has_data_id_field()) { DataIdFieldToProto(field_header, &header_pod_desc); }
if (has_col_num_field()) { ColNumFieldToProto(field_header, &header_pod_desc); }
header_pod_desc.ToProto(proto->mutable_header()->mutable_header_pod_desc());
} else {
opaque_header_.ToProto(proto->mutable_header()->mutable_opaque_header());
opaque_header_pod_desc_.ToProto(proto->mutable_header()->mutable_header_pod_desc());
}
}
......@@ -85,6 +94,7 @@ void BlobDesc::ToProto(BlobDescProto* proto) const {
bool BlobDesc::operator==(const BlobDesc& rhs) const {
return header_is_opaque_ == rhs.header_is_opaque_ && opaque_header_ == rhs.opaque_header_
&& opaque_header_pod_desc_ == rhs.opaque_header_pod_desc_
&& has_data_id_ == rhs.has_data_id_ && has_col_num_ == rhs.has_col_num_
&& max_col_num_ == rhs.max_col_num_ && blob_mem_id_ == rhs.blob_mem_id_
&& body_field_ == rhs.body_field_;
......@@ -93,6 +103,7 @@ bool BlobDesc::operator==(const BlobDesc& rhs) const {
BlobDesc& BlobDesc::operator=(const BlobDesc& blob_desc) {
header_is_opaque_ = blob_desc.header_is_opaque_;
opaque_header_ = blob_desc.opaque_header_;
opaque_header_pod_desc_ = blob_desc.opaque_header_pod_desc_;
has_data_id_ = blob_desc.has_data_id_;
has_col_num_ = blob_desc.has_col_num_;
max_col_num_ = blob_desc.max_col_num_;
......@@ -111,11 +122,12 @@ std::unique_ptr<BlobDesc> ComputePackedBlobDesc(
std::unique_ptr<BlobDesc> ret(new BlobDesc());
const BlobDesc* last_blob_desc = nullptr;
HashMap<int32_t, size_t> blob_mem_id2size;
StructPodDesc opaque_header_pod_desc;
for (auto& pair : lbi2blob_desc) {
BlobDesc* blob_desc = pair.second.get();
RtBlobDesc rt_blob_desc(*blob_desc);
header_byte_size += rt_blob_desc.ByteSizeOfBlobHeader();
*opaque_header_pod_desc.MutStructField(NewFieldId(pair.first)) = rt_blob_desc.header_pod_desc();
int64_t cur_body_byte_size = rt_blob_desc.ByteSizeOfBlobBody();
int32_t blob_mem_id = blob_desc->blob_mem_id();
if (blob_mem_id == -1) {
......@@ -150,12 +162,12 @@ std::unique_ptr<BlobDesc> ComputePackedBlobDesc(
if (header_byte_size == 0) {
ret.reset(new BlobDesc(Shape({total_elem_cnt}), sole_data_type, false, false, max_col_num));
} else {
ret.reset(
new BlobDesc(header_byte_size, Shape({total_elem_cnt}), sole_data_type, max_col_num));
ret.reset(new BlobDesc(opaque_header_pod_desc, header_byte_size, Shape({total_elem_cnt}),
sole_data_type, max_col_num));
}
} else {
ret.reset(
new BlobDesc(header_byte_size, Shape({body_byte_size}), DataType::kChar, max_col_num));
ret.reset(new BlobDesc(opaque_header_pod_desc, header_byte_size, Shape({body_byte_size}),
DataType::kChar, max_col_num));
}
return ret;
}
......
......@@ -5,6 +5,7 @@
#include "oneflow/core/common/shape.h"
#include "oneflow/core/register/field_desc.h"
#include "oneflow/core/register/blob_desc.pb.h"
#include "oneflow/core/register/pod_desc.h"
#include "oneflow/core/job/job_desc.h"
namespace oneflow {
......@@ -18,7 +19,8 @@ class BlobDesc {
BlobDesc(const Shape&, DataType, bool has_data_id, bool has_col_num, int32_t max_col_num);
BlobDesc(const Shape& shape) : body_field_(shape) {}
BlobDesc(const BlobDescProto& proto);
BlobDesc(int64_t header_byte_size, const Shape&, DataType, int32_t max_col_num);
BlobDesc(const StructPodDesc& header_pod_desc, int64_t header_byte_size, const Shape&, DataType,
int32_t max_col_num);
const Shape& shape() const { return body_field_.shape(); }
Shape& mut_shape() { return body_field_.mut_shape(); }
......@@ -46,11 +48,12 @@ class BlobDesc {
private:
void HeaderToProto(BlobDescProto* proto) const;
void DataIdFieldToProto(FieldHeaderDesc* proto) const;
void ColNumFieldToProto(FieldHeaderDesc* proto) const;
void DataIdFieldToProto(FieldHeaderDesc* proto, StructPodDesc* header_pod_desc) const;
void ColNumFieldToProto(FieldHeaderDesc* proto, StructPodDesc* header_pod_desc) const;
bool header_is_opaque_;
FieldDesc opaque_header_;
StructPodDesc opaque_header_pod_desc_;
bool has_data_id_;
bool has_col_num_;
......
......@@ -2,6 +2,7 @@ syntax = "proto2";
package oneflow;
import "oneflow/core/register/field_desc.proto";
import "oneflow/core/register/pod.proto";
message FieldHeaderDesc {
optional FieldDescProto data_id = 1;
......@@ -15,6 +16,7 @@ message BlobHeaderDescProto {
FieldDescProto opaque_header = 3;
FieldHeaderDesc field_header = 4;
}
required StructPodProto header_pod_desc = 5;
}
message BlobDescProto {
......
syntax = "proto2";
package oneflow;
message LogicalBlobId {
optional string op_name = 1;
optional string blob_name = 2;
optional int32 clone_id = 4 [default = -1];
optional bool is_packed_id = 5 [default = false];
}
syntax = "proto2";
package oneflow;
import "oneflow/core/common/shape.proto";
import "oneflow/core/common/data_type.proto";
import "oneflow/core/register/logical_blob_id.proto";
message TensorPodProto {
required ShapeProto shape = 1;
required DataType data_type = 2;
}
message StructPodProto {
repeated FieldPodProto field = 1;
}
enum FieldKey {
kInvalidFieldKey = 0;
kDataId = 1;
kColNum = 2;
}
message FieldId {
oneof field_id_type {
FieldKey key = 1;
LogicalBlobId lbi = 2;
}
}
message FieldPodProto {
required FieldId field_id = 1;
required int32 alignment = 2;
required PodProto pod = 3;
}
message PodProto {
oneof pod_type {
TensorPodProto tensor_pod = 1;
StructPodProto struct_pod = 2;
}
}
#include "oneflow/core/register/pod_desc.h"
namespace oneflow {
namespace {
std::unique_ptr<PodDesc> NewPodDesc(const PodProto& pod) {
if (pod.has_tensor_pod()) { return std::make_unique<TensorPodDesc>(pod.tensor_pod()); }
if (pod.has_struct_pod()) { return std::make_unique<StructPodDesc>(pod.struct_pod()); }
// ignore field pod
UNIMPLEMENTED();
return std::unique_ptr<PodDesc>();
}
} // namespace
FieldId NewFieldId(FieldKey key) {
FieldId ret;
ret.set_key(key);
return ret;
}
FieldId NewFieldId(const LogicalBlobId& lbi) {
FieldId ret;
*ret.mutable_lbi() = lbi;
return ret;
}
TensorPodDesc::TensorPodDesc(const TensorPodProto& tensor_pod) { InitFromProto(tensor_pod); }
TensorPodDesc::TensorPodDesc(const TensorPodDesc& tensor_pod) {
PodProto pod_proto;
tensor_pod.ToProto(&pod_proto);
InitFromProto(pod_proto.tensor_pod());
}
void TensorPodDesc::InitFromProto(const TensorPodProto& tensor_pod) {
shape_ = Shape(tensor_pod.shape());
data_type_ = tensor_pod.data_type();
}
size_t TensorPodDesc::ByteSize() const { return shape_.elem_cnt() * GetSizeOfDataType(data_type_); }
bool TensorPodDesc::operator==(const PodDesc& rhs) const {
const auto* tensor_rhs = dynamic_cast<const TensorPodDesc*>(&rhs);
if (tensor_rhs == nullptr) { return false; }
return shape() == tensor_rhs->shape() && data_type() == tensor_rhs->data_type();
}
void TensorPodDesc::ToProto(PodProto* pod_proto) const {
shape_.ToProto(pod_proto->mutable_tensor_pod()->mutable_shape());
pod_proto->mutable_tensor_pod()->set_data_type(data_type_);
}
FieldPodDesc::FieldPodDesc(const FieldPodProto& field_pod) {
field_id_ = field_pod.field_id();
pod_ = std::move(NewPodDesc(field_pod.pod()));
alignment_ = field_pod.alignment();
}
size_t FieldPodDesc::ByteSize() const { return RoundUp(pod_->ByteSize(), alignment_); }
bool FieldPodDesc::operator==(const PodDesc& rhs) const {
const auto* field_rhs = dynamic_cast<const FieldPodDesc*>(&rhs);
if (field_rhs == nullptr) { return false; }
return field_id() == field_rhs->field_id() && pod() == field_rhs->pod()
&& alignment_ == field_rhs->alignment_;
}
void FieldPodDesc::ToProto(FieldPodProto* field_pod_proto) const {
*field_pod_proto->mutable_field_id() = field_id_;
field_pod_proto->set_alignment(alignment_);
pod_->ToProto(field_pod_proto->mutable_pod());
}
StructPodDesc::StructPodDesc(const StructPodProto& struct_pod_proto) {
InitFromProto(struct_pod_proto);
}
StructPodDesc::StructPodDesc(const StructPodDesc& struct_pod_desc) { *this = struct_pod_desc; }
void StructPodDesc::InitFromProto(const StructPodProto& struct_pod) {
CHECK(field_id2field_idx_.empty());
CHECK(fields_.empty());
for (const auto& field : struct_pod.field()) {
std::unique_ptr<FieldPodDesc> pod(new FieldPodDesc(field));
AddField(std::move(pod));
}
}
size_t StructPodDesc::ByteSize() const {
size_t size = 0;
for (const auto& field : fields_) { size += field->ByteSize(); }
return size;
}
bool StructPodDesc::operator==(const PodDesc& rhs) const {
const auto* struct_rhs = dynamic_cast<const StructPodDesc*>(&rhs);
if (struct_rhs == nullptr) { return false; }
if (field_id2field_idx_ != struct_rhs->field_id2field_idx_) { return false; }
for (int i = 0; i < field_id2field_idx_.size(); ++i) {
if (*fields_.at(i) != *struct_rhs->fields_.at(i)) { return false; }
}
return true;
}
void StructPodDesc::ToProto(StructPodProto* struct_pod_proto) const {
for (const auto& field : fields_) { field->ToProto(struct_pod_proto->add_field()); }
}
bool StructPodDesc::HasField(const FieldId& field_id) const {
return field_id2field_idx_.find(field_id) != field_id2field_idx_.end();
}
StructPodDesc* StructPodDesc::MutStructField(const FieldId& field_id) {
return MutStructField(field_id, 1);
}
StructPodDesc* StructPodDesc::MutStructField(const FieldId& field_id, int32_t alignment) {
if (!HasField(field_id)) { AddField(field_id, std::make_unique<StructPodDesc>(), alignment); }
return MutExistedField(field_id)->MutCast<StructPodDesc>();
}
PodDesc* StructPodDesc::MutExistedField(const FieldId& field_id) {
return fields_.at(field_id2field_idx_.at(field_id))->mut_pod();
}
const PodDesc& StructPodDesc::Field(const FieldId& field_id) const {
return fields_.at(field_id2field_idx_.at(field_id))->pod();
}
void StructPodDesc::AddField(FieldKey field_key, const PodDesc& pod_desc) {
return AddField(NewFieldId(field_key), pod_desc);
}
void StructPodDesc::AddField(const FieldId& field_id, const PodDesc& pod_desc) {
return AddField(field_id, pod_desc, 1);
}
void StructPodDesc::AddField(const FieldId& field_id, const PodDesc& pod_desc, size_t alignment) {
AddField(field_id, pod_desc.Clone(), alignment);
}
void StructPodDesc::AddField(const FieldId& field_id, std::unique_ptr<PodDesc>&& field,
size_t alignment) {
auto* pod = new FieldPodDesc(field_id, std::move(field), alignment);
AddField(std::unique_ptr<FieldPodDesc>(pod));
}
void StructPodDesc::AddField(std::unique_ptr<FieldPodDesc>&& field) {
CHECK(field_id2field_idx_.emplace(field->field_id(), fields_.size()).second);
fields_.emplace_back(std::move(field));
}
size_t StructPodDesc::ByteOffset4Field(const FieldId& field_id) const {
CHECK(HasField(field_id));
size_t offset = 0;
for (int32_t i = 0; i < field_id2field_idx_.at(field_id); ++i) {
offset += fields_.at(i)->ByteSize();
}
return offset;
}
StructPodDesc& StructPodDesc::operator=(const StructPodDesc& struct_pod_desc) {
Clear();
StructPodProto struct_pod_proto;
struct_pod_desc.ToProto(&struct_pod_proto);
InitFromProto(struct_pod_proto);
return *this;
}
void StructPodDesc::Clear() {
CHECK_EQ(fields_.size(), field_id2field_idx_.size());
fields_.clear();
field_id2field_idx_.clear();
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_REGISTER_POD_DESC_H_
#define ONEFLOW_CORE_REGISTER_POD_DESC_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/shape.h"
#include "oneflow/core/register/pod.pb.h"
namespace std {
template<>
struct hash<oneflow::FieldId> {
size_t operator()(const oneflow::FieldId& field_id) const {
if (field_id.has_key()) { return std::hash<int>()(field_id.key()); }
if (field_id.has_lbi()) { return std::hash<oneflow::LogicalBlobId>()(field_id.lbi()); }
UNIMPLEMENTED();
}
};
} // namespace std
namespace oneflow {
FieldId NewFieldId(FieldKey key);
FieldId NewFieldId(const LogicalBlobId& lbi);
inline bool operator==(const FieldId& lhs, const FieldId& rhs) {
PbMd message_diff;
return message_diff.Equivalent(lhs, rhs);
}
class PodDesc {
public:
OF_DISALLOW_COPY_AND_MOVE(PodDesc);
PodDesc() = default;
virtual ~PodDesc() = default;
template<typename T>
const T& Cast() const;
template<typename T>
T* MutCast();
virtual size_t ByteSize() const = 0;
virtual void ToProto(PodProto* pod_proto) const = 0;
virtual std::unique_ptr<PodDesc> Clone() const = 0;
virtual bool operator==(const PodDesc& rhs) const = 0;
bool operator!=(const PodDesc& rhs) const { return !(*this == rhs); }
};
class TensorPodDesc final : public PodDesc {
public:
TensorPodDesc() = default;
TensorPodDesc(const Shape& shape, DataType data_type) : shape_(shape), data_type_(data_type) {}
explicit TensorPodDesc(const TensorPodProto& shape_pod_proto);
explicit TensorPodDesc(const TensorPodDesc& shape_pod);
~TensorPodDesc() = default;
const Shape& shape() const { return shape_; }
DataType data_type() const { return data_type_; }
Shape* mut_shape() { return &shape_; }
void set_data_type(DataType data_type) { data_type_ = data_type; }
void InitFromProto(const TensorPodProto& shape_pod);
size_t ByteSize() const override;
void ToProto(PodProto* pod_proto) const override;
std::unique_ptr<PodDesc> Clone() const override { return std::make_unique<TensorPodDesc>(*this); }
bool operator==(const PodDesc& rhs) const override;
private:
Shape shape_;
DataType data_type_;
};
class FieldPodDesc;
class StructPodDesc final : public PodDesc {
public:
StructPodDesc() = default;
explicit StructPodDesc(const StructPodProto&);
explicit StructPodDesc(const StructPodDesc&);
~StructPodDesc() = default;
StructPodDesc* MutStructField(const FieldId& field_id);
const PodDesc& Field(FieldKey field_key) const { return Field(NewFieldId(field_key)); }
const PodDesc& Field(const FieldId& field_id) const;
void AddField(FieldKey field_key, const PodDesc& pod_desc);
void AddField(const FieldId& field_id, const PodDesc& pod_desc);
size_t ByteSize() const override;
void InitFromProto(const StructPodProto& struct_pod);
bool HasField(FieldKey field_key) const { return HasField(NewFieldId(field_key)); }
bool HasField(const FieldId& field_id) const;
StructPodDesc& operator=(const StructPodDesc&);
std::unique_ptr<PodDesc> Clone() const override { return std::make_unique<StructPodDesc>(*this); }
void ToProto(PodProto* pod_proto) const override { ToProto(pod_proto->mutable_struct_pod()); }
void ToProto(StructPodProto* pod_proto) const;
StructPodDesc* MutStructField(const FieldId& field_id, int32_t default_alignment);
void AddField(const FieldId& field_id, const PodDesc& pod_desc, size_t alignment);
bool operator==(const PodDesc& rhs) const override;
size_t ByteOffset4Field(const FieldId& field_name) const;
private:
void Clear();
PodDesc* MutExistedField(const FieldId& field_id);
void AddField(std::unique_ptr<FieldPodDesc>&& field);
void AddField(const FieldId& field_id, std::unique_ptr<PodDesc>&& field);
void AddField(const FieldId& field_id, std::unique_ptr<PodDesc>&& field, size_t alignment);
std::vector<std::unique_ptr<FieldPodDesc>> fields_;
HashMap<FieldId, int32_t> field_id2field_idx_;
};
class FieldPodDesc final : public PodDesc {
public:
OF_DISALLOW_COPY_AND_MOVE(FieldPodDesc);
~FieldPodDesc() = default;
private:
friend class StructPodDesc;
FieldPodDesc(const FieldId& field_id, std::unique_ptr<PodDesc>&& pod, size_t alignment)
: PodDesc(), field_id_(field_id), pod_(std::move(pod)), alignment_(alignment) {}
explicit FieldPodDesc(const FieldPodProto& field_pod_proto);
size_t ByteSize() const override;
void ToProto(PodProto* pod_proto) const override { UNIMPLEMENTED(); }
std::unique_ptr<PodDesc> Clone() const override { UNIMPLEMENTED(); }
void ToProto(FieldPodProto* field_proto) const;
bool operator==(const PodDesc& rhs) const override;
const PodDesc& pod() const { return *pod_; }
const FieldId& field_id() const { return field_id_; }
PodDesc* mut_pod() { return pod_.get(); }
FieldId field_id_;
std::unique_ptr<PodDesc> pod_;
size_t alignment_;
};
template<typename T>
const T& PodDesc::Cast() const {
static_assert(std::is_same<T, TensorPodDesc>::value || std::is_same<T, StructPodDesc>::value,
"only TensorPodDesc and StructPodDesc supported");
return *dynamic_cast<T*>(this);
}
template<typename T>
T* PodDesc::MutCast() {
static_assert(std::is_same<T, TensorPodDesc>::value || std::is_same<T, StructPodDesc>::value,
"only TensorPodDesc and StructPodDesc supported");
return dynamic_cast<T*>(this);
}
} // namespace oneflow
#endif // ONEFLOW_CORE_REGISTER_POD_DESC_H_
#include "oneflow/core/register/pod_ptr.h"
namespace oneflow {
PodPtr PodPtrField(const PodDesc* pod_desc, const FieldId& field_id, char* ptr) {
const auto* struct_pod = dynamic_cast<const StructPodDesc*>(pod_desc);
CHECK_NOTNULL(struct_pod);
return PodPtr(struct_pod->Field(field_id), ptr + struct_pod->ByteOffset4Field(field_id));
}
bool PodPtr::HasField(const FieldId& field_id) const {
const auto* struct_pod = dynamic_cast<const StructPodDesc*>(pod_desc_);
return struct_pod && struct_pod->HasField(field_id);
}
const PodPtr PodPtr::Field(const FieldId& field_id) const {
return PodPtrField(pod_desc_, field_id, ptr_);
}
PodPtr PodPtr::MutField(const FieldId& field_id) { return PodPtrField(pod_desc_, field_id, ptr_); }
} // namespace oneflow
#ifndef ONEFLOW_CORE_REGISTER_POD_PTR_H_
#define ONEFLOW_CORE_REGISTER_POD_PTR_H_
#include "oneflow/core/register/pod_desc.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/data_type.h"
namespace oneflow {
class PodPtr final {
public:
PodPtr(const PodDesc& pod_desc, char* ptr) : pod_desc_(&pod_desc), ptr_(ptr) {}
~PodPtr() = default;
template<typename T>
const T* TensorPtr() const;
template<typename T>
const T* TensorPtr(FieldKey field_key, const T* default_ptr) const;
template<typename T>
T* MutTensorPtr();
template<typename T>
T* MutTensorPtr(FieldKey field_key, T* default_ptr);
const PodDesc& pod_desc() const { return *pod_desc_; }
char* ptr() const { return ptr_; }
bool HasField(FieldKey field_key) const { return HasField(NewFieldId(field_key)); }
const PodPtr Field(FieldKey field_key) const { return Field(NewFieldId(field_key)); }
PodPtr MutField(FieldKey field_key) { return MutField(NewFieldId(field_key)); }
bool HasField(const FieldId& field_id) const;
const PodPtr Field(const FieldId& field_id) const;
PodPtr MutField(const FieldId& field_id);
private:
template<typename T>
void CheckDataType() {
const auto* tensor_pod = dynamic_cast<const TensorPodDesc*>(pod_desc_);
CHECK_NOTNULL(tensor_pod);
CHECK_EQ(tensor_pod->data_type(), GetDataType<T>::value);
}
const PodDesc* const pod_desc_;
char* const ptr_;
};
template<typename T>
const T* PodPtr::TensorPtr(FieldKey field_key, const T* default_ptr) const {
if (!HasField(field_key)) { return default_ptr; }
return Field(field_key).TensorPtr<T>();
}
template<typename T>
T* PodPtr::MutTensorPtr(FieldKey field_key, T* default_ptr) {
if (!HasField(field_key)) { return default_ptr; }
return MutField(field_key).MutTensorPtr<T>();
}
template<typename T>
const T* PodPtr::TensorPtr() const {
CheckDataType<T>();
return reinterpret_cast<const T*>(ptr_);
}
template<typename T>
T* PodPtr::MutTensorPtr() {
CheckDataType<T>();
return reinterpret_cast<T*>(ptr_);
}
} // namespace oneflow
#endif // ONEFLOW_CORE_REGISTER_POD_PTR_H_
......@@ -2,8 +2,8 @@ syntax = "proto2";
package oneflow;
import "oneflow/core/register/blob_desc.proto";
import "oneflow/core/register/logical_blob_id.proto";
import "oneflow/core/memory/memory_case.proto";
import "oneflow/core/operator/op_conf.proto";
message LbiBlobDescPair {
required LogicalBlobId lbi = 1;
......
......@@ -13,56 +13,29 @@ RtBlobDesc::RtBlobDesc(const BlobDescProto& blob_desc_proto) { InitFromProto(blo
void RtBlobDesc::InitFromProto(const BlobDescProto& blob_desc_proto) {
blob_desc_proto_ = blob_desc_proto;
body_desc_ = FieldDesc(blob_desc_proto.body());
if (blob_desc_proto.header().has_opaque_header()) {
CHECK(header_desc_.emplace("opaque_header", FieldDesc(blob_desc_proto.header().opaque_header()))
.second);
} else {
CHECK(blob_desc_proto.header().has_field_header());
if (blob_desc_proto.header().field_header().has_data_id()) {
CHECK(header_desc_
.emplace("data_id", FieldDesc(blob_desc_proto.header().field_header().data_id()))
.second);
}
if (blob_desc_proto.header().field_header().has_col_num()) {
CHECK(header_desc_
.emplace("col_num", FieldDesc(blob_desc_proto.header().field_header().col_num()))
.second);
}
}
header_pod_desc_.InitFromProto(blob_desc_proto.header().header_pod_desc());
}
const Shape& RtBlobDesc::shape() const { return body_desc_.shape(); }
DataType RtBlobDesc::data_type() const { return body_desc_.data_type(); }
const Shape& RtBlobDesc::shape(const std::string& field_name) const {
auto field_it = GetFieldIteratorOrFail(field_name);
return field_it->second.shape();
}
DataType RtBlobDesc::data_type(const std::string& field_name) const {
auto field_it = GetFieldIteratorOrFail(field_name);
return field_it->second.data_type();
}
bool RtBlobDesc::has_data_id_field() const { return header_pod_desc_.HasField(FieldKey::kDataId); }
bool RtBlobDesc::has_data_id_field() const { return HasField("data_id"); }
bool RtBlobDesc::has_col_num_field() const { return header_pod_desc_.HasField(FieldKey::kColNum); }
bool RtBlobDesc::has_col_num_field() const { return HasField("col_num"); }
size_t RtBlobDesc::ByteSizeOfBlobHeader() const {
size_t header_size = 0;
for (auto& pair : header_desc_) { header_size += ByteSizeOfField(pair.first); }
return header_size;
}
size_t RtBlobDesc::ByteSizeOfBlobHeader() const { return header_pod_desc_.ByteSize(); }
size_t RtBlobDesc::ByteSizeOfBlobBody() const { return body_desc_.AlignedByteSize(); }
size_t RtBlobDesc::ByteSizeOfDataIdField() const {
return HasField("data_id") ? ByteSizeOfField("data_id") : 0;
if (!has_data_id_field()) { return 0; }
return header_pod_desc_.Field(FieldKey::kDataId).ByteSize();
}
size_t RtBlobDesc::ByteSizeOfColNumField() const {
return HasField("col_num") ? ByteSizeOfField("col_num") : 0;
if (!has_col_num_field()) { return 0; }
return header_pod_desc_.Field(FieldKey::kColNum).ByteSize();
}
size_t RtBlobDesc::ByteSizeOfDataContentField() const { return body_desc_.ByteSize(); }
......@@ -74,25 +47,4 @@ bool RtBlobDesc::operator==(const RtBlobDesc& rhs) const {
return message_diff.Equals(blob_desc_proto_, rhs.blob_desc_proto());
}
HashMap<std::string, FieldDesc>::const_iterator RtBlobDesc::GetFieldIteratorOrFail(
const std::string& field_name) const {
auto field_it = header_desc_.find(field_name);
CHECK(field_it != header_desc_.end());
return field_it;
}
bool RtBlobDesc::HasField(const std::string& field_name) const {
return header_desc_.find(field_name) != header_desc_.end();
}
size_t RtBlobDesc::ByteSizeOfField(const std::string& field_name) const {
auto field_it = GetFieldIteratorOrFail(field_name);
return field_it->second.ByteSize();
}
size_t RtBlobDesc::AlignedByteSizeOfField(const std::string& field_name) const {
auto field_it = GetFieldIteratorOrFail(field_name);
return field_it->second.AlignedByteSize();
}
} // namespace oneflow
......@@ -20,11 +20,10 @@ class RtBlobDesc {
const BlobDescProto& blob_desc_proto() const { return blob_desc_proto_; }
const Shape& shape() const; // body shape
DataType data_type() const; // body data type
const Shape& shape(const std::string& field_name) const;
DataType data_type(const std::string& field_name) const;
bool has_data_id_field() const;
bool has_col_num_field() const;
const StructPodDesc& header_pod_desc() const { return header_pod_desc_; }
int32_t max_col_num() const { return blob_desc_proto_.header().max_col_num(); }
......@@ -40,15 +39,10 @@ class RtBlobDesc {
private:
void InitFromProto(const BlobDescProto& proto);
HashMap<std::string, FieldDesc>::const_iterator GetFieldIteratorOrFail(
const std::string& field_name) const;
bool HasField(const std::string& field_name) const;
size_t ByteSizeOfField(const std::string& field_name) const;
size_t AlignedByteSizeOfField(const std::string& field_name) const;
BlobDescProto blob_desc_proto_;
HashMap<std::string, FieldDesc> header_desc_;
FieldDesc body_desc_;
StructPodDesc header_pod_desc_;
};
} // namespace oneflow
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册