提交 9ec6871d 编写于 作者: L leaves-zwx

modify by review

上级 d338b271
......@@ -29,27 +29,32 @@ namespace oneflow {
class DeviceId {
public:
using index_t = uint32_t;
using node_index_t = uint32_t;
using device_type_t = uint32_t;
using device_index_t = uint32_t;
constexpr static size_t kNodeIndexBits = 19;
constexpr static size_t kDeviceTypeBits = 5;
constexpr static size_t kDeviceIndexBits = 7;
constexpr static index_t kMaxNodeIndex = (index_t{1} << kNodeIndexBits) - index_t{1};
constexpr static index_t kMaxDeviceTypeVal = (index_t{1} << kDeviceTypeBits) - index_t{1};
constexpr static index_t kMaxDeviceIndex = (index_t{1} << kDeviceIndexBits) - index_t{1};
DeviceId(index_t node_index, DeviceType device_type, index_t device_index)
constexpr static node_index_t kMaxNodeIndex =
(node_index_t{1} << kNodeIndexBits) - node_index_t{1};
constexpr static device_type_t kMaxDeviceTypeVal =
(device_type_t{1} << kDeviceTypeBits) - device_type_t{1};
constexpr static device_index_t kMaxDeviceIndex =
(device_index_t{1} << kDeviceIndexBits) - device_index_t{1};
DeviceId(node_index_t node_index, DeviceType device_type, device_index_t device_index)
: node_index_(node_index),
device_type_(static_cast<index_t>(device_type)),
device_type_(static_cast<device_type_t>(device_type)),
device_index_(device_index) {
CHECK_LE(node_index_, kMaxNodeIndex);
CHECK_LE(device_type_, kMaxDeviceTypeVal);
CHECK_LE(device_index, kMaxDeviceIndex);
CHECK_LE(device_index_, kMaxDeviceIndex);
}
index_t node_index() const { return node_index_; }
node_index_t node_index() const { return node_index_; }
DeviceType device_type() const { return static_cast<DeviceType>(device_type_); }
index_t device_index() const { return device_index_; }
device_index_t device_index() const { return device_index_; }
bool operator==(const DeviceId& rhs) const {
return node_index_ == rhs.node_index_ && device_type_ == rhs.device_type_
......@@ -59,16 +64,16 @@ class DeviceId {
bool operator!=(const DeviceId& rhs) const { return !(*this == rhs); }
size_t hash() const {
size_t hash = std::hash<index_t>{}(node_index_);
HashCombine(&hash, std::hash<index_t>{}(device_type_));
HashCombine(&hash, std::hash<index_t>{}(device_index_));
size_t hash = std::hash<node_index_t>{}(node_index_);
HashCombine(&hash, std::hash<device_type_t>{}(device_type_));
HashCombine(&hash, std::hash<device_index_t>{}(device_index_));
return hash;
}
private:
index_t node_index_;
index_t device_type_;
index_t device_index_;
node_index_t node_index_;
device_type_t device_type_;
device_index_t device_index_;
};
} // namespace oneflow
......
......@@ -25,7 +25,7 @@ namespace oneflow {
class StreamIndexGenerator {
public:
virtual ~StreamIndexGenerator() {}
using index_t = StreamId::index_t;
using index_t = StreamId::stream_index_t;
virtual index_t GenerateComputeStreamIndex() = 0;
virtual index_t GenerateH2DStreamIndex() = 0;
......
......@@ -65,8 +65,8 @@ void NcclInitCollectiveNode(CollectiveBoxingGenericTaskNode* node,
const int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));
const int64_t device_index = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));
DeviceId device_id{static_cast<DeviceId::index_t>(machine_id), DeviceType::kGPU,
static_cast<DeviceId::index_t>(device_index)};
DeviceId device_id{static_cast<DeviceId::node_index_t>(machine_id), DeviceType::kGPU,
static_cast<DeviceId::device_index_t>(device_index)};
auto* stream_index_generator = dynamic_cast<CudaStreamIndexGenerator*>(
Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id));
CHECK_NOTNULL(stream_index_generator);
......@@ -191,8 +191,8 @@ class NcclCollectiveBoxingP2SNoncontinuousSubTskGphBuilder final : public SubTsk
FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) {
const int64_t machine_id = CHECK_JUST(in_parallel_desc.MachineId4ParallelId(i));
const int64_t device_index = CHECK_JUST(in_parallel_desc.DeviceId4ParallelId(i));
DeviceId device_id{static_cast<DeviceId::index_t>(machine_id), DeviceType::kGPU,
static_cast<DeviceId::index_t>(device_index)};
DeviceId device_id{static_cast<DeviceId::node_index_t>(machine_id), DeviceType::kGPU,
static_cast<DeviceId::device_index_t>(device_index)};
auto* stream_index_generator =
Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id);
auto stream_index = stream_index_generator->GenerateComputeStreamIndex();
......@@ -293,8 +293,8 @@ class NcclCollectiveBoxingS2BNoncontinuousSubTskGphBuilder final : public SubTsk
FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) {
const int64_t machine_id = CHECK_JUST(out_parallel_desc.MachineId4ParallelId(i));
const int64_t device_index = CHECK_JUST(out_parallel_desc.DeviceId4ParallelId(i));
DeviceId device_id{static_cast<DeviceId::index_t>(machine_id), DeviceType::kGPU,
static_cast<DeviceId::index_t>(device_index)};
DeviceId device_id{static_cast<DeviceId::node_index_t>(machine_id), DeviceType::kGPU,
static_cast<DeviceId::device_index_t>(device_index)};
auto* stream_index_generator =
Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id);
auto stream_index = stream_index_generator->GenerateComputeStreamIndex();
......@@ -406,7 +406,7 @@ class CollectiveBoxingScatterThenNcclAllGatherSubTskGphBuilder final : public Su
SliceBoxingTaskNode* slice_node = ctx->task_graph()->NewNode<SliceBoxingTaskNode>();
// slice on cpu
const auto in_machine_id = CHECK_JUST(in_parallel_desc.MachineId4ParallelId(0));
DeviceId device_id{static_cast<DeviceId::index_t>(in_machine_id), DeviceType::kCPU, 0};
DeviceId device_id{static_cast<DeviceId::node_index_t>(in_machine_id), DeviceType::kCPU, 0};
auto* stream_index_generator =
Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id);
auto stream_index = stream_index_generator->GenerateComputeStreamIndex();
......@@ -522,8 +522,8 @@ class NcclCollectiveBoxingAll2AllSubTskGphBuilder final : public SubTskGphBuilde
FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) {
const int64_t machine_id = CHECK_JUST(in_parallel_desc.MachineId4ParallelId(i));
const int64_t device_index = CHECK_JUST(in_parallel_desc.DeviceId4ParallelId(i));
DeviceId device_id{static_cast<DeviceId::index_t>(machine_id), DeviceType::kGPU,
static_cast<DeviceId::index_t>(device_index)};
DeviceId device_id{static_cast<DeviceId::node_index_t>(machine_id), DeviceType::kGPU,
static_cast<DeviceId::device_index_t>(device_index)};
auto* stream_index_generator =
Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id);
auto stream_index = stream_index_generator->GenerateComputeStreamIndex();
......
......@@ -58,8 +58,8 @@ Maybe<SubTskGphBuilderStatus> NaiveB2PSubTskGphBuilder::Build(
int64_t thrd_id = -1;
if (out_parallel_desc.device_type() == DeviceType::kGPU) {
#ifdef WITH_CUDA
DeviceId device_id{static_cast<DeviceId::index_t>(out_machine_id), DeviceType::kGPU,
static_cast<DeviceId::index_t>(out_dev_phy_id)};
DeviceId device_id{static_cast<DeviceId::node_index_t>(out_machine_id), DeviceType::kGPU,
static_cast<DeviceId::device_index_t>(out_dev_phy_id)};
auto* stream_index_generator =
Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id);
auto stream_index = stream_index_generator->GenerateComputeStreamIndex();
......@@ -68,7 +68,8 @@ Maybe<SubTskGphBuilderStatus> NaiveB2PSubTskGphBuilder::Build(
UNIMPLEMENTED();
#endif
} else if (out_parallel_desc.device_type() == DeviceType::kCPU) {
DeviceId device_id{static_cast<DeviceId::index_t>(out_machine_id), DeviceType::kCPU, 0};
DeviceId device_id{static_cast<DeviceId::node_index_t>(out_machine_id), DeviceType::kCPU,
0};
auto* stream_index_generator =
Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id);
auto stream_index = stream_index_generator->GenerateComputeStreamIndex();
......
......@@ -61,8 +61,8 @@ Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build(
} else {
dev_id = CHECK_JUST(pd.DeviceId4ParallelId(parallel_id));
}
DeviceId device_id{static_cast<DeviceId::index_t>(machine_id), pd.device_type(),
static_cast<DeviceId::index_t>(dev_id)};
DeviceId device_id{static_cast<DeviceId::node_index_t>(machine_id), pd.device_type(),
static_cast<DeviceId::device_index_t>(dev_id)};
auto* stream_index_generator =
Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id);
auto stream_index = stream_index_generator->GenerateComputeStreamIndex();
......
......@@ -45,7 +45,7 @@ void CopyHdTaskNode::Init(CopyHdOpConf::Type copy_type, const DeviceId& device_i
set_machine_id(device_id.node_index());
auto* stream_index_generator =
Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id);
StreamId::index_t stream_index = 0;
StreamId::stream_index_t stream_index = 0;
if (copy_type == CopyHdOpConf::H2D) {
stream_index = stream_index_generator->GenerateH2DStreamIndex();
} else if (copy_type == CopyHdOpConf::D2H) {
......@@ -84,7 +84,7 @@ OperatorConf CopyHdTaskNode::NewCopyOpConf() {
void CopyCommNetTaskNode::Init(int64_t machine_id, const LogicalBlobId& lbi) {
set_machine_id(machine_id);
DeviceId device_id{static_cast<DeviceId::index_t>(machine_id), DeviceType::kCPU, 0};
DeviceId device_id{static_cast<DeviceId::node_index_t>(machine_id), DeviceType::kCPU, 0};
auto* generator = dynamic_cast<CPUStreamIndexGenerator*>(
Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id));
CHECK_NOTNULL(generator);
......
......@@ -22,7 +22,7 @@ StreamIndexGetterRegistryManager& StreamIndexGetterRegistryManager::Get() {
return mgr;
}
StreamId::index_t StreamIndexGetterRegistryManager::StreamIndex4DeviceIdAndTaskType(
StreamId::stream_index_t StreamIndexGetterRegistryManager::StreamIndex4DeviceIdAndTaskType(
DeviceId device_id, TaskType task_type) {
auto index_getter_fn = StreamIndexGetterRegistryManager::GetStreamIndexGetterFunc(
device_id.device_type(), task_type);
......
......@@ -47,7 +47,7 @@ class StreamIndexGetterRegistryManager final {
StreamIndexKeyMap<StreamIndexGetterFn>& StreamIndexGetterFuncs();
StreamId::index_t StreamIndex4DeviceIdAndTaskType(DeviceId device_id, TaskType task_type);
StreamId::stream_index_t StreamIndex4DeviceIdAndTaskType(DeviceId device_id, TaskType task_type);
private:
StreamIndexGetterFn GetStreamIndexGetterFunc(DeviceType dev_type, TaskType task_type);
......
......@@ -284,16 +284,17 @@ void GenSortedCompTaskNodes(const OpNode* op_node, std::vector<CompTaskNode*>* s
comp_task_node->mut_parallel_ctx()->set_parallel_id(parallel_idx++);
comp_task_node->mut_parallel_ctx()->set_parallel_num(parallel_num);
DeviceId::index_t device_index = parallel_desc.device_type() == DeviceType::kCPU
? 0
: static_cast<DeviceId::index_t>(dev_phy_id);
DeviceId device_id{static_cast<DeviceId::index_t>(machine_id), parallel_desc.device_type(),
device_index};
StreamId::index_t stream_index{};
DeviceId::device_index_t device_index =
parallel_desc.device_type() == DeviceType::kCPU
? 0
: static_cast<DeviceId::device_index_t>(dev_phy_id);
DeviceId device_id{static_cast<DeviceId::node_index_t>(machine_id),
parallel_desc.device_type(), device_index};
StreamId::stream_index_t stream_index{};
if (op_node->op().op_conf().has_stream_index_hint()) {
int32_t stream_index_hint = op_node->op().op_conf().stream_index_hint();
LOG(INFO) << "set op: " << op_node->op().op_name() << " to stream: " << stream_index_hint;
stream_index = static_cast<StreamId::index_t>(stream_index_hint);
stream_index = static_cast<StreamId::stream_index_t>(stream_index_hint);
} else {
stream_index = StreamIndexGetterRegistryManager::Get().StreamIndex4DeviceIdAndTaskType(
device_id, comp_task_node->GetTaskType());
......@@ -522,8 +523,9 @@ TaskNode* TaskGraph::GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,
const int64_t dev_id = CHECK_JUST(dst_parallel_desc.DeviceId4ParallelId(dst_parallel_id));
DeviceType device_type = dst_parallel_desc.device_type();
auto device_index =
(device_type == DeviceType::kCPU ? 0 : static_cast<DeviceId::index_t>(dev_id));
MemZoneId mem_zone_id{static_cast<MemZoneId::index_t>(dst_machine_id), device_type, device_index};
(device_type == DeviceType::kCPU ? 0 : static_cast<DeviceId::node_index_t>(dev_id));
MemZoneId mem_zone_id{static_cast<MemZoneId::node_index_t>(dst_machine_id), device_type,
device_index};
return GetProxyNode(src_node, lbi, mem_zone_id);
}
......
......@@ -65,10 +65,11 @@ TaskId DecodeTaskIdFromInt64(int64_t task_id_val) {
int64_t device_index = (task_id_val & kDeviceIndexInt64Mask) >> kDeviceIndexShift;
int64_t stream_index = (task_id_val & kStreamIndexInt64Mask) >> kStreamIndexShift;
int64_t task_index = task_id_val & kTaskIndexInt64Mask;
StreamId stream_id{
static_cast<DeviceId::index_t>(node_index), static_cast<DeviceType>(device_type),
static_cast<DeviceId::index_t>(device_index), static_cast<StreamId::index_t>(stream_index)};
return TaskId{stream_id, static_cast<TaskId::index_t>(task_index)};
StreamId stream_id{static_cast<DeviceId::node_index_t>(node_index),
static_cast<DeviceType>(device_type),
static_cast<DeviceId::device_index_t>(device_index),
static_cast<StreamId::stream_index_t>(stream_index)};
return TaskId{stream_id, static_cast<TaskId::task_index_t>(task_index)};
}
int64_t MachineId4ActorId(int64_t actor_id) {
......
......@@ -22,18 +22,19 @@ namespace oneflow {
class TaskId {
public:
using index_t = uint32_t;
using task_index_t = uint32_t;
const static size_t kTaskIndexBits = 21;
constexpr static index_t kMaxTaskIndex = (index_t{1} << kTaskIndexBits) - index_t{1};
constexpr static task_index_t kMaxTaskIndex =
(task_index_t{1} << kTaskIndexBits) - task_index_t{1};
TaskId(const StreamId& stream_id, index_t task_index)
TaskId(const StreamId& stream_id, task_index_t task_index)
: stream_id_(stream_id), task_index_(task_index) {
CHECK_LE(task_index_, kMaxTaskIndex);
}
const StreamId& stream_id() const { return stream_id_; }
index_t task_index() const { return task_index_; }
task_index_t task_index() const { return task_index_; }
bool operator==(const TaskId& rhs) const {
return stream_id_ == rhs.stream_id_ && task_index_ == rhs.task_index_;
......@@ -42,13 +43,13 @@ class TaskId {
size_t hash() const {
size_t hash = stream_id_.hash();
HashCombine(&hash, std::hash<index_t>{}(task_index_));
HashCombine(&hash, std::hash<task_index_t>{}(task_index_));
return hash;
}
private:
StreamId stream_id_;
index_t task_index_;
task_index_t task_index_;
};
int64_t EncodeTaskIdToInt64(const TaskId&);
......
......@@ -22,7 +22,7 @@ namespace oneflow {
class TaskIdGenerator final {
public:
using task_index_t = TaskId::index_t;
using task_index_t = TaskId::task_index_t;
TaskIdGenerator() = default;
OF_DISALLOW_COPY_AND_MOVE(TaskIdGenerator);
......
......@@ -32,7 +32,7 @@ constexpr int64_t kMemZoneIdDeviceIndexInt64Mask = (int64_t{1} << MemZoneId::kDe
const MemZoneId kInvalidMemZoneId = MemZoneId{0, DeviceType::kInvalidDevice, 0};
MemZoneId GetNodeCPUMemZoneId(MemZoneId::index_t node_index) {
MemZoneId GetNodeCPUMemZoneId(MemZoneId::node_index_t node_index) {
return MemZoneId{node_index, DeviceType::kCPU, 0};
}
......@@ -47,9 +47,9 @@ MemZoneId DecodeMemZoneIdFromInt64(int64_t mem_zone_id) {
int64_t node_index = (mem_zone_id & kMemZoneIdNodeIndexInt64Mask) >> kMemZoneIdNodeIndexShift;
int64_t device_type = (mem_zone_id & kMemZoneIdDeviceTypeInt64Mask) >> kMemZoneIdDeviceTypeShift;
int64_t device_index = mem_zone_id & kMemZoneIdDeviceIndexInt64Mask;
return MemZoneId(static_cast<MemZoneId::index_t>(node_index),
return MemZoneId(static_cast<MemZoneId::node_index_t>(node_index),
static_cast<DeviceType>(device_type),
static_cast<MemZoneId::index_t>(device_index));
static_cast<MemZoneId::device_index_t>(device_index));
}
} // namespace oneflow
......@@ -25,7 +25,7 @@ using MemZoneId = DeviceId;
int64_t EncodeMemZoneIdToInt64(const MemZoneId&);
MemZoneId DecodeMemZoneIdFromInt64(int64_t);
MemZoneId GetNodeCPUMemZoneId(MemZoneId::index_t node_index);
MemZoneId GetNodeCPUMemZoneId(MemZoneId::node_index_t node_index);
extern const MemZoneId kInvalidMemZoneId;
......
......@@ -59,9 +59,10 @@ StreamId DecodeStreamIdFromInt64(int64_t stream_id_val) {
int64_t device_type = (stream_id_val & kDeviceTypeInt64Mask) >> kDeviceTypeShift;
int64_t device_index = (stream_id_val & kDeviceIndexInt64Mask) >> kDeviceIndexShift;
int64_t stream_index = (stream_id_val & kStreamIndexInt64Mask);
return StreamId{static_cast<DeviceId::index_t>(node_index), static_cast<DeviceType>(device_type),
static_cast<DeviceId::index_t>(device_index),
static_cast<StreamId::index_t>(stream_index)};
return StreamId{static_cast<DeviceId::node_index_t>(node_index),
static_cast<DeviceType>(device_type),
static_cast<DeviceId::device_index_t>(device_index),
static_cast<StreamId::stream_index_t>(stream_index)};
}
} // namespace oneflow
......@@ -22,26 +22,27 @@ namespace oneflow {
class StreamId {
public:
using index_t = uint32_t;
using stream_index_t = uint32_t;
constexpr static size_t kStreamIndexBits = 12;
constexpr static index_t kMaxStreamIndex = (index_t{1} << kStreamIndexBits) - index_t{1};
constexpr static stream_index_t kMaxStreamIndex =
(stream_index_t{1} << kStreamIndexBits) - stream_index_t{1};
StreamId(const DeviceId& device_id, index_t stream_index)
StreamId(const DeviceId& device_id, stream_index_t stream_index)
: device_id_(device_id), stream_index_(stream_index) {
CHECK_LE(stream_index, kMaxStreamIndex);
}
StreamId(DeviceId::index_t node_index, DeviceType device_type, DeviceId::index_t device_index,
index_t stream_index)
// Builds a StreamId from the individual DeviceId components plus a stream index.
// node_index, device_type and device_index are forwarded unchanged to the DeviceId
// constructor, so each parameter uses the matching DeviceId alias; stream_index is
// checked against kMaxStreamIndex.
StreamId(DeviceId::node_index_t node_index, DeviceType device_type,
         DeviceId::device_index_t device_index, stream_index_t stream_index)
    : device_id_(node_index, device_type, device_index), stream_index_(stream_index) {
  CHECK_LE(stream_index, kMaxStreamIndex);
}
const DeviceId& device_id() const { return device_id_; }
DeviceId::index_t node_index() const { return device_id_.node_index(); }
DeviceId::node_index_t node_index() const { return device_id_.node_index(); }
DeviceType device_type() const { return device_id_.device_type(); }
DeviceId::index_t device_index() const { return device_id_.device_index(); }
index_t stream_index() const { return stream_index_; }
// Returns the device index of the wrapped DeviceId, typed with DeviceId's
// device-index alias (the same type DeviceId::device_index() itself returns).
DeviceId::device_index_t device_index() const { return device_id_.device_index(); }
stream_index_t stream_index() const { return stream_index_; }
bool operator==(const StreamId& rhs) const {
return device_id_ == rhs.device_id_ && stream_index_ == rhs.stream_index_;
......@@ -51,13 +52,13 @@ class StreamId {
size_t hash() const {
size_t hash = device_id_.hash();
HashCombine(&hash, std::hash<index_t>{}(stream_index_));
HashCombine(&hash, std::hash<stream_index_t>{}(stream_index_));
return hash;
}
private:
DeviceId device_id_;
index_t stream_index_;
stream_index_t stream_index_;
};
int64_t EncodeStreamIdToInt64(const StreamId&);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册