diff --git a/mindspore/ccsrc/backend/session/executor.cc b/mindspore/ccsrc/backend/session/executor.cc index 76341fbf5a26b7641628f704c130753e94a9c01e..42ad15aa3c7f570a90cd04983af7f86bf31f143a 100644 --- a/mindspore/ccsrc/backend/session/executor.cc +++ b/mindspore/ccsrc/backend/session/executor.cc @@ -42,6 +42,7 @@ void UpdateOutputTensors(const VectorRef *outputs, if (tensor->NeedSyncDeviceToHostImmediately()) { tensor->data_sync(); tensor->set_device_address(nullptr); + tensor->set_sync_status(kNeedSyncHostToDevice); } } } diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc index 430c3a3e9de07a8866366c0c889da2bc66e4e4c4..feb95d66c0e930035384f320e3af2fcd49e2bdf4 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc @@ -248,7 +248,6 @@ void CPUKernelRuntime::BindInputOutput(session::KernelGraph *kernel_graph, const tensor->data_c())) { MS_LOG(EXCEPTION) << "Parameter node sync host to device failed!"; } - tensor->set_sync_status(kNeedSyncHostToDevice); } address->ref_count_ = INIT_NODE_REF; tensor->set_device_address(address); diff --git a/mindspore/core/ir/tensor.cc b/mindspore/core/ir/tensor.cc index c4c380c19383951ad6d5eda14e1e6d7b157aac02..7b4414df1342457c518a4936079989acc35eb0d8 100644 --- a/mindspore/core/ir/tensor.cc +++ b/mindspore/core/ir/tensor.cc @@ -557,7 +557,6 @@ void Tensor::data_sync() const { if (!device_sync_->SyncDeviceToHost(shape(), static_cast(data().nbytes()), data_type(), data_c())) { MS_LOG(EXCEPTION) << "SyncDeviceToHost failed."; } - sync_status_ = kNeedSyncHostToDevice; } TypeId Tensor::set_data_type(const TypeId data_type) { diff --git a/mindspore/core/ir/tensor.h b/mindspore/core/ir/tensor.h index 1c4631b7f26df73c54292ade75772939ec41e57d..13472b3f0c61296c26d837b784ddfb4019a604cd 100644 --- a/mindspore/core/ir/tensor.h +++ b/mindspore/core/ir/tensor.h @@ -289,7 +289,7 @@ class Tensor : public MetaTensor { if (event_ != nullptr) { event_->Wait(); } - event_ == nullptr; + event_ = nullptr; } void set_sync_status(TensorSyncStatus sync_status) { sync_status_ = sync_status; } @@ -306,7 +306,7 @@ class Tensor : public MetaTensor { bool init_flag_{false}; TensorDataPtr data_{nullptr}; std::string id_{""}; - std::shared_ptr event_{nullptr}; + mutable std::shared_ptr event_{nullptr}; mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice}; DeviceSyncPtr device_sync_{nullptr}; std::vector padding_type_;