提交 4547bf9c 编写于 作者: W willzhang4a58

refine int64

上级 9343dc71
......@@ -2,5 +2,5 @@ syntax = "proto3";
package oneflow;
message ShapeProto {
repeated uint64 shapes = 1;
repeated int64 shapes = 1;
}
......@@ -12,8 +12,8 @@ class CompTaskNode : public TaskNode {
CompTaskNode() = default;
virtual ~CompTaskNode() = default;
int64_t parallel_id() const { return parallel_id_; }
void set_parallel_id(int64_t parallel_id) { parallel_id_ = parallel_id; }
uint64_t parallel_id() const { return parallel_id_; }
void set_parallel_id(uint64_t parallel_id) { parallel_id_ = parallel_id; }
bool IsLossNode() const { TODO(); }
......@@ -52,7 +52,7 @@ class CompTaskNode : public TaskNode {
const HashMap<ExecEdge*, const ExecEdge*>& bp_edge2fw_edge);
void BpSetModelDiffRegst();
int64_t parallel_id_;
uint64_t parallel_id_;
};
......
......@@ -7,7 +7,7 @@ MdUpdtTaskGraph::MdUpdtTaskGraph(
const ChainNode* data_chain,
const std::vector<CompTaskNode*>& sorted_bp_comptasks4data_chain) {
BuildTaskGraph(data_chain);
HashMap<int64_t, CompTaskNode*> parallel_id2updt;
HashMap<uint64_t, CompTaskNode*> parallel_id2updt;
InitFaker2MccoyAndParallelId2UpdtMap(sorted_bp_comptasks4data_chain,
&parallel_id2updt);
BuildExecAndProducedRegsts();
......@@ -42,7 +42,7 @@ void MdUpdtTaskGraph::BuildTaskGraph(const ChainNode* data_chain) {
void MdUpdtTaskGraph::InitFaker2MccoyAndParallelId2UpdtMap(
const std::vector<CompTaskNode*>& sorted_bp_comptasks4data_chain,
HashMap<int64_t, CompTaskNode*>* parallel_id2updt) {
HashMap<uint64_t, CompTaskNode*>* parallel_id2updt) {
std::vector<CompTaskNode*> comptasks4faker_chain;
for (const std::unique_ptr<TaskNode>& node : nodes()) {
CompTaskNode* comp_node = dynamic_cast<CompTaskNode*> (node.get());
......@@ -63,10 +63,10 @@ void MdUpdtTaskGraph::InitFaker2MccoyAndParallelId2UpdtMap(
void MdUpdtTaskGraph::CompleteUpdateTaskAndFwTask(
const std::vector<CompTaskNode*>& sorted_bp_comptasks4data_chain,
const HashMap<int64_t, CompTaskNode*>& parallel_id2updt) {
const HashMap<uint64_t, CompTaskNode*>& parallel_id2updt) {
for (CompTaskNode* bp_task : sorted_bp_comptasks4data_chain) {
// useful vars
int64_t parallel_id = bp_task->parallel_id();
uint64_t parallel_id = bp_task->parallel_id();
CompTaskNode* update_task = parallel_id2updt.at(parallel_id);
TaskNode* fw_task = bp_task->GetFwNode();
RegstDesc* model_diff_regst = bp_task->GetProducedRegstDesc("model_diff");
......
......@@ -23,10 +23,10 @@ class MdUpdtTaskGraph final : public TaskGraph {
void BuildTaskGraph(const ChainNode* data_chain);
void InitFaker2MccoyAndParallelId2UpdtMap(
const std::vector<CompTaskNode*>& sorted_bp_comptasks4data_chain,
HashMap<int64_t, CompTaskNode*>* parallel_id2updt);
HashMap<uint64_t, CompTaskNode*>* parallel_id2updt);
void CompleteUpdateTaskAndFwTask(
const std::vector<CompTaskNode*>& sorted_bp_comptasks4data_chain,
const HashMap<int64_t, CompTaskNode*>& parallel_id2updt);
const HashMap<uint64_t, CompTaskNode*>& parallel_id2updt);
};
} // namespace oneflow
......
......@@ -30,7 +30,7 @@ class RegstDesc {
static const char* kAllLbn;
private:
int64_t regst_desc_id_;
uint64_t regst_desc_id_;
const TaskNode* producer_;
HashMap<std::string, std::unique_ptr<Shape>> lbn2shape_;
......
......@@ -8,5 +8,5 @@ message RegstDescProto {
uint64 producer_task_id = 2;
repeated uint64 subscriber_task_ids = 3;
map<string, ShapeProto> lbn2shape = 4;
uint64 register_num = 5;
int64 register_num = 5;
}
......@@ -10,8 +10,8 @@ StageGraph::StageGraph(std::unique_ptr<const ChainGraph>&& chain_gph) {
for (const std::unique_ptr<ChainNode>& cur_chain : chain_gph_->nodes()) {
chain2stages[cur_chain.get()] = {};
auto parallel_desc = cur_chain->parallel_desc();
int64_t range_idx = 0;
for (int64_t machine_id : parallel_desc->sorted_machine_ids()) {
uint64_t range_idx = 0;
for (uint64_t machine_id : parallel_desc->sorted_machine_ids()) {
StageNode* stage_node = NewFinalNode();
stage_node->mut_machine_id() = machine_id;
stage_node->set_chain_node(cur_chain.get());
......
......@@ -14,10 +14,10 @@ class StageNode final : public Node<StageNode, StageEdge> {
StageNode() = default;
~StageNode() = default;
const int64_t& machine_id() const {
const uint64_t& machine_id() const {
return machine_id_;
}
int64_t& mut_machine_id() {
uint64_t& mut_machine_id() {
return machine_id_;
}
......@@ -35,13 +35,13 @@ class StageNode final : public Node<StageNode, StageEdge> {
return parallel_range_;
}
const std::vector<int64_t>& SortedDevicePhyIds() const {
const std::vector<uint64_t>& SortedDevicePhyIds() const {
return chain_node_->parallel_desc()->sorted_device_phy_ids(machine_id_);
}
private:
const ChainNode* chain_node_;
int64_t machine_id_;
uint64_t machine_id_;
Range parallel_range_;
};
......
......@@ -73,9 +73,9 @@ void TaskGraph::Stage2DeviceCompTaskNodes(
TaskNodesInStage* task_nodes_in_stage,
bool is_first_stage,
bool is_last_stage) {
int64_t parallel_idx = stage->parallel_range().begin();
uint64_t parallel_idx = stage->parallel_range().begin();
for (auto device_phy_id : stage->SortedDevicePhyIds()) {
int64_t thread_local_id =
uint64_t thread_local_id =
IDMgr::Singleton().ThrdLocId4DevicePhyId(device_phy_id);
// comp_task_node
DeviceCompTaskNode* comp_task_node = NewTaskNode<DeviceCompTaskNode> ();
......@@ -113,9 +113,9 @@ void TaskGraph::Stage2DeviceCompTaskNodes(
void TaskGraph::Stage2HostCompTaskNodes(const StageNode* stage,
TaskNodesInStage* task_nodes_in_stage) {
int64_t parallel_begin = stage->parallel_range().begin();
int64_t parallel_end = stage->parallel_range().end();
int64_t parallel_idx = parallel_begin;
uint64_t parallel_begin = stage->parallel_range().begin();
uint64_t parallel_end = stage->parallel_range().end();
uint64_t parallel_idx = parallel_begin;
while (parallel_idx < parallel_end) {
HostCompTaskNode* comp_task_node = NewTaskNode<HostCompTaskNode> ();
comp_task_node->set_stage_node(stage);
......
......@@ -20,7 +20,7 @@ void TaskNode::set_stage_node(const StageNode* new_stage_node) {
CHECK(IsFwNode());
stage_node_ = new_stage_node;
}
int64_t& TaskNode::mut_thrd_loc_id() {
uint64_t& TaskNode::mut_thrd_loc_id() {
CHECK(IsFwNode());
return thrd_loc_id_;
}
......
......@@ -23,13 +23,13 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
TaskNode* GetBpNode() const;
const ChainNode* chain_node() const { return stage_node_->chain_node();}
const StageNode* stage_node() const { return stage_node_; }
const int64_t& thrd_loc_id() const { return thrd_loc_id_; }
const uint64_t& thrd_loc_id() const { return thrd_loc_id_; }
const ExecGraph& exec_gph() const { return exec_gph_; }
// Setters
void SetFwNode() { is_fw_node_ = true; }
void set_stage_node(const StageNode*);
int64_t& mut_thrd_loc_id();
uint64_t& mut_thrd_loc_id();
// return bp_node
std::unique_ptr<TaskNode> BuildAndConnectBpNode();
......@@ -59,7 +59,7 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
private:
// In task_gph level
const StageNode* stage_node_;
int64_t thrd_loc_id_;
uint64_t thrd_loc_id_;
bool is_fw_node_;
TaskNode* related_fw_or_bp_node_;
// In task level
......
......@@ -4,9 +4,11 @@ package oneflow;
import "operator/operator.proto";
import "graph/register_desc.proto";
import "task/task.proto";
import "job/id_manager.proto";
message OfElf {
repeated TaskProto tasks = 1;
repeated RegstDescProto regst_descs = 2;
repeated OperatorProto operators = 4;
repeated OperatorProto operators = 3;
IDMgrProto id_mgr = 4;
}
......@@ -18,10 +18,10 @@ class ParallelDesc {
// Getters
const ParallelPolicy& policy() const { return policy_; }
const DeviceType& device_type() const { return device_type_; }
const std::vector<int64_t>& sorted_machine_ids() const {
const std::vector<uint64_t>& sorted_machine_ids() const {
return sorted_machine_ids_;
}
const std::vector<int64_t>& sorted_device_phy_ids(int64_t machine_id) const {
const std::vector<uint64_t>& sorted_device_phy_ids(int64_t machine_id) const {
// If this is used to describe the disk
// the return shouble be empty
return machine_id2sorted_device_phy_ids_.at(machine_id);
......@@ -39,8 +39,8 @@ class ParallelDesc {
private:
ParallelPolicy policy_;
DeviceType device_type_;
std::vector<int64_t> sorted_machine_ids_;
HashMap<int64_t, std::vector<int64_t>> machine_id2sorted_device_phy_ids_;
std::vector<uint64_t> sorted_machine_ids_;
HashMap<uint64_t, std::vector<uint64_t>> machine_id2sorted_device_phy_ids_;
};
......
......@@ -14,7 +14,7 @@ enum DeviceType {
message Resource {
repeated Machine machines = 1;
uint64 device_num_per_machine = 2;
int64 device_num_per_machine = 2;
DeviceType device_type = 3;
}
......
......@@ -47,16 +47,16 @@ message CopyOpConf {
}
message CloneOpConf {
uint64 out_num = 1;
int64 out_num = 1;
string lbn = 2;
}
message BoxConcatConf {
uint64 axis = 1;
int64 axis = 1;
}
message BoxSplitConf {
uint64 axis = 1;
int64 axis = 1;
}
message BoxCloneConf {
......@@ -64,8 +64,8 @@ message BoxCloneConf {
message BoxingOpConf {
string lbn = 1;
uint64 in_num = 2;
uint64 out_num = 3;
int64 in_num = 2;
int64 out_num = 3;
BoxConcatConf concat_box = 4;
oneof out_box {
BoxSplitConf split_box = 5;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册