提交 817d1ae2 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2403 Session code review

Merge pull request !2403 from JoyLvliang/session-code-review
......@@ -52,6 +52,7 @@ PyObject *GetParamDefaultInputTensor(const AnfNodePtr &node) {
return nullptr;
}
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param());
MS_EXCEPTION_IF_NULL(param_value);
auto py_param = param_value->value();
return py_param.ptr();
}
......@@ -69,7 +70,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
}
if (node->isa<Parameter>()) {
for (size_t input_idx = 0; input_idx < graph.inputs().size(); input_idx++) {
if (input_idx > input_tensors.size()) {
if (input_idx >= input_tensors.size()) {
MS_LOG(EXCEPTION) << "input idx:" << input_idx << "out of range:" << input_tensors.size();
}
if (graph.inputs()[input_idx] == node) {
......@@ -149,6 +150,8 @@ BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
}
ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf);
MS_EXCEPTION_IF_NULL(graph);
auto value_node = anf->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto value = value_node->value();
......@@ -229,6 +232,7 @@ ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph,
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(input_tensor);
auto value_node = std::make_shared<ValueNode>(input_tensor);
MS_EXCEPTION_IF_NULL(value_node);
// construct abstract of value node
auto type_of_tensor = input_tensor->Dtype();
auto shape_of_tensor = input_tensor->shape();
......@@ -242,6 +246,7 @@ ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph,
ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor,
int tensor_mask) {
MS_EXCEPTION_IF_NULL(graph);
auto param = graph->NewParameter();
MS_EXCEPTION_IF_NULL(param);
if (tensor_mask == kParameterWeightTensorMask) {
......@@ -295,6 +300,7 @@ void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
}
bool ExistSummaryNode(const KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
auto ret = graph->get_return();
MS_EXCEPTION_IF_NULL(ret);
auto all_nodes = DeepLinkedGraphSearch(ret);
......@@ -315,7 +321,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
if (!anf->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
}
MS_EXCEPTION_IF_NULL(graph);
auto m_tensor = GetParamDefaultInputTensor(anf);
auto valid_inputs = graph->MutableValidInputs();
MS_EXCEPTION_IF_NULL(valid_inputs);
......@@ -344,6 +350,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf);
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
if (parameters.empty()) {
......@@ -482,6 +489,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf);
MS_EXCEPTION_IF_NULL(graph);
auto value_node = anf->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf);
......@@ -509,6 +517,7 @@ ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, Ker
ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf);
MS_EXCEPTION_IF_NULL(graph);
if (!anf->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
}
......@@ -536,6 +545,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
auto graph = NewKernelGraph();
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Create graph: " << graph->graph_id();
size_t from_other_graph_depend_num = 0;
for (const auto &node : lst) {
......@@ -585,6 +595,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
MS_EXCEPTION_IF_NULL(all_out_graph);
auto node_list = TopoSort(func_graph->get_return());
auto graph = NewKernelGraph();
MS_EXCEPTION_IF_NULL(graph);
front_backend_graph_map_[func_graph] = graph;
MS_LOG(INFO) << "Create graph: " << graph->graph_id();
......@@ -724,8 +735,8 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap
}
auto anf_outputs = kernel_graph->outputs();
for (auto &item : anf_outputs) {
MS_LOG(INFO) << "update output[" << item->DebugString() << "]";
MS_EXCEPTION_IF_NULL(item);
MS_LOG(INFO) << "update output[" << item->DebugString() << "]";
if (AnfAlgo::IsTupleOutput(item) && AnfAlgo::IsRealKernel(item)) {
outputs->emplace_back(CreatTupleForOutput(item, *kernel_graph, input_tensors));
continue;
......@@ -761,6 +772,7 @@ void SessionBasic::GetSummaryNodes(KernelGraph *graph) {
auto node = cnode->input(kSummaryGetItem);
MS_EXCEPTION_IF_NULL(node);
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
MS_EXCEPTION_IF_NULL(item_with_index.first);
if (!AnfAlgo::IsRealKernel(item_with_index.first)) {
MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
}
......@@ -812,6 +824,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> output_args;
for (const auto &output : outputs) {
MS_EXCEPTION_IF_NULL(output);
MS_LOG(INFO) << "output:" << output->DebugString();
}
auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
......@@ -883,7 +896,9 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf
}
auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]);
inputs.push_back(parameter);
graph->MutableInputs()->push_back(parameter);
auto mutable_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(mutable_inputs);
mutable_inputs->push_back(parameter);
}
// set execution order
auto cnode = graph->NewCNode(inputs);
......
......@@ -48,11 +48,7 @@ using OpRunInfoPtr = std::shared_ptr<OpRunInfo>;
class SessionBasic {
public:
SessionBasic() : device_id_(0) {
graphs_ = {};
run_op_graphs_ = {};
summary_callback_ = nullptr;
}
SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) {}
virtual void Init(uint32_t device_id) { device_id_ = device_id; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册