Commit 130cc296 authored by mindspore-ci-bot, committed by Gitee

!2931 Ascend control flow not split graphs

Merge pull request !2931 from zhoufeng/liantiao1
@@ -40,6 +40,9 @@ using kernel::KernelBuildInfoPtr;
using kernel::KernelMod;
using kernel::KernelModPtr;
namespace {
constexpr size_t kNopNodeInputSize = 2;
constexpr size_t kNopNodeRealInputIndex = 1;
std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
MS_EXCEPTION_IF_NULL(shape);
std::vector<size_t> shape_size_t;
@@ -48,6 +51,26 @@ std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
}
} // namespace
AnfNodePtr AnfRuntimeAlgorithm::GetTupleGetItemRealInput(const CNodePtr &tuple_get_item) {
MS_EXCEPTION_IF_NULL(tuple_get_item);
if (tuple_get_item->size() != kTupleGetItemInputSize) {
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
}
return tuple_get_item->input(kRealInputNodeIndexInTupleGetItem);
}
size_t AnfRuntimeAlgorithm::GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
MS_EXCEPTION_IF_NULL(tuple_get_item);
if (tuple_get_item->size() != kTupleGetItemInputSize) {
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
}
auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem);
MS_EXCEPTION_IF_NULL(output_index_value_node);
auto value_node = output_index_value_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
return IntToSize(GetValue<int>(value_node->value()));
}
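// Editor's sketch (not part of this commit): a tuple_get_item CNode stores the
// primitive itself at input(0), the node producing the tuple at
// input(kRealInputNodeIndexInTupleGetItem), and a ValueNode holding the element
// index at input(kInputNodeOutputIndexInTupleGetItem). The two helpers above
// simply decompose such a node:
//   auto real_input = AnfRuntimeAlgorithm::GetTupleGetItemRealInput(tuple_get_item);
//   size_t out_index = AnfRuntimeAlgorithm::GetTupleGetItemOutIndex(tuple_get_item);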
KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, size_t index) {
MS_EXCEPTION_IF_NULL(anf_node);
if (anf_node->isa<ValueNode>()) {
@@ -83,49 +106,47 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, siz
}
}
KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, int index,
bool visit_nop_node,
const std::vector<PrimitivePtr> &return_types) {
MS_EXCEPTION_IF_NULL(anf_node);
if (std::any_of(return_types.begin(), return_types.end(), [&anf_node](const PrimitivePtr &prim_type) -> bool {
return CheckPrimitiveType(anf_node, prim_type);
})) {
return KernelWithIndex(anf_node, index);
}
if (!anf_node->isa<CNode>()) {
return KernelWithIndex(anf_node, 0);
}
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
auto item_with_index_tmp = VisitKernelWithReturnType(GetTupleGetItemRealInput(cnode),
GetTupleGetItemOutIndex(cnode), visit_nop_node, return_types);
if (CheckPrimitiveType(item_with_index_tmp.first, prim::kPrimMakeTuple)) {
MS_EXCEPTION_IF_NULL(item_with_index_tmp.first);
auto make_tuple = item_with_index_tmp.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple);
const std::vector<AnfNodePtr> &make_tuple_inputs = make_tuple->inputs();
size_t make_tuple_input_index = item_with_index_tmp.second + 1;
if (make_tuple_input_index >= make_tuple_inputs.size()) {
MS_LOG(EXCEPTION) << "Index[" << make_tuple_input_index << "] out of range[" << make_tuple_inputs.size()
<< "].";
}
return VisitKernelWithReturnType(make_tuple_inputs[make_tuple_input_index], 0, visit_nop_node, return_types);
}
return item_with_index_tmp;
}
if (CheckPrimitiveType(cnode, prim::kPrimDepend) || CheckPrimitiveType(cnode, prim::kPrimControlDepend)) {
return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), index, visit_nop_node, return_types);
}
if (opt::IsNopNode(cnode) && visit_nop_node) {
if (cnode->size() != kNopNodeInputSize) {
MS_LOG(EXCEPTION) << "Invalid nop node " << cnode->DebugString();
}
return VisitKernelWithReturnType(cnode->input(kNopNodeRealInputIndex), 0, visit_nop_node, return_types);
}
return KernelWithIndex(anf_node, index);
}
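// Editor's sketch: VisitKernelWithReturnType peels tuple_get_item, depend and
// nop wrappers until it reaches a node that either is a real kernel or matches
// one of return_types. For example, given
//   %t = make_tuple(%a, %b)
//   %x = tuple_get_item(%t, 1)
// VisitKernelWithReturnType(%x, 0) follows the item index into the make_tuple
// and yields {%b, 0}, while calling it directly on %t with the default
// return_types ({prim::kPrimMakeTuple}) stops immediately and yields {%t, 0}.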
std::vector<AnfNodePtr> AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node,
@@ -591,7 +612,7 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node,
if (opt::IsNopNode(node) && visit_nop_node) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() == kNopNodeInputSize) {
return AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(cnode, 0);
} else {
MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node";
@@ -613,7 +634,7 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod
if (opt::IsNopNode(node) && visit_nop_node) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().size() == kNopNodeInputSize) {
return AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(cnode, 0);
} else {
MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node.";
@@ -806,7 +827,7 @@ bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) {
IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
return !is_virtual_node;
}
@@ -1117,5 +1138,14 @@ TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputPrecision(const AnfNodePtr &node, s
}
return GetCNodeOutputPrecision(kernel_with_index.first);
}
bool AnfRuntimeAlgorithm::IsCondControlKernel(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->inputs().empty()) {
MS_LOG(EXCEPTION) << "Illegal null input of cnode.";
}
auto input = node->input(kAnfPrimitiveIndex);
return IsPrimitive(input, prim::kPrimLabelGoto) || IsPrimitive(input, prim::kPrimLabelSwitch);
}
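// Editor's note: a typical caller walks an execution order and recurses into
// the child graphs attached to these control kernels (this is the pattern
// RecurseSelectKernelInfo uses later in this commit):
//   for (const auto &cnode : graph->execution_order()) {
//     if (AnfRuntimeAlgorithm::IsCondControlKernel(cnode)) {
//       // handle the label_goto / label_switch kernel's child graphs
//     }
//   }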
} // namespace session
} // namespace mindspore
@@ -42,9 +42,12 @@ using DeviceAddress = device::DeviceAddress;
using DeviceAddressPtr = device::DeviceAddressPtr;
class AnfRuntimeAlgorithm {
public:
// get real input node of tuple_get_item
static AnfNodePtr GetTupleGetItemRealInput(const CNodePtr &tuple_get_item);
static size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item);
// get input_anf_node's real kernel recursively
static KernelWithIndex VisitKernel(const AnfNodePtr &input_anf_node, size_t output_index);
static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, int output_index,
bool visit_nop_node = false,
const std::vector<PrimitivePtr> &return_types = {
prim::kPrimMakeTuple});
@@ -205,6 +208,7 @@ class AnfRuntimeAlgorithm {
static TypeId GetCNodeOutputPrecision(const AnfNodePtr &node);
// get fix output precision from prev node, input_idx is the input index of current node related to prev node.
static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx);
static bool IsCondControlKernel(const CNodePtr &node);
};
} // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm;
@@ -20,6 +20,8 @@
#include <map>
#include <vector>
#include <tuple>
#include <utility>
#include <functional>
#include "backend/session/kernel_graph.h"
#include "utils/base_ref.h"
#include "utils/contract.h"
@@ -29,16 +31,23 @@ namespace mindspore {
namespace session {
class AscendControlParser {
public:
static void LinkGraph(NotNull<KernelGraphPtr> kg);
static void InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> attach_node);
static void InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node,
NotNull<AnfNodePtr> second_node);
static void ExecutorValidate(NotNull<KernelGraphPtr> root_graph);
static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, const AnfNodePtr &jump_node,
NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
private:
class ReferenceCounter;
static void EraseParameter(NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list);
static void EraseLabel(NotNull<KernelGraphPtr> root_graph);
static void ChildGraphDataAssign(NotNull<KernelGraphPtr> kg,
const NotNull<std::vector<std::pair<AnfNodePtr, AnfNodePtr>> *> link_list,
const NotNull<std::set<KernelGraphPtr> *> memo);
static NotNull<CNodePtr> GetStartLabel(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
const CNodePtr &last_label);
static NotNull<CNodePtr> ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
@@ -53,11 +62,10 @@ class AscendControlParser {
static void LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node,
const CNodePtr &last_label);
static AnfNodePtr InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
static std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> ParseCallNode(NotNull<CNodePtr> call_node);
static std::tuple<KernelGraphPtr, std::vector<AnfNodePtr>> ParsePartial(NotNull<AnfNodePtr> node);
// root graph order
static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode,
@@ -65,6 +73,19 @@
static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph,
const NotNull<std::set<KernelGraphPtr> *> memo);
};
class AscendControlParser::ReferenceCounter {
public:
explicit ReferenceCounter(std::function<bool(int32_t, int32_t)> func) : predicate_(func), count_() {}
void AddReadCount(const AnfNodePtr &key, int32_t num);
void AddWriteCount(const AnfNodePtr &key, int32_t num);
void EraseElem(const AnfNodePtr &key);
bool HasValidElem() const;
std::tuple<AnfNodePtr, int32_t, int32_t> GetOneValidElem() const;
private:
std::function<bool(int32_t, int32_t)> predicate_;
std::map<AnfNodePtr, std::pair<int32_t, int32_t>> count_;
};
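// Editor's sketch of the intended usage (the member definitions live in the
// .cc file and are not shown in this hunk); the predicate selects which
// (read_count, write_count) pairs are still worth processing:
//   ReferenceCounter counter([](int32_t read, int32_t write) { return write == 1; });
//   counter.AddWriteCount(parameter, 1);
//   counter.AddReadCount(parameter, 2);
//   while (counter.HasValidElem()) {
//     auto [node, read, write] = counter.GetOneValidElem();
//     // ... rewrite references to `node`, then drop it from the counter:
//     counter.EraseElem(node);
//   }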
} // namespace session
} // namespace mindspore
@@ -289,6 +289,17 @@ static void RecurseToUpdateCallRealInput(NotNull<KernelGraphPtr> graph,
// this action should run from bottom to top
graph->UpdateCallRealInput();
}
void InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph) {
auto return_node = root_graph->get_return();
MS_EXCEPTION_IF_NULL(return_node);
if (return_node->size() <= kReturnDataIndex) {
return;
}
auto make_tuple = root_graph->NewCNode(
{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), root_graph->output()});
root_graph->set_output(make_tuple);
}
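// Editor's note: after this pass the root graph's output is always wrapped in a
// make_tuple, i.e. Return(%x) becomes Return(MakeTuple(%x)), so later output
// handling can treat single-output and multi-output graphs uniformly.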
} // namespace
GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
@@ -305,22 +316,39 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
std::vector<KernelGraphPtr> all_graphs;
auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
BackendOptimization(all_graphs);
// an empty graph doesn't enter the backend
if (root_graph->execution_order().empty()) {
MS_LOG(INFO) << root_graph->ToString() << " is an empty graph.";
InsertMakeTupleForOutput(NOT_NULL(root_graph));
root_graph->set_executable(false);
InitRuntimeResource();
return root_graph->graph_id();
}
// create parameters for the multi-branch output
std::set<KernelGraphPtr> memo;
CreateMultiBranchOutput(NOT_NULL(root_graph), NOT_NULL(&memo));
memo.clear();
// insert goto labels and label_sets
LinkChildGraphs(NOT_NULL(root_graph));
// resource initialize
InitRuntimeResource();
IrFusionPass(NOT_NULL(root_graph), NOT_NULL(&memo));
memo.clear();
SelectKernel(NOT_NULL(root_graph));
memo.clear();
HardwareOptimize(NOT_NULL(root_graph), NOT_NULL(&memo));
memo.clear();
AssignStaticMemory(NOT_NULL(root_graph), NOT_NULL(&memo));
memo.clear();
UpdateRefOutputMap(NOT_NULL(root_graph), NOT_NULL(&memo));
memo.clear();
// add a make_tuple to the graph output
InsertMakeTupleForOutput(NOT_NULL(root_graph));
// validate the root graph, including generating the execution order and so on
RootGraphExecutorValidate(NOT_NULL(root_graph));
// adjust kernel
@@ -1682,7 +1710,7 @@ void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<Pri
bool split_flag = false;
auto apply_list = GetCNodes(TopoSort(graph->get_return()));
// update the root graph child graph order
graph->UpdateChildGraphOrder();
// get child list from current graph
std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(apply_list, cut_prims);
if (child_graph_lists.size() > 1) {
@@ -1714,7 +1742,7 @@ void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<Pri
}
split_flag = true;
}
graph->UpdateChildGraphOrder();
UpdateRealInput(graph, split_flag, memo);
MS_LOG(INFO) << "Split graph[" << graph->graph_id() << "] end";
}
@@ -1753,5 +1781,216 @@ void AscendSession::RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const Not
}
}
}
void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) {
if (memo->find(graph.get()) != memo->end()) {
return;
}
memo->insert(graph.get());
graph->UpdateChildGraphOrder();
for (auto &child_graph : graph->child_graph_order()) {
CreateMultiBranchOutput(NOT_NULL(child_graph), memo);
}
std::map<AnfNodePtr, AnfNodePtr> need_replace_list;
auto node_list = GetCNodes(TopoSort(graph->get_return()));
for (auto &node : node_list) {
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) {
// create a parameter to store the output of the multiple branches and set it as the condition graph's output
auto origin_inputs = graph->inputs();
auto output_param = CreateNewParameterFromCNode(node, true, graph.get().get());
MS_EXCEPTION_IF_NULL(graph->MutableInputs());
graph->MutableInputs()->operator=(origin_inputs);
graph->AddChildGraphResult(output_param);
std::vector<AnfNodePtr> depend_inputs = {
graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name()))), output_param, node};
auto depend = graph->NewCNode(depend_inputs);
need_replace_list.emplace(node, depend);
MS_LOG(INFO) << "Create parameter " << output_param->DebugString() << " for call node " << node->DebugString()
<< ", depend node is " << depend->DebugString();
// insert assigns to transfer each child graph's output to the parameter
auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node);
for (auto &child_graph : child_graphs) {
MS_EXCEPTION_IF_NULL(child_graph);
if (child_graph->get_output_null()) {
continue;
}
auto graph_output = child_graph->output();
AscendControlParser::InsertMultipleAssignToGraph(NOT_NULL(child_graph), nullptr, NOT_NULL(graph_output),
NOT_NULL(output_param));
}
}
}
// search the nodes' inputs and replace each call with depend(parameter, call)
for (auto &node : node_list) {
for (size_t i = 0; i < node->size(); ++i) {
auto input = node->input(i);
auto iter = need_replace_list.find(input);
if (iter != need_replace_list.end()) {
node->set_input(i, iter->second);
}
}
}
}
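// Editor's sketch of the rewrite CreateMultiBranchOutput performs (node names
// are illustrative): every user of a call node is redirected through a fresh
// parameter that each child graph assigns its own output to:
//   before: %y = SomeOp(%call)        // %call = call @child_graph(...)
//   after:  %p = Parameter            // output_param created above
//           %d = Depend(%p, %call)    // keeps the call ordered before the read
//           %y = SomeOp(%d)           // child graphs get Assign(%p, <output>)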
void AscendSession::IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) {
if (memo->find(graph) != memo->end()) {
return;
}
memo->insert(graph.get());
opt::AscendBackendIRFusionOptimization(graph);
opt::AscendBackendFuseBasicOpt(graph, true);
opt::AscendBackendGraphKernelOpt(graph, true);
graph->SetExecOrderByDefault();
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->save_graphs_flag();
auto save_graphs_path = context_ptr->save_graphs_path();
if (save_graphs) {
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
std::string file_path =
save_graphs_path + "/" + "select_kernel_before" + "_graph_" + std::to_string(graph->graph_id()) + ".ir";
DumpIR(file_path, graph.get());
}
for (auto &child_graph : graph->child_graph_order()) {
IrFusionPass(NOT_NULL(child_graph), memo);
}
}
void AscendSession::SelectKernel(NotNull<KernelGraphPtr> root_graph) {
MS_LOG(INFO) << "Start select kernel.";
size_t raise_precision_count = 0;
size_t reduce_precision_count = 0;
std::set<KernelGraphPtr> memo;
(void)RecurseSelectKernelInfo(root_graph, NOT_NULL(&memo), &raise_precision_count, &reduce_precision_count);
memo.clear();
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kGraphMode) {
if (raise_precision_count > 0) {
MS_LOG(WARNING) << "There has " << raise_precision_count
<< " node/nodes used raise precision to selected the kernel!";
}
if (reduce_precision_count > 0) {
MS_LOG(WARNING) << "There has " << raise_precision_count
<< " node/nodes used reduce precision to selected the kernel!";
}
}
MS_LOG(INFO) << "Finish!";
}
void AscendSession::RecurseSelectKernelInfo(NotNull<KernelGraphPtr> graph,
NotNull<std::set<KernelGraphPtr> *> const memo,
size_t *const raise_precision_count,
size_t *const reduce_precision_count) const {
if (memo->find(graph) != memo->end()) {
return;
}
memo->insert(graph.get());
MS_LOG(INFO) << "Start to select kernel info in graph: " << graph->graph_id();
for (const auto &cnode : graph->execution_order()) {
if (AnfAlgo::IsCondControlKernel(cnode)) {
std::vector<KernelGraphPtr> child_graphs;
if (AnfAlgo::HasNodeAttr(kAttrChildGraph, cnode)) {
child_graphs = AnfAlgo::GetNodeAttr<std::vector<KernelGraphPtr>>(cnode, kAttrChildGraph);
}
for (auto &child_graph : child_graphs) {
RecurseSelectKernelInfo(NOT_NULL(child_graph), memo, raise_precision_count, reduce_precision_count);
}
}
auto status = device::ascend::SelectKernelInfo(cnode);
if (status == device::ascend::kStatusRaisePrecision) {
(*raise_precision_count)++;
} else if (status == device::ascend::kStatusReducePrecision) {
(*reduce_precision_count)++;
}
MS_LOG(INFO) << "Select ApplyKernel: " << cnode->DebugString();
}
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->save_graphs_flag();
auto save_graphs_path = context_ptr->save_graphs_path();
if (save_graphs) {
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
std::string file_path =
save_graphs_path + "/" + "select_kernel_after" + "_graph_" + std::to_string(graph->graph_id()) + ".ir";
DumpIR(file_path, graph.get());
}
MS_LOG(INFO) << "Finish selecting kernel info in graph: " << graph->graph_id();
}
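// Editor's note: kAttrChildGraph is expected to be attached to the
// label_goto/label_switch kernels by the control-flow lowering elsewhere in
// this commit; reading it here lets kernel selection recurse depth-first into
// every branch graph before the current graph is finished.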
void AscendSession::HardwareOptimize(NotNull<KernelGraphPtr> graph,
NotNull<std::set<KernelGraphPtr> *> const memo) const {
if (memo->find(graph) != memo->end()) {
return;
}
memo->insert(graph.get());
MS_LOG(INFO) << "Start to do HardwareOptimize in graph: " << graph->graph_id();
// convert the kernel graph to a model
predictmodel::StepConvertGraph(graph.get());
HardwareOptimize(graph.get());
for (auto &child_graph : graph->child_graph_order()) {
HardwareOptimize(NOT_NULL(child_graph), memo);
}
MS_LOG(INFO) << "Finish doing HardwareOptimize in graph: " << graph->graph_id();
}
void AscendSession::AssignStaticMemory(NotNull<KernelGraphPtr> graph,
NotNull<std::set<KernelGraphPtr> *> const memo) const {
if (memo->find(graph) != memo->end()) {
return;
}
memo->insert(graph.get());
MS_LOG(INFO) << "Start to assign static memory for parameter in graph: " << graph->graph_id();
// assign static memory for parameters
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
runtime_instance->AssignStaticMemoryInput(graph.get().get());
runtime_instance->AssignStaticMemoryValueNode(graph.get().get());
for (auto &child_graph : graph->child_graph_order()) {
AssignStaticMemory(NOT_NULL(child_graph), memo);
}
MS_LOG(INFO) << "Finish assigning static memory for parameter in graph: " << graph->graph_id();
}
void AscendSession::UpdateRefOutputMap(NotNull<KernelGraphPtr> graph,
NotNull<std::set<KernelGraphPtr> *> const memo) const {
if (memo->find(graph) != memo->end()) {
return;
}
memo->insert(graph.get());
for (auto &child_graph : graph->child_graph_order()) {
UpdateRefOutputMap(NOT_NULL(child_graph), memo);
// copy the child graph's ref map into this graph
auto child_ref_map = child_graph->GetRefMap();
for (auto &item : child_ref_map) {
if (graph->IsInRefOutputMap(item.first)) {
MS_LOG(WARNING) << "The ref pair <" << item.first.first->DebugString() << ", " << item.first.second
<< "> is already in " << graph->ToString();
continue;
}
graph->AddRefCorrespondPairs(item.first, item.second);
}
}
}
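// Editor's note: a ref pair maps an (output node, output index) key to the
// input node that shares its device memory (in-place style kernels). Merging
// every child graph's map into the parent, with a warning on duplicates, lets
// the root graph resolve ref outputs across child-graph boundaries.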
} // namespace session
} // namespace mindspore
@@ -151,6 +151,15 @@ class AscendSession : public SessionBasic {
// sync initial tensors' data to the device
void SyncInitialTenosrToDevice();
void SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph);
// create parameters to receive data from multi-branch outputs
void CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
void SelectKernel(NotNull<KernelGraphPtr> root_graph);
void RecurseSelectKernelInfo(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> const memo,
size_t *const raise_precision_count, size_t *const reduce_precision_count) const;
void IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
void HardwareOptimize(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
void AssignStaticMemory(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
void UpdateRefOutputMap(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
// member variables
// key is final_graph_id, value is the child graph execution order of the final graph
@@ -616,8 +616,8 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) {
depend_mode = AnfAlgo::GetNodeAttr<int>(cnode, kControlDependMode);
}
MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString()
<< "], depend_mode :" << depend_mode << ".";
MS_LOG(DEBUG) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString()
<< "], depend_mode :" << depend_mode << ".";
if (prior_node->isa<Parameter>() && depend_mode == 1) {
prior_nodes = GetOutputNodes(prior_node);
}
@@ -647,7 +647,8 @@
}
MS_EXCEPTION_IF_NULL(first_node);
MS_EXCEPTION_IF_NULL(second_node);
MS_LOG(INFO) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString();
MS_LOG(DEBUG) << "Add first node:" << first_node->DebugString()
<< ",second node:" << second_node->DebugString();
AddDependEdge(second_node, first_node, 1);
}
}
@@ -991,6 +992,30 @@ bool KernelGraph::IsFinalOutputKernel(const AnfNodePtr &node) const {
return false;
}
void KernelGraph::UpdateChildGraphOrder() {
MS_LOG(INFO) << "Update " << ToString() << " child graph order.";
SetExecOrderByDefault();
auto call_nodes = FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name()));
std::vector<KernelGraphPtr> child_graph_order;
for (auto &call_node : call_nodes) {
MS_EXCEPTION_IF_NULL(call_node);
auto call_child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast<CNodePtr>());
for (const auto &child_graph : call_child_graphs) {
MS_EXCEPTION_IF_NULL(child_graph);
if (child_graph != parent_graph_) {
auto shared_this = std::dynamic_pointer_cast<KernelGraph>(shared_from_this());
MS_EXCEPTION_IF_NULL(shared_this);
child_graph->set_parent_graph(shared_this);
}
child_graph_order.push_back(child_graph);
}
}
for (size_t i = 0; i < child_graph_order.size(); ++i) {
MS_LOG(INFO) << "Child graph[" << i << "][id:" << child_graph_order[i]->graph_id() << "]";
}
child_graph_order_ = child_graph_order;
}
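// Editor's note: the child graph order is derived purely from the call nodes in
// the freshly recomputed execution order, so a graph calling @a and then @b
// ends up with child_graph_order() == {@a, @b}; the parent_graph_ check above
// avoids re-parenting a back-edge call to this graph's own parent.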
std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); }
KernelGraph::~KernelGraph() { device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_); }
@@ -156,6 +156,12 @@ class KernelGraph : public FuncGraph {
bool IsFinalOutputKernel(const AnfNodePtr &node) const;
uint32_t current_epoch() const { return current_epoch_; }
void set_current_epoch(uint32_t epoch) { current_epoch_ = epoch; }
void UpdateChildGraphOrder();
const std::vector<AnfNodePtr> &child_graph_result() const { return child_graph_result_; }
void AddChildGraphResult(const AnfNodePtr &parameter) { child_graph_result_.push_back(parameter); }
void set_child_graph_result(const std::vector<AnfNodePtr> &child_graph_result) {
child_graph_result_ = child_graph_result;
}
private:
// remove value node from graph
@@ -173,6 +179,7 @@ class KernelGraph : public FuncGraph {
void UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends);
std::shared_ptr<std::vector<AnfNodePtr>> inputs_;
std::vector<AnfNodePtr> child_graph_result_;
std::vector<CNodePtr> execution_order_;
uint32_t graph_id_;
uint32_t stream_distinction_label_;
@@ -74,7 +74,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
return input_tensors[input_idx];
}
}
MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << "has no output addr";
MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << " has no output addr";
}
}
// if the process reaches here, item_with_index is a real node (a Parameter or an executable CNode)
@@ -107,8 +107,8 @@
return tensor;
}
BaseRef CreateTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
const std::vector<tensor::TensorPtr> &input_tensors) {
MS_EXCEPTION_IF_NULL(anf);
MS_LOG(INFO) << "Create tensor for output[" << anf->DebugString() << "]";
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
@@ -120,7 +120,7 @@ BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
MS_EXCEPTION_IF_NULL(cnode);
VectorRef ret;
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
auto out = CreateTensorForOutput(cnode->input(i), graph, input_tensors);
ret.push_back(out);
}
return ret;
@@ -133,25 +133,6 @@ BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
return CreateOneTensor(item_with_index.first, item_with_index.second, graph, input_tensors);
}
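// Editor's note: tuple outputs are handled by recursing through the make_tuple
// inputs and collecting the results into a nested VectorRef, so an output like
// make_tuple(%a, make_tuple(%b, %c)) yields (tensor_a, (tensor_b, tensor_c)).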
ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf);
MS_EXCEPTION_IF_NULL(graph);
@@ -880,20 +861,11 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap
const std::vector<tensor::TensorPtr> &input_tensors) const {
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_EXCEPTION_IF_NULL(outputs);
auto anf_outputs = kernel_graph->outputs();
for (auto &item : anf_outputs) {
MS_EXCEPTION_IF_NULL(item);
MS_LOG(INFO) << "Update output[" << item->DebugString() << "]";
outputs->emplace_back(CreateTensorForOutput(item, *kernel_graph, input_tensors));
}
}
@@ -294,6 +294,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(mem_manager_);
auto graph_inputs = graph->inputs();
auto graph_valid_input = graph->valid_inputs();
graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end());
std::vector<AnfNodePtr> need_alloc_nodes;
for (size_t i = 0; i < graph_inputs.size(); ++i) {
auto item = graph_inputs[i];
@@ -240,6 +240,7 @@ constexpr auto kAttrReduceScatterFlag = "reduce_scatter_flag";
constexpr auto kAttrOffset = "offset";
constexpr auto kAttrPsKey = "ps_key";
constexpr auto kAttrOptimizerType = "optim_type";
constexpr auto kAttrChildGraph = "child_graph";
// attr value
constexpr auto kValueTargetSwitch = "target_switch";