提交 10929941 编写于 作者: W willzhang4a58

add set_col_id for blob


Former-commit-id: 271de42e
上级 b6ffa37f
......@@ -117,7 +117,7 @@ struct KernelUtil<DeviceType::kGPU, T> final {
BlobDesc blob_desc = BlobDesc(blob->blob_desc());
char* host_raw_dptr = nullptr;
CudaCheck(cudaMallocHost(&host_raw_dptr, blob->TotalByteSize()));
Blob host_blob(&blob_desc, host_raw_dptr);
Blob host_blob(nullptr, &blob_desc, host_raw_dptr);
// synchronous initialize the host blob
KernelUtil<DeviceType::kCPU, T>::Initialize(nullptr, initializer_conf,
random_seed, &host_blob);
......@@ -137,7 +137,7 @@ struct KernelUtil<DeviceType::kGPU, T> final {
BlobDesc blob_desc = BlobDesc(blob->blob_desc());
char* host_raw_dptr = nullptr;
CudaCheck(cudaMallocHost(&host_raw_dptr, blob->TotalByteSize()));
Blob host_blob(&blob_desc, host_raw_dptr);
Blob host_blob(nullptr, &blob_desc, host_raw_dptr);
KernelUtil<DeviceType::kCPU, T>::InitializeWithModelDir(
ctx, part_id, part_num, model_dir, &host_blob, bn_in_op, dim_num,
num_in_each_dim);
......
......@@ -34,7 +34,7 @@ template<>
// Builds a Blob over a freshly allocated host buffer: allocates
// blob_desc->TotalByteSize() bytes with cudaMallocHost and wraps the result in
// a Blob created with `new`.
// NOTE(review): the two `return` statements below are the removed/added pair
// of this diff hunk (the +/- markers are missing from this view). The second
// one is the updated form; it passes nullptr for the new leading Regst*
// constructor argument.
// NOTE(review): a Blob built with a null Regst* stores that null in regst_,
// and Blob::set_col_id / Blob::set_max_col_id dereference regst_ without a
// check — confirm those setters are never called on blobs created here.
// NOTE(review): nothing in this function releases mem_ptr or the returned
// Blob; presumably the caller is responsible — TODO confirm.
Blob* CreateBlob<DeviceType::kCPU>(const BlobDesc* blob_desc) {
void* mem_ptr = nullptr;
CudaCheck(cudaMallocHost(&mem_ptr, blob_desc->TotalByteSize()));
return new Blob(blob_desc, static_cast<char*>(mem_ptr));
return new Blob(nullptr, blob_desc, static_cast<char*>(mem_ptr));
}
template<>
......
......@@ -10,7 +10,7 @@ template<>
// Builds a Blob over a freshly allocated buffer: allocates
// blob_desc->TotalByteSize() bytes with cudaMalloc and wraps the result in a
// Blob created with `new`.
// NOTE(review): the two `return` statements below are the removed/added pair
// of this diff hunk (the +/- markers are missing from this view). The second
// one is the updated form; it passes nullptr for the new leading Regst*
// constructor argument.
// NOTE(review): a Blob built with a null Regst* stores that null in regst_,
// and Blob::set_col_id / Blob::set_max_col_id dereference regst_ without a
// check — confirm those setters are never called on blobs created here.
// NOTE(review): nothing in this function releases mem_ptr or the returned
// Blob; presumably the caller is responsible — TODO confirm.
Blob* CreateBlob<DeviceType::kGPU>(const BlobDesc* blob_desc) {
void* mem_ptr = nullptr;
CudaCheck(cudaMalloc(&mem_ptr, blob_desc->TotalByteSize()));
return new Blob(blob_desc, static_cast<char*>(mem_ptr));
return new Blob(nullptr, blob_desc, static_cast<char*>(mem_ptr));
}
template<>
......
......@@ -5,7 +5,7 @@
namespace oneflow {
Blob::Blob(const BlobDesc* blob_desc, char* mem_ptr,
Blob::Blob(Regst* regst, const BlobDesc* blob_desc, char* mem_ptr,
const void* comm_net_token) {
mem_ptr_ = mem_ptr;
if (blob_desc->has_data_id_field()) {
......@@ -23,6 +23,7 @@ Blob::Blob(const BlobDesc* blob_desc, char* mem_ptr,
+ blob_desc->ByteSizeOfColNumField();
blob_desc_ = blob_desc;
comm_net_token_ = comm_net_token;
regst_ = regst;
}
const char* Blob::data_id(int32_t no) const {
......@@ -80,6 +81,10 @@ void Blob::CopyFrom(DeviceCtx* device_ctx, const Blob* rhs) {
TotalByteSize());
}
// Forwards `val` to set_col_id() on the Regst pointer that was handed to this
// Blob's constructor.
// NOTE(review): regst_ is dereferenced unconditionally, while several call
// sites in this change construct Blob with a null Regst*; calling this on
// such a blob is a null dereference — confirm callers never do so.
void Blob::set_col_id(int32_t val) {
  regst_->set_col_id(val);
}
// Forwards `val` to set_max_col_id() on the Regst pointer that was handed to
// this Blob's constructor.
// NOTE(review): regst_ is dereferenced unconditionally, while several call
// sites in this change construct Blob with a null Regst*; calling this on
// such a blob is a null dereference — confirm callers never do so.
void Blob::set_max_col_id(int32_t val) {
  regst_->set_max_col_id(val);
}
#define INSTANTIATE_BLOB_FUNC(dev_t) \
template void Blob::CopyDataContentFrom<dev_t>(DeviceCtx*, const Blob*); \
template void Blob::CopyDataIdFrom<dev_t>(DeviceCtx*, const Blob*); \
......
......@@ -7,12 +7,15 @@
namespace oneflow {
class Regst;
class Blob final {
public:
OF_DISALLOW_COPY_AND_MOVE(Blob);
Blob(const BlobDesc* blob_desc, char* mem_ptr)
: Blob(blob_desc, mem_ptr, nullptr) {}
Blob(const BlobDesc* blob_desc, char* mem_ptr, const void* comm_net_token);
Blob(Regst* regst, const BlobDesc* blob_desc, char* mem_ptr)
: Blob(regst, blob_desc, mem_ptr, nullptr) {}
Blob(Regst* regst, const BlobDesc* blob_desc, char* mem_ptr,
const void* comm_net_token);
~Blob() = default;
const char* data_id(int32_t no) const;
......@@ -62,6 +65,9 @@ class Blob final {
template<DeviceType device_type>
void CopyFrom(DeviceCtx* device_ctx, const Blob* rhs);
void set_col_id(int32_t val);
void set_max_col_id(int32_t val);
private:
template<typename T>
void CheckDataType() const {
......@@ -78,6 +84,7 @@ class Blob final {
void* dptr_;
const void* comm_net_token_;
const BlobDesc* blob_desc_;
Regst* regst_;
};
} // namespace oneflow
......
......@@ -24,13 +24,13 @@ void RegstMgr::NewRegsts(const RegstDescProto& regst_desc_proto,
char* cur_pointer = std::get<0>(allocation_result);
for (const std::string& lbn : lbns) {
const BlobDesc* blob_desc = runtime_regst_desc->GetBlobDescFromLbn(lbn);
auto blob_ptr = of_make_unique<Blob>(blob_desc, cur_pointer);
auto blob_ptr = of_make_unique<Blob>(regst, blob_desc, cur_pointer);
CHECK(regst->lbn2blob_.emplace(lbn, std::move(blob_ptr)).second);
cur_pointer += blob_desc->TotalByteSize();
}
regst->packed_blob_.reset(new Blob(runtime_regst_desc->packed_blob_desc(),
std::get<0>(allocation_result),
std::get<1>(allocation_result)));
regst->packed_blob_.reset(new Blob(
regst, runtime_regst_desc->packed_blob_desc(),
std::get<0>(allocation_result), std::get<1>(allocation_result)));
regst->deleter_ = std::get<2>(allocation_result);
OneRegstDone(regst);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册