diff --git a/mindspore/ccsrc/session/ascend_inference_session.cc b/mindspore/ccsrc/session/ascend_inference_session.cc index aef7738d0b1ab85496fd3f8d1ef0170995ff04da..360a0ab954527204388c3f3fb57ab075b05f7bf3 100644 --- a/mindspore/ccsrc/session/ascend_inference_session.cc +++ b/mindspore/ccsrc/session/ascend_inference_session.cc @@ -32,7 +32,6 @@ using mindspore::tensor::TensorPy; namespace mindspore { namespace session { namespace { -std::set weight_infos; static TypeId GetDataType(const py::buffer_info &buf) { if (buf.format.size() == 1) { switch (buf.format.front()) { @@ -105,10 +104,33 @@ void AscendInferenceSession::LoadInputData(const std::shared_ptr &k MS_EXCEPTION_IF_NULL(pk_node); auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); MS_EXCEPTION_IF_NULL(device_address); - if (AnfAlgo::IsParameterWeight(pk_node)) { - if (weight_infos.count(pk_node) != 0) { - continue; + if (!AnfAlgo::IsParameterWeight(pk_node)) { + tensor = inputs[no_weight_input++]; + if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), + LongToSize(tensor->data().nbytes()), tensor->data_type(), + tensor->data_c())) { + MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; } + } + } +} + +GraphId AscendInferenceSession::CompileGraph(NotNull func_graph) { + auto graph_id = AscendSession::CompileGraph(func_graph); + auto kernel_graph = GetGraph(graph_id); + MS_EXCEPTION_IF_NULL(kernel_graph); + // load weight data to device + auto input_nodes = kernel_graph->inputs(); + for (size_t i = 0; i < input_nodes.size(); ++i) { + if (!input_nodes[i]->isa()) { + MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter"; + continue; + } + auto pk_node = input_nodes[i]->cast(); + MS_EXCEPTION_IF_NULL(pk_node); + auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); + MS_EXCEPTION_IF_NULL(device_address); + if (AnfAlgo::IsParameterWeight(pk_node)) { auto param_value = std::dynamic_pointer_cast(pk_node->default_param()); MS_EXCEPTION_IF_NULL(param_value); auto py_param = param_value->value(); @@ -120,16 +142,9 @@ void AscendInferenceSession::LoadInputData(const std::shared_ptr &k LongToSize(buf.size * buf.itemsize), buf_type, buf.ptr)) { MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; } - weight_infos.insert(pk_node); - } else { - tensor = inputs[no_weight_input++]; - if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), - LongToSize(tensor->data().nbytes()), tensor->data_type(), - tensor->data_c())) { - MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; - } } } + return graph_id; } } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/session/ascend_inference_session.h b/mindspore/ccsrc/session/ascend_inference_session.h index 53be881f93dc91b7599307624f4ade335ca5d637..e8ccff3f174cf683e7dafdaf7c346bb7c7af29f6 100644 --- a/mindspore/ccsrc/session/ascend_inference_session.h +++ b/mindspore/ccsrc/session/ascend_inference_session.h @@ -38,6 +38,7 @@ class AscendInferenceSession : public AscendSession { ~AscendInferenceSession() = default; void LoadInputData(const std::shared_ptr &kernel_graph, const std::vector &inputs_const) const; + GraphId CompileGraph(NotNull func_graph) override; }; MS_REG_SESSION(kDavinciInferenceDevice, AscendInferenceSession); } // namespace session