Commit ef596f26 authored by mindspore-ci-bot, committed by Gitee

!802 [control sink] move the opt process to build graph

Merge pull request !802 from chenfei_mindspore/move-opt-into-build-graph
......@@ -22,28 +22,32 @@ namespace mindspore {
namespace kernel {
std::string KernelBuildInfo::GetInputFormat(size_t input_index) const {
if (input_index >= inputs_format_.size()) {
MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node";
MS_LOG(ERROR) << "The index [" << input_index << "] exceeds the number of inputs";
return kInvalidFormat;
}
return inputs_format_[input_index];
}
std::string KernelBuildInfo::GetOutputFormat(size_t output_index) const {
if (output_index >= outputs_format_.size()) {
MS_LOG(EXCEPTION) << "The index [" << output_index << "] is exceed the number of input node";
MS_LOG(ERROR) << "The index [" << output_index << "] exceeds the number of outputs";
return kInvalidFormat;
}
return outputs_format_[output_index];
}
TypeId KernelBuildInfo::GetInputDeviceType(size_t input_index) const {
if (input_index >= inputs_device_type_.size()) {
MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node";
MS_LOG(ERROR) << "The index [" << input_index << "] exceeds the number of inputs";
return TypeId::kNumberTypeEnd;
}
return inputs_device_type_[input_index];
}
TypeId KernelBuildInfo::GetOutputDeviceType(size_t output_index) const {
if (output_index >= outputs_device_type_.size()) {
MS_LOG(EXCEPTION) << "The index [" << output_index << "] is exceed the number of input node";
MS_LOG(ERROR) << "The index [" << output_index << "] exceeds the number of outputs";
return TypeId::kNumberTypeEnd;
}
return outputs_device_type_[output_index];
}
......
......@@ -82,6 +82,9 @@ class KernelBuildInfo {
bool operator==(const KernelBuildInfo &other) const;
public:
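// kInvalidFormat is a sentinel: the Get*Format accessors above now return it instead of
// throwing when the index is out of range, so callers such as
// AnfRuntimeAlgorithm::GetOutputFormat can decide how to report the error.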
static auto constexpr kInvalidFormat = "InvalidFormat";
private:
KernelType kernel_type_;
std::vector<std::string> inputs_format_;
......
......@@ -26,7 +26,7 @@
namespace mindspore {
namespace kernel {
namespace {
void FilterInvaildKernelInfo(const CNodePtr &kernel_node,
void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_info_list);
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_list;
......@@ -63,9 +63,9 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
HcclMetadataInfo(kernel_node, kernel_info_list);
}
if (kernel_info_list->empty()) {
MS_LOG(EXCEPTION) << "op" << kernel_node->DebugString() << "kernel query fail!";
MS_LOG(EXCEPTION) << "Op " << kernel_node->DebugString() << "kernel query fail!";
}
FilterInvaildKernelInfo(kernel_node, kernel_info_list);
FilterInvalidKernelInfo(kernel_node, kernel_info_list);
}
} // namespace kernel
} // namespace mindspore
......@@ -46,24 +46,40 @@ RtKerDescFactory &RtKerDescFactory::Get() {
void GetRtKelInfo(const CNodePtr &kernel_node,
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
MS_LOG(INFO) << "Mng kernel Info.";
MS_EXCEPTION_IF_NULL(kernel_info_list);
MS_EXCEPTION_IF_NULL(kernel_node);
std::string opNameLower = AnfAlgo::GetCNodeName(kernel_node);
(void)std::transform(opNameLower.begin(), opNameLower.end(), opNameLower.begin(), ::tolower);
auto ker_desc_ptr = RtKerDescFactory::Create(opNameLower);
if (ker_desc_ptr == nullptr) {
MS_LOG(DEBUG) << "Mng can't find op [" << opNameLower << "].";
if (ker_desc_ptr != nullptr && !ker_desc_ptr->GetKernelInfo().empty()) {
*kernel_info_list = ker_desc_ptr->GetKernelInfo();
return;
}
MS_EXCEPTION_IF_NULL(ker_desc_ptr);
auto kernel_info = ker_desc_ptr->GetKernelInfo();
if (kernel_info.empty()) {
MS_LOG(DEBUG) << "Rt dose not have op [" << opNameLower << "].";
// if no kernel info is found in the kernel info database, use the default kernel info
auto node_name = AnfAlgo::GetCNodeName(kernel_node);
if (node_name == "StreamSwitch" || node_name == "StreamActive") {
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// set input infos
auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
kernel_build_info_builder->SetInputsFormat(std::vector<std::string>(input_num, kOpFormat_DEFAULT));
std::vector<TypeId> input_types = {};
for (size_t i = 0; i < input_num; i++) {
input_types.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, i));
}
kernel_build_info_builder->SetInputsDeviceType(input_types);
// set output info
auto output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_num, kOpFormat_DEFAULT));
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>(output_num, TypeId::kTypeUnknown));
// set other info
kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE);
kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE);
kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
kernel_info_list->push_back(kernel_build_info_builder->Build());
return;
}
*kernel_info_list = kernel_info;
MS_LOG(DEBUG) << "Rt dose not have op [" << opNameLower << "].";
}
} // namespace kernel
} // namespace mindspore
......@@ -186,7 +186,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
save_graphs_path = ".";
}
if (save_graphs) {
std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_before.ir";
std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_before" + "_graph_" +
std::to_string(kernel_graph->graph_id()) + ".ir";
DumpIR(file_path, kernel_graph);
DumpIRProto(kernel_graph, "before_hwopt");
}
......@@ -208,7 +209,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
if (save_graphs) {
std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_after.ir";
std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_after" + "_graph_" +
std::to_string(kernel_graph->graph_id()) + ".ir";
DumpIR(file_path, kernel_graph);
}
}
......@@ -252,7 +254,8 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
save_graphs_path = ".";
}
if (save_graphs) {
std::string file_path = save_graphs_path + "/" + "hwopt_d_before.ir";
std::string file_path =
save_graphs_path + "/" + "hwopt_d_before" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
DumpIR(file_path, kernel_graph);
}
// data layout optimization
......@@ -278,7 +281,8 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
if (save_graphs) {
std::string file_path = save_graphs_path + "/" + "hwopt_d_end.ir";
std::string file_path =
save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
DumpIR(file_path, kernel_graph, true);
DumpIRProto(kernel_graph, "after_hwopt");
}
......
......@@ -27,6 +27,7 @@
namespace mindspore {
namespace opt {
void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
MS_LOG(INFO) << "start common opt graph:" << kernel_graph->graph_id();
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->save_graphs_flag();
......
......@@ -300,7 +300,12 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
return build_info->GetOutputFormat(output_idx);
auto format = build_info->GetOutputFormat(output_idx);
if (format == kernel::KernelBuildInfo::kInvalidFormat) {
MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
<< " has a invalid output format";
}
return format;
}
std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t input_idx) {
......@@ -314,7 +319,12 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
return build_info->GetInputFormat(input_idx);
auto format = build_info->GetInputFormat(input_idx);
if (format == kernel::KernelBuildInfo::kInvalidFormat) {
MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
<< " has a invalid input format";
}
return format;
}
KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx) {
......@@ -481,7 +491,12 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
return build_info->GetOutputDeviceType(output_idx);
auto dtype = build_info->GetOutputDeviceType(output_idx);
if (dtype == TypeId::kNumberTypeEnd) {
MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
<< " has a invalid dtype";
}
return dtype;
}
TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx) {
......@@ -494,7 +509,12 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
return build_info->GetInputDeviceType(input_idx);
auto dtype = build_info->GetInputDeviceType(input_idx);
if (dtype == TypeId::kNumberTypeEnd) {
MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
<< " has a invalid dtype";
}
return dtype;
}
TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node, size_t input_idx) {
......
......@@ -15,6 +15,9 @@
*/
#include "session/ascend_session.h"
#include <algorithm>
#include <map>
#include <tuple>
#include <set>
#include "operator/ops.h"
#include "ir/meta_tensor.h"
#include "ir/anf.h"
......@@ -75,28 +78,15 @@ void DumpGraphInputArgs(const VectorRef &args) {
void SetStreamDistinctionLabel(const KernelGraphPtr &graph, uint32_t label, bool is_override) {
MS_EXCEPTION_IF_NULL(graph);
for (auto &node : graph->execution_order()) {
if (is_override || AnfAlgo::GetStreamDistinctionLabel(node.get()) == kInvalidDistincLabel) {
MS_EXCEPTION_IF_NULL(node);
AnfAlgo::SetStreamDistinctionLabel(label, node.get());
}
}
}
GraphId GetDistinctionLabel(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
// if graph is empty,use graph id as distinction label
if (graph->execution_order().empty()) {
return graph->graph_id();
if (is_override || graph->stream_distinction_label() == kInvalidDistincLabel) {
graph->set_stream_distinction_label(label);
}
// else use first node of execution order as label
return AnfAlgo::GetStreamDistinctionLabel(graph->execution_order()[0].get());
}
std::vector<BaseRef> GetRealArgs(const KernelGraphPtr graph, const VectorRef &args) {
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> graph_inputs = graph->inputs();
auto valid_inputs = graph->ValidInputs();
auto valid_inputs = graph->valid_inputs();
size_t real_args_size = 0;
std::vector<BaseRef> real_args = {};
for (size_t i = 0; i < args.size(); i++) {
......@@ -141,23 +131,9 @@ std::vector<BaseRef> GetRealArgs(const KernelGraphPtr graph, const VectorRef &ar
GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
MS_LOG(INFO) << "start";
auto graph_id = graph_sum_;
// construct graph, if successfully, graph_sum_ + 1
auto graph = ConstructKernelGraph(lst, outputs);
MS_EXCEPTION_IF_NULL(graph);
opt::AscendBackendIRFusionOptimization(graph);
// select kernel build info
SelectKernel(*graph);
// convert kernel Graph to model
predictmodel::StepConvertGraph(graph);
// optimize graph
HardwareOptimize(graph);
// init runtime resource
InitRuntimeResource();
// assign static memory of parameters
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
runtime_instance->AssignStaticMemoryInput(graph.get());
auto graph_id = graph->graph_id();
MS_LOG(INFO) << "Compile graph " << graph_id << " success";
return graph_id;
}
......@@ -166,16 +142,36 @@ void AscendSession::BuildGraph(GraphId graph_id) {
MS_LOG(INFO) << "start";
auto graph = GetGraph(graph_id);
MS_EXCEPTION_IF_NULL(graph);
// resource initialize
InitRuntimeResource();
// handle multiple graphs
if (graph_id == final_graph_id_) {
if (!graph->executable()) {
return;
}
// insert assigns to child graph
InsertAllAssigns();
// insert switch and active to child graph
MergeSwitchCompile();
// OptChildGraphs
auto graph_order = GetGraphOrder(final_graph_id_);
auto &graph_type = GetGraphOrderType(final_graph_id_);
for (size_t i = 0; i < graph_order.size(); i++) {
if (graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START) {
continue;
}
MS_LOG(INFO) << "Start build child graph " << graph_order[i];
auto child_graph = GetGraph(graph_order[i]);
CompileChildGraph(child_graph);
}
// merge child graph
MergeGraphExecOrder();
} else {
auto single_graph = GetGraph(graph_id);
CompileChildGraph(single_graph);
// set the distinction label of single graph
SetStreamDistinctionLabel(GetGraph(graph_id), graph_id, false);
single_graph->set_stream_distinction_label(graph_id);
single_graph->UpdateExecuteKernelStreamLabel();
}
// adjust execution order because of merged child graphs and other special operations
AdjustKernel(graph);
......@@ -197,9 +193,26 @@ void AscendSession::BuildGraph(GraphId graph_id) {
// load task info to device if it is sink mode
LoadTask(graph);
}
// sync the initial const tensors to device
SyncInitialTenosrToDevice();
MS_LOG(INFO) << "end";
}
void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
MS_EXCEPTION_IF_NULL(child_graph);
opt::AscendBackendIRFusionOptimization(child_graph);
// select kernel build info
SelectKernel(*child_graph);
// convert kernel Graph to model
predictmodel::StepConvertGraph(child_graph);
// optimize graph
HardwareOptimize(child_graph);
// assign static memory of parameters
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
runtime_instance->AssignStaticMemoryInput(child_graph.get());
}
void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *const outputs) {
MS_LOG(INFO) << "start";
......@@ -458,11 +471,9 @@ void AscendSession::Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const
GraphId AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) {
MS_LOG(INFO) << "Start! Args size " << args.size();
auto final_graph = std::make_shared<KernelGraph>();
final_graph_id_ = graph_sum_++;
graphs_[final_graph_id_] = final_graph;
final_graph->set_graph_id(final_graph_id_);
MS_LOG(INFO) << "Create a new final graph" << final_graph_id_ << "success";
auto final_graph = NewKernelGraph();
final_graph_id_ = final_graph->graph_id();
MS_LOG(INFO) << "Create a new final graph" << final_graph_id_ << " success";
// init private variables and bind them with final_graph_id
graph_execute_orders_[final_graph_id_] = std::vector<GraphId>();
graph_order_types_[final_graph_id_] = std::vector<GraphType>();
......@@ -498,6 +509,46 @@ GraphId AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) {
return final_graph_id_;
}
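// Create a stand-in output for fake_graph that mirrors true_output: each real output gets a
// fresh parameter with the same abstract, and make_tuple outputs are expanded recursively so
// that every tuple element is backed by its own parameter.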
AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) {
auto fake_graph = GetGraph(fake_graph_id);
auto output_item_with_index = AnfAlgo::VisitKernelWithReturnType(true_output, 0);
auto create_parameter = [&](const AbstractBasePtr &abstract) -> AnfNodePtr {
auto parameter = fake_graph->NewParameter();
MS_EXCEPTION_IF_NULL(parameter);
parameter->set_abstract(abstract);
auto new_parameter = fake_graph->NewParameter(parameter);
// Add the new parameter to the graph inputs of fake_graph to ensure that all parameters will be allocated memory.
auto graph_inputs = fake_graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs);
graph_inputs->push_back(new_parameter);
return new_parameter;
};
auto create_parameter_from_cnode = [&](const AnfNodePtr &cnode, size_t output_idx) -> AnfNodePtr {
MS_EXCEPTION_IF_NULL(cnode);
auto abstract = cnode->abstract();
MS_EXCEPTION_IF_NULL(abstract);
// create multiple parameters if the real kernel has a tuple output
if (abstract->isa<abstract::AbstractTuple>()) {
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
MS_LOG(INFO) << "tuple_size [" << tuple_abstract->size() << "]";
return create_parameter((*tuple_abstract)[output_idx]);
}
return create_parameter(cnode->abstract());
};
if (AnfAlgo::CheckPrimitiveType(output_item_with_index.first, prim::kPrimMakeTuple)) {
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
auto make_tuple = output_item_with_index.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple);
for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
auto input = make_tuple->inputs()[i];
make_tuple_inputs.push_back(CreateFakeOutput(fake_graph_id, input));
}
return fake_graph->NewCNode(make_tuple_inputs);
}
return create_parameter_from_cnode(output_item_with_index.first, output_item_with_index.second);
}
void AscendSession::SetFinalGraphOutput(const BaseRef &output) {
auto final_graph = GetGraph(final_graph_id_);
MS_EXCEPTION_IF_NULL(final_graph);
......@@ -559,12 +610,6 @@ void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true
condition_graph->AddValueNodeToGraph(counter_const);
// create a new switch op
auto switch_primitive = std::make_shared<Primitive>("StreamSwitch");
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{kNumberTypeInt32});
kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE);
kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE);
kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
auto cond_output_it = condition_output_.find(condition_graph_id);
if (cond_output_it == condition_output_.end()) {
MS_LOG(EXCEPTION) << "Can't find condition graph" << condition_graph_id;
......@@ -574,11 +619,9 @@ void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true
MS_EXCEPTION_IF_NULL(cond_output_kernel);
std::vector<AnfNodePtr> inputs = {NewValueNode(switch_primitive), cond_output_kernel, counter_const};
CNodePtr switch_node = condition_graph->NewCNode(inputs);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), switch_node.get());
MS_EXCEPTION_IF_NULL(switch_node);
switch_node->set_abstract(std::make_shared<abstract::AbstractNone>());
AnfAlgo::SetGraphId(condition_graph_id, switch_node.get());
AnfAlgo::SetStreamDistinctionLabel(GetDistinctionLabel(GetGraph(condition_graph_id)), switch_node.get());
// set attr: cond_ RT_GREATER
AnfAlgo::SetNodeAttr(kAttrSwitchCondition, MakeValue<int>(static_cast<int>(RT_GREATER)), switch_node);
// set attr:data_type
......@@ -586,9 +629,9 @@ void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true
// set attr: true branch graph id, which is the same as the stream distinction label
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(true_graph_id), switch_node);
// append switch at the end of condition graph
std::vector<CNodePtr> exec_order = condition_graph->execution_order();
exec_order.push_back(switch_node);
condition_graph->set_execution_order(exec_order);
auto return_node = condition_graph->get_return();
MS_EXCEPTION_IF_NULL(return_node);
InsertControlDependToGraph(condition_graph_id, return_node->input(1), switch_node);
MS_LOG(INFO) << "Finish!";
}
......@@ -615,8 +658,14 @@ void AscendSession::CopyOutputOfIf(GraphId false_graph_id) {
MS_EXCEPTION_IF_NULL(true_last);
MS_EXCEPTION_IF_NULL(false_last);
MS_LOG(INFO) << "The last graph of false branch is " << false_last_id;
// now only consider the single output
InsertMultipleAssignToGraph(true_last_id, true_last->output(), false_last->output());
// create fake output
auto fake_output_graph = NewKernelGraph();
graph_execute_order.push_back(fake_output_graph->graph_id());
graph_order_type.push_back(COMMON_GRAPH);
fake_output_graph->set_output(CreateFakeOutput(fake_output_graph->graph_id(), final_graph->output()));
final_graph->set_output(fake_output_graph->output());
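// both branches assign their outputs into the shared fake output of the final graph, so the
// result is read from the same parameters no matter which branch was executed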
InsertMultipleAssignToGraph(true_last_id, true_last->output(), final_graph->output());
InsertMultipleAssignToGraph(false_last_id, false_last->output(), final_graph->output());
// insert stream active for loop sink
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
......@@ -650,14 +699,14 @@ void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id,
if (false_graph_id != kInvalidGraphId) {
// the false graph and the condition graph are in the same stream
auto condition_graph = GetGraph(cond_graph_id);
SetStreamDistinctionLabel(GetGraph(false_graph_id), GetDistinctionLabel(condition_graph), true);
SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true);
// if the false graph is a condition graph and has been switch compiled before, its false branch should be updated again
auto cond_it = switches_.find(false_graph_id);
while (cond_it != switches_.end() && cond_it->second.second != kInvalidGraphId) {
cond_graph_id = cond_it->first;
false_graph_id = cond_it->second.second;
condition_graph = GetGraph(cond_graph_id);
SetStreamDistinctionLabel(GetGraph(false_graph_id), GetDistinctionLabel(condition_graph), true);
SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true);
cond_it = switches_.find(false_graph_id);
}
}
......@@ -691,7 +740,7 @@ void AscendSession::MergeSwitchCompile() {
}
// insert stream active to common graph
if (prev_graph_id != kInvalidGraphId) {
InsertStreamActiveToGraph(prev_graph_id, GetDistinctionLabel(condition_graph));
InsertStreamActiveToGraph(prev_graph_id, condition_graph->stream_distinction_label());
}
// if this is an 'if' condition
auto it = while_condition_graphs_.find(cond_graph_id);
......@@ -700,12 +749,39 @@ void AscendSession::MergeSwitchCompile() {
} else {
// if it is a while, insert a stream active to the true graph
GraphId from_graph = it->second;
InsertStreamActiveToGraph(from_graph, GetDistinctionLabel(condition_graph));
InsertStreamActiveToGraph(from_graph, condition_graph->stream_distinction_label());
}
}
MS_LOG(INFO) << "Finish!";
}
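// Materialize the assigns recorded by SetChildGraphParameter: resolve each
// (front node, graph id, input index) tuple to its backend parameter, de-duplicate through a
// std::set, and insert one Assign per unique (arg, parameter) pair into the source graph.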
void AscendSession::InsertAllAssigns() {
std::set<std::pair<AnfNodePtr, AnfNodePtr>> assigns;
for (auto assign : assigns_) {
auto front_anf = std::get<0>(assign);
auto to_graph_id = std::get<1>(assign);
auto input_idx = std::get<2>(assign);
auto to_graph = GetGraph(to_graph_id);
MS_EXCEPTION_IF_NULL(to_graph);
std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
if (input_idx >= graph_inputs.size()) {
MS_LOG(EXCEPTION) << "input_index " << input_idx << " out of range size " << graph_inputs.size();
}
auto backend_parameter = graph_inputs[input_idx];
(void)assigns.insert(std::pair<AnfNodePtr, AnfNodePtr>(front_anf, backend_parameter));
}
// duplicates have been removed by the set; insert one assign per unique pair
for (auto &assign : assigns) {
auto front_anf = assign.first;
auto backend_parameter = assign.second;
auto from_graph_id = GetGraphIdByNode(front_anf);
auto from_graph = GetGraph(from_graph_id);
MS_EXCEPTION_IF_NULL(from_graph);
auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf);
InsertAssignToGraph(from_graph_id, backend_arg, backend_parameter);
}
}
// insert active to graph
void AscendSession::SetActive(GraphId from, GraphId to) {
if (while_condition_graphs_.find(to) != while_condition_graphs_.end()) {
......@@ -735,20 +811,21 @@ void AscendSession::SetActive(GraphId from, GraphId to) {
while_condition_graphs_[to] = from;
}
void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, const AnfNodePtr &backend_parameter) {
void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx) {
MS_LOG(INFO) << "Start!";
MS_EXCEPTION_IF_NULL(backend_parameter);
MS_EXCEPTION_IF_NULL(front_anf);
if (!backend_parameter->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "Backend parameter's type is not a parameter,but is " << backend_parameter->ToString();
}
auto from_graph_id = GetGraphIdByNode(front_anf);
auto from_graph = GetGraph(from_graph_id);
MS_EXCEPTION_IF_NULL(from_graph);
auto to_graph_id = AnfAlgo::GetGraphId(backend_parameter.get());
auto to_graph = GetGraph(to_graph_id);
auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf);
MS_EXCEPTION_IF_NULL(to_graph);
std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
if (input_idx >= graph_inputs.size()) {
MS_LOG(EXCEPTION) << "input_index " << input_idx << " out of range size " << graph_inputs.size();
}
auto backend_parameter = graph_inputs[input_idx];
MS_EXCEPTION_IF_NULL(backend_parameter);
auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf);
MS_LOG(INFO) << "Set node[" << front_anf->DebugString() << "] of graph[" << from_graph_id << "]to node["
<< backend_parameter->DebugString() << "] of graph[" << AnfAlgo::GetGraphId(backend_parameter.get())
<< "]";
......@@ -759,39 +836,21 @@ void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, const An
// if the arg is a parameter of the child graph, it is a parameter of the final graph too
if (front_anf->isa<Parameter>()) {
MS_EXCEPTION_IF_NULL(backend_arg);
if (!AnfAlgo::OutputAddrExist(backend_arg, 0)) {
// set parameter's addr in child graph to parameter in final graph
AnfAlgo::SetOutputAddr(AnfAlgo::GetMutableOutputAddr(backend_parameter, 0), 0, backend_arg.get());
MS_LOG(INFO) << "Assign mem of node" << backend_parameter->DebugString() << " of graph "
<< AnfAlgo::GetGraphId(backend_parameter.get()) << " to node" << backend_arg->DebugString()
<< "of graph " << AnfAlgo::GetGraphId(backend_arg.get());
return;
}
// if a parameter is a weight and not linked to any executable node, its device type will be kTypeUnknown; set its
// device type the same as the arg
if (AnfAlgo::GetOutputDeviceDataType(backend_parameter, 0) == kTypeUnknown) {
AnfAlgo::SetSelectKernelBuildInfo(AnfAlgo::GetSelectKernelBuildInfo(backend_arg), backend_parameter.get());
}
// if front anf is a parameter, we can assign the value back, because backend_parameter won't be changed in its
// graph unless it's a weight. If backend_parameter is a weight, we should assign the value back.
AnfAlgo::SetOutputAddr(AnfAlgo::GetMutableOutputAddr(backend_arg, 0), 0, backend_parameter.get());
MS_LOG(INFO) << "Reuse node [" << backend_arg->DebugString() << "], old node[" << backend_parameter->DebugString()
<< "] will be replaced.";
to_graph->ReplaceNode(backend_parameter, backend_arg);
return;
}
InsertAssignToGraph(from_graph_id, backend_arg, backend_parameter);
MS_LOG(INFO) << "Finish!";
MS_LOG(INFO) << "Assign of node" << backend_arg->DebugString() << " of graph " << from_graph_id << " to node"
<< backend_parameter->DebugString() << "of graph " << to_graph_id;
(void)assigns_.insert(std::tuple<AnfNodePtr, GraphId, size_t>(front_anf, to_graph_id, input_idx));
}
void AscendSession::SetChildGraphParameter(const tensor::TensorPtr &front_tensor, const AnfNodePtr &backend_parameter) {
void AscendSession::SetChildGraphParameter(const tensor::TensorPtr &front_tensor, GraphId to_graph_id,
size_t input_idx) {
MS_LOG(INFO) << "Start!";
// sync data from host to device
MS_EXCEPTION_IF_NULL(front_tensor);
size_t tensor_size = front_tensor->data().nbytes();
auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0);
MS_EXCEPTION_IF_NULL(addr);
if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size,
front_tensor->data_type(), front_tensor->data_c(false))) {
MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!";
}
std::pair<GraphId, size_t> graph_input_pair(to_graph_id, input_idx);
initial_tenosrs_[graph_input_pair] = front_tensor;
MS_LOG(INFO) << "Finish!";
}
......@@ -818,10 +877,9 @@ size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const AnfN
if (output_num > 1 && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
return input_index + output_num;
}
auto &graph_inputs = graph->inputs();
auto &valid_inputs = graph->ValidInputs();
auto valid_inputs = graph->valid_inputs();
if (valid_inputs[input_index]) {
SetChildGraphParameter(node, graph_inputs[input_index]);
SetChildGraphParameter(node, graph->graph_id(), input_index);
} else {
MS_LOG(DEBUG) << "Invalid input arg: " << node->DebugString();
}
......@@ -833,8 +891,7 @@ size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const Valu
if (!value->isa<Tensor>()) {
MS_LOG(EXCEPTION) << "Value Node should be a tensor, unexpected value: " << value->ToString();
}
auto &graph_inputs = graph->inputs();
SetChildGraphParameter(value->cast<TensorPtr>(), graph_inputs[input_index]);
SetChildGraphParameter(value->cast<TensorPtr>(), graph->graph_id(), input_index);
return ++input_index;
}
......@@ -905,8 +962,6 @@ GraphId AscendSession::GetGraphIdByNode(const AnfNodePtr &front_anf) const {
void AscendSession::MergeGraphExecOrder() {
MS_LOG(INFO) << "Start!";
// insert switch to graph
MergeSwitchCompile();
// merge graph order
auto &graph_order = GetGraphOrder(final_graph_id_);
auto &graph_type = GetGraphOrderType(final_graph_id_);
......@@ -916,6 +971,13 @@ void AscendSession::MergeGraphExecOrder() {
MS_LOG(WARNING) << "Graph output is a lonely variable not linked to any op!";
return;
}
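// a merged execution order of several child graphs relies on tasks being sunk to the device,
// so warn when a control sink network runs without task sink mode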
if (graph_order.size() > 1) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!context_ptr->enable_task_sink()) {
MS_LOG(INFO) << "Control sink network should run with task-sink mode!";
}
}
// if the first graph is common, the final graph has no label; set the stream of the final graph to that of the first graph
SetStreamDistinctionLabel(final_graph, graph_order[0], false);
std::vector<CNodePtr> final_exec_order = final_graph->execution_order();
......@@ -930,7 +992,11 @@ void AscendSession::MergeGraphExecOrder() {
MS_EXCEPTION_IF_NULL(child_graph);
auto exec_order = child_graph->execution_order();
MS_LOG(INFO) << "Merge graph,graph_id " << graph_id;
(void)std::copy(exec_order.begin(), exec_order.end(), std::back_inserter(final_exec_order));
(void)std::transform(exec_order.begin(), exec_order.end(), std::back_inserter(final_exec_order),
[&](CNodePtr node) -> CNodePtr {
AnfAlgo::SetStreamDistinctionLabel(child_graph->stream_distinction_label(), node.get());
return node;
});
// add all value nodes of child graphs to final graph
for (auto &value_node : child_graph->graph_value_nodes()) {
final_graph->AddValueNodeToGraph(value_node);
......@@ -969,15 +1035,9 @@ void AscendSession::InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from
// generate a new cnode
auto assign_node = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(assign_node);
assign_node->set_abstract(std::make_shared<abstract::AbstractNone>());
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), assign_node.get());
AnfAlgo::SetStreamDistinctionLabel(GetDistinctionLabel(graph), assign_node.get());
assign_node->set_abstract(to->abstract());
// attach the assign at the end of from graph via a depend edge
auto exec_order = graph->execution_order();
exec_order.push_back(assign_node);
graph->set_execution_order(exec_order);
InsertDependToGraph(graph_id, assign_node);
}
void AscendSession::InsertMultipleAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to) {
......@@ -997,24 +1057,46 @@ void AscendSession::InsertMultipleAssignToGraph(GraphId graph_id, const AnfNodeP
void AscendSession::InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived_stream) {
MS_LOG(INFO) << "Insert stream_active from " << graph_id << " to " << actived_stream;
auto from_graph = graphs_[graph_id];
auto from_graph = GetGraph(graph_id);
MS_EXCEPTION_IF_NULL(from_graph);
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("StreamActive"))};
auto active_node = from_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(active_node);
active_node->set_abstract(std::make_shared<abstract::AbstractNone>());
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), active_node.get());
// set the active stream id into the attr of active node
std::vector<uint32_t> active_index_value = {};
active_index_value.push_back(actived_stream);
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_index_value), active_node);
AnfAlgo::SetStreamDistinctionLabel(GetDistinctionLabel(from_graph), active_node.get());
// attach the active node at the end of from graph via a control depend edge
auto exec_order = from_graph->execution_order();
exec_order.push_back(active_node);
from_graph->set_execution_order(exec_order);
auto return_node = from_graph->get_return();
MS_EXCEPTION_IF_NULL(return_node);
InsertControlDependToGraph(graph_id, return_node->input(1), active_node);
}
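// Attach attch_node to the graph without giving it a data user: wrap the current return value
// and attch_node in a Depend node and make that Depend the new input of the return node.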
void AscendSession::InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node) {
MS_LOG(INFO) << "Insert depend at the end of graph, the attach node is " << attch_node->DebugString();
auto graph = GetGraph(graph_id);
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("depend"))};
auto return_node = graph->get_return();
MS_EXCEPTION_IF_NULL(return_node);
inputs.push_back(return_node->input(1));
inputs.push_back(attch_node);
auto depend_node = graph->NewCNode(inputs);
return_node->set_input(1, depend_node);
}
void AscendSession::InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node,
const AnfNodePtr &second_node) {
MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString()
<< ", the second node is " << second_node->DebugString();
auto graph = GetGraph(graph_id);
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("ControlDepend"))};
inputs.push_back(first_node);
inputs.push_back(second_node);
auto control_depend = graph->NewCNode(inputs);
InsertDependToGraph(graph_id, control_depend);
}
size_t AscendSession::ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph) {
......@@ -1043,5 +1125,29 @@ std::vector<GraphType> &AscendSession::GetGraphOrderType(GraphId final_graph_id)
}
return graph_type_iter->second;
}
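// For each (graph id, input index) -> tensor pair recorded by SetChildGraphParameter, copy the
// host tensor into the device address of the matching backend parameter before the graph runs.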
void AscendSession::SyncInitialTenosrToDevice() {
for (auto &item : initial_tenosrs_) {
auto to_graph_id = item.first.first;
auto input_idx = item.first.second;
auto front_tensor = item.second;
auto to_graph = GetGraph(to_graph_id);
MS_EXCEPTION_IF_NULL(to_graph);
std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
if (input_idx >= graph_inputs.size()) {
MS_LOG(EXCEPTION) << "input_index " << input_idx << " out of range size " << graph_inputs.size();
}
auto backend_parameter = graph_inputs[input_idx];
// sync data from host to device
MS_EXCEPTION_IF_NULL(front_tensor);
size_t tensor_size = front_tensor->data().nbytes();
auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0);
MS_EXCEPTION_IF_NULL(addr);
if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size,
front_tensor->data_type(), front_tensor->data_c(false))) {
MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!";
}
}
}
} // namespace session
} // namespace mindspore
......@@ -21,6 +21,9 @@
#include <vector>
#include <utility>
#include <stack>
#include <map>
#include <tuple>
#include <set>
#include "session/session_basic.h"
#include "session/kernel_graph.h"
#include "kernel/kernel.h"
......@@ -60,6 +63,8 @@ class AscendSession : public SessionBasic {
GraphId GetFinalRunGraph() const override { return final_graph_id_; }
// insert active to graph
void SetActive(GraphId, GraphId) override;
// compile child graph when the session has multiple child graphs
void CompileChildGraph(const KernelGraphPtr &child_graph);
private:
void InitRuntimeResource();
......@@ -95,12 +100,16 @@ class AscendSession : public SessionBasic {
size_t ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph);
// handle condition graph from vm
void InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id);
// insert depend to graph, used to attach control nodes to graph
void InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node);
// insert control depend to graph, used to attach control nodes to graph
void InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, const AnfNodePtr &second_node);
// Get graph by graph id; if it does not exist, return nullptr
KernelGraphPtr GetGraph(GraphId graph_id);
// set child graph parameter if front arg is an anf
void SetChildGraphParameter(const AnfNodePtr &front_anf, const AnfNodePtr &backend_parameter);
void SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx);
// set child graph parameter if front arg is a tensor
void SetChildGraphParameter(const tensor::TensorPtr &front_tensor, const AnfNodePtr &backend_parameter);
void SetChildGraphParameter(const tensor::TensorPtr &front_tensor, GraphId to_graph_id, size_t input_idx);
// update the execution order of all child graphs
void UpdateGraphOrder(GraphId to_graph);
// handle switch when merge
......@@ -113,6 +122,12 @@ class AscendSession : public SessionBasic {
void CopyOutputOfIf(GraphId false_graph_id);
// check if graph cache exist
bool GraphCacheExist(const GraphInfo &graph_info) const;
// insert all assigns to the child graphs
void InsertAllAssigns();
// create fake output of final graph
AnfNodePtr CreateFakeOutput(GraphId final_graph_id, const AnfNodePtr &true_output);
// sync initial tensors' data to device
void SyncInitialTenosrToDevice();
// member variables
// key is final_graph_id,value is child graph execute order of final graph
......@@ -124,6 +139,10 @@ class AscendSession : public SessionBasic {
// record all conditions
std::unordered_map<GraphId, std::pair<GraphId, GraphId>> switches_;
std::unordered_map<GraphId, AnfNodePtr> condition_output_;
// parameter-sharing assigns recorded as (front node, graph id, input index)
std::set<std::tuple<AnfNodePtr, GraphId, size_t>> assigns_;
// initial tensors; their data will be synced to device before running the graph
std::map<std::pair<GraphId, size_t>, tensor::TensorPtr> initial_tenosrs_;
// final_graph_id is used when every root graph has its own session
GraphId final_graph_id_;
};
......
......@@ -295,10 +295,7 @@ ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) {
// set the format of value_node to DEFAULT_FORMAT
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
// set value node initial device data type = infer data type
std::vector<TypeId> types;
for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) {
types.push_back(kTypeUnknown);
}
std::vector<TypeId> types = std::vector<TypeId>(AnfAlgo::GetOutputTensorNum(value_node), kTypeUnknown);
kernel_build_info_builder->SetOutputsDeviceType(types);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
......@@ -330,10 +327,11 @@ void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, cons
MS_LOG(EXCEPTION) << "old can't be same with new";
}
if (backend_front_anf_map_.find(old_backend_anf) == backend_front_anf_map_.end()) {
MS_LOG(EXCEPTION) << "old_backend_anf " << old_backend_anf->DebugString() << " is not exist in the map";
MS_LOG(DEBUG) << "old_backend_anf " << old_backend_anf->DebugString() << " is not exist in the map";
return;
}
if (front_backend_anf_map_.find(backend_front_anf_map_[old_backend_anf]) == front_backend_anf_map_.end()) {
MS_LOG(EXCEPTION) << "anf is not exist in the mape ,old " << old_backend_anf->DebugString();
MS_LOG(EXCEPTION) << "anf is not exist in the map ,old " << old_backend_anf->DebugString();
}
front_backend_anf_map_[backend_front_anf_map_[old_backend_anf]] = new_backend_anf;
backend_front_anf_map_[new_backend_anf] = backend_front_anf_map_[old_backend_anf];
......@@ -528,5 +526,44 @@ bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) {
}
return false;
}
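// Redirect every user of old_anf_node to new_anf_node: patch the users' input lists, the graph
// inputs, the front/backend map and the node output-edge table.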
void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf_node) {
MS_EXCEPTION_IF_NULL(old_anf_node);
MS_EXCEPTION_IF_NULL(new_anf_node);
MS_EXCEPTION_IF_NULL(inputs_);
auto it = node_output_edges_.find(old_anf_node);
if (it == node_output_edges_.end()) {
MS_LOG(EXCEPTION) << "Can't find anf node in node_output_edges map";
}
auto &outputs = it->second;
for (auto &output_node : outputs) {
auto output_cnode = output_node.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(output_cnode);
auto &output_node_inputs = output_cnode->inputs();
for (size_t i = 1; i < output_node_inputs.size(); i++) {
if (output_node_inputs[i] == old_anf_node) {
output_cnode->set_input(i, new_anf_node);
}
}
// update graph inputs
for (size_t i = 0; i < inputs_->size(); i++) {
if ((*inputs_)[i] == old_anf_node) {
(*inputs_)[i] = new_anf_node;
break;
}
}
}
// update front to backend map
FrontBackendlMapUpdate(old_anf_node, new_anf_node);
// update output depend relations
node_output_edges_[new_anf_node] = it->second;
(void)node_output_edges_.erase(old_anf_node);
}
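// Stamp every kernel in the execution order with the graph's current stream distinction label;
// called after the label of a single compiled graph is set in BuildGraph.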
void KernelGraph::UpdateExecuteKernelStreamLabel() {
for (auto &kernel : execution_order_) {
AnfAlgo::SetStreamDistinctionLabel(stream_distinction_label_, kernel.get());
}
}
} // namespace session
} // namespace mindspore
......@@ -27,6 +27,7 @@
#include "ir/func_graph.h"
#include "ir/anf.h"
#include "utils/graph_utils.h"
#include "device/kernel_info.h"
namespace mindspore {
namespace session {
......@@ -37,6 +38,7 @@ class KernelGraph : public FuncGraph {
inputs_ = std::make_shared<std::vector<AnfNodePtr>>();
execution_order_ = {};
executable_ = true;
stream_distinction_label_ = kInvalidDistincLabel;
}
~KernelGraph() override = default;
......@@ -88,7 +90,15 @@ class KernelGraph : public FuncGraph {
void set_executable(bool executable) { executable_ = executable; }
// set invalid inputs for control sink
std::vector<bool> *MutableValidInputs() { return &valid_inputs_; }
const std::vector<bool> &ValidInputs() const { return valid_inputs_; }
std::vector<bool> valid_inputs() const { return valid_inputs_; }
// replace node in graph
void ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf_node);
// set stream label of graph
void set_stream_distinction_label(uint32_t stream_label) { stream_distinction_label_ = stream_label; }
// get stream label of graph
uint32_t stream_distinction_label() { return stream_distinction_label_; }
// refresh execute kernel stream label
void UpdateExecuteKernelStreamLabel();
private:
// remove value node form graph
......@@ -108,6 +118,7 @@ class KernelGraph : public FuncGraph {
std::shared_ptr<std::vector<AnfNodePtr>> inputs_;
std::vector<CNodePtr> execution_order_;
uint32_t graph_id_;
uint32_t stream_distinction_label_;
// record the map between front anf and backend anf; two maps implement a bidirectional map
std::unordered_map<AnfNodePtr, AnfNodePtr> front_backend_anf_map_;
......
......@@ -418,9 +418,8 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
auto graph = std::make_shared<KernelGraph>();
graph->set_graph_id(graph_sum_);
MS_LOG(INFO) << "Create graph: " << graph_sum_;
auto graph = NewKernelGraph();
MS_LOG(INFO) << "Create graph: " << graph->graph_id();
size_t from_other_graph_depend_num = 0;
for (const auto &node : lst) {
MS_EXCEPTION_IF_NULL(node);
......@@ -457,7 +456,6 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
}
graph->SetExecOrderByDefault();
opt::BackendCommonOptimization(graph);
graphs_[graph_sum_++] = graph;
return graph;
}
......@@ -589,14 +587,14 @@ void SessionBasic::Summary(KernelGraph *graph) {
CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> output_args;
for (const auto &output : outputs) {
MS_LOG(INFO) << "output:" << output->DebugString();
}
auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
if (backend_anf != nullptr) {
return backend_anf;
}
for (const auto &output : outputs) {
MS_LOG(INFO) << "output:" << output->DebugString();
}
MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
};
output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
......@@ -696,5 +694,12 @@ BaseRef SessionBasic::TransformBaseRefListToTuple(const BaseRef &base_ref) {
MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
}
}
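// Allocate a kernel graph with the next free graph id and register it in graphs_, so ids stay
// unique across ConstructKernelGraph and SetFinalGraphInput.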
KernelGraphPtr SessionBasic::NewKernelGraph() {
auto graph = std::make_shared<KernelGraph>();
graph->set_graph_id(graph_sum_);
graphs_[graph_sum_++] = graph;
return graph;
}
} // namespace session
} // namespace mindspore
......@@ -104,6 +104,8 @@ class SessionBasic {
const std::vector<bool> &tensors_mask);
// trans BaseRef list to py::tuple
BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref);
// create a new kernel graph and update the graph sum
KernelGraphPtr NewKernelGraph();
std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
......
......@@ -27,6 +27,7 @@ assign_op_info = TBERegOp("Assign") \
.input(1, "value", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \
......