提交 af5019b9 编写于 作者: Z zhoufeng

link child graphs

Signed-off-by: Nzhoufeng <zhoufeng54@huawei.com>
上级 d9c74e0a
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include "hccl/hcom.h" #include "hccl/hcom.h"
#include "common/trans.h" #include "common/trans.h"
#include "runtime/context.h" #include "runtime/context.h"
#include "device/ascend/ascend_label_assign.h"
#include "device/ascend/ascend_stream_assign.h" #include "device/ascend/ascend_stream_assign.h"
#include "device/ascend/ascend_memory_pool.h" #include "device/ascend/ascend_memory_pool.h"
#include "framework/ge_runtime/model_runner.h" #include "framework/ge_runtime/model_runner.h"
...@@ -281,21 +282,24 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { ...@@ -281,21 +282,24 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
return true; return true;
} }
AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance(); AscendStreamAssign &stream_assign_instance = AscendStreamAssign::GetInstance();
AscendLabelAssign &label_assign_instance = AscendLabelAssign::GetInstance();
// the streams' flag not HEAD_STREAM // the streams' flag not HEAD_STREAM
std::vector<uint32_t> wait_active_stream_list; std::vector<uint32_t> wait_active_stream_list;
assign_instance.GetWaitStreams(&wait_active_stream_list); stream_assign_instance.GetWaitStreams(&wait_active_stream_list);
auto force_copy_stream_list = assign_instance.hcom_streams(); auto force_copy_stream_list = stream_assign_instance.hcom_streams();
MS_LOG(INFO) << "call DavinciModel total stream num:" << assign_instance.GetTotalStreamNum() MS_LOG(INFO) << "call DavinciModel total stream num:" << stream_assign_instance.GetTotalStreamNum()
<< ", total event num:" << assign_instance.total_event_num() << ", total event num:" << stream_assign_instance.total_event_num()
<< ", total label num:" << label_assign_instance.GetLabelNum(NOT_NULL(graph))
<< ", wait_active_stream_list size:" << wait_active_stream_list.size() << ", wait_active_stream_list size:" << wait_active_stream_list.size()
<< ", force_copy_stream_list size:" << force_copy_stream_list.size(); << ", force_copy_stream_list size:" << force_copy_stream_list.size();
std::vector<std::shared_ptr<ge::model_runner::OpInfo>> empty_list; std::vector<std::shared_ptr<ge::model_runner::OpInfo>> empty_list;
std::shared_ptr<ge::model_runner::DavinciModel> model = std::make_shared<ge::model_runner::DavinciModel>( std::shared_ptr<ge::model_runner::DavinciModel> model = std::make_shared<ge::model_runner::DavinciModel>(
task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0, task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0,
0, 0, 0, 0, 0, assign_instance.GetTotalStreamNum(), 1, assign_instance.total_event_num(), 0); 0, 0, 0, 0, 0, stream_assign_instance.GetTotalStreamNum(), label_assign_instance.GetLabelNum(NOT_NULL(graph)),
stream_assign_instance.total_event_num(), 0);
auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model)); auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model));
if (!ret.second) { if (!ret.second) {
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
*/ */
#include <vector> #include <vector>
#include <string>
#include <set>
#include "device/ascend/ascend_label_assign.h" #include "device/ascend/ascend_label_assign.h"
#include "session/anf_runtime_algorithm.h" #include "session/anf_runtime_algorithm.h"
...@@ -36,6 +38,7 @@ static void UpdateLabelGoto(NotNull<CNodePtr> node) { ...@@ -36,6 +38,7 @@ static void UpdateLabelGoto(NotNull<CNodePtr> node) {
uint32_t goto_label_id = GetValue<uint32_t>(value); uint32_t goto_label_id = GetValue<uint32_t>(value);
AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue<uint32_t>(goto_label_id), node.get()); AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue<uint32_t>(goto_label_id), node.get());
MS_LOG(INFO) << "Node " << node->DebugString() << " goto label id " << goto_label_id; MS_LOG(INFO) << "Node " << node->DebugString() << " goto label id " << goto_label_id;
node->set_inputs({node->input(0)});
} }
static void UpdateLabelSwitch(NotNull<CNodePtr> node) { static void UpdateLabelSwitch(NotNull<CNodePtr> node) {
...@@ -58,29 +61,93 @@ static void UpdateLabelSwitch(NotNull<CNodePtr> node) { ...@@ -58,29 +61,93 @@ static void UpdateLabelSwitch(NotNull<CNodePtr> node) {
MS_LOG(INFO) << "Switch " << node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id " << goto_label_id; MS_LOG(INFO) << "Switch " << node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id " << goto_label_id;
} }
AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue<std::vector<uint32_t>>(label_list), node.get()); AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue<std::vector<uint32_t>>(label_list), node.get());
node->set_inputs({node->input(0), node->input(1)});
} }
void AscendLabelAssign::AssignLabel(NotNull<const std::shared_ptr<session::KernelGraph> &> graph) { static void AssignLabelForLabelSet(NotNull<std::shared_ptr<session::KernelGraph>> graph, NotNull<uint32_t *> label_id,
auto cnode_list = graph->execution_order(); NotNull<std::set<std::shared_ptr<session::KernelGraph>> *> memo) {
// 1 assign label id to label_set if (memo->find(graph.get()) != memo->end()) {
uint32_t cur_label_id = 0; return;
for (auto &node : cnode_list) { }
if (AnfAlgo::GetCNodeName(node) == kLabelSetOpName) {
AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue<uint32_t>(cur_label_id), node); MS_LOG(INFO) << "Assign label for " << graph->ToString();
MS_LOG(INFO) << "Node " << node->DebugString() << " assign label id " << cur_label_id; auto nodes = TopoSort(graph->get_return());
++cur_label_id; for (auto &node : nodes) {
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
std::string node_name = AnfAlgo::GetCNodeName(node);
if (node_name == kLabelSetOpName && !AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) {
AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue<uint32_t>(*label_id), node);
MS_LOG(INFO) << "Node " << node->DebugString() << " assign label id " << *label_id;
++(*label_id);
} }
} }
// 2 update label_switch / label_goto
for (auto &node : cnode_list) { for (auto &cg : graph->child_graph_order()) {
if (AnfAlgo::GetCNodeName(node) == kLabelGotoOpName) { AssignLabelForLabelSet(NOT_NULL(cg), label_id, memo);
UpdateLabelGoto(NOT_NULL(node)); }
}
static void AssignLabelForGotoSwitch(NotNull<std::shared_ptr<session::KernelGraph>> graph,
NotNull<std::set<std::shared_ptr<session::KernelGraph>> *> memo) {
if (memo->find(graph.get()) != memo->end()) {
return;
}
MS_LOG(INFO) << "Process label goto/switch for " << graph->ToString();
auto nodes = TopoSort(graph->get_return());
for (auto &node : nodes) {
if (!node->isa<CNode>()) {
continue;
} }
if (AnfAlgo::GetCNodeName(node) == kLabelSwitchOpName) { auto cnode = node->cast<CNodePtr>();
UpdateLabelSwitch(NOT_NULL(node)); MS_EXCEPTION_IF_NULL(cnode);
std::string node_name = AnfAlgo::GetCNodeName(node);
if (node_name == kLabelGotoOpName) {
UpdateLabelGoto(NOT_NULL(cnode));
cnode->set_abstract(nullptr);
} }
if (node_name == kLabelSwitchOpName) {
UpdateLabelSwitch(NOT_NULL(cnode));
}
}
for (auto &cg : graph->child_graph_order()) {
AssignLabelForGotoSwitch(NOT_NULL(cg), memo);
}
}
void AscendLabelAssign::AssignLabel(NotNull<std::shared_ptr<session::KernelGraph>> graph) {
MS_LOG(INFO) << "Assign label start.";
std::set<std::shared_ptr<session::KernelGraph>> memo;
uint32_t label_id = 0;
AssignLabelForLabelSet(graph, NOT_NULL(&label_id), NOT_NULL(&memo));
memo.clear();
{
std::lock_guard<std::mutex> lock(label_num_mutex_);
label_num_[graph.get().get()] = label_id;
} }
AssignLabelForGotoSwitch(graph, NOT_NULL(&memo));
MS_LOG(INFO) << "Assign label end.";
}
uint32_t AscendLabelAssign::GetLabelNum(NotNull<const session::KernelGraph *> graph) {
std::lock_guard<std::mutex> lock(label_num_mutex_);
auto iter = label_num_.find(graph.get());
if (iter == label_num_.end()) {
MS_LOG(WARNING) << "Graph " << graph->ToString() << " has not assigned label.";
return 1;
}
return iter->second;
}
uint32_t AscendLabelAssign::GetLabelNum(NotNull<std::shared_ptr<session::KernelGraph>> graph) {
return GetLabelNum(NOT_NULL(graph.get().get()));
} }
} // namespace ascend } // namespace ascend
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ #define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_
#include <memory> #include <memory>
#include <map>
#include "session/kernel_graph.h" #include "session/kernel_graph.h"
#include "utils/contract.h" #include "utils/contract.h"
...@@ -35,11 +36,16 @@ class AscendLabelAssign { ...@@ -35,11 +36,16 @@ class AscendLabelAssign {
AscendLabelAssign(const AscendLabelAssign &) = delete; AscendLabelAssign(const AscendLabelAssign &) = delete;
AscendLabelAssign &operator=(const AscendLabelAssign &) = delete; AscendLabelAssign &operator=(const AscendLabelAssign &) = delete;
void AssignLabel(NotNull<const std::shared_ptr<session::KernelGraph> &> graph); void AssignLabel(NotNull<std::shared_ptr<session::KernelGraph>> graph);
uint32_t GetLabelNum(NotNull<const session::KernelGraph *> graph);
uint32_t GetLabelNum(NotNull<std::shared_ptr<session::KernelGraph>> graph);
private: private:
AscendLabelAssign() = default; AscendLabelAssign() = default;
~AscendLabelAssign() = default; ~AscendLabelAssign() = default;
std::map<const session::KernelGraph *, uint32_t> label_num_;
std::mutex label_num_mutex_;
}; };
} // namespace ascend } // namespace ascend
} // namespace device } // namespace device
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "kernel/rts/label_switch.h" #include "kernel/rts/label_switch.h"
#include <asm-generic/param.h> #include <asm-generic/param.h>
#include <memory> #include <memory>
#include <string>
#include "runtime/stream.h" #include "runtime/stream.h"
#include "framework/ge_runtime/task_info.h" #include "framework/ge_runtime/task_info.h"
#include "session/anf_runtime_algorithm.h" #include "session/anf_runtime_algorithm.h"
...@@ -66,13 +67,33 @@ std::vector<TaskInfoPtr> LabelSwitchKernel::GenTask(const std::vector<AddressPtr ...@@ -66,13 +67,33 @@ std::vector<TaskInfoPtr> LabelSwitchKernel::GenTask(const std::vector<AddressPtr
MS_LOG(INFO) << "LabelSwitchKernel GenTask label size:" << label_size_ << ", stream id:" << stream_id; MS_LOG(INFO) << "LabelSwitchKernel GenTask label size:" << label_size_ << ", stream id:" << stream_id;
std::vector<TaskInfoPtr> task_info_list; std::vector<TaskInfoPtr> task_info_list;
cond_ = inputs[0]->addr; cond_ = inputs[0]->addr;
// std::shared_ptr<LabelSwitchTaskInfo> task_info_ptr = // todo: need update ge task info define
// std::make_shared<LabelSwitchTaskInfo>(stream_id, label_size_, &label_list_, cond_); auto task_info_ptr = std::make_shared<LabelSwitchTaskInfo>(stream_id, 0);
// need updata ge task info define // auto task_info_ptr = std::make_shared<LabelSwitchTaskInfo>(stream_id, label_size_, label_list_, cond_);
std::shared_ptr<LabelSwitchTaskInfo> task_info_ptr = std::make_shared<LabelSwitchTaskInfo>(stream_id, label_size_);
MS_EXCEPTION_IF_NULL(task_info_ptr); MS_EXCEPTION_IF_NULL(task_info_ptr);
task_info_list.emplace_back(task_info_ptr); task_info_list.emplace_back(task_info_ptr);
return task_info_list; return task_info_list;
} }
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> LabelSwitchDesc::GetKernelInfo() {
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> label_switch_build_info{};
vector<string> input_format{kOpFormat_DEFAULT, kOpFormat_DEFAULT};
vector<TypeId> input_type{kNumberTypeUInt32, kNumberTypeBool};
if (input_format.size() != input_type.size()) {
MS_LOG(EXCEPTION) << "Invalid param num, input_format size " << input_format.size() << " input_type size "
<< input_type.size();
}
for (size_t i = 0; i < input_format.size(); ++i) {
auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
builder.SetInputsFormat({input_format[i]});
builder.SetInputsDeviceType({input_type[i]});
builder.SetProcessor(AICORE);
builder.SetKernelType(RT_KERNEL);
builder.SetFusionType(OPAQUE);
label_switch_build_info.emplace_back(builder.Build());
}
return label_switch_build_info;
}
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
...@@ -42,6 +42,14 @@ class LabelSwitchKernel : public RtKernel { ...@@ -42,6 +42,14 @@ class LabelSwitchKernel : public RtKernel {
void *cond_; void *cond_;
}; };
class LabelSwitchDesc : public RtKerDesc {
public:
LabelSwitchDesc() = default;
~LabelSwitchDesc() override = default;
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> GetKernelInfo() override;
};
MS_REG_RTKERNEL_DESC(labelswitch, LabelSwitchDesc);
MS_REG_RTKERNEL(labelswitch, LabelSwitchKernel); MS_REG_RTKERNEL(labelswitch, LabelSwitchKernel);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
......
...@@ -44,6 +44,12 @@ RtKerDescFactory &RtKerDescFactory::Get() { ...@@ -44,6 +44,12 @@ RtKerDescFactory &RtKerDescFactory::Get() {
return _this; return _this;
} }
static bool IsDefaultKernelInfo(const std::string &name) {
static const std::set<std::string> white_list = {kStreamSwitchOpName, kStreamActiveOpName, kLabelSetOpName,
kLabelGotoOpName};
return white_list.find(name) != white_list.end();
}
void GetRtKelInfo(const CNodePtr &kernel_node, void GetRtKelInfo(const CNodePtr &kernel_node,
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) { std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_info_list); MS_EXCEPTION_IF_NULL(kernel_info_list);
...@@ -58,7 +64,7 @@ void GetRtKelInfo(const CNodePtr &kernel_node, ...@@ -58,7 +64,7 @@ void GetRtKelInfo(const CNodePtr &kernel_node,
} }
// if can't find kernel info in kernel info database, use the default kernel info // if can't find kernel info in kernel info database, use the default kernel info
auto node_name = AnfAlgo::GetCNodeName(kernel_node); auto node_name = AnfAlgo::GetCNodeName(kernel_node);
if (node_name == "StreamSwitch" || node_name == "StreamActive") { if (IsDefaultKernelInfo(node_name)) {
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// set input infos // set input infos
auto input_num = AnfAlgo::GetInputTensorNum(kernel_node); auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
......
...@@ -331,12 +331,14 @@ bool ExecuteAction(const ResourcePtr &res) { ...@@ -331,12 +331,14 @@ bool ExecuteAction(const ResourcePtr &res) {
} }
auto graph_id = res->results()[kOutput].cast<GraphId>(); auto graph_id = res->results()[kOutput].cast<GraphId>();
auto bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::MsBackend>>(); std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>();
std::shared_ptr<compile::MsBackend> msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr);
MS_EXCEPTION_IF_NULL(msbc_ptr);
compile::VmEvalFuncPtr run = compile::VmEvalFuncPtr run =
std::make_shared<compile::VmEvalFunc>([&bc_ptr, graph_id](const VectorRef &args) -> BaseRef { std::make_shared<compile::VmEvalFunc>([msbc_ptr, graph_id](const VectorRef &args) -> BaseRef {
MS_LOG(INFO) << "Execute args size" << args.size(); MS_LOG(INFO) << "Execute args size " << args.size();
auto outs = bc_ptr->RunGraph(graph_id, args); auto outs = msbc_ptr->RunGraph(graph_id, args);
MS_LOG(DEBUG) << "out size" << outs.size(); MS_LOG(DEBUG) << "out size " << outs.size();
return outs[0]; return outs[0];
}); });
res->results()[kOutput] = run; res->results()[kOutput] = run;
......
...@@ -6,22 +6,23 @@ file(GLOB_RECURSE _SESSION_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ...@@ -6,22 +6,23 @@ file(GLOB_RECURSE _SESSION_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
) )
if (ENABLE_GPU) if (ENABLE_GPU)
file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"gpu_session.cc" "gpu_session.cc"
) )
list(APPEND _SESSION_SRC_LIST ${_GPU_SRC_LIST}) list(APPEND _SESSION_SRC_LIST ${_GPU_SRC_LIST})
endif () endif ()
if (ENABLE_CPU) if (ENABLE_CPU)
file(GLOB_RECURSE _CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} file(GLOB_RECURSE _CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"cpu_session.cc" "cpu_session.cc"
) )
list(APPEND _SESSION_SRC_LIST ${_CPU_SRC_LIST}) list(APPEND _SESSION_SRC_LIST ${_CPU_SRC_LIST})
endif () endif ()
if (ENABLE_D) if (ENABLE_D)
file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"ascend_session.cc" "ascend_session.cc"
"ascend_control_parser.cc"
) )
list(APPEND _SESSION_SRC_LIST ${_D_SRC_LIST}) list(APPEND _SESSION_SRC_LIST ${_D_SRC_LIST})
endif () endif ()
......
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <utility>
#include <memory>
#include "session/ascend_control_parser.h"
#include "session/anf_runtime_algorithm.h"
namespace mindspore {
namespace session {
static VectorRef GetCallArgs(std::vector<AnfNodePtr>::iterator iter_begin, std::vector<AnfNodePtr>::iterator iter_end) {
VectorRef call_args;
for (auto iter = iter_begin; iter != iter_end; ++iter) {
if (utils::isa<ValueNode>(*iter)) {
call_args.push_back(GetValueNode(*iter));
} else {
call_args.push_back(*iter);
}
}
return call_args;
}
void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
std::set<KernelGraphPtr> memo;
ProcessKernelGraph(kg, nullptr, nullptr, {}, NOT_NULL(&memo));
}
NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
const CNodePtr &last_label, const VectorRef &args,
NotNull<std::set<KernelGraphPtr> *> memo) {
MS_LOG(INFO) << "Start process KernelGraph " << kg->ToString();
// 0. recursive condition
if (memo->find(kg) != memo->end()) {
MS_LOG(INFO) << "KernelGraph has beed processed: " << kg->ToString();
return NOT_NULL(kg->get_start_label());
}
// 2. args replace placeholder
LinkParentGraph(kg, last_node, last_label, args);
// 3. topological sort
std::vector<CNodePtr> nodes = GetCNodes(TopoSort(kg->get_return()));
if (nodes.empty()) {
MS_LOG(EXCEPTION) << "KernelGraph " << kg->ToString() << " has no cnodes!";
}
// 4. insert first_label
auto start_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
for (auto node : nodes) {
if (!IsPrimitiveCNode(node, prim::kPrimPartial)) {
InsertControlDependToGraph(kg, NOT_NULL(start_label), NOT_NULL(node));
break;
}
}
kg->set_start_label(start_label);
// 5. traverse
for (size_t i = 0; i < nodes.size(); ++i) {
auto &cnode = nodes[i];
if (cnode->size() < kCNodePrim + 1) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
}
AnfNodePtr fn = cnode->input(kCNodePrim);
if (!IsPrimitive(fn, prim::kPrimCall) || cnode->size() < kCNodeCallArg + 1) {
MS_LOG(DEBUG) << "continue node " << cnode->DebugString();
continue;
}
AnfNodePtr arg = cnode->input(kCNodeCallArg);
if (IsValueNode<KernelGraph>(arg)) {
RecurseCall(kg, NOT_NULL(cnode), (i + 1 < nodes.size() ? nodes[i + 1] : nullptr), memo);
} else if (!arg->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Unknown type call node " << cnode->DebugString();
} else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitch)) {
auto arg_cnode = arg->cast<CNodePtr>();
cnode->set_inputs(cnode->inputs());
RecurseSwitch(kg, NOT_NULL(cnode), memo);
} else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitchLayer)) {
auto arg_cnode = arg->cast<CNodePtr>();
cnode->set_inputs(cnode->inputs());
RecurseSwitchLayer(kg, NOT_NULL(cnode), memo);
}
}
MS_LOG(INFO) << "End KernelGraph process: " << kg->ToString();
return NOT_NULL(start_label);
}
std::vector<CNodePtr> AscendControlParser::GetCNodes(const std::vector<AnfNodePtr> &in) {
std::vector<CNodePtr> out;
for (auto &node : in) {
if (node->isa<CNode>()) {
out.push_back(node->cast<CNodePtr>());
}
}
return out;
}
void AscendControlParser::InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> attch_node) {
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("depend"))};
auto return_node = kg->get_return();
MS_EXCEPTION_IF_NULL(return_node);
inputs.push_back(return_node->input(1));
inputs.push_back(attch_node.get());
auto depend_node = kg->NewCNode(inputs);
return_node->set_input(1, depend_node);
}
void AscendControlParser::InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node,
NotNull<AnfNodePtr> second_node) {
MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString()
<< ", the second node is " << second_node->DebugString();
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimControlDepend->name())),
first_node, second_node};
auto control_depend = kg->NewCNode(inputs);
InsertDependToGraph(kg, NOT_NULL(control_depend));
}
void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node,
const CNodePtr &last_label, const VectorRef &args) {
if (from_graph_call_node != nullptr) {
SetSubGraphInput(kg, NOT_NULL(from_graph_call_node), args);
}
auto origin_return = kg->get_return();
std::vector<AnfNodePtr> origin_return_inputs = origin_return->inputs();
// if entry graph, replace return with make_tuple
if (from_graph_call_node == nullptr || last_label == nullptr) {
MS_LOG(INFO) << kg->ToString() << " is entry graph.";
std::vector<AnfNodePtr> make_tuple_inputs = {std::make_shared<ValueNode>(prim::kPrimMakeTuple)};
make_tuple_inputs.insert(make_tuple_inputs.end(), origin_return_inputs.begin() + 1, origin_return_inputs.end());
auto make_tuple = kg->NewCNode(make_tuple_inputs);
origin_return->set_inputs({origin_return->input(kCNodePrim), make_tuple});
} else {
// else replace return with label_goto
auto label_goto =
kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName)), last_label});
InsertDependToGraph(kg, NOT_NULL(label_goto));
}
}
void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
NotNull<std::set<KernelGraphPtr> *> memo) {
MS_LOG(INFO) << "process call func " << cur_node->DebugString();
// 1 get kernel graph
auto origin_inputs = cur_node->inputs();
std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName))};
auto call_args = GetCallArgs(origin_inputs.begin() + 1, origin_inputs.end());
if (!IsValueNode<KernelGraph>(origin_inputs[kCNodeCallArg])) {
MS_LOG(WARNING) << "Node " << cur_node->DebugString(10) << " index " << kCNodeCallArg << " is not a ValueNode";
return;
}
// 2 return label
auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
// 3 add depend relationship
InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label));
if (next_node != nullptr && next_node != kg->get_return()) {
InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node));
}
auto call_kg = GetValueNode<KernelGraphPtr>(origin_inputs[kCNodeCallArg]);
// 4 modify call op to goto op
cur_node->set_input(kCNodePrim, new_inputs[kCNodePrim]);
// 5 recurse sub graph
CNodePtr sub_label = ProcessKernelGraph(NOT_NULL(call_kg), cur_node, back_label, call_args, memo);
new_inputs.push_back(sub_label);
new_inputs.insert(new_inputs.end(), origin_inputs.begin(), origin_inputs.end());
cur_node->set_inputs(new_inputs);
cur_node->set_abstract(nullptr);
MS_LOG(INFO) << "success process call func " << cur_node->DebugString();
}
void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
NotNull<std::set<KernelGraphPtr> *> memo) {
MS_LOG(INFO) << "process switch node " << cur_node->DebugString();
if (cur_node->size() < kCNodeSwitchLength) {
MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLength;
}
// 1 return label
auto back_label = kg->NewCNode({std::make_shared<ValueNode>(prim::kPrimLabelSet)});
// 2 recurse sub graph
auto origin_switch_inputs = cur_node->inputs();
std::vector<AnfNodePtr> new_switch_inputs = {
std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)),
origin_switch_inputs[kCNodeSwitchCond]};
for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) {
// 2.1 branch kernel graph and args
CNodePtr partial;
KernelGraphPtr branch_fg;
VectorRef call_args;
std::tie(partial, branch_fg, call_args) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
// 2.2 add depend relationship
InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label));
// 2.3 recurse sub graph
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, call_args, memo);
new_switch_inputs.push_back(branch_label);
}
std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]);
new_switch_inputs.insert(new_switch_inputs.end(), origin_switch_inputs.begin(), origin_switch_inputs.end());
cur_node->set_inputs(new_switch_inputs);
cur_node->set_abstract(nullptr);
MS_LOG(INFO) << "success process switch func " << cur_node->DebugString();
}
void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
NotNull<std::set<KernelGraphPtr> *> memo) {
MS_LOG(INFO) << "process switch node " << cur_node->DebugString();
if (cur_node->size() < kCNodeSwitchLayerLength) {
MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength;
}
auto branch_tuple = cur_node->input(kCNodeSwitchLayerBranch);
MS_EXCEPTION_IF_NULL(branch_tuple);
if (!branch_tuple->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength;
}
auto branch_partial = utils::cast<CNodePtr>(branch_tuple)->inputs();
// 1 return label
auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName))});
// 2 recurse sub graph
auto origin_switch_inputs = cur_node->inputs();
std::vector<AnfNodePtr> new_switch_inputs = {std::make_shared<ValueNode>(prim::kPrimLabelSwitch),
origin_switch_inputs[kCNodeSwitchCond]};
for (size_t i = 0; i < branch_partial.size(); ++i) {
// 2.1 branch kernel graph and args
CNodePtr partial;
KernelGraphPtr branch_fg;
VectorRef call_args;
std::tie(partial, branch_fg, call_args) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
// 2.2 add depend relationship
InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label));
// 2.3 recurse sub graph
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, call_args, memo);
new_switch_inputs.push_back(branch_label);
}
new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end());
cur_node->set_inputs(new_switch_inputs);
cur_node->set_abstract(nullptr);
MS_LOG(INFO) << "success process switch layer " << cur_node->DebugString();
}
std::tuple<CNodePtr, KernelGraphPtr, VectorRef> AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) {
if (!node.get()->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString();
}
// 2.1 branch kernel graph and args
auto partial_cnode = utils::cast<CNodePtr>(node.get());
if (partial_cnode->size() < kCNodePartialLength) {
MS_LOG(EXCEPTION) << "Inputs of partial node must more than " << kCNodePartialLength;
}
auto partial_inputs = partial_cnode->inputs();
auto branch_kg = GetValueNode<KernelGraphPtr>(partial_inputs[kCNodePartialFunc]);
auto call_args = GetCallArgs(partial_inputs.begin() + kCNodePartialFunc + 1, partial_inputs.end());
return {partial_cnode, branch_kg, call_args};
}
void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
NotNull<AnfNodePtr> to) {
if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) &&
AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) {
return;
}
if (from.get() == to.get()) {
return;
}
MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to "
<< to->DebugString();
// config inputs of assign node
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("Assign")), to, from};
// generate a new cnode
auto assign_node = kg->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(assign_node);
assign_node->set_abstract(to->abstract());
// append the assign at the end of from graph
InsertDependToGraph(kg, NOT_NULL(assign_node));
}
size_t AscendControlParser::SetChildGraphInput(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> node,
size_t input_index) {
auto output_num = AnfAlgo::GetOutputTensorNum(node);
if (output_num > 1 && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
return input_index + output_num;
}
auto &graph_inputs = kg->inputs();
if (input_index >= graph_inputs.size()) {
MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size();
}
auto backend_parameter = graph_inputs[input_index];
if (node.get()->isa<Parameter>()) {
MS_EXCEPTION_IF_NULL(backend_parameter);
MS_LOG(INFO) << "Reuse node [" << node->DebugString() << "], old node[" << backend_parameter->DebugString()
<< "] will be replaced.";
kg->ReplaceNode(backend_parameter, node);
return input_index;
}
InsertAssignToGraph(kg, node, NOT_NULL(backend_parameter));
return input_index + 1;
}
void AscendControlParser::SetSubGraphInput(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> from_graph_call_node,
const VectorRef &args) {}
} // namespace session
} // namespace mindspore
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H
#define MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H
#include <set>
#include <vector>
#include <tuple>
#include "session/kernel_graph.h"
#include "utils/base_ref.h"
#include "utils/contract.h"
namespace mindspore {
namespace session {
class AscendControlParser {
public:
static void LinkGraph(NotNull<KernelGraphPtr> kg);
static void InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> attch_node);
static void InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node,
NotNull<AnfNodePtr> second_node);
private:
static NotNull<CNodePtr> ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
const CNodePtr &last_label, const VectorRef &args,
NotNull<std::set<KernelGraphPtr> *> memo);
static void RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
NotNull<std::set<KernelGraphPtr> *> memo);
static void RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
NotNull<std::set<KernelGraphPtr> *> memo);
static void RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
NotNull<std::set<KernelGraphPtr> *> memo);
static std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &in);
static void LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node,
const CNodePtr &last_label, const VectorRef &args);
static void SetSubGraphInput(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> from_graph_call_node,
const VectorRef &args);
static std::tuple<CNodePtr, KernelGraphPtr, VectorRef> ParsePartial(NotNull<AnfNodePtr> node);
static void InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
static size_t SetChildGraphInput(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> node, size_t input_index);
static constexpr size_t kCNodePrim = 0;
static constexpr size_t kCNodeCallArg = 1;
static constexpr size_t kCNodeSwitchCond = 1;
static constexpr size_t kCNodeSwitchTrue = 2;
static constexpr size_t kCNodeSwitchFalse = 3;
static constexpr size_t kCNodeSwitchLength = 4;
static constexpr size_t kCNodePartialLength = 2;
static constexpr size_t kCNodePartialFunc = 1;
static constexpr size_t kCNodeSwitchLayerCond = 1;
static constexpr size_t kCNodeSwitchLayerBranch = 2;
static constexpr size_t kCNodeSwitchLayerLength = 3;
};
} // namespace session
} // namespace mindspore
#endif // MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H
...@@ -160,14 +160,14 @@ void ClearRunOpMemoryResource(const KernelGraphPtr &kernel_graph) { ...@@ -160,14 +160,14 @@ void ClearRunOpMemoryResource(const KernelGraphPtr &kernel_graph) {
std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &anf_nodes) { std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &anf_nodes) {
std::vector<CNodePtr> cnodes = {}; std::vector<CNodePtr> cnodes = {};
size_t i = 0; size_t i = 0;
for (const auto anf : anf_nodes) { for (auto anf : anf_nodes) {
MS_LOG(INFO) << "apply_list[" << i++ << "] = " << anf->DebugString(); MS_LOG(INFO) << "apply_list[" << i++ << "] = " << anf->DebugString();
MS_EXCEPTION_IF_NULL(anf); MS_EXCEPTION_IF_NULL(anf);
if (anf->isa<CNode>()) { if (anf->isa<CNode>()) {
cnodes.push_back(anf->cast<CNodePtr>()); cnodes.push_back(anf->cast<CNodePtr>());
} }
} }
return std::move(cnodes); return cnodes;
} }
std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, const std::vector<CNodePtr> &cnodes) { std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, const std::vector<CNodePtr> &cnodes) {
...@@ -189,7 +189,7 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co ...@@ -189,7 +189,7 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co
ret.push_back(std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.end())); ret.push_back(std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.end()));
} }
} }
return std::move(ret); return ret;
} }
void UpdateRealInput(KernelGraph *graph) { void UpdateRealInput(KernelGraph *graph) {
...@@ -232,7 +232,7 @@ void UpdateRealInput(KernelGraph *graph) { ...@@ -232,7 +232,7 @@ void UpdateRealInput(KernelGraph *graph) {
auto ret = std::vector<AnfNodePtr>(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end()); auto ret = std::vector<AnfNodePtr>(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end());
partial_cnode->set_inputs( partial_cnode->set_inputs(
std::vector<AnfNodePtr>(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2)); std::vector<AnfNodePtr>(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2));
return std::move(ret); return ret;
}; };
bind_call_partial_with_parameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get()); bind_call_partial_with_parameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get());
bind_call_partial_with_parameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get()); bind_call_partial_with_parameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get());
...@@ -256,27 +256,28 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { ...@@ -256,27 +256,28 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
// split switch // split switch
SplitGraph(graph); SplitGraph(graph);
// insert goto labels and label_sets // insert goto labels and label_sets
LinkChildGraphs(graph.get()); LinkChildGraphs(NOT_NULL(graph));
// resource initialize // resource initialize
InitRuntimeResource(); InitRuntimeResource();
// ir fusion // assign label
IRFusion(graph); AssignLabel(NOT_NULL(graph));
// kernel select if (!graph->executable()) {
SelectKernelGraphKernel(*graph); return graph->graph_id();
// convert model of predict module }
ConvertPredictModel(graph); for (auto iter : graphs_) {
// hardware optimize if (iter.second == graph) {
HardwareOptimizeGraphs(graph); MS_LOG(INFO) << "Entry graph " << graph->ToString() << " graph id " << graph->graph_id();
final_graph_id_ = graph->graph_id();
}
MS_LOG(INFO) << "CompileChildGraph " << iter.second->ToString();
CompileChildGraph(iter.second);
}
// adjust kernel // adjust kernel
AdjustKernel(graph); AdjustKernel(graph);
// root graph valiate,include genearte execute order and so on // root graph valiate,include genearte execute order and so on
RootGraphExecutorValidate(graph.get()); RootGraphExecutorValidate(graph.get());
// assign stream // assign stream
AssignStream(graph); AssignStream(graph);
// assign label
AssignLabel(NOT_NULL(graph));
// build kernel if node is cnode
BuildKernel(graph);
// alloc mem // alloc mem
MemoryAlloc(graph.get()); MemoryAlloc(graph.get());
// task generate // task generate
...@@ -556,7 +557,7 @@ void AscendSession::AssignStream(const std::shared_ptr<KernelGraph> &kernel_grap ...@@ -556,7 +557,7 @@ void AscendSession::AssignStream(const std::shared_ptr<KernelGraph> &kernel_grap
MS_LOG(INFO) << "Finish!"; MS_LOG(INFO) << "Finish!";
} }
void AscendSession::AssignLabel(NotNull<const KernelGraphPtr &> kernel_graph) const { void AscendSession::AssignLabel(NotNull<KernelGraphPtr> kernel_graph) const {
MS_LOG(INFO) << "Start!"; MS_LOG(INFO) << "Start!";
device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kernel_graph); device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kernel_graph);
MS_LOG(INFO) << "Finish!"; MS_LOG(INFO) << "Finish!";
...@@ -1305,29 +1306,13 @@ void AscendSession::InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived ...@@ -1305,29 +1306,13 @@ void AscendSession::InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived
} }
void AscendSession::InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node) { void AscendSession::InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node) {
MS_LOG(INFO) << "Insert depend at the end of graph, the attach node is " << attch_node->DebugString(); AscendControlParser::InsertDependToGraph(NOT_NULL(GetGraph(graph_id)), NOT_NULL(attch_node));
auto graph = GetGraph(graph_id);
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("depend"))};
auto return_node = graph->get_return();
MS_EXCEPTION_IF_NULL(return_node);
inputs.push_back(return_node->input(1));
inputs.push_back(attch_node);
auto depend_node = graph->NewCNode(inputs);
return_node->set_input(1, depend_node);
} }
void AscendSession::InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, void AscendSession::InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node,
const AnfNodePtr &second_node) { const AnfNodePtr &second_node) {
MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString() AscendControlParser::InsertControlDependToGraph(NOT_NULL(GetGraph(graph_id)), NOT_NULL(first_node),
<< ", the second node is " << second_node->DebugString(); NOT_NULL(second_node));
auto graph = GetGraph(graph_id);
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("ControlDepend"))};
inputs.push_back(first_node);
inputs.push_back(second_node);
auto control_depend = graph->NewCNode(inputs);
InsertDependToGraph(graph_id, control_depend);
} }
size_t AscendSession::ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph) { size_t AscendSession::ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph) {
...@@ -1482,5 +1467,8 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) { ...@@ -1482,5 +1467,8 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
SplitGraph(child_graph); SplitGraph(child_graph);
} }
} }
void AscendSession::LinkChildGraphs(NotNull<KernelGraphPtr> graph) { AscendControlParser::LinkGraph(graph); }
} // namespace session } // namespace session
} // namespace mindspore } // namespace mindspore
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include "session/kernel_graph.h" #include "session/kernel_graph.h"
#include "kernel/kernel.h" #include "kernel/kernel.h"
#include "session/session_factory.h" #include "session/session_factory.h"
#include "session/ascend_control_parser.h"
namespace mindspore { namespace mindspore {
namespace session { namespace session {
...@@ -74,7 +75,7 @@ class AscendSession : public SessionBasic { ...@@ -74,7 +75,7 @@ class AscendSession : public SessionBasic {
void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) const; void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void AssignLabel(NotNull<const KernelGraphPtr &> kernel_graph) const; void AssignLabel(NotNull<KernelGraphPtr> kernel_graph) const;
void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void MemoryAlloc(KernelGraph *kernel_graph) const; void MemoryAlloc(KernelGraph *kernel_graph) const;
void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const; void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const;
...@@ -96,7 +97,8 @@ class AscendSession : public SessionBasic { ...@@ -96,7 +97,8 @@ class AscendSession : public SessionBasic {
void SetFinalGraphOutput(const VectorRef &vec_output); void SetFinalGraphOutput(const VectorRef &vec_output);
void SplitGraph(const KernelGraphPtr &graph); void SplitGraph(const KernelGraphPtr &graph);
void LinkChildGraphs(KernelGraph *graph) {} void LinkChildGraphs(NotNull<KernelGraphPtr> graph);
void IRFusion(const KernelGraphPtr &graph) {} void IRFusion(const KernelGraphPtr &graph) {}
void SelectKernelGraphKernel(const KernelGraph &graph) {} void SelectKernelGraphKernel(const KernelGraph &graph) {}
void ConvertPredictModel(const KernelGraphPtr graph) {} void ConvertPredictModel(const KernelGraphPtr graph) {}
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/anf.h" #include "ir/anf.h"
#include "utils/graph_utils.h" #include "utils/graph_utils.h"
#include "utils/contract.h"
#include "device/kernel_info.h" #include "device/kernel_info.h"
namespace mindspore { namespace mindspore {
...@@ -108,6 +109,7 @@ class KernelGraph : public FuncGraph { ...@@ -108,6 +109,7 @@ class KernelGraph : public FuncGraph {
std::vector<std::shared_ptr<KernelGraph>> child_graph_order() const { return child_graph_order_; } std::vector<std::shared_ptr<KernelGraph>> child_graph_order() const { return child_graph_order_; }
// checkout whether current graph is leaf graph // checkout whether current graph is leaf graph
bool IsLeafGraph() const; bool IsLeafGraph() const;
// set input_tensors pointer of control parameter // set input_tensors pointer of control parameter
void set_input_ctrl_tensors(const std::shared_ptr<std::vector<tensor::TensorPtr>> &input_tensors_ptr) { void set_input_ctrl_tensors(const std::shared_ptr<std::vector<tensor::TensorPtr>> &input_tensors_ptr) {
input_ctrl_tensors_ = input_tensors_ptr; input_ctrl_tensors_ = input_tensors_ptr;
...@@ -126,6 +128,9 @@ class KernelGraph : public FuncGraph { ...@@ -126,6 +128,9 @@ class KernelGraph : public FuncGraph {
// used to dump ir // used to dump ir
std::string ToString() const override; std::string ToString() const override;
void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; }
CNodePtr get_start_label() { return start_label_; }
private: private:
// remove value node form graph // remove value node form graph
bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
...@@ -168,12 +173,16 @@ class KernelGraph : public FuncGraph { ...@@ -168,12 +173,16 @@ class KernelGraph : public FuncGraph {
std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> node_to_child_graphs_; std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> node_to_child_graphs_;
// child graph execute order in root graph // child graph execute order in root graph
std::vector<std::shared_ptr<KernelGraph>> child_graph_order_; std::vector<std::shared_ptr<KernelGraph>> child_graph_order_;
// input_tensors of control parameter // input_tensors of control parameter
std::shared_ptr<std::vector<tensor::TensorPtr>> input_ctrl_tensors_; std::shared_ptr<std::vector<tensor::TensorPtr>> input_ctrl_tensors_;
// parameter graph // parameter graph
std::shared_ptr<KernelGraph> parent_graph_; std::shared_ptr<KernelGraph> parent_graph_;
// record real parameters,inputs_ is the formal parameters // record real parameters,inputs_ is the formal parameters
std::map<AnfNodePtr, std::set<AnfNodePtr>> real_inputs_; std::map<AnfNodePtr, std::set<AnfNodePtr>> real_inputs_;
CNodePtr start_label_;
}; };
} // namespace session } // namespace session
using KernelGraphPtr = std::shared_ptr<session::KernelGraph>; using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;
......
...@@ -61,6 +61,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ...@@ -61,6 +61,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/transform/*.cc" "../../../mindspore/ccsrc/transform/*.cc"
"../../../mindspore/ccsrc/session/anf_runtime_algorithm.cc" "../../../mindspore/ccsrc/session/anf_runtime_algorithm.cc"
"../../../mindspore/ccsrc/session/ascend_session.cc" "../../../mindspore/ccsrc/session/ascend_session.cc"
"../../../mindspore/ccsrc/session/ascend_control_parser.cc"
"../../../mindspore/ccsrc/session/kernel_graph.cc" "../../../mindspore/ccsrc/session/kernel_graph.cc"
"../../../mindspore/ccsrc/session/session_basic.cc" "../../../mindspore/ccsrc/session/session_basic.cc"
"../../../mindspore/ccsrc/session/session_factory.cc" "../../../mindspore/ccsrc/session/session_factory.cc"
......
...@@ -22,7 +22,9 @@ namespace mindspore { ...@@ -22,7 +22,9 @@ namespace mindspore {
namespace device { namespace device {
namespace ascend { namespace ascend {
void AscendLabelAssign::AssignLabel(NotNull<const std::shared_ptr<session::KernelGraph> &>) {} void AscendLabelAssign::AssignLabel(NotNull<std::shared_ptr<session::KernelGraph>> graph) {}
uint32_t AscendLabelAssign::GetLabelNum(NotNull<const session::KernelGraph *> graph) { return 1; }
uint32_t AscendLabelAssign::GetLabelNum(NotNull<std::shared_ptr<session::KernelGraph>> graph) { return 1; }
void AscendStreamAssign::AssignStreamNew(const KernelGraphPtr &graph) { return; } void AscendStreamAssign::AssignStreamNew(const KernelGraphPtr &graph) { return; }
...@@ -39,9 +41,7 @@ bool TaskGenerator::GenTasks(const std::vector<CNodePtr> &anf_node_list, std::ve ...@@ -39,9 +41,7 @@ bool TaskGenerator::GenTasks(const std::vector<CNodePtr> &anf_node_list, std::ve
} // namespace ascend } // namespace ascend
void KernelAdjust::Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { return; } void KernelAdjust::Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { return; }
void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { return; } void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { return; }
bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { return true; }
return true;
}
bool KernelAdjust::NeedInsertSwitch() { return true; } bool KernelAdjust::NeedInsertSwitch() { return true; }
void KernelAdjust::Profiling(NotNull<session::KernelGraph *> kernel_graph_ptr) { return; } void KernelAdjust::Profiling(NotNull<session::KernelGraph *> kernel_graph_ptr) { return; }
} // namespace device } // namespace device
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册