提交 8f4bab4e 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2410 Support insert memcpy between two hccl op if the part outputs of prior...

!2410 Support insert memcpy between two hccl op if the part outputs of prior hccl op linking to next hccl op
Merge pull request !2410 from huanghui/insert-memcpy-async-pass
......@@ -87,6 +87,7 @@
#include "backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h"
#include "backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h"
#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h"
#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h"
#include "backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h"
#include "backend/optimizer/ascend/format_type/insert_transdata_for_runop.h"
#include "backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.h"
......@@ -340,6 +341,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
other_pm->AddPass(std::make_shared<AllGatherFusion>());
other_pm->AddPass(std::make_shared<ReduceScatterFusion>());
other_pm->AddPass(std::make_shared<BroadcastFusion>());
other_pm->AddPass(std::make_shared<InsertMemcpyAsyncForCascade>());
other_pm->AddPass(std::make_shared<ParameterTransOpFusion>());
other_pm->AddPass(std::make_shared<RefreshParameterFormat>());
optimizer->AddPassManager(other_pm);
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h"
#include <vector>
#include <set>
#include <string>
#include "utils/utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "frontend/optimizer/opt.h"
#include "backend/optimizer/ascend/ascend_helper.h"
namespace mindspore {
namespace opt {
namespace {
bool IsPartOutputsOfHcclOp(const AnfNodePtr &node, const CNodePtr &cur_hccl, const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(cur_hccl);
MS_EXCEPTION_IF_NULL(graph);
if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
return false;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto prev_node = cnode->input(kRealInputNodeIndexInTupleGetItem);
MS_EXCEPTION_IF_NULL(prev_node);
if (!AnfAlgo::IsCommunicationOp(prev_node)) {
return false;
}
auto prev_hccl_op = prev_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(prev_hccl_op);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto &node_users = manager->node_users();
auto iter = node_users.find(prev_hccl_op);
if (iter == node_users.end()) {
MS_LOG(EXCEPTION) << "node has no output in manager";
}
for (const auto &node_index : iter->second) {
AnfNodePtr output = node_index.first;
MS_EXCEPTION_IF_NULL(output);
if (IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
bool is_contain = false;
for (size_t i = 1; i < cur_hccl->size(); ++i) {
if (cur_hccl->input(i) == output) {
is_contain = true;
break;
}
}
if (!is_contain) {
return true;
}
}
}
return false;
}
} // namespace
AnfNodePtr InsertMemcpyAsyncForCascade::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(hccl_node);
std::vector<AnfNodePtr> memcpy_async_list;
std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)};
for (size_t i = 1; i < hccl_node->size(); ++i) {
auto input = hccl_node->input(i);
MS_EXCEPTION_IF_NULL(input);
// when input is also a hccl op and just part outputs of it linking with cur_hccl_op
if (IsPartOutputsOfHcclOp(input, hccl_node, graph)) {
auto memcpy_async = CreateMemcpyAsyncOp(graph, input);
auto kernel_info = std::make_shared<device::KernelInfo>();
memcpy_async->set_kernel_info(kernel_info);
MS_EXCEPTION_IF_NULL(kernel_select_);
kernel_select_->SelectKernel(memcpy_async->cast<CNodePtr>());
new_inputs.push_back(memcpy_async);
memcpy_async_list.push_back(memcpy_async);
} else {
new_inputs.push_back(input);
}
}
if (!memcpy_async_list.empty()) {
CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node);
new_hccl_node->set_inputs(new_inputs);
return new_hccl_node;
}
return nullptr;
}
const AnfNodePtr InsertMemcpyAsyncForCascade::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (func_graph == nullptr || node == nullptr || !node->isa<CNode>()) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
if (!AnfAlgo::IsCommunicationOp(node)) {
return nullptr;
}
return InsertMemcpyAsync(func_graph, cnode);
}
} // namespace opt
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_CASCADE_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_CASCADE_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ascend_helper.h"
namespace mindspore {
namespace opt {
class InsertMemcpyAsyncForCascade : public PatternProcessPass {
public:
explicit InsertMemcpyAsyncForCascade(bool multigraph = true)
: PatternProcessPass("insert_memcpy_async_for_cascade", multigraph),
kernel_select_(std::make_shared<KernelSelect>()) {}
~InsertMemcpyAsyncForCascade() override = default;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
AnfNodePtr InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const;
KernelSelectPtr kernel_select_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_OP_CASCADE_H_
......@@ -32,12 +32,17 @@ const std::set<std::string> kNeedInsertMemcpyOpSet = {kLambNextMVOpName, kLambNe
bool IsParameterOrValueNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
return kernel_with_index.first->isa<Parameter>() || kernel_with_index.first->isa<ValueNode>();
auto real_node = kernel_with_index.first;
MS_EXCEPTION_IF_NULL(real_node);
if (real_node->isa<Parameter>()) {
return true;
}
return real_node->isa<ValueNode>();
}
void TransferControl(const CNodePtr &hccl_node, const AnfNodePtr &memcpy_async, const FuncGraphPtr &graph) {
void TransferControl(const CNodePtr &hccl_node, const std::vector<AnfNodePtr> &memcpy_async_list,
const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(hccl_node);
MS_EXCEPTION_IF_NULL(memcpy_async);
MS_EXCEPTION_IF_NULL(graph);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
......@@ -48,49 +53,62 @@ void TransferControl(const CNodePtr &hccl_node, const AnfNodePtr &memcpy_async,
}
// find hccl_node's output which is a control depend
for (const auto &node_index : iter->second) {
AnfNodePtr output = node_index.first;
int output_index = node_index.second;
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) {
CNodePtr control_depend = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(control_depend);
std::vector<AnfNodePtr> new_inputs;
for (size_t i = 0; i < control_depend->size(); ++i) {
if (i == IntToSize(output_index)) {
new_inputs.push_back(memcpy_async);
} else {
new_inputs.push_back(control_depend->input(i));
}
if (!AnfAlgo::CheckPrimitiveType(node_index.first, prim::kPrimControlDepend)) {
continue;
}
CNodePtr control_depend = node_index.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(control_depend);
std::vector<AnfNodePtr> new_inputs;
for (size_t i = 0; i < control_depend->size(); ++i) {
if (i == IntToSize(node_index.second)) {
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
make_tuple_inputs.insert(make_tuple_inputs.end(), memcpy_async_list.begin(), memcpy_async_list.end());
make_tuple_inputs.emplace_back(hccl_node);
auto make_tuple = graph->NewCNode(make_tuple_inputs);
MS_EXCEPTION_IF_NULL(make_tuple);
new_inputs.push_back(make_tuple);
} else {
new_inputs.push_back(control_depend->input(i));
}
control_depend->set_inputs(new_inputs);
}
control_depend->set_inputs(new_inputs);
}
}
} // namespace
bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const {
bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input,
const CNodePtr &cur_node) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(input);
MS_EXCEPTION_IF_NULL(cur_node);
// when input is a parameter or is a value node
if (IsParameterOrValueNode(input)) {
return true;
}
// when input is a Ref or some special cnodes
if (kernel_query_->IsTbeRef(input) ||
kNeedInsertMemcpyOpSet.find(AnfAlgo::GetCNodeName(input)) != kNeedInsertMemcpyOpSet.end()) {
return true;
}
if (input->isa<CNode>()) {
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto &node_users = manager->node_users();
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto &node_users = manager->node_users();
auto iter = node_users.find(input);
if (iter == node_users.end()) {
MS_LOG(EXCEPTION) << "node has no output in manager";
}
// when input is used by others
if (iter->second.size() > 1) {
return true;
// when input is a Ref cnode
if (kernel_query_->IsTbeRef(input)) {
return true;
}
// when input is some special cnodes
if (kNeedInsertMemcpyOpSet.find(AnfAlgo::GetCNodeName(input)) != kNeedInsertMemcpyOpSet.end()) {
return true;
}
// when input is used by others
auto iter = node_users.find(input);
if (iter == node_users.end()) {
MS_LOG(EXCEPTION) << "node has no output in manager";
}
if (iter->second.size() > 1) {
return true;
}
}
return false;
}
......@@ -98,21 +116,20 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con
void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(hccl_node);
bool has_insert_memcpy = false;
AnfNodePtr memcpy_async = nullptr;
std::vector<AnfNodePtr> memcpy_async_list;
std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)};
for (size_t i = 1; i < hccl_node->size(); ++i) {
auto input = hccl_node->input(i);
if (NeedInsertMemcpy(graph, input)) {
memcpy_async = CreateMemcpyAsyncOp(graph, input);
has_insert_memcpy = true;
if (NeedInsertMemcpy(graph, input, hccl_node)) {
auto memcpy_async = CreateMemcpyAsyncOp(graph, input);
new_inputs.push_back(memcpy_async);
memcpy_async_list.push_back(memcpy_async);
} else {
new_inputs.push_back(input);
}
}
if (has_insert_memcpy) {
if (!memcpy_async_list.empty()) {
CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node);
new_hccl_node->set_inputs(new_inputs);
auto manager = graph->manager();
......@@ -122,9 +139,7 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co
MS_LOG(DEBUG) << "end replace";
// transer hccl op's control to the memcpy_async
if (hccl_node->size() == 2) {
TransferControl(new_hccl_node, memcpy_async, graph);
}
TransferControl(new_hccl_node, memcpy_async_list, graph);
}
}
......
......@@ -32,7 +32,7 @@ class InsertMemcpyAsyncForHcclOp : public PatternProcessPass {
private:
void InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const;
bool NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const;
bool NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input, const CNodePtr &cur_node) const;
KernelQueryPtr kernel_query_;
};
} // namespace opt
......
......@@ -22,6 +22,7 @@
#include "utils/utils.h"
#include "backend/kernel_compiler/kernel_build_info.h"
#include "backend/optimizer/common/optimizer.h"
#include "ir/param_value.h"
#define private public
#define protected public
#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h"
......@@ -44,12 +45,10 @@ class MockInsertMemcpyForHcclKernelQuery : public KernelQuery {
~MockInsertMemcpyForHcclKernelQuery() override = default;
bool IsTbeRef(const AnfNodePtr &node) override {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
if (!node->isa<CNode>()) {
return false;
}
auto name = AnfAlgo::GetCNodeName(cnode);
return name == "ApplyMomentum";
return AnfAlgo::GetCNodeName(node->cast<CNodePtr>()) == "ApplyMomentum";
}
};
......@@ -105,6 +104,11 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond2) {
AbstractBasePtrList args_spec_list{x_abstract};
auto kg = GetKernelGraph(g, args_spec_list);
EXPECT_NE(kg, nullptr);
for (auto p : kg->parameters()) {
auto param = p->cast<ParameterPtr>();
EXPECT_NE(param, nullptr);
param->set_default_param(std::make_shared<ParamValue>());
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
......@@ -146,10 +150,16 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond4) {
ASSERT_TRUE(g != nullptr);
std::vector<int> shp_x{1, 64, 112, 112};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract, x_abstract, x_abstract};
AbstractBasePtrList args_spec_list{x_abstract, x_abstract};
auto kg = GetKernelGraph(g, args_spec_list);
EXPECT_NE(kg, nullptr);
for (auto p : kg->parameters()) {
auto param = p->cast<ParameterPtr>();
EXPECT_NE(param, nullptr);
param->set_default_param(std::make_shared<ParamValue>());
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>();
......@@ -161,5 +171,34 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond4) {
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond4", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWInsertMemcpyForHccl, test_cond5) {
get_py_fun_.SetDoResolve(true);
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond5", "before");
ASSERT_TRUE(g != nullptr);
std::vector<int> shp_x{1, 64, 112, 112};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract};
auto kg = GetKernelGraph(g, args_spec_list);
EXPECT_NE(kg, nullptr);
for (auto p : kg->parameters()) {
auto param = p->cast<ParameterPtr>();
EXPECT_NE(param, nullptr);
param->set_default_param(std::make_shared<ParamValue>());
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>();
pass->kernel_query_ = std::make_shared<MockInsertMemcpyForHcclKernelQuery>();
pm->AddPass(pass);
optimizer->AddPassManager(pm);
auto new_graph = optimizer->Optimize(kg);
kg->SetExecOrderByDefault();
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond5", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
} // namespace opt
} // namespace mindspore
......@@ -17,6 +17,7 @@ from mindspore.ops import Primitive
from mindspore.ops import operations as P
all_reduce = P.AllReduce()
broadcast = P.Broadcast(1)
memcpy_async = Primitive('memcpy_async')
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
......@@ -101,20 +102,40 @@ def test_insert_memcpy_async_for_hccl_op_cond4(tag):
fns = FnDict()
@fns
def before(a, b, c, d, e):
res1 = apply_momentun(a, b, c, d, e)
res2 = all_reduce(a)
res = control_depend(res1, res2)
res = make_tuple(res, res2)
def before(a, b):
x = relu(a)
y = all_reduce(b)
res = control_depend(x, y)
return res
@fns
def after(a, b, c, d, e):
res1 = apply_momentun(a, b, c, d, e)
res2 = memcpy_async(a)
res3 = all_reduce(res2)
res = control_depend(res1, res2)
res = make_tuple(res, res3)
def after(a, b):
x = relu(a)
y1 = memcpy_async(b)
y2 = all_reduce(y1)
res = control_depend(x, make_tuple(y1, y2))
return make_tuple(res)
return fns[tag]
def test_insert_memcpy_async_for_hccl_op_cond5(tag):
fns = FnDict()
@fns
def before(a, b, c):
x = relu(a)
y = broadcast((b, c))
res = control_depend(x, y)
return res
@fns
def after(a, b, c):
x = relu(a)
m1 = memcpy_async(b)
m2 = memcpy_async(c)
y = broadcast(m1, m2)
res = control_depend(x, make_tuple(m1, m2, y))
return make_tuple(res)
return fns[tag]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册