Commit cc54bb56 authored by chenfei

move opt to build graph

Parent 64abbeaa
......@@ -22,28 +22,32 @@ namespace mindspore {
namespace kernel {
std::string KernelBuildInfo::GetInputFormat(size_t input_index) const {
if (input_index >= inputs_format_.size()) {
MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node";
MS_LOG(ERROR) << "The index [" << input_index << "] is exceed the number of input node";
return kInvalidFormat;
}
return inputs_format_[input_index];
}
std::string KernelBuildInfo::GetOutputFormat(size_t output_index) const {
if (output_index >= outputs_format_.size()) {
MS_LOG(EXCEPTION) << "The index [" << output_index << "] is exceed the number of input node";
MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of input node";
return kInvalidFormat;
}
return outputs_format_[output_index];
}
TypeId KernelBuildInfo::GetInputDeviceType(size_t input_index) const {
if (input_index >= inputs_device_type_.size()) {
MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node";
MS_LOG(ERROR) << "The index [" << input_index << "] is exceed the number of input";
return TypeId::kNumberTypeEnd;
}
return inputs_device_type_[input_index];
}
TypeId KernelBuildInfo::GetOutputDeviceType(size_t output_index) const {
if (output_index >= outputs_device_type_.size()) {
MS_LOG(EXCEPTION) << "The index [" << output_index << "] is exceed the number of input node";
MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of output";
return TypeId::kNumberTypeEnd;
}
return outputs_device_type_[output_index];
}
......
......@@ -82,6 +82,9 @@ class KernelBuildInfo {
bool operator==(const KernelBuildInfo &other) const;
public:
static auto constexpr kInvalidFormat = "InvalidFormat";
private:
KernelType kernel_type_;
std::vector<std::string> inputs_format_;
......
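Taken together, the hunks above switch KernelBuildInfo's getters from throwing to logging an error and returning a sentinel (kInvalidFormat for formats, TypeId::kNumberTypeEnd for device types), leaving escalation to the caller. Below is a minimal standalone sketch of that pattern; BuildInfo and GetOutputFormatOrThrow are hypothetical stand-ins, not MindSpore's real classes.

```cpp
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

class BuildInfo {
 public:
  static constexpr auto kInvalidFormat = "InvalidFormat";
  explicit BuildInfo(std::vector<std::string> formats) : formats_(std::move(formats)) {}
  // Out-of-range lookups log and return a sentinel instead of throwing,
  // so the policy decision stays with the caller.
  std::string GetFormat(std::size_t index) const {
    if (index >= formats_.size()) {
      std::cerr << "The index [" << index << "] exceeds the number of formats\n";
      return kInvalidFormat;
    }
    return formats_[index];
  }

 private:
  std::vector<std::string> formats_;
};

// Caller-side escalation, mirroring the AnfRuntimeAlgorithm changes later in
// this commit: the sentinel becomes an exception where it is unrecoverable.
std::string GetOutputFormatOrThrow(const BuildInfo &info, std::size_t index) {
  auto format = info.GetFormat(index);
  if (format == BuildInfo::kInvalidFormat) {
    throw std::runtime_error("node has an invalid output format");
  }
  return format;
}

int main() {
  BuildInfo info({"DefaultFormat"});
  std::cout << GetOutputFormatOrThrow(info, 0) << '\n';  // prints DefaultFormat
  return 0;
}
```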
......@@ -26,7 +26,7 @@
namespace mindspore {
namespace kernel {
namespace {
void FilterInvaildKernelInfo(const CNodePtr &kernel_node,
void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_info_list);
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_list;
......@@ -63,9 +63,9 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
HcclMetadataInfo(kernel_node, kernel_info_list);
}
if (kernel_info_list->empty()) {
MS_LOG(EXCEPTION) << "op" << kernel_node->DebugString() << "kernel query fail!";
MS_LOG(EXCEPTION) << "Op " << kernel_node->DebugString() << "kernel query fail!";
}
FilterInvaildKernelInfo(kernel_node, kernel_info_list);
FilterInvalidKernelInfo(kernel_node, kernel_info_list);
}
} // namespace kernel
} // namespace mindspore
......@@ -46,24 +46,40 @@ RtKerDescFactory &RtKerDescFactory::Get() {
void GetRtKelInfo(const CNodePtr &kernel_node,
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
MS_LOG(INFO) << "Mng kernel Info.";
MS_EXCEPTION_IF_NULL(kernel_info_list);
MS_EXCEPTION_IF_NULL(kernel_node);
std::string opNameLower = AnfAlgo::GetCNodeName(kernel_node);
(void)std::transform(opNameLower.begin(), opNameLower.end(), opNameLower.begin(), ::tolower);
auto ker_desc_ptr = RtKerDescFactory::Create(opNameLower);
if (ker_desc_ptr == nullptr) {
MS_LOG(DEBUG) << "Mng can't find op [" << opNameLower << "].";
if (ker_desc_ptr != nullptr && !ker_desc_ptr->GetKernelInfo().empty()) {
*kernel_info_list = ker_desc_ptr->GetKernelInfo();
return;
}
MS_EXCEPTION_IF_NULL(ker_desc_ptr);
auto kernel_info = ker_desc_ptr->GetKernelInfo();
if (kernel_info.empty()) {
MS_LOG(DEBUG) << "Rt dose not have op [" << opNameLower << "].";
// if kernel info is not found in the kernel info database, use the default kernel info
auto node_name = AnfAlgo::GetCNodeName(kernel_node);
if (node_name == "StreamSwitch" || node_name == "StreamActive") {
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// set input infos
auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
kernel_build_info_builder->SetInputsFormat(std::vector<std::string>(input_num, kOpFormat_DEFAULT));
std::vector<TypeId> input_types = {};
for (size_t i = 0; i < input_num; i++) {
input_types.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, i));
}
kernel_build_info_builder->SetInputsDeviceType(input_types);
// set output info
auto output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_num, kOpFormat_DEFAULT));
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>(output_num, TypeId::kTypeUnknown));
// set other info
kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE);
kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE);
kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
kernel_info_list->push_back(kernel_build_info_builder->Build());
return;
}
*kernel_info_list = kernel_info;
MS_LOG(DEBUG) << "Rt dose not have op [" << opNameLower << "].";
}
} // namespace kernel
} // namespace mindspore
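The rewritten GetRtKelInfo now reads as a chain of early returns: use registry info when present, otherwise synthesize a default build info for StreamSwitch/StreamActive, otherwise log at debug level. A compact sketch of that control flow, using hypothetical stub types (KerDesc, CreateDesc, NeedsDefaultInfo) rather than the real runtime API:

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-ins for RtKerDescFactory and KernelBuildInfo.
struct KernelInfo { std::string format; };
struct KerDesc {
  std::vector<KernelInfo> infos;
  std::vector<KernelInfo> GetKernelInfo() const { return infos; }
};

std::shared_ptr<KerDesc> CreateDesc(const std::string &name) {
  if (name == "memcpyasync") {
    return std::make_shared<KerDesc>(KerDesc{{{"DefaultFormat"}}});
  }
  return nullptr;  // the registry does not know this op
}

bool NeedsDefaultInfo(const std::string &name) {
  return name == "streamswitch" || name == "streamactive";
}

void GetKernelInfoList(const std::string &name, std::vector<KernelInfo> *list) {
  auto desc = CreateDesc(name);
  // Fast path: the registry knows the op and has non-empty info.
  if (desc != nullptr && !desc->GetKernelInfo().empty()) {
    *list = desc->GetKernelInfo();
    return;
  }
  // Fallback: synthesize a default build info for a known subset of ops.
  if (NeedsDefaultInfo(name)) {
    list->push_back({"DefaultFormat"});
    return;
  }
  // Otherwise leave the list unchanged and report at debug level.
  std::cout << "Rt does not have op [" << name << "].\n";
}

int main() {
  std::vector<KernelInfo> list;
  GetKernelInfoList("streamswitch", &list);
  std::cout << "entries: " << list.size() << '\n';  // entries: 1
  return 0;
}
```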
......@@ -186,7 +186,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
save_graphs_path = ".";
}
if (save_graphs) {
std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_before.ir";
std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_before" + "_graph_" +
std::to_string(kernel_graph->graph_id()) + ".ir";
DumpIR(file_path, kernel_graph);
DumpIRProto(kernel_graph, "before_hwopt");
}
......@@ -208,7 +209,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
if (save_graphs) {
std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_after.ir";
std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_after" + "_graph_" +
std::to_string(kernel_graph->graph_id()) + ".ir ";
DumpIR(file_path, kernel_graph);
}
}
......@@ -252,7 +254,8 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
save_graphs_path = ".";
}
if (save_graphs) {
std::string file_path = save_graphs_path + "/" + "hwopt_d_before.ir";
std::string file_path =
save_graphs_path + "/" + "hwopt_d_before" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
DumpIR(file_path, kernel_graph);
}
// data layout optimization
......@@ -278,7 +281,8 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
if (save_graphs) {
std::string file_path = save_graphs_path + "/" + "hwopt_d_end.ir";
std::string file_path =
save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
DumpIR(file_path, kernel_graph, true);
DumpIRProto(kernel_graph, "after_hwopt");
}
......
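The dump-file names above now embed the graph id, so multi-graph runs no longer overwrite each other's IR dumps. A hypothetical helper illustrating the resulting naming scheme:

```cpp
#include <cstdint>
#include <string>

// Hypothetical helper mirroring the per-graph dump-file naming above.
std::string DumpFileName(const std::string &dir, const std::string &tag, std::uint32_t graph_id) {
  return dir + "/" + tag + "_graph_" + std::to_string(graph_id) + ".ir";
}
// DumpFileName(".", "hwopt_d_ir_fusion_before", 0)
//   -> "./hwopt_d_ir_fusion_before_graph_0.ir"
```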
......@@ -27,6 +27,7 @@
namespace mindspore {
namespace opt {
void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
MS_LOG(INFO) << "start common opt graph:" << kernel_graph->graph_id();
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->save_graphs_flag();
......
......@@ -300,7 +300,12 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
return build_info->GetOutputFormat(output_idx);
auto format = build_info->GetOutputFormat(output_idx);
if (format == kernel::KernelBuildInfo::kInvalidFormat) {
MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
<< " has a invalid output format";
}
return format;
}
std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t input_idx) {
......@@ -314,7 +319,12 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
return build_info->GetInputFormat(input_idx);
auto format = build_info->GetInputFormat(input_idx);
if (format == kernel::KernelBuildInfo::kInvalidFormat) {
MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
<< " has a invalid input format";
}
return format;
}
KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx) {
......@@ -481,7 +491,12 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
return build_info->GetOutputDeviceType(output_idx);
auto dtype = build_info->GetOutputDeviceType(output_idx);
if (dtype == TypeId::kNumberTypeEnd) {
MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
<< " has a invalid dtype";
}
return dtype;
}
TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx) {
......@@ -494,7 +509,12 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
return build_info->GetInputDeviceType(input_idx);
auto dtype = build_info->GetInputDeviceType(input_idx);
if (dtype == TypeId::kNumberTypeEnd) {
MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
<< " has a invalid dtype";
}
return dtype;
}
TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node, size_t input_idx) {
......
......@@ -21,6 +21,9 @@
#include <vector>
#include <utility>
#include <stack>
#include <map>
#include <tuple>
#include <set>
#include "session/session_basic.h"
#include "session/kernel_graph.h"
#include "kernel/kernel.h"
......@@ -60,6 +63,8 @@ class AscendSession : public SessionBasic {
GraphId GetFinalRunGraph() const override { return final_graph_id_; }
// insert active to graph
void SetActive(GraphId, GraphId) override;
// compile child graph when the session has multiple child graphs
void CompileChildGraph(const KernelGraphPtr &child_graph);
private:
void InitRuntimeResource();
......@@ -95,12 +100,16 @@ class AscendSession : public SessionBasic {
size_t ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph);
// handle condition graph from vm
void InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id);
// insert depend to graph, used to attach control nodes to graph
void InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attach_node);
// insert control depend to graph, used to attach control nodes to graph
void InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, const AnfNodePtr &second_node);
// Get graph by graph id; return nullptr if it does not exist
KernelGraphPtr GetGraph(GraphId graph_id);
// set child graph parameter if front arg is an anf node
void SetChildGraphParameter(const AnfNodePtr &front_anf, const AnfNodePtr &backend_parameter);
void SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx);
// set child graph parameter if front arg is a tensor
void SetChildGraphParameter(const tensor::TensorPtr &front_tensor, const AnfNodePtr &backend_parameter);
void SetChildGraphParameter(const tensor::TensorPtr &front_tensor, GraphId to_graph_id, size_t input_idx);
// update the execution order of all child graphs
void UpdateGraphOrder(GraphId to_graph);
// handle switch when merge
......@@ -113,6 +122,12 @@ class AscendSession : public SessionBasic {
void CopyOutputOfIf(GraphId false_graph_id);
// check if graph cache exist
bool GraphCacheExist(const GraphInfo &graph_info) const;
// insert all assigns to the child graphs
void InsertAllAssigns();
// create fake output of final graph
AnfNodePtr CreateFakeOutput(GraphId final_graph_id, const AnfNodePtr &true_output);
// sync initial tensors' data to device
void SyncInitialTenosrToDevice();
// member variables
// key is final_graph_id,value is child graph execute order of final graph
......@@ -124,6 +139,10 @@ class AscendSession : public SessionBasic {
// record all conditions
std::unordered_map<GraphId, std::pair<GraphId, GraphId>> switches_;
std::unordered_map<GraphId, AnfNodePtr> condition_output_;
// share parameters
std::set<std::tuple<AnfNodePtr, GraphId, size_t>> assigns_;
// initial tensors; their data will be synced to device before running the graph
std::map<std::pair<GraphId, size_t>, tensor::TensorPtr> initial_tenosrs_;
// final_graph_id is used when every root graph has its own session
GraphId final_graph_id_;
};
......
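The two new containers hold per-graph bookkeeping: assigns_ appears to record (parameter, graph, input index) triples for InsertAllAssigns, and initial_tenosrs_ maps (graph id, input index) to tensors whose data must reach the device before the graph runs. A standalone sketch of how such containers are keyed, with hypothetical stand-ins for AnfNodePtr and tensor::TensorPtr:

```cpp
#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <set>
#include <tuple>
#include <utility>

// Hypothetical stand-ins; not MindSpore's real pointer types.
using GraphId = std::uint32_t;
using NodePtr = std::shared_ptr<int>;
using TensorPtr = std::shared_ptr<float>;

int main() {
  // each entry: (backend parameter, target graph, input index)
  std::set<std::tuple<NodePtr, GraphId, std::size_t>> assigns;
  // tensors keyed by (graph id, input index), synced to device before running
  std::map<std::pair<GraphId, std::size_t>, TensorPtr> initial_tensors;

  auto param = std::make_shared<int>(0);
  assigns.insert({param, 1, 0});                             // a pending assign
  initial_tensors[{1, 0}] = std::make_shared<float>(3.14f);  // initial input data
  return 0;
}
```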
......@@ -295,10 +295,7 @@ ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) {
// set the format of value_node to DEFAULT_FORMAT
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
// set value node initial device data type = infer data type
std::vector<TypeId> types;
for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) {
types.push_back(kTypeUnknown);
}
std::vector<TypeId> types(AnfAlgo::GetOutputTensorNum(value_node), kTypeUnknown);
kernel_build_info_builder->SetOutputsDeviceType(types);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
......@@ -330,10 +327,11 @@ void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, cons
MS_LOG(EXCEPTION) << "old can't be same with new";
}
if (backend_front_anf_map_.find(old_backend_anf) == backend_front_anf_map_.end()) {
MS_LOG(EXCEPTION) << "old_backend_anf " << old_backend_anf->DebugString() << " is not exist in the map";
MS_LOG(DEBUG) << "old_backend_anf " << old_backend_anf->DebugString() << " is not exist in the map";
return;
}
if (front_backend_anf_map_.find(backend_front_anf_map_[old_backend_anf]) == front_backend_anf_map_.end()) {
MS_LOG(EXCEPTION) << "anf is not exist in the mape ,old " << old_backend_anf->DebugString();
MS_LOG(EXCEPTION) << "anf is not exist in the map ,old " << old_backend_anf->DebugString();
}
front_backend_anf_map_[backend_front_anf_map_[old_backend_anf]] = new_backend_anf;
backend_front_anf_map_[new_backend_anf] = backend_front_anf_map_[old_backend_anf];
......@@ -528,5 +526,44 @@ bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) {
}
return false;
}
void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf_node) {
MS_EXCEPTION_IF_NULL(old_anf_node);
MS_EXCEPTION_IF_NULL(new_anf_node);
MS_EXCEPTION_IF_NULL(inputs_);
auto it = node_output_edges_.find(old_anf_node);
if (it == node_output_edges_.end()) {
MS_LOG(EXCEPTION) << "Can't find anf node in node_output_edges map";
}
auto &outputs = it->second;
for (auto &output_node : outputs) {
auto output_cnode = output_node.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(output_cnode);
auto &output_node_inputs = output_cnode->inputs();
for (size_t i = 1; i < output_node_inputs.size(); i++) {
if (output_node_inputs[i] == old_anf_node) {
output_cnode->set_input(i, new_anf_node);
}
}
}
// update graph inputs once, outside the loop over consumers
for (size_t i = 0; i < inputs_->size(); i++) {
if ((*inputs_)[i] == old_anf_node) {
(*inputs_)[i] = new_anf_node;
break;
}
}
// update front to backend map
FrontBackendlMapUpdate(old_anf_node, new_anf_node);
// update output depend relations
node_output_edges_[new_anf_node] = it->second;
(void)node_output_edges_.erase(old_anf_node);
}
void KernelGraph::UpdateExecuteKernelStreamLabel() {
for (auto &kernel : execution_order_) {
AnfAlgo::SetStreamDistinctionLabel(stream_distinction_label_, kernel.get());
}
}
} // namespace session
} // namespace mindspore
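ReplaceNode has to keep three structures consistent: each consumer's input slots, the graph's input list, and the node_output_edges_ map. A tiny standalone model of that rewiring, with hypothetical Node/MiniGraph types in place of MindSpore's AnfNode/KernelGraph:

```cpp
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical minimal node type; not MindSpore's AnfNode.
struct Node {
  std::string name;
  std::vector<std::shared_ptr<Node>> inputs;
};
using NodePtr = std::shared_ptr<Node>;

struct MiniGraph {
  std::vector<NodePtr> graph_inputs;
  std::map<NodePtr, std::vector<NodePtr>> output_edges;  // node -> consumers

  void ReplaceNode(const NodePtr &old_node, const NodePtr &new_node) {
    // 1. rewire every consumer's matching input slot
    for (auto &consumer : output_edges[old_node]) {
      std::replace(consumer->inputs.begin(), consumer->inputs.end(), old_node, new_node);
    }
    // 2. update the graph inputs once, outside the consumer loop
    std::replace(graph_inputs.begin(), graph_inputs.end(), old_node, new_node);
    // 3. move the output-edge record from the old node to the new one
    output_edges[new_node] = output_edges[old_node];
    output_edges.erase(old_node);
  }
};

int main() {
  MiniGraph g;
  auto a = std::make_shared<Node>(Node{"a", {}});
  auto b = std::make_shared<Node>(Node{"b", {}});
  auto consumer = std::make_shared<Node>(Node{"c", {a}});
  g.graph_inputs = {a};
  g.output_edges[a] = {consumer};
  g.ReplaceNode(a, b);  // consumer now reads b; the graph input is b
  return 0;
}
```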
......@@ -27,6 +27,7 @@
#include "ir/func_graph.h"
#include "ir/anf.h"
#include "utils/graph_utils.h"
#include "device/kernel_info.h"
namespace mindspore {
namespace session {
......@@ -37,6 +38,7 @@ class KernelGraph : public FuncGraph {
inputs_ = std::make_shared<std::vector<AnfNodePtr>>();
execution_order_ = {};
executable_ = true;
stream_distinction_label_ = kInvalidDistincLabel;
}
~KernelGraph() override = default;
......@@ -88,7 +90,15 @@ class KernelGraph : public FuncGraph {
void set_executable(bool executable) { executable_ = executable; }
// set invalid inputs for control sink
std::vector<bool> *MutableValidInputs() { return &valid_inputs_; }
const std::vector<bool> &ValidInputs() const { return valid_inputs_; }
std::vector<bool> valid_inputs() const { return valid_inputs_; }
// replace node in graph
void ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf_node);
// set stream label of graph
void set_stream_distinction_label(uint32_t stream_label) { stream_distinction_label_ = stream_label; }
// get stream label of graph
uint32_t stream_distinction_label() const { return stream_distinction_label_; }
// refresh execute kernel stream label
void UpdateExecuteKernelStreamLabel();
private:
// remove value node form graph
......@@ -108,6 +118,7 @@ class KernelGraph : public FuncGraph {
std::shared_ptr<std::vector<AnfNodePtr>> inputs_;
std::vector<CNodePtr> execution_order_;
uint32_t graph_id_;
uint32_t stream_distinction_label_;
// record the map between front anf and backend anf; use two maps to implement a bidirectional map
std::unordered_map<AnfNodePtr, AnfNodePtr> front_backend_anf_map_;
......
......@@ -417,9 +417,8 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
auto graph = std::make_shared<KernelGraph>();
graph->set_graph_id(graph_sum_);
MS_LOG(INFO) << "Create graph: " << graph_sum_;
auto graph = NewKernelGraph();
MS_LOG(INFO) << "Create graph: " << graph->graph_id();
size_t from_other_graph_depend_num = 0;
for (const auto &node : lst) {
MS_EXCEPTION_IF_NULL(node);
......@@ -456,7 +455,6 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
}
graph->SetExecOrderByDefault();
opt::BackendCommonOptimization(graph);
graphs_[graph_sum_++] = graph;
return graph;
}
......@@ -588,14 +586,14 @@ void SessionBasic::Summary(KernelGraph *graph) {
CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> output_args;
for (const auto &output : outputs) {
MS_LOG(INFO) << "output:" << output->DebugString();
}
auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
if (backend_anf != nullptr) {
return backend_anf;
}
for (const auto &output : outputs) {
MS_LOG(INFO) << "output:" << output->DebugString();
}
MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
};
output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
......@@ -695,5 +693,12 @@ BaseRef SessionBasic::TransformBaseRefListToTuple(const BaseRef &base_ref) {
MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
}
}
KernelGraphPtr SessionBasic::NewKernelGraph() {
auto graph = std::make_shared<KernelGraph>();
graph->set_graph_id(graph_sum_);
graphs_[graph_sum_++] = graph;
return graph;
}
} // namespace session
} // namespace mindspore
......@@ -104,6 +104,8 @@ class SessionBasic {
const std::vector<bool> &tensors_mask);
// trans BaseRef list to py::tuple
BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref);
// create a new kernel graph and update the graph sum
KernelGraphPtr NewKernelGraph();
std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
......
......@@ -27,6 +27,7 @@ assign_op_info = TBERegOp("Assign") \
.input(1, "value", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \
......