提交 1cb8d9da 编写于 作者: C chujinjin

optimize updateoutput in gpu

上级 6eddd65c
...@@ -426,7 +426,12 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &gr ...@@ -426,7 +426,12 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &gr
if (op_run_info.value != nullptr) { if (op_run_info.value != nullptr) {
std::vector<tensor::TensorPtr> pre_output_tensors; std::vector<tensor::TensorPtr> pre_output_tensors;
TensorValueToTensor(op_run_info.value, &pre_output_tensors); TensorValueToTensor(op_run_info.value, &pre_output_tensors);
std::copy(pre_output_tensors.begin(), pre_output_tensors.end(), std::back_inserter(outputs)); for (auto &pre_output : pre_output_tensors) {
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(pre_output->data_type(), pre_output->shape());
tensor->set_device_address(pre_output->device_address());
tensor->set_dirty(false);
outputs.emplace_back(tensor);
}
} else { } else {
UpdateOutputs(graph, &outputs, input_tensors); UpdateOutputs(graph, &outputs, input_tensors);
} }
......
...@@ -300,7 +300,18 @@ py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph ...@@ -300,7 +300,18 @@ py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph
} }
// Fetch outputs // Fetch outputs
VectorRef outputs; VectorRef outputs;
UpdateOutputs(kernel_graph, &outputs, input_tensors); if (op_run_info.value != nullptr) {
std::vector<tensor::TensorPtr> pre_output_tensors;
TensorValueToTensor(op_run_info.value, &pre_output_tensors);
for (auto &pre_output : pre_output_tensors) {
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(pre_output->data_type(), pre_output->shape());
tensor->set_device_address(pre_output->device_address());
tensor->set_dirty(false);
outputs.emplace_back(tensor);
}
} else {
UpdateOutputs(kernel_graph, &outputs, input_tensors);
}
// Trans output to tuple // Trans output to tuple
auto output_tensors = TransformBaseRefListToTuple(outputs); auto output_tensors = TransformBaseRefListToTuple(outputs);
if (!utils::isa<PyObjectRef>(output_tensors) || if (!utils::isa<PyObjectRef>(output_tensors) ||
......
...@@ -565,9 +565,9 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat ...@@ -565,9 +565,9 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
if (session == nullptr) { if (session == nullptr) {
session = session::SessionFactory::Get().Create(device_target); session = session::SessionFactory::Get().Create(device_target);
MS_EXCEPTION_IF_NULL(session);
session->Init(ms_context->device_id());
} }
MS_EXCEPTION_IF_NULL(session);
session->Init(ms_context->device_id());
std::vector<tensor::TensorPtr> input_tensors; std::vector<tensor::TensorPtr> input_tensors;
std::vector<int> tensors_mask; std::vector<int> tensors_mask;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册