提交 83ff625d 编写于 作者: K kswang

sync data for cpu

上级 2a2dd7d3
......@@ -192,8 +192,12 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph,
if (item->isa<Parameter>()) {
auto address = AnfAlgo::GetMutableOutputAddr(item, 0);
auto tensor = inputs[input_idx];
auto tensor_address = tensor->device_address();
MS_EXCEPTION_IF_NULL(address);
MS_EXCEPTION_IF_NULL(tensor);
if (tensor_address != nullptr && tensor_address != address) {
(void)tensor->data_sync();
}
std::vector<int> data_shape = tensor->shape();
size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies<size_t>());
if (tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32) {
......
......@@ -103,17 +103,19 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
auto pk_node = input_node->cast<ParameterPtr>();
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
auto tensor_address = tensor->device_address();
bool need_sync = false;
if (ms_context->enable_pynative_infer()) {
if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) {
if (tensor_address.get() == nullptr || tensor_address != device_address) {
need_sync = true;
}
} else if (tensor->is_dirty()) {
need_sync = true;
} else if (tensor_address != device_address) {
if (tensor_address->DeviceType() == device_address->DeviceType()) {
AnfAlgo::SetOutputAddr(tensor_address, 0, pk_node.get());
} else {
if (tensor->is_dirty()) {
need_sync = true;
} else if (tensor->device_address() != device_address) {
AnfAlgo::SetOutputAddr(tensor->device_address(), 0, pk_node.get());
need_sync = false;
}
}
if (need_sync) {
......
......@@ -76,9 +76,15 @@ std::string GetCNodeTarget(const AnfNodePtr &node) {
return default_target;
}
auto primitive = value->cast<PrimitivePtr>();
ValuePtr att_target = primitive->GetAttr("primitive_target");
auto att_target = primitive->GetAttr("primitive_target");
if (att_target != nullptr) {
std::string target = GetValue<std::string>(att_target);
if (!att_target->isa<StringImm>()) {
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
}
auto target = GetValue<std::string>(att_target);
if (kTargetSet.find(target) == kTargetSet.end()) {
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
}
return target;
}
return default_target;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册