提交 5b39a3ea 编写于 作者: L lvliang

fix-check-nullptr-by-calling-function

上级 99e7e43c
...@@ -572,7 +572,6 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con ...@@ -572,7 +572,6 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
// run graph steps // run graph steps
void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
const std::vector<tensor::TensorPtr> &inputs_const) const { const std::vector<tensor::TensorPtr> &inputs_const) const {
MS_EXCEPTION_IF_NULL(kernel_graph);
std::vector<tensor::TensorPtr> inputs(inputs_const); std::vector<tensor::TensorPtr> inputs(inputs_const);
size_t input_ctrl_size = 1; size_t input_ctrl_size = 1;
MS_EXCEPTION_IF_NULL(context_); MS_EXCEPTION_IF_NULL(context_);
...@@ -585,6 +584,8 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap ...@@ -585,6 +584,8 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
MS_LOG(EXCEPTION) << "tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size() MS_LOG(EXCEPTION) << "tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
<< ", input_ctrl_size:" << input_ctrl_size; << ", input_ctrl_size:" << input_ctrl_size;
} }
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
auto tensor = inputs[i]; auto tensor = inputs[i];
MS_EXCEPTION_IF_NULL(tensor); MS_EXCEPTION_IF_NULL(tensor);
...@@ -594,8 +595,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap ...@@ -594,8 +595,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
auto pk_node = input_node->cast<ParameterPtr>(); auto pk_node = input_node->cast<ParameterPtr>();
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
bool need_sync = false; bool need_sync = false;
MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); if (ms_context->enable_pynative_infer()) {
if (MsContext::GetInstance()->enable_pynative_infer()) {
if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) { if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) {
need_sync = true; need_sync = true;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册