未验证 提交 a4a7e4df 编写于 作者: C cheng cheng 提交者: GitHub

Remove RtBlobDesc (#4644)

* Remove RtBlobDesc

* refine code for RuntimeBlobShapeInferHelper::BlobDesc4BnInOp
Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 a0348862
......@@ -28,15 +28,14 @@ Maybe<void> EagerBlobObject::TryInitBlob() {
Maybe<void> EagerBlobObject::InitBlob() {
CHECK_NE_OR_RETURN(blob_desc_.data_type(), DataType::kInvalidDataType);
rt_blob_desc_.reset(new RtBlobDesc(blob_desc_));
{
header_buffer_.reset();
int64_t header_byte_size = rt_blob_desc_->ByteSizeOfBlobHeader();
int64_t header_byte_size = blob_desc_.ByteSizeOfBlobHeader();
const auto& FreeHeader = [header_byte_size](char* dptr) { std::free(dptr); };
char* ptr = reinterpret_cast<char*>(std::malloc(header_byte_size));
header_buffer_ = std::unique_ptr<char, std::function<void(char*)>>(ptr, FreeHeader);
}
blob_.reset(new Blob(*mem_case_, rt_blob_desc_.get(), header_buffer_.get(), nullptr));
blob_.reset(new Blob(*mem_case_, &blob_desc_, header_buffer_.get(), nullptr));
return Maybe<void>::Ok();
}
......
......@@ -68,9 +68,6 @@ class EagerBlobObject final : public BlobObject {
std::shared_ptr<TensorBuffer> tensor_buffer_;
std::size_t blob_body_bytes_;
MemoryAllocator non_pod_initer_;
protected:
std::unique_ptr<RtBlobDesc> rt_blob_desc_;
};
} // namespace eager
......
......@@ -28,9 +28,7 @@ class LazyRefBlobObject final : public BlobObject {
LazyRefBlobObject(Blob* blob)
: BlobObject(std::make_shared<MemoryCase>(blob->mem_case()),
std::make_shared<Shape>(blob->static_shape()), blob->data_type()) {
const auto& rt_blob_desc = blob->blob_desc();
blob_desc_ =
BlobDesc(rt_blob_desc.body_shape(), rt_blob_desc.data_type(), rt_blob_desc.is_dynamic());
blob_desc_ = blob->blob_desc();
ref_blob_ = blob;
}
~LazyRefBlobObject() override = default;
......
......@@ -17,7 +17,7 @@ limitations under the License.
#include "oneflow/core/common/util.h"
#include "oneflow/core/graph/inplace_lbi_graph.h"
#include "oneflow/core/graph/id_serialization.h"
#include "oneflow/core/register/runtime_blob_desc.h"
#include "oneflow/core/register/blob_desc.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/operator/variable_op.h"
#include "oneflow/core/graph/op_graph.h"
......
......@@ -22,7 +22,6 @@ limitations under the License.
#include "oneflow/core/job/placement.pb.h"
#include "oneflow/core/operator/op_conf.pb.h"
#include "oneflow/core/register/logical_blob_id.pb.h"
#include "oneflow/core/register/runtime_blob_desc.h"
namespace oneflow {
......
......@@ -16,7 +16,6 @@ limitations under the License.
#include "oneflow/core/framework/user_op_conf.h"
#include "oneflow/core/job_rewriter/job_pass.h"
#include "oneflow/core/job_rewriter/pass_util.h"
#include "oneflow/core/register/runtime_blob_desc.h"
namespace oneflow {
......
......@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/job_rewriter/job_pass.h"
#include "oneflow/core/register/runtime_blob_desc.h"
#include "oneflow/core/framework/framework.h"
namespace oneflow {
......
......@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/job_rewriter/job_pass.h"
#include "oneflow/core/register/runtime_blob_desc.h"
#include "oneflow/core/framework/framework.h"
namespace oneflow {
......
......@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/job_rewriter/job_pass.h"
#include "oneflow/core/register/runtime_blob_desc.h"
#include "oneflow/core/framework/framework.h"
namespace oneflow {
......
......@@ -15,7 +15,6 @@ limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/job_rewriter/job_pass.h"
#include "oneflow/core/register/runtime_blob_desc.h"
namespace oneflow {
......
......@@ -15,7 +15,6 @@ limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/job_rewriter/job_pass.h"
#include "oneflow/core/register/runtime_blob_desc.h"
namespace oneflow {
......
......@@ -51,17 +51,15 @@ class OnDemandHostBlob final {
explicit OnDemandHostBlob(const Blob* like) {
Shape shape;
like->shape().ToShape(&shape);
blob_desc_.reset(new RtBlobDesc(BlobDesc(shape, like->data_type())));
blob_desc_.reset(new BlobDesc(shape, like->data_type()));
Init();
}
explicit OnDemandHostBlob(const RtBlobDesc& blob_desc) {
blob_desc_.reset(new RtBlobDesc(BlobDesc(blob_desc.body_shape(), blob_desc.data_type())));
explicit OnDemandHostBlob(const BlobDesc& blob_desc) {
blob_desc_.reset(new BlobDesc(blob_desc));
Init();
}
explicit OnDemandHostBlob(const Shape& shape, DataType data_type) {
BlobDesc blob_desc(data_type);
blob_desc.mut_shape() = shape;
blob_desc_.reset(new RtBlobDesc(blob_desc));
blob_desc_.reset(new BlobDesc(shape, data_type));
Init();
}
~OnDemandHostBlob() = default;
......@@ -80,7 +78,7 @@ class OnDemandHostBlob final {
std::vector<char> header;
std::vector<char> data;
std::unique_ptr<Blob> blob_;
std::unique_ptr<RtBlobDesc> blob_desc_;
std::unique_ptr<const BlobDesc> blob_desc_;
};
template<DeviceType device_type>
......
......@@ -77,13 +77,11 @@ void RuntimeBlobShapeInferHelper::UpdateInputBlobDescs7OpInferCacheKey(
}
BlobDesc* RuntimeBlobShapeInferHelper::BlobDesc4BnInOp(const std::string& bn_in_op,
const RtBlobDesc& rt_blob_desc) {
BlobDesc* blob_desc = bn_in_op2blob_desc_.at(bn_in_op).get();
if (blob_desc != nullptr) { return blob_desc; }
blob_desc =
new BlobDesc(rt_blob_desc.body_shape(), rt_blob_desc.data_type(), rt_blob_desc.is_dynamic());
bn_in_op2blob_desc_.at(bn_in_op).reset(blob_desc);
return blob_desc;
const BlobDesc& blob_desc) {
auto it = bn_in_op2blob_desc_.find(bn_in_op);
if (it == bn_in_op2blob_desc_.end()) { return nullptr; }
if (!it->second) { it->second.reset(new BlobDesc(blob_desc)); }
return it->second.get();
}
void RuntimeBlobShapeInferHelper::InferShape(std::function<Blob*(const std::string&)> BnInOp2Blob) {
......
......@@ -22,7 +22,7 @@ limitations under the License.
namespace oneflow {
class Blob;
class RtBlobDesc;
class BlobDesc;
class RuntimeBlobShapeInferHelper final {
public:
......@@ -34,7 +34,7 @@ class RuntimeBlobShapeInferHelper final {
private:
void UpdateInputBlobDescs7OpInferCacheKey(std::function<Blob*(const std::string&)> BnInOp2Blob);
BlobDesc* BlobDesc4BnInOp(const std::string& bn_in_op, const RtBlobDesc& rt_blob_desc);
BlobDesc* BlobDesc4BnInOp(const std::string& bn_in_op, const BlobDesc& rt_blob_desc);
std::shared_ptr<Operator> op_;
HashSet<std::string> ibns_;
......
......@@ -36,7 +36,7 @@ namespace {
void FillTensorDescWithBlob(const Blob* blob, user_op::NaiveTensorDesc* tensor_desc) {
BlobDescProto proto;
blob->blob_desc().body_shape().ToProto(proto.mutable_shape());
blob->blob_desc().shape().ToProto(proto.mutable_shape());
proto.set_data_type(blob->blob_desc().data_type());
proto.set_is_dynamic(blob->blob_desc().is_dynamic());
*tensor_desc = proto;
......
......@@ -113,15 +113,15 @@ void MemoryAllocator::Deallocate(char* dptr, MemoryCase mem_case) {
}
void InitNonPODTypeBlobIfNeed(MemoryAllocator* allocator, Blob* blob_ptr) {
const RtBlobDesc& blob_desc = blob_ptr->blob_desc();
const BlobDesc& blob_desc = blob_ptr->blob_desc();
if (blob_desc.data_type() == kOFRecord) {
int64_t elem_cnt = blob_desc.body_shape().elem_cnt();
int64_t elem_cnt = blob_desc.shape().elem_cnt();
FOR_RANGE(int64_t, idx, 0, elem_cnt) {
allocator->PlacementNew(&blob_ptr->mut_dptr<OFRecord>()[idx]);
}
}
if (blob_desc.data_type() == kTensorBuffer) {
int64_t elem_cnt = blob_desc.body_shape().elem_cnt();
int64_t elem_cnt = blob_desc.shape().elem_cnt();
FOR_RANGE(int64_t, idx, 0, elem_cnt) {
allocator->PlacementNew(&blob_ptr->mut_dptr<TensorBuffer>()[idx]);
}
......
......@@ -18,16 +18,16 @@ limitations under the License.
namespace oneflow {
Blob::Blob(const MemoryCase& mem_case, const RtBlobDesc* blob_desc, char* header_ptr) {
Blob::Blob(const MemoryCase& mem_case, const BlobDesc* blob_desc, char* header_ptr) {
Init(mem_case, blob_desc, header_ptr, header_ptr + blob_desc->ByteSizeOfBlobHeader());
}
Blob::Blob(const MemoryCase& mem_case, const RtBlobDesc* blob_desc, char* header_ptr,
Blob::Blob(const MemoryCase& mem_case, const BlobDesc* blob_desc, char* header_ptr,
char* body_ptr) {
Init(mem_case, blob_desc, header_ptr, body_ptr);
}
void Blob::Init(const MemoryCase& mem_case, const RtBlobDesc* blob_desc, char* header_ptr,
void Blob::Init(const MemoryCase& mem_case, const BlobDesc* blob_desc, char* header_ptr,
char* body_ptr) {
mem_case_ = mem_case;
blob_desc_ = blob_desc;
......
......@@ -19,7 +19,7 @@ limitations under the License.
#include "oneflow/core/device/device_context.h"
#include "oneflow/core/job/resource.pb.h"
#include "oneflow/core/memory/memory_case.pb.h"
#include "oneflow/core/register/runtime_blob_desc.h"
#include "oneflow/core/register/blob_desc.h"
#include "oneflow/core/common/shape_view.h"
#include "oneflow/core/common/symbol.h"
......@@ -48,16 +48,16 @@ class BlobAccessCheckerIf final : public BlobAccessChecker {
class Blob final {
public:
OF_DISALLOW_COPY_AND_MOVE(Blob);
Blob(const MemoryCase& mem_case, const RtBlobDesc* blob_desc, char* header_ptr);
Blob(const MemoryCase& mem_case, const RtBlobDesc* blob_desc, char* header_ptr, char* body_ptr);
Blob(const MemoryCase& mem_case, const BlobDesc* blob_desc, char* header_ptr);
Blob(const MemoryCase& mem_case, const BlobDesc* blob_desc, char* header_ptr, char* body_ptr);
virtual ~Blob() = default;
DataType data_type() const { return blob_desc_->data_type(); }
const char* header_ptr() const { return header_ptr_; }
char* mut_header_ptr() { return header_ptr_; }
char* mut_contiguous_header_ptr();
const RtBlobDesc& blob_desc() const { return *blob_desc_; }
const RtBlobDesc* blob_desc_ptr() const { return blob_desc_; }
const BlobDesc& blob_desc() const { return *blob_desc_; }
const BlobDesc* blob_desc_ptr() const { return blob_desc_; }
template<typename T = void>
const T* dptr() const {
......@@ -75,7 +75,7 @@ class Blob final {
CheckDataType<T>(data_type());
return static_cast<T*>(dptr_);
}
const Shape& static_shape() const { return blob_desc_->body_shape(); }
const Shape& static_shape() const { return blob_desc_->shape(); }
const ShapeView& shape_view() const { return *shape_view_; }
const ShapeView& shape() const { return *shape_view_; }
MutShapeView* mut_shape_view() {
......@@ -105,12 +105,12 @@ class Blob final {
const BlobAccessChecker* blob_access_checker() { return this->blob_access_checker_; }
private:
void Init(const MemoryCase& mem_case, const RtBlobDesc* blob_desc, char* header_ptr,
void Init(const MemoryCase& mem_case, const BlobDesc* blob_desc, char* header_ptr,
char* body_ptr);
const BlobAccessChecker* blob_access_checker_;
MemoryCase mem_case_;
const RtBlobDesc* blob_desc_;
const BlobDesc* blob_desc_;
void* dptr_;
char* header_ptr_;
std::unique_ptr<ShapeView> shape_view_;
......
......@@ -68,4 +68,18 @@ bool BlobDesc::operator==(const BlobDesc& rhs) const {
&& (is_dynamic() == rhs.is_dynamic());
}
size_t BlobDesc::ByteSizeOfBlobHeader() const { return shape().NumAxes() * sizeof(int64_t); }
size_t BlobDesc::ByteSizeOfBlobBody() const {
return shape().elem_cnt() * GetSizeOfDataType(data_type());
}
size_t BlobDesc::AlignedByteSizeOfBlobBody() const {
return RoundUp(ByteSizeOfBlobBody(), BlobDesc::kAlignSize);
}
size_t BlobDesc::AlignedTotalByteSize() const {
return ByteSizeOfBlobHeader() + AlignedByteSizeOfBlobBody();
}
} // namespace oneflow
......@@ -63,6 +63,11 @@ class BlobDesc final {
void CopyFrom(const BlobDesc&);
size_t ByteSizeOfBlobHeader() const;
size_t ByteSizeOfBlobBody() const;
size_t AlignedByteSizeOfBlobBody() const;
size_t AlignedTotalByteSize() const;
private:
std::shared_ptr<Shape> shape_;
DataType data_type_;
......
......@@ -17,7 +17,6 @@ limitations under the License.
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/graph/copy_task_node.h"
#include "oneflow/core/job/id_manager.h"
#include "oneflow/core/register/runtime_blob_desc.h"
#include "oneflow/core/register/runtime_register_desc.h"
namespace oneflow {
......@@ -116,7 +115,7 @@ void RegstDesc::ForEachLbi(std::function<void(const LogicalBlobId&)> func) const
void RegstDesc::EraseZeroSizeBlob() {
EraseIf<LogicalBlobId, std::unique_ptr<BlobDesc>>(
&lbi2blob_desc_, [](HashMap<LogicalBlobId, std::unique_ptr<BlobDesc>>::iterator it) {
return RtBlobDesc(*(it->second)).ByteSizeOfBlobBody() == 0;
return it->second->ByteSizeOfBlobBody() == 0;
});
}
......@@ -159,8 +158,7 @@ void RegstDesc::ToProto(RegstDescProto* ret) const {
}
bool RegstDesc::HasSameMemSize(const RegstDesc* rhs) {
return RtBlobDesc(*SoleBlobDesc()).AlignedTotalByteSize()
== RtBlobDesc(*(rhs->SoleBlobDesc())).AlignedTotalByteSize();
return SoleBlobDesc()->AlignedTotalByteSize() == rhs->SoleBlobDesc()->AlignedTotalByteSize();
}
bool RegstDesc::HasSameBlobDescs(const RegstDesc* rhs) {
......
......@@ -187,12 +187,12 @@ void RegstMgr::NewBlobsInOneRegst(const std::vector<LbiBlobDescPair>& lbis, Regs
if (main_mem_ptr == nullptr) {
cur_body_pointer = nullptr;
} else {
cur_body_pointer = main_mem_ptr + rt_regst_desc->GetSoleRtBlobDesc()->ByteSizeOfBlobHeader();
cur_body_pointer = main_mem_ptr + rt_regst_desc->GetSoleBlobDesc()->ByteSizeOfBlobHeader();
}
}
rt_regst_desc->ForEachBlobDescOffsetInOnRegst([&](int64_t ordinal, const LogicalBlobId& lbi,
const RtBlobDesc* blob_desc,
int64_t body_offset, int64_t header_offset) {
const BlobDesc* blob_desc, int64_t body_offset,
int64_t header_offset) {
std::unique_ptr<Blob> blob_ptr;
if (cur_body_pointer == nullptr) {
blob_ptr.reset(new Blob(regst->regst_desc()->mem_case(), blob_desc,
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/register/runtime_blob_desc.h"
namespace oneflow {
RtBlobDesc::RtBlobDesc(const BlobDesc& blob_desc) {
shape_ = blob_desc.shape();
data_type_ = blob_desc.data_type();
is_dynamic_ = blob_desc.is_dynamic();
}
RtBlobDesc::RtBlobDesc(const BlobDescProto& proto) {
shape_ = Shape(proto.shape());
data_type_ = proto.data_type();
is_dynamic_ = proto.is_dynamic();
}
size_t RtBlobDesc::ByteSizeOfBlobHeader() const { return shape_.NumAxes() * sizeof(int64_t); }
size_t RtBlobDesc::ByteSizeOfBlobBody() const { return Capacity(); }
size_t RtBlobDesc::AlignedByteSizeOfBlobBody() const {
return RoundUp(ByteSizeOfBlobBody(), BlobDesc::kAlignSize);
}
size_t RtBlobDesc::AlignedTotalByteSize() const {
return ByteSizeOfBlobHeader() + AlignedByteSizeOfBlobBody();
}
bool RtBlobDesc::operator==(const RtBlobDesc& rhs) const {
return (shape_ == rhs.shape_) && (data_type_ == rhs.data_type_)
&& (is_dynamic_ == rhs.is_dynamic_);
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_REGISTER_RUNTIME_BLOB_DESC_H_
#define ONEFLOW_CORE_REGISTER_RUNTIME_BLOB_DESC_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/register/blob_desc.h"
#include "oneflow/core/register/blob_desc.pb.h"
namespace oneflow {
class RtBlobDesc final {
public:
OF_DISALLOW_COPY_AND_MOVE(RtBlobDesc);
RtBlobDesc() = delete;
~RtBlobDesc() = default;
explicit RtBlobDesc(const BlobDesc& blob_desc);
explicit RtBlobDesc(const BlobDescProto& blob_desc_proto);
bool is_dynamic() const { return is_dynamic_; }
DataType data_type() const { return data_type_; }
int64_t NumAxes() const { return shape_.NumAxes(); }
int64_t Capacity() const { return shape_.elem_cnt() * GetSizeOfDataType(data_type()); }
const Shape& body_shape() const { return shape_; }
size_t ByteSizeOfBlobHeader() const;
size_t ByteSizeOfBlobBody() const;
size_t AlignedByteSizeOfBlobBody() const;
size_t AlignedTotalByteSize() const;
bool operator==(const RtBlobDesc& rhs) const;
private:
Shape shape_;
DataType data_type_;
bool is_dynamic_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_REGISTER_RUNTIME_BLOB_DESC_H_
......@@ -35,14 +35,14 @@ RtRegstDesc::RtRegstDesc(const RegstDescProto& proto) {
sorted_lbi_vec_.reserve(lbi_pairs.size());
for (int64_t i = 0; i < lbi_pairs.size(); ++i) {
const LbiBlobDescPair& pair = lbi_pairs.at(i);
sorted_blob_desc_vec_.push_back(std::make_unique<RtBlobDesc>(pair.blob_desc()));
sorted_blob_desc_vec_.push_back(std::make_unique<const BlobDesc>(pair.blob_desc()));
sorted_lbi_vec_.push_back(pair.lbi());
lbi2blob_desc_ordinal_.emplace(pair.lbi(), i);
}
CHECK(data_regst_desc.has_time_shape());
data_regst_time_shape_.reset(new Shape(data_regst_desc.time_shape()));
} else {
sorted_blob_desc_vec_.push_back(std::make_unique<RtBlobDesc>(BlobDesc(DataType::kChar)));
sorted_blob_desc_vec_.push_back(std::make_unique<const BlobDesc>(BlobDesc(DataType::kChar)));
}
}
......@@ -55,16 +55,16 @@ int64_t RtRegstDesc::GetOrdinalForLbi(const LogicalBlobId& lbi) const {
}
}
const RtBlobDesc* RtRegstDesc::GetRtBlobDescFromLbi(const LogicalBlobId& lbi) const {
const BlobDesc* RtRegstDesc::GetBlobDescFromLbi(const LogicalBlobId& lbi) const {
auto it = lbi2blob_desc_ordinal_.find(lbi);
if (it == lbi2blob_desc_ordinal_.end()) {
return nullptr;
} else {
return GetRtBlobDescByOrdinal(it->second);
return GetBlobDescByOrdinal(it->second);
}
}
const RtBlobDesc* RtRegstDesc::GetRtBlobDescByOrdinal(int64_t ordinal) const {
const BlobDesc* RtRegstDesc::GetBlobDescByOrdinal(int64_t ordinal) const {
return sorted_blob_desc_vec_.at(ordinal).get();
}
......@@ -72,13 +72,13 @@ const LogicalBlobId& RtRegstDesc::GetLbiByOrdinal(int64_t ordinal) const {
return sorted_lbi_vec_.at(ordinal);
}
const RtBlobDesc* RtRegstDesc::GetSoleRtBlobDesc() const {
const BlobDesc* RtRegstDesc::GetSoleBlobDesc() const {
CHECK_EQ(sorted_blob_desc_vec_.size(), 1);
return sorted_blob_desc_vec_.at(0).get();
}
size_t RtRegstDesc::TotalByteSize4AllRegst() const {
return GetSoleRtBlobDesc()->AlignedTotalByteSize() * register_num_;
return GetSoleBlobDesc()->AlignedTotalByteSize() * register_num_;
}
size_t RtRegstDesc::TotalMainByteSize4AllRegst() const {
......@@ -87,9 +87,9 @@ size_t RtRegstDesc::TotalMainByteSize4AllRegst() const {
size_t RtRegstDesc::MainByteSize4OneRegst() const {
if (mem_case_.has_device_cuda_mem()) {
return GetSoleRtBlobDesc()->AlignedByteSizeOfBlobBody();
return GetSoleBlobDesc()->AlignedByteSizeOfBlobBody();
} else {
return GetSoleRtBlobDesc()->AlignedTotalByteSize();
return GetSoleBlobDesc()->AlignedTotalByteSize();
}
}
......@@ -99,7 +99,7 @@ size_t RtRegstDesc::TotalSeparatedHeaderByteSize4AllRegst() const {
size_t RtRegstDesc::SeparatedHeaderByteSize4OneRegst() const {
if (mem_case_.has_device_cuda_mem()) {
return GetSoleRtBlobDesc()->ByteSizeOfBlobHeader();
return GetSoleBlobDesc()->ByteSizeOfBlobHeader();
} else {
return 0;
}
......@@ -112,12 +112,12 @@ const Shape& RtRegstDesc::data_regst_time_shape() const {
}
void RtRegstDesc::ForEachBlobDescOffsetInOnRegst(
const std::function<void(int64_t ordinal, const LogicalBlobId& lbi, const RtBlobDesc* desc,
const std::function<void(int64_t ordinal, const LogicalBlobId& lbi, const BlobDesc* desc,
int64_t body_offset, int64_t header_offset)>& Handler) const {
int64_t cur_body_offset = 0;
int64_t cur_header_offset = 0;
for (int64_t i = 0; i < sorted_blob_desc_vec_.size(); ++i) {
const RtBlobDesc* blob_desc = sorted_blob_desc_vec_.at(i).get();
const BlobDesc* blob_desc = sorted_blob_desc_vec_.at(i).get();
const LogicalBlobId& lbi = sorted_lbi_vec_.at(i);
Handler(i, lbi, blob_desc, cur_body_offset, cur_header_offset);
cur_body_offset += blob_desc->AlignedByteSizeOfBlobBody();
......
......@@ -17,7 +17,7 @@ limitations under the License.
#define ONEFLOW_CORE_REGISTER_RUNTIME_REGISTER_DESC_H_
#include "oneflow/core/memory/memory_case.pb.h"
#include "oneflow/core/register/runtime_blob_desc.h"
#include "oneflow/core/register/blob_desc.h"
#include "oneflow/core/register/register_desc.pb.h"
namespace oneflow {
......@@ -39,9 +39,9 @@ class RtRegstDesc {
int64_t lbi_num() const { return sorted_lbi_vec_.size(); }
int64_t GetOrdinalForLbi(const LogicalBlobId& lbi) const;
const RtBlobDesc* GetRtBlobDescFromLbi(const LogicalBlobId& lbi) const;
const RtBlobDesc* GetRtBlobDescByOrdinal(int64_t ordinal) const;
const RtBlobDesc* GetSoleRtBlobDesc() const;
const BlobDesc* GetBlobDescFromLbi(const LogicalBlobId& lbi) const;
const BlobDesc* GetBlobDescByOrdinal(int64_t ordinal) const;
const BlobDesc* GetSoleBlobDesc() const;
const LogicalBlobId& GetLbiByOrdinal(int64_t ordinal) const;
size_t TotalByteSize4AllRegst() const;
size_t TotalMainByteSize4AllRegst() const;
......@@ -51,7 +51,7 @@ class RtRegstDesc {
const Shape& data_regst_time_shape() const;
void ForEachBlobDescOffsetInOnRegst(
const std::function<void(int64_t ordinal, const LogicalBlobId& lbi, const RtBlobDesc* desc,
const std::function<void(int64_t ordinal, const LogicalBlobId& lbi, const BlobDesc* desc,
int64_t body_offset, int64_t header_offset)>& Handler) const;
private:
......@@ -63,7 +63,7 @@ class RtRegstDesc {
MemoryCase mem_case_;
HashMap<LogicalBlobId, int64_t> lbi2blob_desc_ordinal_;
std::unique_ptr<Shape> data_regst_time_shape_;
std::vector<std::unique_ptr<RtBlobDesc>> sorted_blob_desc_vec_;
std::vector<std::unique_ptr<const BlobDesc>> sorted_blob_desc_vec_;
std::vector<LogicalBlobId> sorted_lbi_vec_;
};
......
......@@ -36,7 +36,7 @@ namespace oneflow {
namespace xrt {
static Parameter BuildParameter(const Blob& blob, const std::string& name) {
const auto& desc = blob.blob_desc();
return Parameter(name, const_cast<void*>(blob.dptr<void>()), desc.body_shape(), desc.data_type());
return Parameter(name, const_cast<void*>(blob.dptr<void>()), desc.shape(), desc.data_type());
}
} // namespace xrt
......@@ -47,16 +47,11 @@ void BlobDescGetter<device_type>::DumpEntryBlobDescTo(
const auto& io_mapping = launch_conf.input_output_mapping();
for (const auto& bn : kernel_->op_attribute().input_bns()) {
const RtBlobDesc& runtime_desc = get_blob_fn_(bn)->blob_desc();
BlobDesc blob_desc(kernel_->job_desc().DefaultDataType());
blob_desc.mut_shape() = runtime_desc.body_shape();
blob_desc.set_data_type(runtime_desc.data_type());
blob_desc.set_is_dynamic(runtime_desc.is_dynamic());
// Map blob_name to function's input name.
std::string blob_name = xrt::BlobIdToName(kernel_->BnInOp2Lbi(bn));
// CHECK_GT(io_mapping.count(blob_name), 0);
const std::string& mapping_name = io_mapping.at(blob_name);
entry_blob_desc->emplace(mapping_name, std::move(blob_desc));
entry_blob_desc->emplace(mapping_name, get_blob_fn_(bn)->blob_desc());
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册