提交 0252bca8 编写于 作者: J Jinhui Yuan 提交者: GitHub

sketch of merge reduce project (#1159)

* sketch of merge reduce project

* add reduce_concat, reduce_split in logical graph (#1160)

* add reduce_concat, reduce_split in logical graph

* init ReduceTaskNodes in CollectReduceTaskNodes

* add CompTaskNode for ReduceConcat & ReduceSplit

* set ReduceConcat/Split color index

* copy blob desc from ReduceConcat in to ReduceSplit out

* refine CollectReduceTaskNodes

* SetMemSharing for ReduceConcat, ReduceSplit regst

* complete ReduceConcat & ReduceSplit op

* fill ReduceConcat & ReduceSplit kernel

* simplify ReduceConcatCompActor

* make ReduceScatter & ReduceSplit as input-wise actor

* reduce_scatter & reduce_split use is_inplace

* use ByteSizeOfBlobBody for reduce related packed blob

* Fix dev merge reduce (#1168)

* check concat and split occur simultaneously

* fix ReduceScatter & ReduceSplit as Inputwise actor

* ReduceConcat & ReduceSplit works

* fix single gpu issue

* Refactor reduce (#1170)

* backup, not complete yet

* remove reduce_id

* rm useless comment

* add reduce_graph (#1169)

* add reduce_graph

* fix iter

* add IsLogicalNodeMergeable and fix bug

* remove needless constructor calls

* node VisualStr may conflict, using node_id_str instead

* reduce group works (#1171)

* refine

* sort nodes in topo (#1172)

* add reduce_group_size in job_conf, fix 121 config of ReduceSplit and MdUpdt

* resolve code review issues (variable names)

* refine variable names

* Dev merge reduce rename reduce group (#1174)

* ReduceGraph=>ChainLogicalGraph

* rename Group=>Chain

* reformat

* use pointer instead of reference for mutable argument

* format change

* worker node only pull sub_plan (#1176)

* log compile time

* use c++11 member initialization syntax

* FixPackedBlobDescOfProducedRegst for ReduceSplit

* Dev merge reduce refine chain logical graph (#1177)

* remove IsMerageable

* split TryMergeOneChain and rename to TryMergeTwoChains

* reformat

* resolve review issues


Former-commit-id: 3aa79c70
上级 216c4585
......@@ -11,6 +11,4 @@ void NaiveActor::Act() {
});
}
REGISTER_ACTOR(TaskType::kReduceScatter, NaiveActor);
} // namespace oneflow
#include "oneflow/core/actor/reduce_concat_compute_actor.h"
namespace oneflow {
// Delegates all actor setup to the shared input-wise base class; this actor
// adds no state of its own beyond other_val_ (filled lazily per regst).
void ReduceConcatCompActor::VirtualCompActorInit(const TaskProto& proto) {
  InputWiseCompActor::Init(proto);
}
// Publishes the (input-bn index, inplace flag) pair for the regst currently
// being processed; the kernel reads it through the opaque "other" slot.
void ReduceConcatCompActor::SetKernelCtxOther(void** other) {
  const int64_t in_bn_id = InBnId4RegstDescId(cur_processed_regst_desc_id());
  other_val_.first = in_bn_id;
  other_val_.second = EnableInplace();
  // other_val_ is a member, so the address stays valid for the kernel call.
  *other = static_cast<void*>(&other_val_);
}
REGISTER_ACTOR(TaskType::kReduceConcat, ReduceConcatCompActor);
} // namespace oneflow
#ifndef ONEFLOW_CORE_ACTOR_REDUCE_CONCAT_COMPUTE_ACTOR_H_
#define ONEFLOW_CORE_ACTOR_REDUCE_CONCAT_COMPUTE_ACTOR_H_
#include "oneflow/core/actor/input_wise_compute_actor.h"
namespace oneflow {
// Actor that concatenates several incoming model-diff regsts into one blob
// (the stage feeding reduce-scatter).  Inputs are handled one regst at a
// time via the InputWiseCompActor protocol.
class ReduceConcatCompActor final : public InputWiseCompActor {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ReduceConcatCompActor);
  ReduceConcatCompActor() = default;
  ~ReduceConcatCompActor() = default;

 private:
  void VirtualCompActorInit(const TaskProto& proto) override;
  // Hands the kernel the (input index, inplace flag) for the current regst.
  void SetKernelCtxOther(void** other) override;

  // Backing storage for the value passed via SetKernelCtxOther; must be a
  // member so its address outlives the call.
  std::pair<int64_t, bool> other_val_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_REDUCE_CONCAT_COMPUTE_ACTOR_H_
#include "oneflow/core/actor/reduce_scatter_compute_actor.h"
namespace oneflow {
// Delegates all actor setup to the shared input-wise base class.
void ReduceScatterCompActor::VirtualCompActorInit(const TaskProto& proto) {
  InputWiseCompActor::Init(proto);
}
// Tells the kernel whether inplace reduce is enabled for the current regst.
void ReduceScatterCompActor::SetKernelCtxOther(void** other) {
  const bool inplace = EnableInplace();
  other_val_ = inplace;
  // other_val_ is a member, so the pointer handed out stays valid.
  *other = static_cast<void*>(&other_val_);
}
REGISTER_ACTOR(TaskType::kReduceScatter, ReduceScatterCompActor);
} // namespace oneflow
#ifndef ONEFLOW_CORE_ACTOR_REDUCE_SCATTER_COMPUTE_ACTOR_H_
#define ONEFLOW_CORE_ACTOR_REDUCE_SCATTER_COMPUTE_ACTOR_H_
#include "oneflow/core/actor/input_wise_compute_actor.h"
namespace oneflow {
// Actor for the reduce-scatter stage; processes inputs one regst at a time
// via InputWiseCompActor and only forwards an inplace flag to its kernel.
class ReduceScatterCompActor final : public InputWiseCompActor {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ReduceScatterCompActor);
  ReduceScatterCompActor() = default;
  ~ReduceScatterCompActor() = default;

 private:
  void VirtualCompActorInit(const TaskProto& proto) override;
  // Exposes EnableInplace() to the kernel via the opaque "other" pointer.
  void SetKernelCtxOther(void** other) override;

  // Backing storage for the flag passed to the kernel.
  bool other_val_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_REDUCE_SCATTER_COMPUTE_ACTOR_H_
#include "oneflow/core/actor/reduce_split_compute_actor.h"
namespace oneflow {
// Delegates all actor setup to the shared input-wise base class.
void ReduceSplitCompActor::VirtualCompActorInit(const TaskProto& proto) {
  InputWiseCompActor::Init(proto);
}
// Tells the kernel whether inplace operation is enabled for this actor.
void ReduceSplitCompActor::SetKernelCtxOther(void** other) {
  const bool inplace = EnableInplace();
  other_val_ = inplace;
  // Member storage keeps the pointed-to flag alive across the kernel call.
  *other = static_cast<void*>(&other_val_);
}
REGISTER_ACTOR(TaskType::kReduceSplit, ReduceSplitCompActor);
} // namespace oneflow
#ifndef ONEFLOW_CORE_ACTOR_REDUCE_SPLIT_COMPUTE_ACTOR_H_
#define ONEFLOW_CORE_ACTOR_REDUCE_SPLIT_COMPUTE_ACTOR_H_
#include "oneflow/core/actor/input_wise_compute_actor.h"
namespace oneflow {
// Actor that splits the fully-reduced blob back into per-model-diff pieces
// (the inverse of ReduceConcat).  Input-wise processing; only an inplace
// flag is forwarded to the kernel.
class ReduceSplitCompActor final : public InputWiseCompActor {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ReduceSplitCompActor);
  ReduceSplitCompActor() = default;
  ~ReduceSplitCompActor() = default;

 private:
  void VirtualCompActorInit(const TaskProto& proto) override;
  // Exposes EnableInplace() to the kernel via the opaque "other" pointer.
  void SetKernelCtxOther(void** other) override;

  // Backing storage for the flag passed to the kernel.
  bool other_val_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_REDUCE_SPLIT_COMPUTE_ACTOR_H_
......@@ -69,34 +69,13 @@ void CommNet::AddWorkToStream(void* actor_read_id, const std::function<void()>&
}
CommNet::CommNet(const Plan& plan) {
HashMap<int64_t, int64_t> rid2mid;
HashMap<int64_t, int64_t> tid2mid;
int64_t this_machine_id = Global<MachineCtx>::Get()->this_machine_id();
for (const TaskProto& task_proto : plan.task()) {
for (const auto& regst_desc_it : task_proto.produced_regst_desc()) {
rid2mid.emplace(regst_desc_it.second.regst_desc_id(), task_proto.machine_id());
}
CHECK(tid2mid.emplace(task_proto.task_id(), task_proto.machine_id()).second);
}
for (const TaskProto& task_proto : plan.task()) {
if (task_proto.machine_id() != this_machine_id) { continue; }
for (const auto& regst_desc_set_it : task_proto.consumed_regst_desc_id()) {
for (int64_t regst_desc_id : regst_desc_set_it.second.regst_desc_id()) {
auto rid2mid_it = rid2mid.find(regst_desc_id);
CHECK(rid2mid_it != rid2mid.end());
peer_machine_id_.insert(rid2mid_it->second);
}
}
for (const auto& regst_desc_it : task_proto.produced_regst_desc()) {
for (int64_t consumer_task_id : regst_desc_it.second.consumer_task_id()) {
auto tid2mid_it = tid2mid.find(consumer_task_id);
CHECK(tid2mid_it != tid2mid.end());
peer_machine_id_.insert(tid2mid_it->second);
}
}
}
peer_machine_id_.erase(this_machine_id);
HashMap<int64_t, MachineIds> net_topo;
net_topo = PbMap2HashMap(plan.net_topo().peer_machine_ids());
auto machine_ids_it = net_topo.find(this_machine_id);
CHECK(machine_ids_it != net_topo.end());
std::vector<int64_t> peer_machine_ids = PbRf2StdVec(machine_ids_it->second.machine_id());
peer_machine_id_.insert(peer_machine_ids.begin(), peer_machine_ids.end());
ready_cb_poller_ = std::thread([this]() {
std::function<void()> cb;
......
#include "oneflow/core/operator/fully_connected_op.h"
#include "oneflow/core/graph/chain_logical_graph.h"
#include "oneflow/core/graph/logical_graph.h"
namespace oneflow {
// A candidate group of logical nodes being fused into one chain.  The
// ancestor/descendant sets let merge legality be decided by plain set
// comparison (see TryMergeTwoParallelChains / TryMergeTwoConnectedChains).
struct ChainLogicalGraph::Chain {
  // Nodes currently in the chain; topologically ordered only after
  // SortNodesInChains runs.
  std::vector<const LogicalNode*> nodes;
  // Strict ancestors of every node in the chain.
  HashSet<const LogicalNode*> ancestors;
  // ancestors plus the chain's own nodes.
  HashSet<const LogicalNode*> ancestors_and_this;
  // Strict descendants of every node in the chain.
  HashSet<const LogicalNode*> descendants;
  // descendants plus the chain's own nodes.
  HashSet<const LogicalNode*> descendants_and_this;
  // Set from IsLogicalNodeMergeable at chain creation.
  bool is_mergeable;

  // Chains are only merged when parallel descs match; since every merge
  // enforces this, comparing the front nodes is sufficient.
  bool IsParallelDescEqual(const Chain& rhs) const {
    CHECK_GT(nodes.size(), 0);
    CHECK_GT(rhs.nodes.size(), 0);
    return nodes.front()->parallel_desc()->Equal(rhs.nodes.front()->parallel_desc().get());
  }
};
// Pipeline: seed one chain per logical node, greedily merge chains to a
// fixed point, restore topological order inside each chain, then
// materialize the chain-level graph and dump it as a dot file.
ChainLogicalGraph::ChainLogicalGraph(const LogicalGraph& logical_graph) {
  std::list<Chain> chains;
  HashMap<const LogicalNode*, std::list<Chain>::iterator> node2chain;
  HashMap<const LogicalNode*, size_t> node2topo_order;
  InitChains(logical_graph, &chains, &node2chain, &node2topo_order);
  MergeChains(&chains, &node2chain);
  SortNodesInChains(&chains, node2topo_order);
  BuildGraph(logical_graph, &chains);
  ToDotWithAutoFilePath();
}
// Creates one single-node chain per logical node, then fills each chain's
// ancestor/descendant sets with one forward and one reverse topological
// sweep.  Also records a per-node ordering index used later by
// SortNodesInChains.
void ChainLogicalGraph::InitChains(
    const LogicalGraph& logical_graph, std::list<Chain>* chain_list,
    HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it,
    HashMap<const LogicalNode*, size_t>* logical2order_in_topo) {
  logical_graph.ForEachNode([&](const LogicalNode* node) {
    chain_list->emplace_back();
    logical2chain_it->insert({node, --chain_list->end()});
    Chain& chain = chain_list->back();
    chain.nodes = {node};
    chain.is_mergeable = IsLogicalNodeMergeable(node);
    // NOTE(review): the "topo" index is the ForEachNode visit order, not the
    // TopoForEachNode order — presumably ForEachNode already visits in a
    // topologically consistent order; confirm against Graph::ForEachNode.
    size_t order_in_topo = logical2order_in_topo->size();
    logical2order_in_topo->emplace(node, order_in_topo);
  });
  // Forward sweep: a node's ancestors are the union of each predecessor's
  // ancestors plus the predecessor itself.  Valid because TopoForEachNode
  // guarantees predecessors are finalized before the node is visited.
  logical_graph.TopoForEachNode([&](const LogicalNode* node) {
    auto cur_chain = logical2chain_it->at(node);
    for (const LogicalEdge* edge : node->in_edges()) {
      LogicalNode* pred_node = edge->src_node();
      auto pred_chain = logical2chain_it->at(pred_node);
      cur_chain->ancestors.insert(pred_chain->ancestors.begin(), pred_chain->ancestors.end());
      cur_chain->ancestors.insert(pred_node);
    }
    cur_chain->ancestors_and_this.insert(cur_chain->ancestors.begin(), cur_chain->ancestors.end());
    cur_chain->ancestors_and_this.insert(cur_chain->nodes.begin(), cur_chain->nodes.end());
  });
  // Reverse sweep: symmetric computation for descendants.
  logical_graph.ReverseTopoForEachNode([&](const LogicalNode* node) {
    auto cur_chain = logical2chain_it->at(node);
    for (const LogicalEdge* edge : node->out_edges()) {
      LogicalNode* succ_node = edge->dst_node();
      auto succ_chain = logical2chain_it->at(succ_node);
      cur_chain->descendants.insert(succ_chain->descendants.begin(), succ_chain->descendants.end());
      cur_chain->descendants.insert(succ_node);
    }
    cur_chain->descendants_and_this.insert(cur_chain->descendants.begin(),
                                           cur_chain->descendants.end());
    cur_chain->descendants_and_this.insert(cur_chain->nodes.begin(), cur_chain->nodes.end());
  });
}
// Greedily merges chains until no merge applies or only one chain remains.
void ChainLogicalGraph::MergeChains(
    std::list<Chain>* chain_list,
    HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it) {
  while (chain_list->size() > 1) {
    if (!TryMergeTwoChains(chain_list, logical2chain_it)) { break; }
  }
}
// Attempts a single merge step.  Parallel (sibling) chains are preferred;
// directly connected producer/consumer chains are the fallback.  Returns
// true when some pair of chains was merged.
bool ChainLogicalGraph::TryMergeTwoChains(
    std::list<Chain>* chain_list,
    HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it) {
  if (TryMergeTwoParallelChains(chain_list, logical2chain_it)) { return true; }
  return TryMergeTwoConnectedChains(chain_list, logical2chain_it);
}
// Merges two "parallel" chains: same parallel desc, identical ancestor set
// and identical descendant set.  Such chains have no path between them, so
// fusing them cannot introduce a cycle.  Returns true after the first merge.
bool ChainLogicalGraph::TryMergeTwoParallelChains(
    std::list<Chain>* chain_list,
    HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it) {
  for (auto lhs = chain_list->begin(); lhs != chain_list->end(); ++lhs) {
    if (!lhs->is_mergeable) { continue; }
    auto rhs = lhs;
    for (++rhs; rhs != chain_list->end(); ++rhs) {
      if (!rhs->is_mergeable) { continue; }
      if (!lhs->IsParallelDescEqual(*rhs)) { continue; }
      if (lhs->ancestors != rhs->ancestors) { continue; }
      if (lhs->descendants != rhs->descendants) { continue; }
      // Fold rhs into lhs and retarget the node->chain map; erasing rhs is
      // safe because we return immediately afterwards.
      for (const LogicalNode* node : rhs->nodes) {
        lhs->nodes.push_back(node);
        lhs->ancestors_and_this.insert(node);
        lhs->descendants_and_this.insert(node);
        logical2chain_it->at(node) = lhs;
      }
      chain_list->erase(rhs);
      return true;
    }
  }
  return false;
}
// Merges a directly connected producer chain (pred) into its consumer chain
// (succ) when the two are "sequential": pred's ancestors-plus-itself equals
// succ's ancestors, and pred's descendants equals succ's
// descendants-plus-itself.  Those set equalities mean every path into succ
// passes through pred and every path out of pred enters succ, so fusing
// them preserves acyclicity.  Returns true after the first merge.
bool ChainLogicalGraph::TryMergeTwoConnectedChains(
    std::list<Chain>* chain_list,
    HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it) {
  for (auto succ_chain_it = chain_list->begin(); succ_chain_it != chain_list->end();
       ++succ_chain_it) {
    if (!succ_chain_it->is_mergeable) { continue; }
    for (const LogicalNode* node_in_succ : succ_chain_it->nodes) {
      for (const LogicalEdge* in_edge : node_in_succ->in_edges()) {
        auto pred_chain_it = logical2chain_it->at(in_edge->src_node());
        if (pred_chain_it == succ_chain_it) { continue; }
        if (!pred_chain_it->is_mergeable) { continue; }
        if (!pred_chain_it->IsParallelDescEqual(*succ_chain_it)) { continue; }
        if (pred_chain_it->ancestors_and_this != succ_chain_it->ancestors
            || pred_chain_it->descendants != succ_chain_it->descendants_and_this) {
          continue;
        }
        // Absorb succ's nodes into pred.  Each absorbed node moves from
        // pred's "descendants" to its "this" set, so descendants_and_this
        // stays consistent without an explicit update.
        for (const LogicalNode* node : succ_chain_it->nodes) {
          pred_chain_it->nodes.push_back(node);
          pred_chain_it->ancestors_and_this.insert(node);
          pred_chain_it->descendants.erase(node);
          logical2chain_it->at(node) = pred_chain_it;
        }
        // Erasing while ranging over chain_list is safe only because we
        // return immediately and never touch succ_chain_it again.
        chain_list->erase(succ_chain_it);
        return true;
      }
    }
  }
  return false;
}
void ChainLogicalGraph::SortNodesInChains(
std::list<Chain>* chain_list,
const HashMap<const LogicalNode*, size_t>& logical2order_in_topo) {
for (Chain& chain : *chain_list) {
std::sort(chain.nodes.begin(), chain.nodes.end(),
[&](const LogicalNode* a, const LogicalNode* b) {
return logical2order_in_topo.at(a) < logical2order_in_topo.at(b);
});
}
}
// Materializes the chain-level graph: one ChainLogicalNode per chain, and
// one edge per ordered pair of distinct chains connected by at least one
// logical edge (duplicates collapsed through the pair set).
void ChainLogicalGraph::BuildGraph(const LogicalGraph& logical_graph,
                                   std::list<Chain>* chain_list) {
  HashMap<const LogicalNode*, ChainLogicalNode*> logical_node2chain_logical_node;
  for (const Chain& chain : *chain_list) {
    ChainLogicalNode* chain_logical_node = NewNode();
    chain_logical_node->mut_logical_nodes() = chain.nodes;
    for (const LogicalNode* node : chain.nodes) {
      CHECK(logical_node2chain_logical_node.emplace(node, chain_logical_node).second);
    }
  }
  // NOTE(review): std::unordered_set keyed on std::pair relies on a
  // project-provided std::hash<std::pair> specialization — the standard
  // library has none; confirm it is declared in a common header.
  std::unordered_set<std::pair<ChainLogicalNode*, ChainLogicalNode*>> pred_succ_pairs;
  logical_graph.ForEachEdge([&](const LogicalEdge* edge) {
    pred_succ_pairs.emplace(logical_node2chain_logical_node.at(edge->src_node()),
                            logical_node2chain_logical_node.at(edge->dst_node()));
  });
  for (auto& pair : pred_succ_pairs) {
    // Skip intra-chain edges; only cross-chain connectivity becomes an edge.
    if (pair.first == pair.second) { continue; }
    ChainLogicalEdge* edge = NewEdge();
    Connect(pair.first, edge, pair.second);
  }
}
bool ChainLogicalGraph::IsLogicalNodeMergeable(const LogicalNode* logical_node) const {
if (logical_node->parallel_desc()->policy() != kDataParallel) { return false; }
if (!dynamic_cast<const NormalForwardLogicalNode*>(logical_node)) { return false; }
for (const std::shared_ptr<Operator>& op : logical_node->op_vec()) {
if (dynamic_cast<FullyConnectedOp*>(op.get())) { return false; }
if (op->IsRecurrentOp()) { return false; }
}
return true;
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_CHAIN_LOGICAL_GRAPH_H_
#define ONEFLOW_CORE_GRAPH_CHAIN_LOGICAL_GRAPH_H_
#include "oneflow/core/graph/graph.h"
#include "oneflow/core/graph/logical_graph.h"
namespace oneflow {
class ChainLogicalEdge;
// A node of the chain-level graph: wraps the group of LogicalNodes that
// were fused into one chain by ChainLogicalGraph.
class ChainLogicalNode final : public Node<ChainLogicalNode, ChainLogicalEdge> {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ChainLogicalNode);
  ChainLogicalNode() = default;
  ~ChainLogicalNode() override = default;

  // The logical nodes contained in this chain (sorted topologically by
  // ChainLogicalGraph::SortNodesInChains before the graph is built).
  const std::vector<const LogicalNode*>& logical_nodes() const { return logical_nodes_; }
  std::vector<const LogicalNode*>& mut_logical_nodes() { return logical_nodes_; }

 private:
  std::vector<const LogicalNode*> logical_nodes_;
};
// An edge of the chain-level graph; carries no payload — connectivity only.
class ChainLogicalEdge final : public Edge<ChainLogicalNode, ChainLogicalEdge> {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ChainLogicalEdge);
  ChainLogicalEdge() = default;
  ~ChainLogicalEdge() override = default;
};
// A coarsened view of the LogicalGraph in which mergeable logical nodes are
// grouped into chains.  Used by LogicalGraph::GroupNodesForReduceStruct to
// decide which forward nodes share one ReduceConcat/ReduceSplit pair.
class ChainLogicalGraph final : public Graph<ChainLogicalNode, ChainLogicalEdge> {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ChainLogicalGraph);
  explicit ChainLogicalGraph(const LogicalGraph& logical_graph);
  ~ChainLogicalGraph() override = default;

 private:
  struct Chain;

  // Seeds one chain per logical node and computes ancestor/descendant sets.
  void InitChains(const LogicalGraph& logical_graph, std::list<Chain>* chain_list,
                  HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it,
                  HashMap<const LogicalNode*, size_t>* logical2order_in_topo);
  // Repeats TryMergeTwoChains until a fixed point is reached.
  void MergeChains(std::list<Chain>* chain_list,
                   HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it);
  // One merge step; returns true if a pair of chains was merged.
  bool TryMergeTwoChains(std::list<Chain>* chain_list,
                         HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it);
  // Merges sibling chains with identical ancestor and descendant sets.
  bool TryMergeTwoParallelChains(
      std::list<Chain>* chain_list,
      HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it);
  // Merges a producer chain into its sole consumer chain.
  bool TryMergeTwoConnectedChains(
      std::list<Chain>* chain_list,
      HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it);
  // Restores topological ordering of nodes within each chain after merging.
  void SortNodesInChains(std::list<Chain>* chain_list,
                         const HashMap<const LogicalNode*, size_t>& logical2order_in_topo);
  // Emits ChainLogicalNodes/Edges from the final chain partition.
  void BuildGraph(const LogicalGraph& logical_graph, std::list<Chain>* chain_list);
  // Data-parallel NormalForward nodes without FC/recurrent ops are mergeable.
  bool IsLogicalNodeMergeable(const LogicalNode* logical_node) const;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_CHAIN_LOGICAL_GRAPH_H_
......@@ -148,10 +148,12 @@ template<typename NodeType, typename EdgeType>
template<typename StreamT>
void Graph<NodeType, EdgeType>::ToDotWithStream(StreamT& out_stream) {
out_stream << "digraph {\n";
this->ForEachNode([&](NodeType* node) { out_stream << "\"" << node->VisualStr() << "\"\n"; });
this->ForEachNode([&](NodeType* node) {
out_stream << "\"" << node->node_id_str() << "\" [label=\"" << node->VisualStr() << "\"]\n";
});
this->ForEachEdge([&](const EdgeType* edge) {
out_stream << "\"" << edge->src_node()->VisualStr() << "\" -> "
<< "\"" << edge->dst_node()->VisualStr() << "\""
out_stream << "\"" << edge->src_node()->node_id_str() << "\" -> "
<< "\"" << edge->dst_node()->node_id_str() << "\""
<< "[label=\"" << edge->VisualStr() << "\"];\n";
});
out_stream << "}\n";
......
......@@ -2,11 +2,14 @@
#include "oneflow/core/graph/task_graph.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/operator/op_conf.pb.h"
#include "oneflow/core/graph/chain_logical_graph.h"
#include "oneflow/core/common/balanced_splitter.h"
namespace oneflow {
LogicalGraph::LogicalGraph(bool is_train) {
BuildFwStruct();
if (is_train) { GroupNodesForReduceStruct(); }
SetMainModelParallel();
if (is_train) { BuildBwStruct(); }
MergeEdge();
......@@ -20,6 +23,30 @@ LogicalGraph::LogicalGraph(bool is_train) {
ToDotWithAutoFilePath();
}
// Partitions the forward logical nodes into groups (fw_node_groups_) whose
// model diffs will later share one ReduceConcat/ReduceSplit pair (see
// BuildModelStruct).  Groups come from the chains of ChainLogicalGraph;
// chains reaching reduce_group_size are split into balanced sub-groups.
void LogicalGraph::GroupNodesForReduceStruct() {
  ChainLogicalGraph chain_logical_graph(*this);
  std::vector<std::vector<const LogicalNode*>> fw_node_groups;
  chain_logical_graph.ForEachNode(
      [&](ChainLogicalNode* node) { fw_node_groups.emplace_back(node->logical_nodes()); });
  for (auto& fw_node_group : fw_node_groups) {
    // NOTE(review): size() (size_t) is compared against reduce_group_size —
    // confirm the conf value is non-negative (signed/unsigned comparison).
    if (fw_node_group.size() < Global<JobDesc>::Get()->reduce_group_size()) {
      fw_node_groups_.emplace_back(std::move(fw_node_group));
    } else {
      int64_t fw_node_group_size = fw_node_group.size();
      // NOTE(review): when size is an exact multiple of reduce_group_size
      // this yields one more segment than ceil division would (size/g + 1);
      // presumably intended to keep segments strictly below the limit —
      // confirm, otherwise use (size + g - 1) / g.
      int64_t seg_num = fw_node_group_size / Global<JobDesc>::Get()->reduce_group_size() + 1;
      BalancedSplitter bs(fw_node_group_size, seg_num);
      FOR_RANGE(int64_t, idx, 0, seg_num) {
        std::vector<const LogicalNode*> sub_fw_node_group;
        Range range = bs.At(idx);
        FOR_RANGE(int64_t, nid, range.begin(), range.end()) {
          sub_fw_node_group.emplace_back(fw_node_group[nid]);
        }
        fw_node_groups_.emplace_back(std::move(sub_fw_node_group));
      }
    }
  }
}
template<typename LogicalNodeType>
void LogicalGraph::ForEachLogicalNode(std::function<void(LogicalNodeType*)> func) {
std::vector<LogicalNodeType*> valid_nodes;
......@@ -352,6 +379,7 @@ void LogicalGraph::BuildAccuracyPrintStruct() {
void LogicalGraph::BuildModelStruct(bool is_train) {
HashMap<const LogicalNode*, NormalMdUpdtLogicalNode*> first_shared2mdupdt;
HashMap<const LogicalNode*, ReduceCtx> fw_node2reduce_ctx;
ForEachLogicalNode<ForwardLogicalNode>([&](ForwardLogicalNode* fw_logical) {
if (Global<JobDesc>::Get()->enable_write_snapshot()
&& fw_logical->HasOpWithForwardModelBlob()) {
......@@ -397,17 +425,69 @@ void LogicalGraph::BuildModelStruct(bool is_train) {
}
if (md_diff_acc_logical->parallel_desc()->parallel_num() > 1
&& md_diff_acc_logical->parallel_desc()->policy() == kDataParallel) {
BuildReduceStruct(md_diff_acc_logical, md_updt_logical);
ReduceCtx reduce_ctx;
reduce_ctx.fw_logicals.emplace_back(fw_logical);
reduce_ctx.bw_logicals.emplace_back(bw_logical);
reduce_ctx.md_diff_acc_logicals.emplace_back(md_diff_acc_logical);
reduce_ctx.md_updt_logicals.emplace_back(md_updt_logical);
CHECK(fw_node2reduce_ctx.emplace(fw_logical, reduce_ctx).second);
} else {
Connect<LogicalNode>(md_diff_acc_logical, NewEdge(), md_updt_logical);
}
}
}
});
for (auto& fw_node_group : fw_node_groups_) {
ReduceCtx group_reduce_ctx;
for (auto& fw_node : fw_node_group) {
auto reduce_ctx_it = fw_node2reduce_ctx.find(fw_node);
if (reduce_ctx_it != fw_node2reduce_ctx.end()) {
auto& reduce_ctx = reduce_ctx_it->second;
group_reduce_ctx.fw_logicals.emplace_back(reduce_ctx.fw_logicals.at(0));
group_reduce_ctx.bw_logicals.emplace_back(reduce_ctx.bw_logicals.at(0));
group_reduce_ctx.md_diff_acc_logicals.emplace_back(reduce_ctx.md_diff_acc_logicals.at(0));
group_reduce_ctx.md_updt_logicals.emplace_back(reduce_ctx.md_updt_logicals.at(0));
}
}
BuildReduceStruct(group_reduce_ctx);
}
SetupNormalMdUpdtOp();
}
void LogicalGraph::BuildReduceStruct(LogicalNode* src, LogicalNode* dst) {
// Wires the reduce structure for one group of forward nodes.  With more
// than one forward node, diffs are funneled through a ReduceConcat node,
// scatter/add/gather nodes, and fanned back out through a ReduceSplit node.
// A single-node group skips concat/split entirely.
void LogicalGraph::BuildReduceStruct(const ReduceCtx& reduce_ctx) {
  if (reduce_ctx.fw_logicals.size() > 1) {
    // All nodes in the group share a parallel desc, so the first one's
    // desc is used for the new concat/split nodes.
    std::shared_ptr<const ParallelDesc> src_pd = reduce_ctx.fw_logicals[0]->parallel_desc();

    // One concat input per forward node's diff.
    OperatorConf reduce_concat_op_conf;
    reduce_concat_op_conf.set_name("reduce_concat_" + NewUniqueId());
    reduce_concat_op_conf.set_device_type(src_pd->device_type());
    reduce_concat_op_conf.mutable_reduce_concat_conf()->set_in_num(reduce_ctx.fw_logicals.size());
    LogicalNode* reduce_concat_node = NewNode<ReduceConcatLogicalNode>();
    reduce_concat_node->mut_op_vec() = {ConstructOp(reduce_concat_op_conf)};
    reduce_concat_node->mut_parallel_desc() = src_pd;

    // Symmetric split producing one output per model-update node.
    OperatorConf reduce_split_op_conf;
    reduce_split_op_conf.set_name("reduce_split_" + NewUniqueId());
    reduce_split_op_conf.set_device_type(src_pd->device_type());
    reduce_split_op_conf.mutable_reduce_split_conf()->set_out_num(reduce_ctx.fw_logicals.size());
    LogicalNode* reduce_split_node = NewNode<ReduceSplitLogicalNode>();
    reduce_split_node->mut_op_vec() = {ConstructOp(reduce_split_op_conf)};
    reduce_split_node->mut_parallel_desc() = src_pd;

    // diff-acc -> concat, split -> model-update.
    for (auto& md_diff_acc_node : reduce_ctx.md_diff_acc_logicals) {
      Connect(md_diff_acc_node, NewEdge(), reduce_concat_node);
    }
    for (auto& md_updt_node : reduce_ctx.md_updt_logicals) {
      Connect(reduce_split_node, NewEdge(), md_updt_node);
    }
    AddReduceScatterAddGatherNodes(reduce_concat_node, reduce_split_node);
  } else if (reduce_ctx.fw_logicals.size() == 1) {
    // No grouping: connect the scatter/add/gather chain directly.
    AddReduceScatterAddGatherNodes(reduce_ctx.md_diff_acc_logicals.at(0),
                                   reduce_ctx.md_updt_logicals.at(0));
  }
}
void LogicalGraph::AddReduceScatterAddGatherNodes(LogicalNode* src, LogicalNode* dst) {
std::shared_ptr<const ParallelDesc> src_pd = src->parallel_desc();
std::shared_ptr<const ParallelDesc> dst_pd = dst->parallel_desc();
CHECK_EQ(src_pd->parallel_num(), dst_pd->parallel_num());
......
......@@ -27,8 +27,15 @@ class LogicalGraph final : public Graph<LogicalNode, LogicalEdge> {
LogicalBlobId lbi;
std::vector<LogicalEdge*> edges;
};
// Bundle of the logical nodes participating in one model-diff reduce:
// forward, backward, diff-accumulate and model-update nodes.  Collected
// per forward node in BuildModelStruct, grouped via fw_node_groups_, and
// consumed by BuildReduceStruct.
struct ReduceCtx {
  std::vector<LogicalNode*> fw_logicals;
  std::vector<LogicalNode*> bw_logicals;
  std::vector<LogicalNode*> md_diff_acc_logicals;
  std::vector<LogicalNode*> md_updt_logicals;
};
template<typename LogicalNodeType>
void ForEachLogicalNode(std::function<void(LogicalNodeType*)> Handler);
void GroupNodesForReduceStruct();
void BuildFwStruct();
void NaiveBuildFwStruct(HashMap<std::string, std::vector<LogicalNode*>>* op_name2nodes);
......@@ -46,7 +53,8 @@ class LogicalGraph final : public Graph<LogicalNode, LogicalEdge> {
void BuildLossPrintStruct();
void BuildAccuracyPrintStruct();
void BuildModelStruct(bool is_train);
void BuildReduceStruct(LogicalNode* src, LogicalNode* dst);
void AddReduceScatterAddGatherNodes(LogicalNode* src, LogicalNode* dst);
void BuildReduceStruct(const ReduceCtx& reduce_ctx);
void SetupNormalMdUpdtOp();
MdSaveLogicalNode* BuildMdSaveStruct(const ForwardLogicalNode* fw_logical,
LogicalNode* need_save_logical);
......@@ -58,6 +66,7 @@ class LogicalGraph final : public Graph<LogicalNode, LogicalEdge> {
int64_t total_mbn_num_;
std::vector<std::vector<const LogicalNode*>> fw_node_groups_;
HashMap<const LogicalEdge*, std::string> edge2ibn_;
HashMap<const LogicalEdge*, std::string> edge2obn_;
};
......
......@@ -14,6 +14,8 @@
#include "oneflow/core/graph/reduce_local_add_compute_task_node.h"
#include "oneflow/core/graph/reduce_global_add_compute_task_node.h"
#include "oneflow/core/graph/reduce_gather_compute_task_node.h"
#include "oneflow/core/graph/reduce_concat_compute_task_node.h"
#include "oneflow/core/graph/reduce_split_compute_task_node.h"
#include "oneflow/core/graph/accuracy_compute_task_node.h"
#include "oneflow/core/graph/accuracy_accumulate_compute_task_node.h"
#include "oneflow/core/graph/accuracy_print_compute_task_node.h"
......@@ -335,6 +337,18 @@ REGISTER_BLD_SUB_TSK_GPH_MTHD("MdDiffAcc"
REGISTER_BLD_SUB_TSK_GPH_MTHD("NormalBackward"
"NormalMdUpdt",
BldSubTskGphToNormalMdUpdt);
REGISTER_BLD_SUB_TSK_GPH_MTHD("NormalBackward"
"ReduceConcat",
&TaskGraph::BldSubTskGphByOneToOne);
REGISTER_BLD_SUB_TSK_GPH_MTHD("MdDiffAcc"
"ReduceConcat",
&TaskGraph::BldSubTskGphByOneToOne);
REGISTER_BLD_SUB_TSK_GPH_MTHD("ReduceConcat"
"ReduceScatter",
&TaskGraph::BldSubTskGphByOneToOne);
REGISTER_BLD_SUB_TSK_GPH_MTHD("NormalBackward"
"ReduceScatter",
&TaskGraph::BldSubTskGphByOneToOne);
REGISTER_BLD_SUB_TSK_GPH_MTHD("MdDiffAcc"
"ReduceScatter",
&TaskGraph::BldSubTskGphByOneToOne);
......@@ -353,6 +367,12 @@ REGISTER_BLD_SUB_TSK_GPH_MTHD("ReduceGlobalAdd"
REGISTER_BLD_SUB_TSK_GPH_MTHD("ReduceGather"
"NormalMdUpdt",
&TaskGraph::BldSubTskGphByOneToOne);
REGISTER_BLD_SUB_TSK_GPH_MTHD("ReduceSplit"
"NormalMdUpdt",
&TaskGraph::BldSubTskGphByOneToOne);
REGISTER_BLD_SUB_TSK_GPH_MTHD("ReduceGather"
"ReduceSplit",
&TaskGraph::BldSubTskGphByOneToOne);
BldBoxingOpConfMthd GetMthdForBldBoxingOpConf(const LogicalNode* src, const LogicalNode* dst) {
std::string k = ConcatTypeName(src, dst);
......@@ -395,10 +415,12 @@ REGISTER_BLD_BOXING_OP_CONF_MTHD("NormalBackward"
OF_PP_MAKE_TUPLE_SEQ(MdSave, kMdSaveArea) \
OF_PP_MAKE_TUPLE_SEQ(MdDiffAcc, kDataBackwardArea) \
OF_PP_MAKE_TUPLE_SEQ(Print, kPrintArea) \
OF_PP_MAKE_TUPLE_SEQ(ReduceConcat, kMdUpdtArea) \
OF_PP_MAKE_TUPLE_SEQ(ReduceScatter, kMdUpdtArea) \
OF_PP_MAKE_TUPLE_SEQ(ReduceLocalAdd, kMdUpdtArea) \
OF_PP_MAKE_TUPLE_SEQ(ReduceGlobalAdd, kMdUpdtArea) \
OF_PP_MAKE_TUPLE_SEQ(ReduceGather, kMdUpdtArea) \
OF_PP_MAKE_TUPLE_SEQ(ReduceSplit, kMdUpdtArea) \
OF_PP_MAKE_TUPLE_SEQ(Accuracy, kDataForwardArea) \
OF_PP_MAKE_TUPLE_SEQ(AccuracyAcc, kDataForwardArea) \
OF_PP_MAKE_TUPLE_SEQ(AccuracyPrint, kPrintArea)
......
......@@ -203,6 +203,8 @@ DECLARE_NAIVE_LOGICAL_NODE(ReduceScatterLogicalNode);
DECLARE_NAIVE_LOGICAL_NODE(ReduceLocalAddLogicalNode);
DECLARE_NAIVE_LOGICAL_NODE(ReduceGlobalAddLogicalNode);
DECLARE_NAIVE_LOGICAL_NODE(ReduceGatherLogicalNode);
DECLARE_NAIVE_LOGICAL_NODE(ReduceConcatLogicalNode);
DECLARE_NAIVE_LOGICAL_NODE(ReduceSplitLogicalNode);
} // namespace oneflow
......
......@@ -11,7 +11,8 @@ void NormalBackwardCompTaskNode::ProduceAllRegstsAndBindEdges() {
for (TaskEdge* edge : out_edges()) {
const LogicalNode* succ_logical = GetOneSuccLogicalNodeOnEdge(edge);
if (succ_logical->TypeName() == "MdDiffAcc" || succ_logical->TypeName() == "NormalMdUpdt"
|| succ_logical->TypeName() == "ReduceScatter") {
|| succ_logical->TypeName() == "ReduceScatter"
|| succ_logical->TypeName() == "ReduceConcat") {
edge->AddRegst("model_diff", ProduceRegst("model_diff", true));
} else {
BindEdgeWithProducedRegst(edge, "in_diff");
......
......@@ -43,7 +43,10 @@ void NormalMdUpdtCompTaskNode::ProduceAllRegstsAndBindEdges() {
void NormalMdUpdtCompTaskNode::ConsumeAllRegsts() {
if (!IsTrainable()) { return; }
for (TaskEdge* edge : in_edges()) {
ConsumeRegst("model_diff_acc_" + NewUniqueId(), edge->GetSoleRegst());
auto regst_descs = edge->GetRegsts();
for (auto& regst_desc : regst_descs) {
ConsumeRegst("model_diff_acc_" + NewUniqueId(), regst_desc);
}
}
}
......
#include "oneflow/core/graph/reduce_concat_compute_task_node.h"
#include "oneflow/core/graph/logical_node.h"
namespace oneflow {
// Produces the single "out" regst carrying the concatenated diffs and binds
// it to the sole out edge.  NOTE(review): the trailing (false, 1, 1) args
// presumably select no-reuse and min/max regst num of 1 (single-buffered) —
// confirm against ProduceRegst's signature.
void ReduceConcatCompTaskNode::ProduceAllRegstsAndBindEdges() {
  this->SoleOutEdge()->AddRegst("out", ProduceRegst("out", false, 1, 1));
}
// Binds each consumed regst to slot "in_<i>", where i follows the
// topological order (order_in_graph) of the NormalBackward node that
// produced it — this makes the concat input order deterministic and lets
// ReduceSplit pair its outputs with these inputs by index.
void ReduceConcatCompTaskNode::ConsumeAllRegsts() {
  struct EdgeInfo {
    int64_t bw_node_order;
    TaskEdge* edge;
  };
  std::vector<EdgeInfo> edge_infos;
  for (TaskEdge* edge : in_edges()) {
    // Walk upstream (e.g. through MdDiffAcc tasks) to the producing
    // NormalBackward node; its order_in_graph is the sort key.
    TaskNode* src_node = edge->src_node();
    while (src_node->GetTaskType() != TaskType::kNormalBackward) {
      src_node = src_node->SoleInEdge()->src_node();
    }
    CompTaskNode* bw_node = dynamic_cast<CompTaskNode*>(src_node);
    // Fix: guard the downcast — dereferencing a failed dynamic_cast is UB.
    CHECK(bw_node != nullptr);
    EdgeInfo edge_info{bw_node->order_in_graph(), edge};
    edge_infos.emplace_back(edge_info);
  }
  std::sort(edge_infos.begin(), edge_infos.end(), [](const EdgeInfo& lhs, const EdgeInfo& rhs) {
    return lhs.bw_node_order < rhs.bw_node_order;
  });
  FOR_RANGE(size_t, idx, 0, edge_infos.size()) {
    ConsumeRegst("in_" + std::to_string(idx), edge_infos[idx].edge->GetSoleRegst());
  }
}
// Builds the exec node for the concat op: binds "in_<i>" regsts to the
// op's input bns in order, registers the op's single output lbi on the
// "out" regst, and infers the output blob desc.
void ReduceConcatCompTaskNode::BuildExecGphAndRegst() {
  ExecNode* node = mut_exec_gph().NewNode();
  std::shared_ptr<Operator> reduce_concat_op = this->logical_node()->SoleOp();
  node->mut_op() = reduce_concat_op;
  // Input indices line up with the "in_<i>" naming chosen in ConsumeAllRegsts.
  FOR_RANGE(size_t, i, 0, reduce_concat_op->input_bns().size()) {
    node->BindBnWithRegst(reduce_concat_op->input_bns().Get(i),
                          GetSoleConsumedRegst("in_" + std::to_string(i)));
  }
  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
  out_regst->AddLbi(reduce_concat_op->BnInOp2Lbi(reduce_concat_op->SoleObn()));
  node->BindBnWithRegst(reduce_concat_op->SoleObn(), out_regst);
  node->InferBlobDescs(parallel_ctx());
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_REDUCE_CONCAT_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_REDUCE_CONCAT_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/compute_task_node.h"
namespace oneflow {
// Task node that concatenates per-model model diffs into one blob before
// the reduce-scatter stage.  Inputs are ordered by the topological order of
// the producing backward nodes so ReduceSplit can mirror the layout.
class ReduceConcatCompTaskNode final : public CompTaskNode {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ReduceConcatCompTaskNode);
  ReduceConcatCompTaskNode() = default;
  ~ReduceConcatCompTaskNode() = default;

  void ProduceAllRegstsAndBindEdges() override;
  void ConsumeAllRegsts() override;

  TaskType GetTaskType() const override { return TaskType::kReduceConcat; }
  CudaWorkType GetCudaWorkType() const override { return CudaWorkType::kMix; }

 private:
  void BuildExecGphAndRegst() override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_REDUCE_CONCAT_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/reduce_split_compute_task_node.h"
#include "oneflow/core/graph/logical_node.h"
#include "oneflow/core/register/register_desc.h"
namespace oneflow {
// Produces one "out_<i>" regst per out edge.  The index i mirrors the
// concat input order: edges are sorted by the topological order of the
// backward node reachable through the consuming NormalMdUpdt node, matching
// ReduceConcatCompTaskNode::ConsumeAllRegsts.
void ReduceSplitCompTaskNode::ProduceAllRegstsAndBindEdges() {
  struct EdgeInfo {
    int64_t bw_node_order;
    TaskEdge* edge;
  };
  std::vector<EdgeInfo> edge_infos;
  for (TaskEdge* edge : out_edges()) {
    TaskNode* dst_node = edge->dst_node();
    CHECK(dst_node->GetTaskType() == TaskType::kNormalMdUpdt);
    CompTaskNode* mdupdt_node = dynamic_cast<CompTaskNode*>(dst_node);
    // Fix: guard the downcasts — dereferencing a failed dynamic_cast is UB.
    CHECK(mdupdt_node != nullptr);
    for (TaskEdge* mdupdt_edge : mdupdt_node->out_edges()) {
      if (IsBackwardTaskType(mdupdt_edge->dst_node()->GetTaskType())) {
        CompTaskNode* bw_node = dynamic_cast<CompTaskNode*>(mdupdt_edge->dst_node());
        CHECK(bw_node != nullptr);
        // There may be multiple out_regsts on the same edge for shared_model app
        EdgeInfo edge_info{bw_node->order_in_graph(), edge};
        edge_infos.emplace_back(edge_info);
      }
    }
  }
  std::sort(edge_infos.begin(), edge_infos.end(), [](const EdgeInfo& lhs, const EdgeInfo& rhs) {
    return lhs.bw_node_order < rhs.bw_node_order;
  });
  FOR_RANGE(size_t, idx, 0, edge_infos.size()) {
    std::string out_regst_name = "out_" + std::to_string(idx);
    std::shared_ptr<RegstDesc> out_regst = ProduceRegst(out_regst_name, false, 1, 1);
    edge_infos[idx].edge->AddRegst(out_regst_name, out_regst);
  }
}
// The split has exactly one input: the fully reduced/gathered blob arriving
// on its sole in edge.
void ReduceSplitCompTaskNode::ConsumeAllRegsts() {
  ConsumeRegst("in", this->SoleInEdge()->GetSoleRegst());
}
// Builds the exec node for the split op.  Output blob descs are not
// inferred here: each "out_<i>" desc is copied from the matching "in_<i>"
// regst of the peer ReduceConcat node, so split outputs reproduce the
// original per-model-diff shapes.
void ReduceSplitCompTaskNode::BuildExecGphAndRegst() {
  ExecNode* node = mut_exec_gph().NewNode();
  std::shared_ptr<Operator> reduce_split_op = this->logical_node()->SoleOp();
  node->mut_op() = reduce_split_op;
  node->BindBnWithRegst(reduce_split_op->SoleIbn(), GetSoleConsumedRegst("in"));
  CompTaskNode* reduce_concat_node = FindPeerReduceConcatTaskNode();
  // Sanity: the concat's inputs and our outputs must correspond 1:1.
  CHECK_EQ(reduce_concat_node->consumed_regsts().size(), produced_regsts().size());
  FOR_RANGE(size_t, i, 0, reduce_split_op->output_bns().size()) {
    std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out_" + std::to_string(i));
    CHECK(out_regst.get() != nullptr);
    out_regst->CopyBlobDescFrom(
        reduce_concat_node->GetSoleConsumedRegst("in_" + std::to_string(i)).get());
    node->BindBnWithRegst(reduce_split_op->output_bns().Get(i), out_regst);
  }
}
// Walks upstream through CompTaskNode predecessors (taking the first
// comp-task in-edge at each step) until the peer ReduceConcat task is
// reached, or until no comp-task predecessor exists.  Fix: the original
// carried a found_direct_node flag that was tested both by the loop
// condition and by a trailing dead `if`-break; the flag is removed and the
// dead branch eliminated — behavior is unchanged.
// NOTE(review): on a dead end this still returns the last node visited,
// not a kReduceConcat node; the caller's CHECK_EQ is what catches that.
CompTaskNode* ReduceSplitCompTaskNode::FindPeerReduceConcatTaskNode() {
  CompTaskNode* src_node = this;
  while (src_node->GetTaskType() != TaskType::kReduceConcat) {
    CompTaskNode* pred_comp_node = nullptr;
    for (TaskEdge* edge : src_node->in_edges()) {
      pred_comp_node = dynamic_cast<CompTaskNode*>(edge->src_node());
      if (pred_comp_node != nullptr) { break; }
    }
    if (pred_comp_node == nullptr) { break; }  // dead end: no comp-task predecessor
    src_node = pred_comp_node;
    // The walk must stay on the reduce path; landing on a backward node
    // would mean the graph is wired wrong.
    CHECK(src_node->GetTaskType() != TaskType::kNormalBackward);
  }
  return src_node;
}
// Flattens each produced "out_i" packed blob to a rank-1 shape whose element
// count is RoundUp(elem_cnt, parallel_num) — presumably rounding up to a
// multiple of parallel_num so per-device slices of the reduce buffer are
// equally sized (TODO confirm RoundUp's contract).
void ReduceSplitCompTaskNode::FixPackedBlobDescOfProducedRegst() {
  int64_t out_regst_num = produced_regsts().size();
  FOR_RANGE(int64_t, idx, 0, out_regst_num) {
    std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out_" + std::to_string(idx));
    // The regst must already be locked: its blob descs are finalized here.
    CHECK(out_regst->IsLocked());
    Shape& shape = out_regst->MutBlobDesc(GenPackedLbi())->mut_shape();
    shape =
        Shape({static_cast<int64_t>(RoundUp(shape.elem_cnt(), parallel_ctx()->parallel_num()))});
  }
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_REDUCE_SPLIT_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_REDUCE_SPLIT_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/compute_task_node.h"
namespace oneflow {
// Compute task node for the ReduceSplit op: the inverse of ReduceConcat.
// Consumes one packed "in" regst (the gathered reduce result) and produces
// one "out_i" regst per original blob, each bound to a downstream consumer.
class ReduceSplitCompTaskNode final : public CompTaskNode {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ReduceSplitCompTaskNode);
  ReduceSplitCompTaskNode() = default;
  ~ReduceSplitCompTaskNode() = default;
  void ProduceAllRegstsAndBindEdges() override;
  void ConsumeAllRegsts() override;
  TaskType GetTaskType() const override { return TaskType::kReduceSplit; }
  CudaWorkType GetCudaWorkType() const override { return CudaWorkType::kMix; }
 private:
  void BuildExecGphAndRegst() override;
  void FixPackedBlobDescOfProducedRegst() override;
  // Walks upstream to locate the matching ReduceConcat task node.
  CompTaskNode* FindPeerReduceConcatTaskNode();
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_REDUCE_SPLIT_COMPUTE_TASK_NODE_H_
......@@ -55,6 +55,7 @@ TaskGraph::TaskGraph(std::unique_ptr<const LogicalGraph>&& logical_gph) {
&logical2sorted_out_box, MutBufTask, AllocateCpuThrdIdEvenly);
SetAreaIdForNewNodes(logical_edge->src_node(), logical_edge->dst_node());
});
MergeChainAndSetOrderInGraphForEachNode();
ToDotWithAutoFilePath();
}
......@@ -100,10 +101,7 @@ void TaskGraph::RemoveEmptyRegsts() {
ForEachNode([&](TaskNode* node) { node->UnbindBnWithEmptyRegst(); });
}
void TaskGraph::AddOrderingCtrlEdgeInSameChain() {
MergeChainAndSetOrderInGraphForEachNode();
BuildCtrlRegstDescInSameChain();
}
void TaskGraph::AddOrderingCtrlEdgeInSameChain() { BuildCtrlRegstDescInSameChain(); }
void TaskGraph::MergeChainAndSetOrderInGraphForEachNode() {
ChainGraph chain_graph(*this);
......@@ -136,12 +134,12 @@ void TaskGraph::BuildCtrlRegstDescInSameChain() {
}
struct ReduceTaskNodes {
CompTaskNode* scatter;
CompTaskNode* local_add;
CompTaskNode* global_add;
CompTaskNode* gather;
ReduceTaskNodes() : scatter(nullptr), local_add(nullptr), global_add(nullptr), gather(nullptr) {}
CompTaskNode* concat = nullptr;
CompTaskNode* scatter = nullptr;
CompTaskNode* local_add = nullptr;
CompTaskNode* global_add = nullptr;
CompTaskNode* gather = nullptr;
CompTaskNode* split = nullptr;
};
void TaskGraph::EnableMemSharingInReduceStruct() {
......@@ -163,6 +161,20 @@ void TaskGraph::CollectReduceTaskNodes(
return nullptr;
};
auto FindConcatAndScatter = [&](CompTaskNode* bw_or_md_diff_acc,
ReduceTaskNodes* reduce_task_nodes) {
CompTaskNode* concat_task_node =
FindSuccReduceTaskNode(bw_or_md_diff_acc, TaskType::kReduceConcat);
if (concat_task_node != nullptr) {
reduce_task_nodes->concat = concat_task_node;
reduce_task_nodes->scatter =
FindSuccReduceTaskNode(reduce_task_nodes->concat, TaskType::kReduceScatter);
} else {
reduce_task_nodes->scatter =
FindSuccReduceTaskNode(bw_or_md_diff_acc, TaskType::kReduceScatter);
}
};
ForEachNode([&](TaskNode* task_node) {
if (IsBackwardTaskType(task_node->GetTaskType()) == false) { return; }
if (task_node->device_type() != DeviceType::kGPU) { return; }
......@@ -175,15 +187,16 @@ void TaskGraph::CollectReduceTaskNodes(
}
ReduceTaskNodes& reduce_task_nodes = (*bw2reduce_tasks)[bw_task_node];
CompTaskNode* tmp_task_node = FindSuccReduceTaskNode(bw_task_node, TaskType::kMdDiffAcc);
if (tmp_task_node != nullptr) {
reduce_task_nodes.scatter = FindSuccReduceTaskNode(tmp_task_node, TaskType::kReduceScatter);
CompTaskNode* diff_acc_task_node = FindSuccReduceTaskNode(bw_task_node, TaskType::kMdDiffAcc);
if (diff_acc_task_node != nullptr) {
FindConcatAndScatter(diff_acc_task_node, &reduce_task_nodes);
} else {
reduce_task_nodes.scatter = FindSuccReduceTaskNode(bw_task_node, TaskType::kReduceScatter);
FindConcatAndScatter(bw_task_node, &reduce_task_nodes);
}
tmp_task_node = FindSuccReduceTaskNode(reduce_task_nodes.scatter, TaskType::kReduceLocalAdd);
if (tmp_task_node != nullptr) {
reduce_task_nodes.local_add = tmp_task_node;
CompTaskNode* local_add_task_node =
FindSuccReduceTaskNode(reduce_task_nodes.scatter, TaskType::kReduceLocalAdd);
if (local_add_task_node != nullptr) {
reduce_task_nodes.local_add = local_add_task_node;
reduce_task_nodes.global_add =
FindSuccReduceTaskNode(reduce_task_nodes.local_add, TaskType::kReduceGlobalAdd);
} else {
......@@ -192,6 +205,9 @@ void TaskGraph::CollectReduceTaskNodes(
}
reduce_task_nodes.gather =
FindSuccReduceTaskNode(reduce_task_nodes.global_add, TaskType::kReduceGather);
reduce_task_nodes.split =
FindSuccReduceTaskNode(reduce_task_nodes.gather, TaskType::kReduceSplit);
if (reduce_task_nodes.split == nullptr) { CHECK(reduce_task_nodes.concat == nullptr); }
CHECK(reduce_task_nodes.scatter != nullptr);
CHECK(reduce_task_nodes.global_add != nullptr);
......@@ -199,6 +215,41 @@ void TaskGraph::CollectReduceTaskNodes(
});
}
// If the reduce group has a concat/split pair, makes all their regsts share
// one memory region: the concat output and split input alias the whole
// buffer at offset 0, and each concat "in_idx" / split "out_idx" pair
// aliases the sub-range at the cumulative byte offset of the preceding
// inputs. Byte sizes of matching regsts are checked at every step.
// SetMemSharedField4Regst(regst, offset) is supplied by the caller and is
// expected to stamp the shared-mem id and the given offset onto the regst.
void TaskGraph::EnableMemSharingInReduceConcatSplitIfNeed(
    const ReduceTaskNodes& reduce_task_nodes,
    std::function<void(RegstDesc*, int64_t)> SetMemSharedField4Regst) {
  // No concat means this reduce group was built without the concat/split
  // optimization; nothing to share.
  if (reduce_task_nodes.concat == nullptr) { return; }
  int32_t reduce_num = reduce_task_nodes.split->produced_regsts().size();
  std::shared_ptr<RegstDesc> concat_out_regst = reduce_task_nodes.concat->GetProducedRegst("out");
  std::shared_ptr<RegstDesc> split_in_regst = reduce_task_nodes.split->GetSoleConsumedRegst("in");
  const BlobDesc* concat_out_packed = concat_out_regst->GetBlobDesc(GenPackedLbi());
  const BlobDesc* split_in_packed = split_in_regst->GetBlobDesc(GenPackedLbi());
  size_t concat_out_byte_size = RtBlobDesc(*concat_out_packed).ByteSizeOfBlobBody();
  size_t split_in_byte_size = RtBlobDesc(*split_in_packed).ByteSizeOfBlobBody();
  // The concat output and split input must describe the same packed buffer.
  CHECK_EQ(concat_out_byte_size, split_in_byte_size);
  SetMemSharedField4Regst(concat_out_regst.get(), 0);
  SetMemSharedField4Regst(split_in_regst.get(), 0);
  int64_t offset = 0;
  FOR_RANGE(int32_t, idx, 0, reduce_num) {
    auto concat_in_regst =
        reduce_task_nodes.concat->GetSoleConsumedRegst("in_" + std::to_string(idx));
    auto split_out_regst = reduce_task_nodes.split->GetProducedRegst("out_" + std::to_string(idx));
    SetMemSharedField4Regst(concat_in_regst.get(), offset);
    SetMemSharedField4Regst(split_out_regst.get(), offset);
    // Check shape invariant
    const BlobDesc* concat_in_packed = concat_in_regst->GetBlobDesc(GenPackedLbi());
    const BlobDesc* split_out_packed = split_out_regst->GetBlobDesc(GenPackedLbi());
    size_t concat_in_byte_size = RtBlobDesc(*concat_in_packed).ByteSizeOfBlobBody();
    size_t split_out_byte_size = RtBlobDesc(*split_out_packed).ByteSizeOfBlobBody();
    CHECK_EQ(concat_in_byte_size, split_out_byte_size);
    offset += concat_in_byte_size;
  }
}
void TaskGraph::EnableMemSharingInOneReduce(const ReduceTaskNodes& reduce_task_nodes) {
std::shared_ptr<const ParallelDesc> parallel_desc =
reduce_task_nodes.scatter->logical_node()->parallel_desc();
......@@ -211,12 +262,14 @@ void TaskGraph::EnableMemSharingInOneReduce(const ReduceTaskNodes& reduce_task_n
int64_t mem_shared_id = Global<IDMgr>::Get()->NewMemSharedId();
std::vector<int64_t> blob_index2offset(parallel_num, 0);
auto SetMemSharedField4Regst = [&](RegstDesc* regst, int64_t blob_id) {
auto SetMemSharedField4Regst = [&](RegstDesc* regst, int64_t offset) {
regst->set_enable_mem_sharing(true);
regst->set_mem_shared_id(mem_shared_id);
regst->set_mem_shared_offset(blob_index2offset.at(blob_id));
regst->set_mem_shared_offset(offset);
};
EnableMemSharingInReduceConcatSplitIfNeed(reduce_task_nodes, SetMemSharedField4Regst);
// scatter
{
std::shared_ptr<RegstDesc> consumed_regst =
......@@ -233,7 +286,8 @@ void TaskGraph::EnableMemSharingInOneReduce(const ReduceTaskNodes& reduce_task_n
for (int64_t i = 0; i < parallel_num; ++i) {
SetMemSharedField4Regst(
reduce_task_nodes.scatter->GetProducedRegst("out_" + std::to_string(i)).get(), i);
reduce_task_nodes.scatter->GetProducedRegst("out_" + std::to_string(i)).get(),
blob_index2offset.at(i));
}
}
......@@ -243,7 +297,7 @@ void TaskGraph::EnableMemSharingInOneReduce(const ReduceTaskNodes& reduce_task_n
CHECK_EQ(mem_shared_id, consumed_regst->mem_shared_id());
CHECK_EQ(blob_index2offset.at(blob_id), consumed_regst->mem_shared_offset());
} else {
SetMemSharedField4Regst(consumed_regst, blob_id);
SetMemSharedField4Regst(consumed_regst, blob_index2offset.at(blob_id));
}
};
......@@ -267,7 +321,7 @@ void TaskGraph::EnableMemSharingInOneReduce(const ReduceTaskNodes& reduce_task_n
for (int64_t i = 0; i < machine_num; ++i) {
SetMemSharedField4Regst(
reduce_task_nodes.local_add->GetProducedRegst("out_" + std::to_string(i)).get(),
i * dev_num_of_each_machine + dev_index_of_this_machine);
blob_index2offset.at(i * dev_num_of_each_machine + dev_index_of_this_machine));
}
}
......@@ -278,18 +332,20 @@ void TaskGraph::EnableMemSharingInOneReduce(const ReduceTaskNodes& reduce_task_n
for (const auto& kv : consumed_regsts) {
int64_t in_parallel_id = oneflow_cast<int64_t>(kv.first.substr(3));
CHECK_EQ(1, kv.second.size());
RegstDesc* consumed_regst = kv.second.front().get();
SetOrCheck4ConsumedRegst(consumed_regst, in_parallel_id == parallel_id, in_parallel_id);
SetOrCheck4ConsumedRegst(kv.second.front().get(), in_parallel_id == parallel_id,
in_parallel_id);
}
};
// global add
int consumed_regst_num = reduce_task_nodes.local_add ? machine_num : parallel_num;
HandleMemSharedFieldOfConsumedRegsts(reduce_task_nodes.global_add, consumed_regst_num);
SetMemSharedField4Regst(reduce_task_nodes.global_add->GetProducedRegst("out").get(), parallel_id);
SetMemSharedField4Regst(reduce_task_nodes.global_add->GetProducedRegst("out").get(),
blob_index2offset.at(parallel_id));
// gather
HandleMemSharedFieldOfConsumedRegsts(reduce_task_nodes.gather, parallel_num);
SetMemSharedField4Regst(reduce_task_nodes.gather->GetProducedRegst("out").get(), 0);
SetMemSharedField4Regst(reduce_task_nodes.gather->GetProducedRegst("out").get(),
blob_index2offset.at(0));
}
void TaskGraph::AddCtrlEdge4MemSharingInOneReduce(const ReduceTaskNodes& reduce_task_nodes) {
......
......@@ -87,6 +87,9 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
template<typename TaskNodeType>
bool IsEndingTaskType(TaskType type);
void EnableMemSharingInReduceConcatSplitIfNeed(
const ReduceTaskNodes&, std::function<void(RegstDesc*, int64_t)> SetMemSharedField4Regst);
void GeneratePersistenceThrdId(
const std::vector<std::pair<int64_t, CompTaskNode*>>& persistence_nodes);
......
......@@ -331,6 +331,12 @@ std::shared_ptr<RegstDesc> TaskEdge::GetSoleRegst() const {
return name_in_producer2regst_.begin()->second;
}
std::vector<std::shared_ptr<RegstDesc>> TaskEdge::GetRegsts() const {
std::vector<std::shared_ptr<RegstDesc>> regst_descs;
for (auto& pair : name_in_producer2regst_) { regst_descs.emplace_back(pair.second); }
return regst_descs;
}
// Registers `regst` on this edge under its producer-side name; inserting a
// duplicate name is a programming error.
void TaskEdge::AddRegst(const std::string& name_in_producer, std::shared_ptr<RegstDesc> regst) {
  const bool inserted = name_in_producer2regst_.emplace(name_in_producer, regst).second;
  CHECK(inserted);
}
......@@ -356,12 +362,10 @@ RegstDescIdSet* FindOrCreateConsumedCtrlRegstDescIdSet(TaskProto* task_proto,
}
std::map<TaskType, std::string> task_type2color = {
{kInvalid, "0"}, {kNormalForward, "2"}, {kNormalBackward, "3"},
{kRecordLoad, "1"}, {kDecode, "1"}, {kLoss, "4"},
{kLossAcc, "5"}, {kLossPrint, "1"}, {kNormalMdUpdt, "6"},
{kMdSave, "1"}, {kMdDiffAcc, "7"}, {kCopyHd, "8"},
{kCopyCommNet, "9"}, {kBoxing, "10"}, {kPrint, "1"},
{kReduceScatter, "2"}, {kReduceLocalAdd, "2"}, {kReduceGlobalAdd, "2"},
{kReduceGather, "2"}, {kAccuracy, "4"}, {kAccuracyPrint, "1"},
{kAccuracyAcc, "5"}};
{kInvalid, "0"}, {kNormalForward, "2"}, {kNormalBackward, "3"}, {kRecordLoad, "1"},
{kDecode, "1"}, {kLoss, "4"}, {kLossAcc, "5"}, {kLossPrint, "1"},
{kNormalMdUpdt, "6"}, {kMdSave, "1"}, {kMdDiffAcc, "7"}, {kCopyHd, "8"},
{kCopyCommNet, "9"}, {kBoxing, "10"}, {kPrint, "1"}, {kReduceConcat, "2"},
{kReduceScatter, "2"}, {kReduceLocalAdd, "2"}, {kReduceGlobalAdd, "2"}, {kReduceGather, "2"},
{kReduceSplit, "2"}, {kAccuracy, "4"}, {kAccuracyPrint, "1"}, {kAccuracyAcc, "5"}};
} // namespace oneflow
......@@ -36,10 +36,10 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
std::shared_ptr<RegstDesc> GetProducedRegst(const std::string& name);
const std::list<std::shared_ptr<RegstDesc>>& GetConsumedRegst(const std::string& name);
std::shared_ptr<RegstDesc> GetSoleConsumedRegst(const std::string& name);
const HashMap<std::string, std::shared_ptr<RegstDesc>>& produced_regsts() {
const HashMap<std::string, std::shared_ptr<RegstDesc>>& produced_regsts() const {
return produced_regsts_;
}
const HashMap<std::string, std::list<std::shared_ptr<RegstDesc>>>& consumed_regsts() {
const HashMap<std::string, std::list<std::shared_ptr<RegstDesc>>>& consumed_regsts() const {
return consumed_regsts_;
}
DeviceType device_type() const;
......@@ -126,6 +126,7 @@ class TaskEdge final : public Edge<TaskNode, TaskEdge> {
std::shared_ptr<RegstDesc> GetRegst(const std::string& name_in_producer) const;
std::shared_ptr<RegstDesc> GetSoleRegst() const;
std::vector<std::shared_ptr<RegstDesc>> GetRegsts() const;
void AddRegst(const std::string& name_in_producer, std::shared_ptr<RegstDesc> regst);
......
......@@ -48,6 +48,47 @@ Plan Compiler::Compile() {
return plan;
}
// Derives the machine-level network topology from the compiled plan and
// stores it into plan->net_topo. Two machines are peers iff a regst produced
// on one is consumed by a task on the other; the relation is recorded
// symmetrically, and self-links are removed before serialization.
void Compiler::GenNetTopo(Plan* plan) {
  // regst_desc_id -> producer machine id
  HashMap<int64_t, int64_t> rid2mid;
  // task_id -> machine id
  HashMap<int64_t, int64_t> tid2mid;
  // machine id -> set of peer machine ids (std::map for deterministic order)
  std::map<int64_t, std::set<int64_t>> net_topo;
  // Pass 1: index producers and tasks by machine.
  for (const TaskProto& task_proto : plan->task()) {
    for (const auto& regst_desc_it : task_proto.produced_regst_desc()) {
      rid2mid.emplace(regst_desc_it.second.regst_desc_id(), task_proto.machine_id());
    }
    CHECK(tid2mid.emplace(task_proto.task_id(), task_proto.machine_id()).second);
  }
  // Pass 2: connect producer machine and each consumer machine, both ways.
  for (const TaskProto& task_proto : plan->task()) {
    for (const auto& regst_desc_it : task_proto.produced_regst_desc()) {
      int64_t rid = regst_desc_it.second.regst_desc_id();
      auto rid2mid_it = rid2mid.find(rid);
      CHECK(rid2mid_it != rid2mid.end());
      int64_t producer_mid = rid2mid_it->second;
      for (int64_t consumer_task_id : regst_desc_it.second.consumer_task_id()) {
        auto tid2mid_it = tid2mid.find(consumer_task_id);
        CHECK(tid2mid_it != tid2mid.end());
        int64_t consumer_mid = tid2mid_it->second;
        net_topo[producer_mid].insert(consumer_mid);
        net_topo[consumer_mid].insert(producer_mid);
      }
    }
  }
  // Convert to the protobuf representation, dropping self-loops.
  HashMap<int64_t, MachineIds> std_net_topo;
  NetTopo& pb_net_topo = *(plan->mutable_net_topo());
  for (auto& pair : net_topo) {
    int64_t src_mid = pair.first;
    if (pair.second.count(src_mid)) { pair.second.erase(src_mid); }
    std::vector<int64_t> peer_mids(pair.second.begin(), pair.second.end());
    MachineIds pb_mids;
    *(pb_mids.mutable_machine_id()) = StdVec2PbRf<int64_t>(peer_mids);
    CHECK(std_net_topo.emplace(src_mid, pb_mids).second);
  }
  *(pb_net_topo.mutable_peer_machine_ids()) = HashMap2PbMap(std_net_topo);
}
Plan Compiler::DoCompile() {
#ifdef WITH_CUDA
Global<CudnnConvCtxCache>::New();
......@@ -78,6 +119,7 @@ Plan Compiler::DoCompile() {
task_node->ToProto(plan.mutable_task()->Add());
});
plan.set_total_mbn_num(total_mbn_num);
GenNetTopo(&plan);
ToDotFile(plan, JoinPath(LogDir(), "/dot/plan.dot"));
#ifdef WITH_CUDA
Global<CudnnConvCtxCache>::Delete();
......
......@@ -18,6 +18,7 @@ class Compiler final {
private:
Plan DoCompile();
void GenNetTopo(Plan* plan);
};
} // namespace oneflow
......
......@@ -56,6 +56,7 @@ message OtherConf {
optional uint64 rdma_mem_block_mbyte = 110 [default = 8];
optional uint64 rdma_recv_msg_buf_mbyte = 111 [default = 6];
optional int64 reduce_group_size = 112 [default = 20];
optional bool collect_act_event = 113 [default = false];
optional bool enable_mem_sharing = 114 [default = true];
optional bool use_ordered_allreduce_in_mdupdt = 115 [default = false];
......
......@@ -51,6 +51,7 @@ class JobDesc final {
return IsTrain() && job_conf_.other().enable_write_snapshot();
}
bool enable_blob_mem_sharing() const { return job_conf_.other().enable_blob_mem_sharing(); }
int64_t reduce_group_size() const { return job_conf_.other().reduce_group_size(); }
// machine_name <-> machine_id
int64_t MachineID4MachineName(const std::string& machine_name) const;
......
......@@ -81,6 +81,8 @@ std::string cluster_thrd_ids_key(const std::string& plan_name) {
return plan_name + "_cluster_thrd_ids";
}
std::string net_topo_key(const std::string& plan_name) { return plan_name + "_net_topo"; }
std::string sub_plan_key(const std::string& plan_name, int64_t machine_id, int64_t thrd_id) {
return plan_name + "_" + std::to_string(machine_id) + "_" + std::to_string(thrd_id);
}
......@@ -114,6 +116,8 @@ void PushPlan(const std::string& plan_name, const Plan& plan) {
}
Global<CtrlClient>::Get()->PushKV(total_mbn_num_key(plan_name),
std::to_string(plan.total_mbn_num()));
Global<CtrlClient>::Get()->PushKV(net_topo_key(plan_name), plan.net_topo());
}
void PullPlan(const std::string& plan_name, Plan* plan) {
......@@ -122,18 +126,21 @@ void PullPlan(const std::string& plan_name, Plan* plan) {
PrintProtoToTextFile(cluster_thrd_ids, JoinPath(LogDir(), cluster_thrd_ids_key(plan_name)));
HashMap<int64_t, ThrdIds> machine_id2thrd_ids;
machine_id2thrd_ids = PbMap2HashMap(cluster_thrd_ids.machine_id2thrd_ids());
for (const auto& pair : machine_id2thrd_ids) {
int64_t machine_id = pair.first;
std::vector<int64_t> thrd_id_vec = PbRf2StdVec(pair.second.thrd_id());
for (auto thrd_id : thrd_id_vec) {
SubPlan sub_plan;
Global<CtrlClient>::Get()->PullKV(sub_plan_key(plan_name, machine_id, thrd_id), &sub_plan);
plan->mutable_task()->MergeFrom(sub_plan.task());
}
int64_t machine_id = Global<MachineCtx>::Get()->this_machine_id();
auto thrd_ids_it = machine_id2thrd_ids.find(machine_id);
CHECK(thrd_ids_it != machine_id2thrd_ids.end());
std::vector<int64_t> thrd_id_vec = PbRf2StdVec(thrd_ids_it->second.thrd_id());
for (auto thrd_id : thrd_id_vec) {
SubPlan sub_plan;
Global<CtrlClient>::Get()->PullKV(sub_plan_key(plan_name, machine_id, thrd_id), &sub_plan);
plan->mutable_task()->MergeFrom(sub_plan.task());
}
NetTopo net_topo;
std::string total_mbn_num;
Global<CtrlClient>::Get()->PullKV(total_mbn_num_key(plan_name), &total_mbn_num);
plan->set_total_mbn_num(oneflow_cast<int64_t>(total_mbn_num));
Global<CtrlClient>::Get()->PullKV(net_topo_key(plan_name), &net_topo);
*(plan->mutable_net_topo()) = net_topo;
}
bool HasRelayPlacement() {
......@@ -194,9 +201,12 @@ Oneflow::Oneflow(const std::string& job_conf_filepath, const std::string& this_m
Plan improved_plan;
PushAvailableMemDescOfThisMachine();
AvailableMemDesc amd;
double start = GetCurTime();
if (machine_ctx->IsThisMachineMaster()) {
double start = GetCurTime();
naive_plan = Compiler().Compile();
LOG(INFO) << "compile time: " << GetCurTime() - start;
amd = PullAvailableMemDesc();
mem_shared_plan = Improver().ImproveMemSharedIdOnly(amd, naive_plan);
PushPlan("naive_plan", naive_plan);
......@@ -208,6 +218,7 @@ Oneflow::Oneflow(const std::string& job_conf_filepath, const std::string& this_m
OF_BARRIER();
PrintProtoToTextFile(naive_plan, JoinPath(LogDir(), "naive_plan"));
PrintProtoToTextFile(mem_shared_plan, JoinPath(LogDir(), "mem_shared_plan"));
LOG(INFO) << "push_pull_plan:" << GetCurTime() - start;
if (HasRelayPlacement()) {
// Experiment Runtime
{ Runtime experiment_run(mem_shared_plan, true); }
......
......@@ -3,7 +3,16 @@ package oneflow;
import "oneflow/core/job/task.proto";
message MachineIds {
repeated int64 machine_id = 1;
}
message NetTopo {
map<int64, MachineIds> peer_machine_ids = 1;
}
message Plan {
repeated TaskProto task = 1;
required int64 total_mbn_num = 2;
required NetTopo net_topo = 3;
}
......@@ -27,9 +27,11 @@ enum TaskType {
kReduceGlobalAdd = 18;
kReduceGather = 19;
kReduceScatter = 20;
kAccuracy = 21;
kAccuracyAcc = 22;
kAccuracyPrint = 23;
kReduceConcat = 21;
kReduceSplit = 22;
kAccuracy = 23;
kAccuracyAcc = 24;
kAccuracyPrint = 25;
};
enum AreaType {
......
......@@ -104,6 +104,14 @@ message ReduceGatherKernelConf {
repeated int64 data_offset = 1;
}
message ReduceConcatKernelConf {
repeated int64 data_offset = 1;
}
message ReduceSplitKernelConf {
repeated int64 data_offset = 1;
}
message AccuracyKernelConf {
required DataType prediction_type = 1;
required DataType label_type = 2;
......@@ -154,6 +162,8 @@ message KernelConf {
NormalizationKernelConf normalization_conf = 250;
LocalResponseNormalizationKernelConf local_response_normalization_conf = 300;
ReduceGatherKernelConf reduce_gather_conf = 350;
ReduceConcatKernelConf reduce_concat_conf = 351;
ReduceSplitKernelConf reduce_split_conf = 352;
AccuracyKernelConf accuracy_conf = 401;
}
}
#include "oneflow/core/kernel/reduce_concat_kernel.h"
namespace oneflow {
template<DeviceType device_type>
// Copies one input blob into its slot of the packed "out" blob.
// ctx.other is a pair<in_bn_id, is_inplace> supplied by the actor: which
// input to process this invocation, and whether the input already aliases
// the output buffer (mem sharing), in which case no copy is needed.
void ReduceConcatKernel<device_type>::ForwardDataContent(
    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
  const auto* other_val = static_cast<std::pair<int64_t, bool>*>(ctx.other);
  int64_t in_bn_id = other_val->first;
  bool is_inplace = other_val->second;
  if (is_inplace) { return; }
  Blob* out_blob = BnInOp2Blob("out");
  char* dst_cur_dptr = out_blob->mut_dptr<char>();
  // data_offset was precomputed by ReduceConcatOp::VirtualGenKernelConf as
  // the cumulative byte size of the preceding inputs.
  dst_cur_dptr += this->kernel_conf().reduce_concat_conf().data_offset().Get(in_bn_id);
  Blob* in_blob = BnInOp2Blob(this->op_attribute().input_bns().Get(in_bn_id));
  size_t in_byte_size = in_blob->ByteSizeOfDataContentField();
  Memcpy<device_type>(ctx.device_ctx, dst_cur_dptr, in_blob->dptr<char>(), in_byte_size);
}
ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kReduceConcatConf, ReduceConcatKernel);
} // namespace oneflow
#ifndef ONEFLOW_CORE_KERNEL_REDUCE_CONCAT_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_REDUCE_CONCAT_KERNEL_H_
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
// Kernel for the ReduceConcat op: copies each "in_i" blob into the packed
// "out" blob at its precomputed byte offset (no-op when running in-place).
template<DeviceType device_type>
class ReduceConcatKernel final : public KernelIf<device_type> {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ReduceConcatKernel);
  ReduceConcatKernel() = default;
  ~ReduceConcatKernel() = default;

 private:
  void ForwardDataContent(const KernelCtx&,
                          std::function<Blob*(const std::string&)>) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_REDUCE_CONCAT_KERNEL_H_
......@@ -5,14 +5,14 @@ namespace oneflow {
template<DeviceType device_type>
void ReduceScatterKernel<device_type>::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
if (device_type == DeviceType::kGPU) { return; }
bool is_inplace = *static_cast<bool*>(ctx.other);
if (is_inplace) { return; }
const Blob* in_blob = BnInOp2Blob("in");
const char* src_cur_dptr = in_blob->dptr<char>();
for (const std::string& obn : this->op_attribute().output_bns()) {
Blob* out_blob = BnInOp2Blob(obn);
size_t out_byte_size = out_blob->ByteSizeOfDataContentField();
Memcpy<DeviceType::kCPU>(ctx.device_ctx, out_blob->mut_dptr<char>(), src_cur_dptr,
out_byte_size);
Memcpy<device_type>(ctx.device_ctx, out_blob->mut_dptr<char>(), src_cur_dptr, out_byte_size);
src_cur_dptr += out_byte_size;
}
}
......
#include "oneflow/core/kernel/reduce_split_kernel.h"
namespace oneflow {
template<DeviceType device_type>
// Scatters the packed "in" blob into the output blobs in order, advancing a
// cursor by each output's body size. ctx.other is a bool flag set by the
// actor: when true the outputs already alias the input buffer (mem sharing)
// and no copy is performed.
void ReduceSplitKernel<device_type>::ForwardDataContent(
    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
  if (*static_cast<bool*>(ctx.other)) { return; }
  const char* src_ptr = BnInOp2Blob("in")->dptr<char>();
  for (const std::string& obn : this->op_attribute().output_bns()) {
    Blob* dst_blob = BnInOp2Blob(obn);
    size_t copy_bytes = dst_blob->ByteSizeOfDataContentField();
    Memcpy<device_type>(ctx.device_ctx, dst_blob->mut_dptr<char>(), src_ptr, copy_bytes);
    src_ptr += copy_bytes;
  }
}
ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kReduceSplitConf, ReduceSplitKernel);
} // namespace oneflow
#ifndef ONEFLOW_CORE_KERNEL_REDUCE_SPLIT_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_REDUCE_SPLIT_KERNEL_H_
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
// Kernel for the ReduceSplit op: scatters the packed "in" blob into the
// per-blob "out_i" outputs (no-op when running in-place).
template<DeviceType device_type>
class ReduceSplitKernel final : public KernelIf<device_type> {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ReduceSplitKernel);
  ReduceSplitKernel() = default;
  ~ReduceSplitKernel() = default;

 private:
  void ForwardDataContent(const KernelCtx&,
                          std::function<Blob*(const std::string&)>) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_REDUCE_SPLIT_KERNEL_H_
......@@ -577,6 +577,14 @@ message TransposeOpConf {
repeated int32 perm = 3;
}
message ReduceConcatOpConf {
required int32 in_num = 1;
}
message ReduceSplitOpConf {
required int32 out_num = 1;
}
message ReduceScatterOpConf {
required int32 out_num = 1;
};
......@@ -652,9 +660,11 @@ message OperatorConf {
ReduceLocalAddOpConf reduce_local_add_conf = 401;
ReduceGlobalAddOpConf reduce_global_add_conf = 402;
ReduceGatherOpConf reduce_gather_conf = 403;
RecordLoadOpConf record_load_conf = 404;
AccuracyOpConf accuracy_conf=405;
AccuracyPrintOpConf accuracy_print_conf = 406;
ReduceConcatOpConf reduce_concat_conf = 404;
ReduceSplitOpConf reduce_split_conf = 405;
RecordLoadOpConf record_load_conf = 406;
AccuracyOpConf accuracy_conf=407;
AccuracyPrintOpConf accuracy_print_conf = 408;
}
}
......
#include "oneflow/core/operator/reduce_concat_op.h"
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/register/runtime_blob_desc.h"
namespace oneflow {
// Enrolls one input bn per concatenated blob ("in_0" ... "in_{n-1}") and a
// single packed output "out". The `false` flag presumably suppresses diff
// blobs for these bns, matching the other reduce ops.
void ReduceConcatOp::InitFromOpConf() {
  CHECK(op_conf().has_reduce_concat_conf());
  const int32_t in_num = op_conf().reduce_concat_conf().in_num();
  for (int32_t idx = 0; idx < in_num; ++idx) {
    EnrollInputBn("in_" + std::to_string(idx), false);
  }
  EnrollOutputBn("out", false);
}
// Returns the op-specific section of the OperatorConf.
const PbMessage& ReduceConcatOp::GetCustomizedConf() const {
  return op_conf().reduce_concat_conf();
}
// Infers the packed output blob: a flat rank-1 blob whose element count is
// the sum of all input element counts. All non-shape attributes (data type,
// etc.) are inherited from the first input via the struct copy.
// NOTE(review): summing elem_cnt assumes every input shares the first
// input's data type — byte sizes would disagree otherwise; confirm the
// graph builder guarantees this.
void ReduceConcatOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                                    const ParallelContext* parallel_ctx) const {
  int32_t in_num = op_conf().reduce_concat_conf().in_num();
  // A concat of fewer than two inputs should have been elided upstream.
  CHECK_GE(in_num, 2);
  BlobDesc* first_in_blob = GetBlobDesc4BnInOp(input_bns().Get(0));
  BlobDesc* out_blob = GetBlobDesc4BnInOp(SoleObn());
  *out_blob = *first_in_blob;
  int64_t out_blob_elem_cnt = first_in_blob->shape().elem_cnt();
  for (int32_t i = 1; i < in_num; ++i) {
    out_blob_elem_cnt += GetBlobDesc4BnInOp(input_bns().Get(i))->shape().elem_cnt();
  }
  out_blob->mut_shape() = Shape({out_blob_elem_cnt});
}
// Records, for each input, its byte offset inside the packed output buffer
// so the kernel can memcpy each input to the right position. Offsets are the
// cumulative ByteSizeOfBlobBody of the preceding inputs; the running total
// must finally equal the output's body size.
void ReduceConcatOp::VirtualGenKernelConf(
    std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*,
    KernelConf* kernel_conf) const {
  ReduceConcatKernelConf* reduce_concat_conf = kernel_conf->mutable_reduce_concat_conf();
  int64_t offset = 0;
  for (int32_t i = 0; i < op_conf().reduce_concat_conf().in_num(); ++i) {
    reduce_concat_conf->mutable_data_offset()->Add(offset);
    offset += RtBlobDesc(*(GetBlobDesc4BnInOp(input_bns().Get(i)))).ByteSizeOfBlobBody();
  }
  CHECK_EQ(offset, RtBlobDesc(*GetBlobDesc4BnInOp(SoleObn())).ByteSizeOfBlobBody());
}
// The single output is always published under the fixed blob name "out",
// regardless of the output bn passed in.
LogicalBlobId ReduceConcatOp::obn2lbi(const std::string& output_bn) const {
  LogicalBlobId lbi;
  lbi.set_op_name(op_name());
  lbi.set_blob_name("out");
  return lbi;
}
REGISTER_OP(OperatorConf::kReduceConcatConf, ReduceConcatOp);
} // namespace oneflow
#ifndef ONEFLOW_CORE_OPERATOR_REDUCE_CONCAT_OP_H_
#define ONEFLOW_CORE_OPERATOR_REDUCE_CONCAT_OP_H_
#include "oneflow/core/operator/operator.h"
namespace oneflow {
// Op that packs `in_num` model-diff blobs into one flat output blob so the
// subsequent reduce ops operate on a single contiguous buffer.
class ReduceConcatOp final : public Operator {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ReduceConcatOp);
  ReduceConcatOp() = default;
  ~ReduceConcatOp() = default;
  void InitFromOpConf() override;
  const PbMessage& GetCustomizedConf() const override;
  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                      const ParallelContext* parallel_ctx) const override;
 private:
  // Emits per-input byte offsets into the kernel conf.
  void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                            const ParallelContext*, KernelConf*) const override;
  LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); }
  LogicalBlobId obn2lbi(const std::string& output_bn) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_OPERATOR_REDUCE_CONCAT_OP_H_
#include "oneflow/core/operator/reduce_split_op.h"
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/register/runtime_blob_desc.h"
namespace oneflow {
// Enrolls one output bn per split blob ("out_0" ... "out_{n-1}") followed by
// the single packed input "in". The enroll order (outputs before input)
// matches the original and is preserved deliberately.
void ReduceSplitOp::InitFromOpConf() {
  CHECK(op_conf().has_reduce_split_conf());
  const int32_t out_num = op_conf().reduce_split_conf().out_num();
  for (int32_t idx = 0; idx < out_num; ++idx) {
    EnrollOutputBn("out_" + std::to_string(idx), false);
  }
  EnrollInputBn("in", false);
}
// Returns the op-specific section of the OperatorConf.
const PbMessage& ReduceSplitOp::GetCustomizedConf() const { return op_conf().reduce_split_conf(); }
// Records, for each output, its byte offset inside the packed input buffer
// so the kernel can memcpy each slice out of the right position. Offsets are
// the cumulative ByteSizeOfBlobBody of the preceding outputs; the running
// total must finally equal the input's body size.
void ReduceSplitOp::VirtualGenKernelConf(
    std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*,
    KernelConf* kernel_conf) const {
  ReduceSplitKernelConf* reduce_split_conf = kernel_conf->mutable_reduce_split_conf();
  int64_t offset = 0;
  for (int32_t i = 0; i < op_conf().reduce_split_conf().out_num(); ++i) {
    reduce_split_conf->mutable_data_offset()->Add(offset);
    offset += RtBlobDesc(*(GetBlobDesc4BnInOp(output_bns().Get(i)))).ByteSizeOfBlobBody();
  }
  CHECK_EQ(offset, RtBlobDesc(*GetBlobDesc4BnInOp(SoleIbn())).ByteSizeOfBlobBody());
}
REGISTER_OP(OperatorConf::kReduceSplitConf, ReduceSplitOp);
} // namespace oneflow
#ifndef ONEFLOW_CORE_OPERATOR_REDUCE_SPLIT_OP_H_
#define ONEFLOW_CORE_OPERATOR_REDUCE_SPLIT_OP_H_
#include "oneflow/core/operator/operator.h"
namespace oneflow {
// Op that unpacks the flat reduced blob back into `out_num` per-blob
// outputs — the inverse of ReduceConcatOp.
class ReduceSplitOp final : public Operator {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ReduceSplitOp);
  ReduceSplitOp() = default;
  ~ReduceSplitOp() = default;
  void InitFromOpConf() override;
  const PbMessage& GetCustomizedConf() const override;
  // Intentionally empty: output blob descs are assigned by the task graph,
  // which copies them from the peer ReduceConcat node's inputs.
  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                      const ParallelContext* parallel_ctx) const override {}
 private:
  // Emits per-output byte offsets into the kernel conf.
  void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                            const ParallelContext*, KernelConf*) const override;
  LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); }
  LogicalBlobId obn2lbi(const std::string& output_bn) const override { return GenPackedLbi(); }
};
} // namespace oneflow
#endif // ONEFLOW_CORE_OPERATOR_REDUCE_SPLIT_OP_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册