From 83ff625d52c119525a48a9898525e9745f2a91c3 Mon Sep 17 00:00:00 2001 From: kswang Date: Fri, 12 Jun 2020 19:39:27 +0800 Subject: [PATCH] sync data for cpu --- mindspore/ccsrc/device/cpu/cpu_kernel_runtime.cc | 4 ++++ mindspore/ccsrc/session/gpu_session.cc | 14 ++++++++------ mindspore/ccsrc/vm/transform.cc | 10 ++++++++-- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/mindspore/ccsrc/device/cpu/cpu_kernel_runtime.cc b/mindspore/ccsrc/device/cpu/cpu_kernel_runtime.cc index 67328f04c..ddcc6b825 100644 --- a/mindspore/ccsrc/device/cpu/cpu_kernel_runtime.cc +++ b/mindspore/ccsrc/device/cpu/cpu_kernel_runtime.cc @@ -192,8 +192,12 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph, if (item->isa()) { auto address = AnfAlgo::GetMutableOutputAddr(item, 0); auto tensor = inputs[input_idx]; + auto tensor_address = tensor->device_address(); MS_EXCEPTION_IF_NULL(address); MS_EXCEPTION_IF_NULL(tensor); + if (tensor_address != nullptr && tensor_address != address) { + (void)tensor->data_sync(); + } std::vector data_shape = tensor->shape(); size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies()); if (tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32) { diff --git a/mindspore/ccsrc/session/gpu_session.cc b/mindspore/ccsrc/session/gpu_session.cc index aa0cf3ca4..560697053 100644 --- a/mindspore/ccsrc/session/gpu_session.cc +++ b/mindspore/ccsrc/session/gpu_session.cc @@ -103,17 +103,19 @@ void GPUSession::LoadInputData(const std::shared_ptr &kernel_graph, if (input_node->isa() && AnfAlgo::OutputAddrExist(input_node, 0)) { auto pk_node = input_node->cast(); auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); + auto tensor_address = tensor->device_address(); bool need_sync = false; if (ms_context->enable_pynative_infer()) { - if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) { + if (tensor_address.get() == nullptr || tensor_address != device_address) { need_sync = true; } - } else { - if (tensor->is_dirty()) { + } else if (tensor->is_dirty()) { + need_sync = true; + } else if (tensor_address != device_address) { + if (tensor_address->DeviceType() == device_address->DeviceType()) { + AnfAlgo::SetOutputAddr(tensor_address, 0, pk_node.get()); + } else { need_sync = true; - } else if (tensor->device_address() != device_address) { - AnfAlgo::SetOutputAddr(tensor->device_address(), 0, pk_node.get()); - need_sync = false; } } if (need_sync) { diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index 732107beb..ed4fee18b 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -76,9 +76,15 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { return default_target; } auto primitive = value->cast(); - ValuePtr att_target = primitive->GetAttr("primitive_target"); + auto att_target = primitive->GetAttr("primitive_target"); if (att_target != nullptr) { - std::string target = GetValue(att_target); + if (!att_target->isa()) { + MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; + } + auto target = GetValue(att_target); + if (kTargetSet.find(target) == kTargetSet.end()) { + MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; + } return target; } return default_target; -- GitLab