Commit 6316343e authored by laiyongqiang

add atomic clean op for every communication op's input

Parent 034453e4
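For orientation before the diff: the change walks the graph's execution order, records which producer-node outputs feed a communication op, de-duplicates those output indexes, and tags each producer with kAttrAutomicOutputIndexs so an atomic-clean (zeroing) op can be emitted for it. The sketch below is a minimal, self-contained model of that collection-and-dedup pattern; Node, is_communication_op, and CollectCommInputInfo are simplified stand-ins invented here, not MindSpore's AnfNodePtr, AnfRuntimeAlgorithm, or KernelGraph APIs.

```cpp
#include <cstddef>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Simplified stand-in for a graph node: a name, a flag marking communication
// ops (e.g. AllReduce), and its inputs as (producer node, output index) pairs.
struct Node {
  std::string name;
  bool is_communication_op;
  std::vector<std::pair<const Node *, size_t>> inputs;
};

// Collect, per producer node, the output indexes that feed any communication
// op, then drop duplicates via std::set (which also sorts the indexes).
// Mirrors the shape of the new GetCommunicationOpInputInfo helper.
std::map<const Node *, std::vector<size_t>> CollectCommInputInfo(
    const std::vector<const Node *> &execution_order) {
  std::map<const Node *, std::vector<size_t>> info;
  for (const Node *kernel : execution_order) {
    if (!kernel->is_communication_op) {
      continue;
    }
    for (const auto &input : kernel->inputs) {
      info[input.first].push_back(input.second);
    }
  }
  for (auto &entry : info) {
    std::set<size_t> unique(entry.second.begin(), entry.second.end());
    entry.second.assign(unique.begin(), unique.end());
  }
  return info;
}

int main() {
  Node producer{"Conv2D", false, {}};
  Node all_reduce{"AllReduce", true, {{&producer, 0}, {&producer, 0}, {&producer, 1}}};
  auto info = CollectCommInputInfo({&producer, &all_reduce});
  // Outputs 0 and 1 of Conv2D feed AllReduce, so they would be tagged with
  // kAttrAutomicOutputIndexs and zeroed by an atomic-clean op before use.
  for (size_t index : info[&producer]) {
    std::cout << producer.name << " output " << index << std::endl;
  }
  return 0;
}
```

Building the std::set from the vector and assigning it back is the same de-duplication idiom the commit itself uses in both IsAtomicNode and GetCommunicationOpInputInfo.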
@@ -184,11 +184,17 @@ bool IsAtomicNode(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
auto kernel_mod = AnfAlgo::GetKernelMod(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_mod);
auto atomic_flag = false;
std::vector<size_t> clean_output_indexs;
if (AnfAlgo::HasNodeAttr(kAttrAutomicOutputIndexs, kernel_node)) {
clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(kernel_node, kAttrAutomicOutputIndexs);
atomic_flag = true;
}
auto parameters_indexs = kernel_mod->GenParameters();
if (parameters_indexs.empty()) {
return false;
return atomic_flag;
}
auto atomic_flag = false;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
auto workspace_size_list = kernel_mod->GetWorkspaceSizeList();
@@ -199,7 +205,7 @@ bool IsAtomicNode(const CNodePtr &kernel_node) {
parameters_indexs.push_back(0);
}
}
std::vector<size_t> clean_output_indexs;
// parameters_indexs is ordered as input -> workspace -> output
size_t index = 0;
while (index < output_num) {
@@ -210,6 +216,8 @@ bool IsAtomicNode(const CNodePtr &kernel_node) {
index++;
}
if (atomic_flag) {
std::set<size_t> s(clean_output_indexs.begin(), clean_output_indexs.end());
clean_output_indexs.assign(s.begin(), s.end());
AnfAlgo::SetNodeAttr(kAttrAutomicOutputIndexs, MakeValue(clean_output_indexs), kernel_node);
}
for (size_t i = 0; i < workspace_num; ++i) {
@@ -238,11 +246,49 @@ bool KernelBuild(const mindspore::session::KernelGraph *kernel_graph_ptr) {
return ret;
}
std::map<AnfNodePtr, std::vector<size_t>> GetCommunicationOpInputInfo(
const mindspore::session::KernelGraph *kernel_graph) {
std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map;
for (auto &kernel : kernel_graph->execution_order()) {
auto input_num = AnfAlgo::GetInputTensorNum(kernel);
if (mindspore::session::AnfRuntimeAlgorithm::IsCommunicationOp(kernel)) {
for (size_t i = 0; i < input_num; i++) {
auto input_node = kernel->input(i + 1);
auto kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
MS_LOG(INFO) << " Add atomic clean for single communication op input, comm:" << kernel->fullname_with_scope()
<< " input_node: " << kernel_input.first->fullname_with_scope()
<< " index: " << kernel_input.second;
auto iter = comm_input_info_map.find(kernel_input.first);
if (iter != comm_input_info_map.end()) {
iter->second.push_back(kernel_input.second);
} else {
std::vector<size_t> indexes = {kernel_input.second};
comm_input_info_map[kernel_input.first] = indexes;
}
}
}
}
// remove duplicate indexes
for (auto &info : comm_input_info_map) {
std::set<size_t> s(info.second.begin(), info.second.end());
info.second.assign(s.begin(), s.end());
}
return comm_input_info_map;
}
void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
std::vector<CNodePtr> new_nodes;
std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map = GetCommunicationOpInputInfo(kernel_graph);
for (const auto &anf_node : kernel_graph->execution_order()) {
std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node);
if (comm_input_info_map.find(anf_node) != comm_input_info_map.end()) {
auto indexes = comm_input_info_map[anf_node];
AnfAlgo::SetNodeAttr(kAttrAutomicOutputIndexs, MakeValue(indexes), anf_node);
}
if (apply_function_name == prim::kPrimMaxPoolGrad->name() &&
AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) {
auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName);
......
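The behavioural core of the first hunk is the early return in IsAtomicNode: an empty GenParameters() no longer means "not atomic"; a node already tagged with kAttrAutomicOutputIndexs by KernelBuildPreprocess still gets its outputs cleaned. Below is a minimal sketch of that control flow; FakeNode and IsAtomicNodeSketch are hypothetical stand-ins for the AnfAlgo/KernelMod APIs, not the committed code.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a node: has_clean_attr models
// AnfAlgo::HasNodeAttr(kAttrAutomicOutputIndexs, node), and
// generated_parameters models what kernel_mod->GenParameters() returns.
struct FakeNode {
  bool has_clean_attr;
  std::vector<size_t> clean_output_indexs;
  std::vector<size_t> generated_parameters;
};

// Mirrors the new early return: a node already tagged by the communication-op
// preprocess pass stays atomic even when GenParameters() yields nothing.
bool IsAtomicNodeSketch(const FakeNode &node) {
  bool atomic_flag = false;
  if (node.has_clean_attr) {
    atomic_flag = true;  // indexes were provided by KernelBuildPreprocess
  }
  if (node.generated_parameters.empty()) {
    return atomic_flag;  // before this commit: unconditional "return false;"
  }
  // ... the real code goes on to merge output/workspace indexes here ...
  return atomic_flag;
}

int main() {
  FakeNode tagged{true, {0, 1}, {}};
  FakeNode untagged{false, {}, {}};
  assert(IsAtomicNodeSketch(tagged));     // now atomic despite empty parameters
  assert(!IsAtomicNodeSketch(untagged));  // unchanged behaviour
  return 0;
}
```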