提交 f42b3bbf 作者: mindspore-ci-bot 提交者: Gitee

!5598 add tensor sync status

Merge pull request !5598 from kisnwang/async-run-graph
...@@ -410,7 +410,7 @@ void AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_i ...@@ -410,7 +410,7 @@ void AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_i
for (auto &pre_output : pre_output_tensors) { for (auto &pre_output : pre_output_tensors) {
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(pre_output->data_type(), pre_output->shape()); tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(pre_output->data_type(), pre_output->shape());
tensor->set_device_address(pre_output->device_address()); tensor->set_device_address(pre_output->device_address());
tensor->set_dirty(false); tensor->set_sync_status(kNoNeedSync);
outputs->emplace_back(tensor); outputs->emplace_back(tensor);
} }
} else { } else {
......
...@@ -38,9 +38,9 @@ void UpdateOutputTensors(VectorRef *outputs, ...@@ -38,9 +38,9 @@ void UpdateOutputTensors(VectorRef *outputs,
auto address = AnfAlgo::GetMutableOutputAddr(node, output_index); auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
tensor->set_device_address(address); tensor->set_device_address(address);
} }
if (tensor->need_sync()) { if (tensor->NeedSyncDeviceToHostImmediately()) {
tensor->data_sync(); tensor->data_sync();
tensor->set_need_sync(false); tensor->set_sync_status(kNoNeedSync);
} }
} }
} }
......
...@@ -158,7 +158,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, ...@@ -158,7 +158,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
if (tensor_address == nullptr || tensor_address != device_address) { if (tensor_address == nullptr || tensor_address != device_address) {
need_sync = true; need_sync = true;
} }
} else if (tensor->is_dirty() || tensor_address == nullptr) { } else if (tensor->NeedSyncHostToDevice() || tensor_address == nullptr) {
need_sync = true; need_sync = true;
} else if (tensor_address != device_address) { } else if (tensor_address != device_address) {
if (tensor_address->DeviceType() == device_address->DeviceType()) { if (tensor_address->DeviceType() == device_address->DeviceType()) {
...@@ -177,7 +177,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, ...@@ -177,7 +177,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
} }
} }
} }
tensor->set_dirty(false); tensor->set_sync_status(kNoNeedSync);
} }
} }
...@@ -332,7 +332,7 @@ void GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info ...@@ -332,7 +332,7 @@ void GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info
for (auto &pre_output : pre_output_tensors) { for (auto &pre_output : pre_output_tensors) {
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(pre_output->data_type(), pre_output->shape()); tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(pre_output->data_type(), pre_output->shape());
tensor->set_device_address(pre_output->device_address()); tensor->set_device_address(pre_output->device_address());
tensor->set_dirty(false); tensor->set_sync_status(kNoNeedSync);
outputs->emplace_back(tensor); outputs->emplace_back(tensor);
} }
} else { } else {
......
...@@ -75,7 +75,7 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o ...@@ -75,7 +75,7 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o
temp_shape.emplace_back(1); temp_shape.emplace_back(1);
tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape); tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index)); tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
tensor->set_dirty(false); tensor->set_sync_status(kNoNeedSync);
tensor->SetNeedWait(true); tensor->SetNeedWait(true);
return tensor; return tensor;
} }
...@@ -96,12 +96,13 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o ...@@ -96,12 +96,13 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) { ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
tensor->set_need_sync(true); tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
} else {
tensor->set_sync_status(kNeedSyncDeviceToHost);
} }
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
tensor->SetNeedWait(true); tensor->SetNeedWait(true);
} }
tensor->set_dirty(false);
return tensor; return tensor;
} }
...@@ -198,7 +199,7 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vecto ...@@ -198,7 +199,7 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vecto
auto *cur_val = static_cast<int32_t *>(cur_loop_tensor->data_c()); auto *cur_val = static_cast<int32_t *>(cur_loop_tensor->data_c());
MS_EXCEPTION_IF_NULL(cur_val); MS_EXCEPTION_IF_NULL(cur_val);
*cur_val = 0; *cur_val = 0;
cur_loop_tensor->set_dirty(true); cur_loop_tensor->set_sync_status(kNeedSyncHostToDevice);
// set loop_count to zero // set loop_count to zero
MS_EXCEPTION_IF_NULL(inputs); MS_EXCEPTION_IF_NULL(inputs);
inputs->push_back(cur_loop_tensor); inputs->push_back(cur_loop_tensor);
...@@ -209,7 +210,7 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vecto ...@@ -209,7 +210,7 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vecto
auto *next_val = static_cast<int32_t *>(next_loop_tensor->data_c()); auto *next_val = static_cast<int32_t *>(next_loop_tensor->data_c());
MS_EXCEPTION_IF_NULL(next_val); MS_EXCEPTION_IF_NULL(next_val);
*next_val = 0; *next_val = 0;
next_loop_tensor->set_dirty(true); next_loop_tensor->set_sync_status(kNeedSyncHostToDevice);
// set loop_count to zero // set loop_count to zero
MS_EXCEPTION_IF_NULL(inputs); MS_EXCEPTION_IF_NULL(inputs);
inputs->push_back(next_loop_tensor); inputs->push_back(next_loop_tensor);
...@@ -219,7 +220,7 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vecto ...@@ -219,7 +220,7 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vecto
auto *epoch_val = static_cast<int32_t *>(epoch_tensor->data_c()); auto *epoch_val = static_cast<int32_t *>(epoch_tensor->data_c());
MS_EXCEPTION_IF_NULL(epoch_val); MS_EXCEPTION_IF_NULL(epoch_val);
*epoch_val = graph->current_epoch(); *epoch_val = graph->current_epoch();
epoch_tensor->set_dirty(true); epoch_tensor->set_sync_status(kNeedSyncHostToDevice);
inputs->push_back(epoch_tensor); inputs->push_back(epoch_tensor);
MS_LOG(INFO) << "Load epoch_val:" << *epoch_val; MS_LOG(INFO) << "Load epoch_val:" << *epoch_val;
...@@ -927,7 +928,7 @@ bool TensorNeedSync(const AnfNodePtr &parameter, const tensor::TensorPtr &tensor ...@@ -927,7 +928,7 @@ bool TensorNeedSync(const AnfNodePtr &parameter, const tensor::TensorPtr &tensor
if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) { if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
return tensor->device_address().get() == nullptr || tensor->device_address() != device_address; return tensor->device_address().get() == nullptr || tensor->device_address() != device_address;
} }
if (tensor->is_dirty()) { if (tensor->NeedSyncHostToDevice()) {
return true; return true;
} }
if (tensor->device_address() != device_address) { if (tensor->device_address() != device_address) {
...@@ -976,7 +977,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap ...@@ -976,7 +977,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
} }
} }
tensor->set_dirty(false); tensor->set_sync_status(kNoNeedSync);
} }
} }
...@@ -1124,7 +1125,7 @@ void SessionBasic::Summary(KernelGraph *graph) { ...@@ -1124,7 +1125,7 @@ void SessionBasic::Summary(KernelGraph *graph) {
tensor->data_type(), tensor->data_c())) { tensor->data_type(), tensor->data_c())) {
MS_LOG(ERROR) << "Failed to sync output from device to host."; MS_LOG(ERROR) << "Failed to sync output from device to host.";
} }
tensor->set_dirty(false); tensor->set_sync_status(kNoNeedSync);
params_list[output_item.first] = tensor; params_list[output_item.first] = tensor;
} }
// call callback function here // call callback function here
......
...@@ -373,7 +373,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat ...@@ -373,7 +373,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
auto tensor = py::cast<tensor::TensorPtr>(input); auto tensor = py::cast<tensor::TensorPtr>(input);
auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr()); auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
new_tensor->set_device_address(tensor->device_address()); new_tensor->set_device_address(tensor->device_address());
new_tensor->set_dirty(tensor->is_dirty()); new_tensor->set_sync_status(tensor->sync_status());
result[i] = new_tensor; result[i] = new_tensor;
} }
*status = PYNATIVE_SUCCESS; *status = PYNATIVE_SUCCESS;
......
...@@ -162,7 +162,7 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *k ...@@ -162,7 +162,7 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *k
} }
if (bound_addresses_.find(address) != bound_addresses_.end()) { if (bound_addresses_.find(address) != bound_addresses_.end()) {
tensor->set_device_address(address); tensor->set_device_address(address);
tensor->set_need_sync(true); tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
} else { } else {
if (infer_type_id != device_type_id) { if (infer_type_id != device_type_id) {
size_t type_size = GetTypeByte(TypeIdToType(device_type_id)); size_t type_size = GetTypeByte(TypeIdToType(device_type_id));
...@@ -170,15 +170,16 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *k ...@@ -170,15 +170,16 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *k
size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies<size_t>()); size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies<size_t>());
address->ptr_ = resource_manager_.MemMalloc(tensor_size); address->ptr_ = resource_manager_.MemMalloc(tensor_size);
tensor->set_device_address(address); tensor->set_device_address(address);
tensor->set_need_sync(true); tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
} else { } else {
tensor->set_device_address(nullptr); tensor->set_device_address(nullptr);
address->ptr_ = tensor->data_c(); address->ptr_ = tensor->data_c();
tensor->set_sync_status(kNoNeedSync);
} }
address->ref_count_ = INIT_NODE_REF; address->ref_count_ = INIT_NODE_REF;
(void)bound_addresses_.insert(address); (void)bound_addresses_.insert(address);
} }
tensor->set_dirty(false);
return tensor; return tensor;
} }
...@@ -247,7 +248,7 @@ void CPUKernelRuntime::BindInputOutput(session::KernelGraph *kernel_graph, const ...@@ -247,7 +248,7 @@ void CPUKernelRuntime::BindInputOutput(session::KernelGraph *kernel_graph, const
tensor->data_c())) { tensor->data_c())) {
MS_LOG(EXCEPTION) << "Parameter node sync host to device failed!"; MS_LOG(EXCEPTION) << "Parameter node sync host to device failed!";
} }
tensor->set_dirty(true); tensor->set_sync_status(kNeedSyncHostToDevice);
} }
address->ref_count_ = INIT_NODE_REF; address->ref_count_ = INIT_NODE_REF;
tensor->set_device_address(address); tensor->set_device_address(address);
......
...@@ -534,7 +534,7 @@ bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph ...@@ -534,7 +534,7 @@ bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph
auto pk_node = input_node->cast<ParameterPtr>(); auto pk_node = input_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(tensor); MS_EXCEPTION_IF_NULL(tensor);
MS_EXCEPTION_IF_NULL(pk_node); MS_EXCEPTION_IF_NULL(pk_node);
if (tensor->is_dirty() || !pk_node->has_default()) { if (tensor->NeedSyncHostToDevice() || !pk_node->has_default()) {
need_sync = true; need_sync = true;
} }
} }
...@@ -551,7 +551,7 @@ bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph ...@@ -551,7 +551,7 @@ bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph
return false; return false;
} }
} }
tensor->set_dirty(false); tensor->set_sync_status(kNoNeedSync);
} }
return true; return true;
} }
......
...@@ -422,10 +422,9 @@ Tensor::Tensor(const Tensor &tensor) ...@@ -422,10 +422,9 @@ Tensor::Tensor(const Tensor &tensor)
: MetaTensor(tensor), : MetaTensor(tensor),
init_flag_(tensor.init_flag_), init_flag_(tensor.init_flag_),
data_(tensor.data_), data_(tensor.data_),
dirty_(tensor.dirty_),
id_(tensor.id_), id_(tensor.id_),
event_(tensor.event_), event_(tensor.event_),
need_sync_(tensor.need_sync_), sync_status_(tensor.sync_status_),
device_sync_(tensor.device_sync_), device_sync_(tensor.device_sync_),
padding_type_(tensor.padding_type()) {} padding_type_(tensor.padding_type()) {}
...@@ -433,10 +432,9 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type) ...@@ -433,10 +432,9 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type)
: MetaTensor(data_type, tensor.shape_), : MetaTensor(data_type, tensor.shape_),
init_flag_(tensor.init_flag_), init_flag_(tensor.init_flag_),
data_(MakeTensorData(data_type, tensor.shape_, tensor.data_->data(), tensor.data_type_)), data_(MakeTensorData(data_type, tensor.shape_, tensor.data_->data(), tensor.data_type_)),
dirty_(tensor.dirty_),
id_(tensor.id_), id_(tensor.id_),
event_(tensor.event_), event_(tensor.event_),
need_sync_(tensor.need_sync_), sync_status_(tensor.sync_status_),
device_sync_(tensor.device_sync_), device_sync_(tensor.device_sync_),
padding_type_(tensor.padding_type()) {} padding_type_(tensor.padding_type()) {}
...@@ -483,12 +481,11 @@ bool Tensor::ValueEqual(const Tensor &tensor) const { ...@@ -483,12 +481,11 @@ bool Tensor::ValueEqual(const Tensor &tensor) const {
Tensor &Tensor::AssignValue(const Tensor &tensor) { Tensor &Tensor::AssignValue(const Tensor &tensor) {
if (this != &tensor) { if (this != &tensor) {
MetaTensor::operator=(tensor); MetaTensor::operator=(tensor);
dirty_ = tensor.dirty_;
device_sync_ = tensor.device_sync_; device_sync_ = tensor.device_sync_;
data_ = tensor.data_; data_ = tensor.data_;
id_ = tensor.id_; id_ = tensor.id_;
event_ = tensor.event_; event_ = tensor.event_;
need_sync_ = tensor.need_sync_; sync_status_ = tensor.sync_status_;
padding_type_ = tensor.padding_type_; padding_type_ = tensor.padding_type_;
} }
return *this; return *this;
......
...@@ -36,7 +36,7 @@ ...@@ -36,7 +36,7 @@
// Other namespace should be a sub namespace of mindspore namespace in the ME project. // Other namespace should be a sub namespace of mindspore namespace in the ME project.
namespace mindspore { namespace mindspore {
// brief mindspore::tensor namespace // brief mindspore::tensor namespace
// enum TensorSyncStatus { kNoNeedSync, kNeedSyncHostToDevice, kNeedSyncDeviceToHost, kNeedSyncDeviceToHostImmediately };
// A sub namespace in ME to support tensor related definition. // A sub namespace in ME to support tensor related definition.
namespace tensor { namespace tensor {
// Tensor data interface. // Tensor data interface.
...@@ -260,9 +260,6 @@ class Tensor : public MetaTensor { ...@@ -260,9 +260,6 @@ class Tensor : public MetaTensor {
bool is_init() const { return init_flag_; } bool is_init() const { return init_flag_; }
void set_init_flag(bool flag) { init_flag_ = flag; } void set_init_flag(bool flag) { init_flag_ = flag; }
bool is_dirty() const { return dirty_; }
void set_dirty(const bool dirty) { dirty_ = dirty; }
DeviceSyncPtr device_address() const { return device_sync_; } DeviceSyncPtr device_address() const { return device_sync_; }
void set_device_address(const DeviceSyncPtr &device_sync) { device_sync_ = device_sync; } void set_device_address(const DeviceSyncPtr &device_sync) { device_sync_ = device_sync; }
void set_padding_type(std::vector<Axis> padding_type) { padding_type_ = padding_type; } void set_padding_type(std::vector<Axis> padding_type) { padding_type_ = padding_type; }
...@@ -293,17 +290,22 @@ class Tensor : public MetaTensor { ...@@ -293,17 +290,22 @@ class Tensor : public MetaTensor {
event_ == nullptr; event_ == nullptr;
} }
void set_need_sync(bool need_sync) { need_sync_ = need_sync; } void set_sync_status(TensorSyncStatus sync_status) { sync_status_ = sync_status; }
TensorSyncStatus sync_status() const { return sync_status_; }
bool NeedSyncDeviceToHostImmediately() const { return sync_status_ == kNeedSyncDeviceToHostImmediately; }
bool NeedSyncDeviceToHost() const { return sync_status_ == kNeedSyncDeviceToHost; }
bool need_sync() const { return need_sync_; } bool NeedSyncHostToDevice() const { return sync_status_ == kNeedSyncHostToDevice; }
private: private:
bool init_flag_{false}; bool init_flag_{false};
TensorDataPtr data_{nullptr}; TensorDataPtr data_{nullptr};
bool dirty_{true};
std::string id_{""}; std::string id_{""};
std::shared_ptr<WaitEvent> event_{nullptr}; std::shared_ptr<WaitEvent> event_{nullptr};
bool need_sync_{false}; TensorSyncStatus sync_status_{kNeedSyncHostToDevice};
DeviceSyncPtr device_sync_{nullptr}; DeviceSyncPtr device_sync_{nullptr};
std::vector<Axis> padding_type_; std::vector<Axis> padding_type_;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册