提交 39945d0f 编写于 作者: Y YuJianfeng

Add AllGather fusion pass

上级 38ad5673
......@@ -21,7 +21,7 @@
#include "pre_activate/ascend/ir_fission/bn_grad_split.h"
#include "pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h"
#include "pre_activate/ascend/ir_fission/layer_norm_grad_split.h"
#include "pre_activate/pass/allreduce_fusion.h"
#include "pre_activate/pass/communication_op_fusion.h"
#include "pre_activate/ascend/ir_fusion/square_sum_fusion.h"
#include "pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h"
#include "pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h"
......@@ -254,6 +254,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
auto optimizer = std::make_shared<GraphOptimizer>();
auto other_pm = std::make_shared<PassManager>("other_pm");
other_pm->AddPass(std::make_shared<AllReduceFusion>());
other_pm->AddPass(std::make_shared<AllGatherFusion>());
other_pm->AddPass(std::make_shared<ParameterTransOpFusion>());
other_pm->AddPass(std::make_shared<BufferFusion>());
other_pm->AddPass(std::make_shared<GetitemTuple>());
......
......@@ -13,14 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/pass/allreduce_fusion.h"
#include "pre_activate/pass/communication_op_fusion.h"
#include <vector>
#include <string>
#include <memory>
#include <unordered_map>
#include "utils/utils.h"
#include "utils/graph_utils.h"
#include "operator/ops.h"
#include "device/kernel_info.h"
......@@ -31,9 +29,12 @@
namespace mindspore {
namespace opt {
namespace {
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AllReduceInfo_t &allreduce_node_info, size_t start_index,
constexpr auto kAttrDefaultGroup = "default_group";
constexpr auto kAttrDefaultOp = "default_op";
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CommunicationOpInfo &communication_op_info, size_t start_index,
size_t end_index) {
if (end_index >= allreduce_node_info.allreduce_node.size()) {
if (end_index >= communication_op_info.communication_op_nodes.size()) {
MS_LOG(EXCEPTION) << "end index out of vector size";
}
std::vector<std::string> inputs_device_format;
......@@ -43,7 +44,7 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AllReduceInfo_t &allred
std::vector<std::vector<size_t>> outputs_shape;
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
for (size_t idx = start_index; idx <= end_index; ++idx) {
auto cnode = allreduce_node_info.allreduce_node[idx];
auto cnode = communication_op_info.communication_op_nodes[idx];
MS_EXCEPTION_IF_NULL(cnode);
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index));
......@@ -64,14 +65,38 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AllReduceInfo_t &allred
builder.SetOutputsDeviceType(outputs_device_type);
return builder.Build();
}
std::string GetFusionGroupKey(const AnfNodePtr &node) {
auto primitive = AnfAlgo::GetCNodePrimitive(node);
MS_EXCEPTION_IF_NULL(primitive);
ValuePtr attr_fusion = primitive->GetAttr(kAttrFusion);
if (attr_fusion == nullptr) {
return "";
}
int fusion = GetValue<int>(attr_fusion);
if (fusion == 0) {
return "";
}
std::string group = kAttrDefaultGroup;
ValuePtr attr_group = primitive->GetAttr(kAttrGroup);
if (attr_group != nullptr) {
group = GetValue<std::string>(attr_group);
}
std::string op = kAttrDefaultOp;
ValuePtr attr_op = primitive->GetAttr(kAttrOp);
if (attr_op != nullptr) {
op = GetValue<std::string>(attr_op);
}
return group + op + std::to_string(fusion);
}
} // namespace
bool AllReduceFusion::GetSplitSegments(const AllReduceInfo_t &allreduce_node_info, size_t *segment_num,
std::vector<size_t> *segment_index) const {
bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num,
std::vector<size_t> *segment_index) const {
MS_EXCEPTION_IF_NULL(segment_num);
MS_EXCEPTION_IF_NULL(segment_index);
size_t allreduce_node_size = allreduce_node_info.allreduce_node.size();
MS_LOG(INFO) << "graph all reduce node size " << allreduce_node_size;
size_t communication_op_node_size = communication_op_info.communication_op_nodes.size();
MS_LOG(INFO) << "graph " << op_name_ << " node size " << communication_op_node_size;
auto parallel_context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel_context);
......@@ -82,30 +107,31 @@ bool AllReduceFusion::GetSplitSegments(const AllReduceInfo_t &allreduce_node_inf
uint32_t last_index = 0;
for (size_t i = 0; i < split_indices.size(); ++i) {
uint32_t index = split_indices[i];
if (index <= last_index || index >= allreduce_node_size) {
MS_LOG(EXCEPTION) << "invalid allreduce split index " << i << " " << index;
if (index <= last_index || index >= communication_op_node_size) {
MS_LOG(EXCEPTION) << "invalid " << op_name_ << " split index " << i << " " << index;
}
segment_index->push_back(index);
last_index = index;
segments++;
}
if (last_index != allreduce_node_size - 1) {
segment_index->push_back(allreduce_node_size - 1);
if (last_index != communication_op_node_size - 1) {
segment_index->push_back(communication_op_node_size - 1);
segments++;
}
} else {
segments = groups_;
for (size_t i = 0; i < segments - 1; ++i) {
segment_index->push_back((i + 1) * (allreduce_node_size / segments) - 1);
segment_index->push_back((i + 1) * (communication_op_node_size / segments) - 1);
}
segment_index->push_back(allreduce_node_size - 1);
segment_index->push_back(communication_op_node_size - 1);
}
if (segments >= allreduce_node_size) {
MS_LOG(INFO) << "fusion not changed: segment_num=" << segments << ", allreduce_node_size=" << allreduce_node_size;
if (segments >= communication_op_node_size) {
MS_LOG(INFO) << "fusion not changed: segment_num=" << segments
<< ", communication_op_node_size=" << communication_op_node_size;
return false;
}
if (segment_index->at(segments - 1) != allreduce_node_size - 1) {
if (segment_index->at(segments - 1) != communication_op_node_size - 1) {
MS_LOG(EXCEPTION) << "the last segment index is invalid.";
}
for (size_t i = 0; i < segments - 1; ++i) {
......@@ -118,19 +144,19 @@ bool AllReduceFusion::GetSplitSegments(const AllReduceInfo_t &allreduce_node_inf
return true;
}
AnfNodePtr AllReduceFusion::CreateFusedAllReduce(const FuncGraphPtr &func_graph,
const AllReduceInfo_t &allreduce_node_info, size_t start_index,
size_t end_index) const {
AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr &func_graph,
const CommunicationOpInfo &communication_op_info,
size_t start_index, size_t end_index) const {
MS_EXCEPTION_IF_NULL(func_graph);
auto prim = std::make_shared<Primitive>(kAllReduceOpName);
auto prim = std::make_shared<Primitive>(op_name_);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> fusion_inputs = {NewValueNode(prim)};
// get all inputs of current segment
if (end_index >= allreduce_node_info.allreduce_node.size()) {
if (end_index >= communication_op_info.communication_op_nodes.size()) {
MS_LOG(EXCEPTION) << "end index out of vector size";
}
for (size_t idx = start_index; idx <= end_index; ++idx) {
auto cnode = allreduce_node_info.allreduce_node[idx];
auto cnode = communication_op_info.communication_op_nodes[idx];
MS_EXCEPTION_IF_NULL(cnode);
fusion_inputs.insert(fusion_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
}
......@@ -141,14 +167,14 @@ AnfNodePtr AllReduceFusion::CreateFusedAllReduce(const FuncGraphPtr &func_graph,
fused_node->set_kernel_info(kernel_info);
AbstractBasePtrList abstract_list;
for (size_t idx = start_index; idx <= end_index; ++idx) {
auto cnode = allreduce_node_info.allreduce_node[idx];
auto cnode = communication_op_info.communication_op_nodes[idx];
MS_EXCEPTION_IF_NULL(cnode);
AnfAlgo::CopyNodeAttr("fusion", cnode, fused_node);
AnfAlgo::CopyNodeAttr("op", cnode, fused_node);
AnfAlgo::CopyNodeAttr("group", cnode, fused_node);
abstract_list.push_back(cnode->abstract());
}
auto kernel_build_info = GenerateKernelBuildInfo(allreduce_node_info, start_index, end_index);
auto kernel_build_info = GenerateKernelBuildInfo(communication_op_info, start_index, end_index);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_node.get());
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
MS_EXCEPTION_IF_NULL(abstract_tuple);
......@@ -156,8 +182,8 @@ AnfNodePtr AllReduceFusion::CreateFusedAllReduce(const FuncGraphPtr &func_graph,
return fused_node;
}
bool AllReduceFusion::DoFusion(const FuncGraphPtr &func_graph, const AllReduceInfo_t &allreduce_node_info,
size_t segment_num, const std::vector<size_t> &segment_index) const {
bool CommunicationOpFusion::DoFusion(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info,
size_t segment_num, const std::vector<size_t> &segment_index) const {
MS_EXCEPTION_IF_NULL(func_graph);
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
......@@ -169,12 +195,13 @@ bool AllReduceFusion::DoFusion(const FuncGraphPtr &func_graph, const AllReduceIn
start_index = end_index + 1;
continue;
}
AnfNodePtr new_allreduce = CreateFusedAllReduce(func_graph, allreduce_node_info, start_index, end_index);
// replace old allreduce with new allreduce
AnfNodePtr new_communication_op =
CreateFusedCommunicationOp(func_graph, communication_op_info, start_index, end_index);
// replace old communication op with new communication op
for (auto idx = start_index; idx <= end_index; ++idx) {
std::vector<AnfNodePtr> tuple_getitem_input;
tuple_getitem_input.push_back(NewValueNode(prim::kPrimTupleGetItem));
tuple_getitem_input.push_back(new_allreduce);
tuple_getitem_input.push_back(new_communication_op);
auto index = NewValueNode(SizeToInt(idx - start_index));
MS_EXCEPTION_IF_NULL(index);
auto imm = std::make_shared<Int32Imm>(idx - start_index);
......@@ -185,10 +212,10 @@ bool AllReduceFusion::DoFusion(const FuncGraphPtr &func_graph, const AllReduceIn
tuple_getitem_input.push_back(index);
AnfNodePtr tuple_getitem = func_graph->NewCNode(tuple_getitem_input);
MS_EXCEPTION_IF_NULL(tuple_getitem);
auto allreduce_node_item = allreduce_node_info.allreduce_node.at(idx);
MS_EXCEPTION_IF_NULL(allreduce_node_item);
tuple_getitem->set_abstract(allreduce_node_item->abstract());
if (!manager->Replace(allreduce_node_item, tuple_getitem)) {
auto communication_op_node_item = communication_op_info.communication_op_nodes.at(idx);
MS_EXCEPTION_IF_NULL(communication_op_node_item);
tuple_getitem->set_abstract(communication_op_node_item->abstract());
if (!manager->Replace(communication_op_node_item, tuple_getitem)) {
MS_LOG(EXCEPTION) << "manager replace node failed";
}
}
......@@ -198,29 +225,24 @@ bool AllReduceFusion::DoFusion(const FuncGraphPtr &func_graph, const AllReduceIn
return changed;
}
bool AllReduceFusion::Run(const FuncGraphPtr &func_graph) {
bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
const float input_grad_size_num = 0.0;
const float input_grad_time_num = 0.0;
// divide candidate fusion groups with same (group,op,fusion) attrs, fusion==0 means not fusion
std::unordered_map<std::string, AllReduceInfo_t> candidate_groups;
std::unordered_map<std::string, CommunicationOpInfo> candidate_groups;
std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kAllReduceOpName) {
auto primitive = AnfAlgo::GetCNodePrimitive(node);
MS_EXCEPTION_IF_NULL(primitive);
int fusion = GetValue<int>(primitive->GetAttr("fusion"));
if (fusion == 0) {
if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == op_name_) {
std::string key = GetFusionGroupKey(node);
if (key.empty()) {
continue;
}
std::string group = GetValue<std::string>(primitive->GetAttr("group"));
std::string op = GetValue<std::string>(primitive->GetAttr("op"));
std::string key = group + op + std::to_string(fusion);
if (candidate_groups.find(key) == candidate_groups.end()) {
AllReduceInfo_t allreduce_node_info;
candidate_groups[key] = allreduce_node_info;
CommunicationOpInfo communication_op_info;
candidate_groups[key] = communication_op_info;
}
candidate_groups[key].allreduce_node.push_back(node->cast<CNodePtr>());
candidate_groups[key].communication_op_nodes.push_back(node->cast<CNodePtr>());
candidate_groups[key].input_grad_size.push_back(input_grad_size_num);
candidate_groups[key].input_grad_time.push_back(input_grad_time_num);
}
......@@ -228,7 +250,7 @@ bool AllReduceFusion::Run(const FuncGraphPtr &func_graph) {
// split candidate group to segments according to _group class member
bool changed = false;
for (auto &it : candidate_groups) {
if (it.second.allreduce_node.size() <= 1) {
if (it.second.communication_op_nodes.size() <= 1) {
continue;
}
size_t segment_num = 0;
......
......@@ -13,37 +13,55 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ALLREDUCE_FUSION_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ALLREDUCE_FUSION_H_
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_
#include <utility>
#include <vector>
#include <string>
#include "pre_activate/common/pass.h"
#include "ir/func_graph.h"
#include "ir/anf.h"
#include "utils/utils.h"
namespace mindspore {
namespace opt {
struct AllReduceInfo_t {
std::vector<CNodePtr> allreduce_node;
struct CommunicationOpInfo {
std::vector<CNodePtr> communication_op_nodes;
std::vector<float> input_grad_size;
std::vector<float> input_grad_time;
};
class AllReduceFusion : public Pass {
class CommunicationOpFusion : public Pass {
public:
explicit AllReduceFusion(size_t groups = 1) : Pass("all_reduce_fusion"), groups_(groups) {}
~AllReduceFusion() override = default;
explicit CommunicationOpFusion(const std::string &name, std::string op_name, size_t groups = 1)
: Pass(name), op_name_(std::move(op_name)), groups_(groups) {}
~CommunicationOpFusion() override = default;
bool Run(const FuncGraphPtr &graph) override;
private:
bool DoFusion(const FuncGraphPtr &func_graph, const AllReduceInfo_t &allreduce_node_info, size_t segment_num,
bool DoFusion(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info, size_t segment_num,
const std::vector<size_t> &segment_index) const;
AnfNodePtr CreateFusedAllReduce(const FuncGraphPtr &func_graph, const AllReduceInfo_t &allreduce_node_info,
size_t start_index, size_t end_index) const;
bool GetSplitSegments(const AllReduceInfo_t &allreduce_node_info, size_t *segment_num,
AnfNodePtr CreateFusedCommunicationOp(const FuncGraphPtr &func_graph,
const CommunicationOpInfo &communication_op_info, size_t start_index,
size_t end_index) const;
bool GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num,
std::vector<size_t> *segment_index) const;
std::string op_name_;
size_t groups_ = 1;
};
class AllReduceFusion : public CommunicationOpFusion {
public:
explicit AllReduceFusion(size_t groups = 1) : CommunicationOpFusion("all_reduce_fusion", kAllReduceOpName, groups) {}
~AllReduceFusion() override = default;
};
class AllGatherFusion : public CommunicationOpFusion {
public:
explicit AllGatherFusion(size_t groups = 1) : CommunicationOpFusion("all_gather_fusion", kAllGatherOpName, groups) {}
~AllGatherFusion() override = default;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ALLREDUCE_FUSION_H_
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_
......@@ -20,7 +20,7 @@
#include "device/gpu/gpu_stream_assign.h"
#include "pre_activate/common/optimizer.h"
#include "pre_activate/common/pass_manager.h"
#include "pre_activate/pass/allreduce_fusion.h"
#include "pre_activate/pass/communication_op_fusion.h"
#include "device/kernel_runtime_manager.h"
#include "predict/predict.h"
#include "common/utils.h"
......
......@@ -154,6 +154,9 @@ constexpr auto kAttrOutputUsedNum = "output_used_num";
constexpr auto kAttrHasBias = "has_bias";
constexpr auto kAttrN = "n";
constexpr auto kAttrLabelForInsertStreamActive = "label_for_insert_stream_active";
constexpr auto kAttrFusion = "fusion";
constexpr auto kAttrGroup = "group";
constexpr auto kAttrOp = "op";
// attr value
constexpr auto kValueTargetSwitch = "target_switch";
......
......@@ -20,7 +20,7 @@
#include "ir/manager.h"
#include "debug/anf_ir_dump.h"
#include "session/anf_runtime_algorithm.h"
#include "pre_activate/pass/allreduce_fusion.h"
#include "pre_activate/pass/communication_op_fusion.h"
#include "pre_activate/common/optimizer.h"
#include "device/kernel_info.h"
#include "pre_activate/common/pass_manager.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册