提交 b5e57956 编写于 作者: K kswang

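Fix unnecessary tensor sync: set kNeedSyncHostToDevice in UpdateOutputTensors right after the immediate device-to-host sync, and no longer in Tensor::data_sync() or CPUKernelRuntime::BindInputOutput. Also correct the no-op comparison `event_ == nullptr;` to the assignment `event_ = nullptr;` and declare `event_` as `mutable`.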

上级 c1cb20ac
...@@ -42,6 +42,7 @@ void UpdateOutputTensors(const VectorRef *outputs, ...@@ -42,6 +42,7 @@ void UpdateOutputTensors(const VectorRef *outputs,
if (tensor->NeedSyncDeviceToHostImmediately()) { if (tensor->NeedSyncDeviceToHostImmediately()) {
tensor->data_sync(); tensor->data_sync();
tensor->set_device_address(nullptr); tensor->set_device_address(nullptr);
tensor->set_sync_status(kNeedSyncHostToDevice);
} }
} }
} }
......
...@@ -248,7 +248,6 @@ void CPUKernelRuntime::BindInputOutput(session::KernelGraph *kernel_graph, const ...@@ -248,7 +248,6 @@ void CPUKernelRuntime::BindInputOutput(session::KernelGraph *kernel_graph, const
tensor->data_c())) { tensor->data_c())) {
MS_LOG(EXCEPTION) << "Parameter node sync host to device failed!"; MS_LOG(EXCEPTION) << "Parameter node sync host to device failed!";
} }
tensor->set_sync_status(kNeedSyncHostToDevice);
} }
address->ref_count_ = INIT_NODE_REF; address->ref_count_ = INIT_NODE_REF;
tensor->set_device_address(address); tensor->set_device_address(address);
......
...@@ -557,7 +557,6 @@ void Tensor::data_sync() const { ...@@ -557,7 +557,6 @@ void Tensor::data_sync() const {
if (!device_sync_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) { if (!device_sync_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) {
MS_LOG(EXCEPTION) << "SyncDeviceToHost failed."; MS_LOG(EXCEPTION) << "SyncDeviceToHost failed.";
} }
sync_status_ = kNeedSyncHostToDevice;
} }
TypeId Tensor::set_data_type(const TypeId data_type) { TypeId Tensor::set_data_type(const TypeId data_type) {
......
...@@ -289,7 +289,7 @@ class Tensor : public MetaTensor { ...@@ -289,7 +289,7 @@ class Tensor : public MetaTensor {
if (event_ != nullptr) { if (event_ != nullptr) {
event_->Wait(); event_->Wait();
} }
event_ == nullptr; event_ = nullptr;
} }
void set_sync_status(TensorSyncStatus sync_status) { sync_status_ = sync_status; } void set_sync_status(TensorSyncStatus sync_status) { sync_status_ = sync_status; }
...@@ -306,7 +306,7 @@ class Tensor : public MetaTensor { ...@@ -306,7 +306,7 @@ class Tensor : public MetaTensor {
bool init_flag_{false}; bool init_flag_{false};
TensorDataPtr data_{nullptr}; TensorDataPtr data_{nullptr};
std::string id_{""}; std::string id_{""};
std::shared_ptr<WaitEvent> event_{nullptr}; mutable std::shared_ptr<WaitEvent> event_{nullptr};
mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice}; mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice};
DeviceSyncPtr device_sync_{nullptr}; DeviceSyncPtr device_sync_{nullptr};
std::vector<Axis> padding_type_; std::vector<Axis> padding_type_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册