提交 10929941 编写于 作者: W willzhang4a58

add set_col_id for blob


Former-commit-id: 271de42e
上级 b6ffa37f
...@@ -117,7 +117,7 @@ struct KernelUtil<DeviceType::kGPU, T> final { ...@@ -117,7 +117,7 @@ struct KernelUtil<DeviceType::kGPU, T> final {
BlobDesc blob_desc = BlobDesc(blob->blob_desc()); BlobDesc blob_desc = BlobDesc(blob->blob_desc());
char* host_raw_dptr = nullptr; char* host_raw_dptr = nullptr;
CudaCheck(cudaMallocHost(&host_raw_dptr, blob->TotalByteSize())); CudaCheck(cudaMallocHost(&host_raw_dptr, blob->TotalByteSize()));
Blob host_blob(&blob_desc, host_raw_dptr); Blob host_blob(nullptr, &blob_desc, host_raw_dptr);
// synchronous initialize the host blob // synchronous initialize the host blob
KernelUtil<DeviceType::kCPU, T>::Initialize(nullptr, initializer_conf, KernelUtil<DeviceType::kCPU, T>::Initialize(nullptr, initializer_conf,
random_seed, &host_blob); random_seed, &host_blob);
...@@ -137,7 +137,7 @@ struct KernelUtil<DeviceType::kGPU, T> final { ...@@ -137,7 +137,7 @@ struct KernelUtil<DeviceType::kGPU, T> final {
BlobDesc blob_desc = BlobDesc(blob->blob_desc()); BlobDesc blob_desc = BlobDesc(blob->blob_desc());
char* host_raw_dptr = nullptr; char* host_raw_dptr = nullptr;
CudaCheck(cudaMallocHost(&host_raw_dptr, blob->TotalByteSize())); CudaCheck(cudaMallocHost(&host_raw_dptr, blob->TotalByteSize()));
Blob host_blob(&blob_desc, host_raw_dptr); Blob host_blob(nullptr, &blob_desc, host_raw_dptr);
KernelUtil<DeviceType::kCPU, T>::InitializeWithModelDir( KernelUtil<DeviceType::kCPU, T>::InitializeWithModelDir(
ctx, part_id, part_num, model_dir, &host_blob, bn_in_op, dim_num, ctx, part_id, part_num, model_dir, &host_blob, bn_in_op, dim_num,
num_in_each_dim); num_in_each_dim);
......
...@@ -34,7 +34,7 @@ template<> ...@@ -34,7 +34,7 @@ template<>
Blob* CreateBlob<DeviceType::kCPU>(const BlobDesc* blob_desc) { Blob* CreateBlob<DeviceType::kCPU>(const BlobDesc* blob_desc) {
void* mem_ptr = nullptr; void* mem_ptr = nullptr;
CudaCheck(cudaMallocHost(&mem_ptr, blob_desc->TotalByteSize())); CudaCheck(cudaMallocHost(&mem_ptr, blob_desc->TotalByteSize()));
return new Blob(blob_desc, static_cast<char*>(mem_ptr)); return new Blob(nullptr, blob_desc, static_cast<char*>(mem_ptr));
} }
template<> template<>
......
...@@ -10,7 +10,7 @@ template<> ...@@ -10,7 +10,7 @@ template<>
Blob* CreateBlob<DeviceType::kGPU>(const BlobDesc* blob_desc) { Blob* CreateBlob<DeviceType::kGPU>(const BlobDesc* blob_desc) {
void* mem_ptr = nullptr; void* mem_ptr = nullptr;
CudaCheck(cudaMalloc(&mem_ptr, blob_desc->TotalByteSize())); CudaCheck(cudaMalloc(&mem_ptr, blob_desc->TotalByteSize()));
return new Blob(blob_desc, static_cast<char*>(mem_ptr)); return new Blob(nullptr, blob_desc, static_cast<char*>(mem_ptr));
} }
template<> template<>
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
namespace oneflow { namespace oneflow {
Blob::Blob(const BlobDesc* blob_desc, char* mem_ptr, Blob::Blob(Regst* regst, const BlobDesc* blob_desc, char* mem_ptr,
const void* comm_net_token) { const void* comm_net_token) {
mem_ptr_ = mem_ptr; mem_ptr_ = mem_ptr;
if (blob_desc->has_data_id_field()) { if (blob_desc->has_data_id_field()) {
...@@ -23,6 +23,7 @@ Blob::Blob(const BlobDesc* blob_desc, char* mem_ptr, ...@@ -23,6 +23,7 @@ Blob::Blob(const BlobDesc* blob_desc, char* mem_ptr,
+ blob_desc->ByteSizeOfColNumField(); + blob_desc->ByteSizeOfColNumField();
blob_desc_ = blob_desc; blob_desc_ = blob_desc;
comm_net_token_ = comm_net_token; comm_net_token_ = comm_net_token;
regst_ = regst;
} }
const char* Blob::data_id(int32_t no) const { const char* Blob::data_id(int32_t no) const {
...@@ -80,6 +81,10 @@ void Blob::CopyFrom(DeviceCtx* device_ctx, const Blob* rhs) { ...@@ -80,6 +81,10 @@ void Blob::CopyFrom(DeviceCtx* device_ctx, const Blob* rhs) {
TotalByteSize()); TotalByteSize());
} }
void Blob::set_col_id(int32_t val) { regst_->set_col_id(val); }
void Blob::set_max_col_id(int32_t val) { regst_->set_max_col_id(val); }
#define INSTANTIATE_BLOB_FUNC(dev_t) \ #define INSTANTIATE_BLOB_FUNC(dev_t) \
template void Blob::CopyDataContentFrom<dev_t>(DeviceCtx*, const Blob*); \ template void Blob::CopyDataContentFrom<dev_t>(DeviceCtx*, const Blob*); \
template void Blob::CopyDataIdFrom<dev_t>(DeviceCtx*, const Blob*); \ template void Blob::CopyDataIdFrom<dev_t>(DeviceCtx*, const Blob*); \
......
...@@ -7,12 +7,15 @@ ...@@ -7,12 +7,15 @@
namespace oneflow { namespace oneflow {
class Regst;
class Blob final { class Blob final {
public: public:
OF_DISALLOW_COPY_AND_MOVE(Blob); OF_DISALLOW_COPY_AND_MOVE(Blob);
Blob(const BlobDesc* blob_desc, char* mem_ptr) Blob(Regst* regst, const BlobDesc* blob_desc, char* mem_ptr)
: Blob(blob_desc, mem_ptr, nullptr) {} : Blob(regst, blob_desc, mem_ptr, nullptr) {}
Blob(const BlobDesc* blob_desc, char* mem_ptr, const void* comm_net_token); Blob(Regst* regst, const BlobDesc* blob_desc, char* mem_ptr,
const void* comm_net_token);
~Blob() = default; ~Blob() = default;
const char* data_id(int32_t no) const; const char* data_id(int32_t no) const;
...@@ -62,6 +65,9 @@ class Blob final { ...@@ -62,6 +65,9 @@ class Blob final {
template<DeviceType device_type> template<DeviceType device_type>
void CopyFrom(DeviceCtx* device_ctx, const Blob* rhs); void CopyFrom(DeviceCtx* device_ctx, const Blob* rhs);
void set_col_id(int32_t val);
void set_max_col_id(int32_t val);
private: private:
template<typename T> template<typename T>
void CheckDataType() const { void CheckDataType() const {
...@@ -78,6 +84,7 @@ class Blob final { ...@@ -78,6 +84,7 @@ class Blob final {
void* dptr_; void* dptr_;
const void* comm_net_token_; const void* comm_net_token_;
const BlobDesc* blob_desc_; const BlobDesc* blob_desc_;
Regst* regst_;
}; };
} // namespace oneflow } // namespace oneflow
......
...@@ -24,13 +24,13 @@ void RegstMgr::NewRegsts(const RegstDescProto& regst_desc_proto, ...@@ -24,13 +24,13 @@ void RegstMgr::NewRegsts(const RegstDescProto& regst_desc_proto,
char* cur_pointer = std::get<0>(allocation_result); char* cur_pointer = std::get<0>(allocation_result);
for (const std::string& lbn : lbns) { for (const std::string& lbn : lbns) {
const BlobDesc* blob_desc = runtime_regst_desc->GetBlobDescFromLbn(lbn); const BlobDesc* blob_desc = runtime_regst_desc->GetBlobDescFromLbn(lbn);
auto blob_ptr = of_make_unique<Blob>(blob_desc, cur_pointer); auto blob_ptr = of_make_unique<Blob>(regst, blob_desc, cur_pointer);
CHECK(regst->lbn2blob_.emplace(lbn, std::move(blob_ptr)).second); CHECK(regst->lbn2blob_.emplace(lbn, std::move(blob_ptr)).second);
cur_pointer += blob_desc->TotalByteSize(); cur_pointer += blob_desc->TotalByteSize();
} }
regst->packed_blob_.reset(new Blob(runtime_regst_desc->packed_blob_desc(), regst->packed_blob_.reset(new Blob(
std::get<0>(allocation_result), regst, runtime_regst_desc->packed_blob_desc(),
std::get<1>(allocation_result))); std::get<0>(allocation_result), std::get<1>(allocation_result)));
regst->deleter_ = std::get<2>(allocation_result); regst->deleter_ = std::get<2>(allocation_result);
OneRegstDone(regst); OneRegstDone(regst);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册