提交 817d1ae2 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2403 Session code review

Merge pull request !2403 from JoyLvliang/session-code-review
...@@ -52,6 +52,7 @@ PyObject *GetParamDefaultInputTensor(const AnfNodePtr &node) { ...@@ -52,6 +52,7 @@ PyObject *GetParamDefaultInputTensor(const AnfNodePtr &node) {
return nullptr; return nullptr;
} }
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param()); auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param());
MS_EXCEPTION_IF_NULL(param_value);
auto py_param = param_value->value(); auto py_param = param_value->value();
return py_param.ptr(); return py_param.ptr();
} }
...@@ -69,7 +70,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne ...@@ -69,7 +70,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
} }
if (node->isa<Parameter>()) { if (node->isa<Parameter>()) {
for (size_t input_idx = 0; input_idx < graph.inputs().size(); input_idx++) { for (size_t input_idx = 0; input_idx < graph.inputs().size(); input_idx++) {
if (input_idx > input_tensors.size()) { if (input_idx >= input_tensors.size()) {
MS_LOG(EXCEPTION) << "input idx:" << input_idx << "out of range:" << input_tensors.size(); MS_LOG(EXCEPTION) << "input idx:" << input_idx << "out of range:" << input_tensors.size();
} }
if (graph.inputs()[input_idx] == node) { if (graph.inputs()[input_idx] == node) {
...@@ -149,6 +150,8 @@ BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph, ...@@ -149,6 +150,8 @@ BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
} }
ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) { ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf);
MS_EXCEPTION_IF_NULL(graph);
auto value_node = anf->cast<ValueNodePtr>(); auto value_node = anf->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node); MS_EXCEPTION_IF_NULL(value_node);
auto value = value_node->value(); auto value = value_node->value();
...@@ -229,6 +232,7 @@ ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph, ...@@ -229,6 +232,7 @@ ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph,
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(input_tensor); MS_EXCEPTION_IF_NULL(input_tensor);
auto value_node = std::make_shared<ValueNode>(input_tensor); auto value_node = std::make_shared<ValueNode>(input_tensor);
MS_EXCEPTION_IF_NULL(value_node);
// construct abstract of value node // construct abstract of value node
auto type_of_tensor = input_tensor->Dtype(); auto type_of_tensor = input_tensor->Dtype();
auto shape_of_tensor = input_tensor->shape(); auto shape_of_tensor = input_tensor->shape();
...@@ -242,6 +246,7 @@ ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph, ...@@ -242,6 +246,7 @@ ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph,
ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor, ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor,
int tensor_mask) { int tensor_mask) {
MS_EXCEPTION_IF_NULL(graph);
auto param = graph->NewParameter(); auto param = graph->NewParameter();
MS_EXCEPTION_IF_NULL(param); MS_EXCEPTION_IF_NULL(param);
if (tensor_mask == kParameterWeightTensorMask) { if (tensor_mask == kParameterWeightTensorMask) {
...@@ -295,6 +300,7 @@ void DumpGraphOutput(const Any &any, size_t recurse_level = 0) { ...@@ -295,6 +300,7 @@ void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
} }
bool ExistSummaryNode(const KernelGraph *graph) { bool ExistSummaryNode(const KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
auto ret = graph->get_return(); auto ret = graph->get_return();
MS_EXCEPTION_IF_NULL(ret); MS_EXCEPTION_IF_NULL(ret);
auto all_nodes = DeepLinkedGraphSearch(ret); auto all_nodes = DeepLinkedGraphSearch(ret);
...@@ -315,7 +321,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf ...@@ -315,7 +321,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
if (!anf->isa<Parameter>()) { if (!anf->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter"; MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
} }
MS_EXCEPTION_IF_NULL(graph);
auto m_tensor = GetParamDefaultInputTensor(anf); auto m_tensor = GetParamDefaultInputTensor(anf);
auto valid_inputs = graph->MutableValidInputs(); auto valid_inputs = graph->MutableValidInputs();
MS_EXCEPTION_IF_NULL(valid_inputs); MS_EXCEPTION_IF_NULL(valid_inputs);
...@@ -344,6 +350,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf ...@@ -344,6 +350,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf); MS_EXCEPTION_IF_NULL(anf);
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]"; MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
auto parameters = CreateParameterFromTuple(anf, valid_input, graph); auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
if (parameters.empty()) { if (parameters.empty()) {
...@@ -482,6 +489,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) ...@@ -482,6 +489,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) { ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf); MS_EXCEPTION_IF_NULL(anf);
MS_EXCEPTION_IF_NULL(graph);
auto value_node = anf->cast<ValueNodePtr>(); auto value_node = anf->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node); MS_EXCEPTION_IF_NULL(value_node);
auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf); auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf);
...@@ -509,6 +517,7 @@ ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, Ker ...@@ -509,6 +517,7 @@ ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, Ker
ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) { ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf); MS_EXCEPTION_IF_NULL(anf);
MS_EXCEPTION_IF_NULL(graph);
if (!anf->isa<Parameter>()) { if (!anf->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter"; MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
} }
...@@ -536,6 +545,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph ...@@ -536,6 +545,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode; std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
auto graph = NewKernelGraph(); auto graph = NewKernelGraph();
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Create graph: " << graph->graph_id(); MS_LOG(INFO) << "Create graph: " << graph->graph_id();
size_t from_other_graph_depend_num = 0; size_t from_other_graph_depend_num = 0;
for (const auto &node : lst) { for (const auto &node : lst) {
...@@ -585,6 +595,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP ...@@ -585,6 +595,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
MS_EXCEPTION_IF_NULL(all_out_graph); MS_EXCEPTION_IF_NULL(all_out_graph);
auto node_list = TopoSort(func_graph->get_return()); auto node_list = TopoSort(func_graph->get_return());
auto graph = NewKernelGraph(); auto graph = NewKernelGraph();
MS_EXCEPTION_IF_NULL(graph);
front_backend_graph_map_[func_graph] = graph; front_backend_graph_map_[func_graph] = graph;
MS_LOG(INFO) << "Create graph: " << graph->graph_id(); MS_LOG(INFO) << "Create graph: " << graph->graph_id();
...@@ -724,8 +735,8 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap ...@@ -724,8 +735,8 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap
} }
auto anf_outputs = kernel_graph->outputs(); auto anf_outputs = kernel_graph->outputs();
for (auto &item : anf_outputs) { for (auto &item : anf_outputs) {
MS_LOG(INFO) << "update output[" << item->DebugString() << "]";
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
MS_LOG(INFO) << "update output[" << item->DebugString() << "]";
if (AnfAlgo::IsTupleOutput(item) && AnfAlgo::IsRealKernel(item)) { if (AnfAlgo::IsTupleOutput(item) && AnfAlgo::IsRealKernel(item)) {
outputs->emplace_back(CreatTupleForOutput(item, *kernel_graph, input_tensors)); outputs->emplace_back(CreatTupleForOutput(item, *kernel_graph, input_tensors));
continue; continue;
...@@ -761,6 +772,7 @@ void SessionBasic::GetSummaryNodes(KernelGraph *graph) { ...@@ -761,6 +772,7 @@ void SessionBasic::GetSummaryNodes(KernelGraph *graph) {
auto node = cnode->input(kSummaryGetItem); auto node = cnode->input(kSummaryGetItem);
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
MS_EXCEPTION_IF_NULL(item_with_index.first);
if (!AnfAlgo::IsRealKernel(item_with_index.first)) { if (!AnfAlgo::IsRealKernel(item_with_index.first)) {
MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString(); MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
} }
...@@ -812,6 +824,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std: ...@@ -812,6 +824,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> output_args; std::vector<AnfNodePtr> output_args;
for (const auto &output : outputs) { for (const auto &output : outputs) {
MS_EXCEPTION_IF_NULL(output);
MS_LOG(INFO) << "output:" << output->DebugString(); MS_LOG(INFO) << "output:" << output->DebugString();
} }
auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr { auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
...@@ -883,7 +896,9 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf ...@@ -883,7 +896,9 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf
} }
auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]); auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]);
inputs.push_back(parameter); inputs.push_back(parameter);
graph->MutableInputs()->push_back(parameter); auto mutable_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(mutable_inputs);
mutable_inputs->push_back(parameter);
} }
// set execution order // set execution order
auto cnode = graph->NewCNode(inputs); auto cnode = graph->NewCNode(inputs);
......
...@@ -48,11 +48,7 @@ using OpRunInfoPtr = std::shared_ptr<OpRunInfo>; ...@@ -48,11 +48,7 @@ using OpRunInfoPtr = std::shared_ptr<OpRunInfo>;
class SessionBasic { class SessionBasic {
public: public:
SessionBasic() : device_id_(0) { SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) {}
graphs_ = {};
run_op_graphs_ = {};
summary_callback_ = nullptr;
}
virtual void Init(uint32_t device_id) { device_id_ = device_id; } virtual void Init(uint32_t device_id) { device_id_ = device_id; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册