Commit f42b3bbf authored by mindspore-ci-bot, committed by Gitee

!5598 add tensor sync status

Merge pull request !5598 from kisnwang/async-run-graph
......@@ -410,7 +410,7 @@ void AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_i
for (auto &pre_output : pre_output_tensors) {
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(pre_output->data_type(), pre_output->shape());
tensor->set_device_address(pre_output->device_address());
tensor->set_dirty(false);
tensor->set_sync_status(kNoNeedSync);
outputs->emplace_back(tensor);
}
} else {
......
......@@ -38,9 +38,9 @@ void UpdateOutputTensors(VectorRef *outputs,
auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
tensor->set_device_address(address);
}
if (tensor->need_sync()) {
if (tensor->NeedSyncDeviceToHostImmediately()) {
tensor->data_sync();
tensor->set_need_sync(false);
tensor->set_sync_status(kNoNeedSync);
}
}
}
......
......@@ -158,7 +158,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
if (tensor_address == nullptr || tensor_address != device_address) {
need_sync = true;
}
} else if (tensor->is_dirty() || tensor_address == nullptr) {
} else if (tensor->NeedSyncHostToDevice() || tensor_address == nullptr) {
need_sync = true;
} else if (tensor_address != device_address) {
if (tensor_address->DeviceType() == device_address->DeviceType()) {
......@@ -177,7 +177,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
}
}
}
tensor->set_dirty(false);
tensor->set_sync_status(kNoNeedSync);
}
}
......@@ -332,7 +332,7 @@ void GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info
for (auto &pre_output : pre_output_tensors) {
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(pre_output->data_type(), pre_output->shape());
tensor->set_device_address(pre_output->device_address());
tensor->set_dirty(false);
tensor->set_sync_status(kNoNeedSync);
outputs->emplace_back(tensor);
}
} else {
......
......@@ -75,7 +75,7 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o
temp_shape.emplace_back(1);
tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
tensor->set_dirty(false);
tensor->set_sync_status(kNoNeedSync);
tensor->SetNeedWait(true);
return tensor;
}
......@@ -96,12 +96,13 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
tensor->set_need_sync(true);
tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
} else {
tensor->set_sync_status(kNeedSyncDeviceToHost);
}
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
tensor->SetNeedWait(true);
}
tensor->set_dirty(false);
return tensor;
}
......@@ -198,7 +199,7 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vecto
auto *cur_val = static_cast<int32_t *>(cur_loop_tensor->data_c());
MS_EXCEPTION_IF_NULL(cur_val);
*cur_val = 0;
cur_loop_tensor->set_dirty(true);
cur_loop_tensor->set_sync_status(kNeedSyncHostToDevice);
// set loop_count to zero
MS_EXCEPTION_IF_NULL(inputs);
inputs->push_back(cur_loop_tensor);
......@@ -209,7 +210,7 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vecto
auto *next_val = static_cast<int32_t *>(next_loop_tensor->data_c());
MS_EXCEPTION_IF_NULL(next_val);
*next_val = 0;
next_loop_tensor->set_dirty(true);
next_loop_tensor->set_sync_status(kNeedSyncHostToDevice);
// set loop_count to zero
MS_EXCEPTION_IF_NULL(inputs);
inputs->push_back(next_loop_tensor);
......@@ -219,7 +220,7 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vecto
auto *epoch_val = static_cast<int32_t *>(epoch_tensor->data_c());
MS_EXCEPTION_IF_NULL(epoch_val);
*epoch_val = graph->current_epoch();
epoch_tensor->set_dirty(true);
epoch_tensor->set_sync_status(kNeedSyncHostToDevice);
inputs->push_back(epoch_tensor);
MS_LOG(INFO) << "Load epoch_val:" << *epoch_val;
......@@ -927,7 +928,7 @@ bool TensorNeedSync(const AnfNodePtr &parameter, const tensor::TensorPtr &tensor
if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
return tensor->device_address().get() == nullptr || tensor->device_address() != device_address;
}
if (tensor->is_dirty()) {
if (tensor->NeedSyncHostToDevice()) {
return true;
}
if (tensor->device_address() != device_address) {
......@@ -976,7 +977,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
}
}
tensor->set_dirty(false);
tensor->set_sync_status(kNoNeedSync);
}
}
......@@ -1124,7 +1125,7 @@ void SessionBasic::Summary(KernelGraph *graph) {
tensor->data_type(), tensor->data_c())) {
MS_LOG(ERROR) << "Failed to sync output from device to host.";
}
tensor->set_dirty(false);
tensor->set_sync_status(kNoNeedSync);
params_list[output_item.first] = tensor;
}
// call callback function here
......
......@@ -373,7 +373,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
auto tensor = py::cast<tensor::TensorPtr>(input);
auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
new_tensor->set_device_address(tensor->device_address());
new_tensor->set_dirty(tensor->is_dirty());
new_tensor->set_sync_status(tensor->sync_status());
result[i] = new_tensor;
}
*status = PYNATIVE_SUCCESS;
......
......@@ -162,7 +162,7 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *k
}
if (bound_addresses_.find(address) != bound_addresses_.end()) {
tensor->set_device_address(address);
tensor->set_need_sync(true);
tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
} else {
if (infer_type_id != device_type_id) {
size_t type_size = GetTypeByte(TypeIdToType(device_type_id));
......@@ -170,15 +170,16 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *k
size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies<size_t>());
address->ptr_ = resource_manager_.MemMalloc(tensor_size);
tensor->set_device_address(address);
tensor->set_need_sync(true);
tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
} else {
tensor->set_device_address(nullptr);
address->ptr_ = tensor->data_c();
tensor->set_sync_status(kNoNeedSync);
}
address->ref_count_ = INIT_NODE_REF;
(void)bound_addresses_.insert(address);
}
tensor->set_dirty(false);
return tensor;
}
......@@ -247,7 +248,7 @@ void CPUKernelRuntime::BindInputOutput(session::KernelGraph *kernel_graph, const
tensor->data_c())) {
MS_LOG(EXCEPTION) << "Parameter node sync host to device failed!";
}
tensor->set_dirty(true);
tensor->set_sync_status(kNeedSyncHostToDevice);
}
address->ref_count_ = INIT_NODE_REF;
tensor->set_device_address(address);
......
......@@ -534,7 +534,7 @@ bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph
auto pk_node = input_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(tensor);
MS_EXCEPTION_IF_NULL(pk_node);
if (tensor->is_dirty() || !pk_node->has_default()) {
if (tensor->NeedSyncHostToDevice() || !pk_node->has_default()) {
need_sync = true;
}
}
......@@ -551,7 +551,7 @@ bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph
return false;
}
}
tensor->set_dirty(false);
tensor->set_sync_status(kNoNeedSync);
}
return true;
}
......
......@@ -422,10 +422,9 @@ Tensor::Tensor(const Tensor &tensor)
: MetaTensor(tensor),
init_flag_(tensor.init_flag_),
data_(tensor.data_),
dirty_(tensor.dirty_),
id_(tensor.id_),
event_(tensor.event_),
need_sync_(tensor.need_sync_),
sync_status_(tensor.sync_status_),
device_sync_(tensor.device_sync_),
padding_type_(tensor.padding_type()) {}
......@@ -433,10 +432,9 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type)
: MetaTensor(data_type, tensor.shape_),
init_flag_(tensor.init_flag_),
data_(MakeTensorData(data_type, tensor.shape_, tensor.data_->data(), tensor.data_type_)),
dirty_(tensor.dirty_),
id_(tensor.id_),
event_(tensor.event_),
need_sync_(tensor.need_sync_),
sync_status_(tensor.sync_status_),
device_sync_(tensor.device_sync_),
padding_type_(tensor.padding_type()) {}
......@@ -483,12 +481,11 @@ bool Tensor::ValueEqual(const Tensor &tensor) const {
Tensor &Tensor::AssignValue(const Tensor &tensor) {
if (this != &tensor) {
MetaTensor::operator=(tensor);
dirty_ = tensor.dirty_;
device_sync_ = tensor.device_sync_;
data_ = tensor.data_;
id_ = tensor.id_;
event_ = tensor.event_;
need_sync_ = tensor.need_sync_;
sync_status_ = tensor.sync_status_;
padding_type_ = tensor.padding_type_;
}
return *this;
......
......@@ -36,7 +36,7 @@
// Other namespace should be a sub namespace of mindspore namespace in the ME project.
namespace mindspore {
// brief mindspore::tensor namespace
//
enum TensorSyncStatus { kNoNeedSync, kNeedSyncHostToDevice, kNeedSyncDeviceToHost, kNeedSyncDeviceToHostImmediately };
// A sub namespace in ME to support tensor related definition.
namespace tensor {
// Tensor data interface.
......@@ -260,9 +260,6 @@ class Tensor : public MetaTensor {
bool is_init() const { return init_flag_; }
void set_init_flag(bool flag) { init_flag_ = flag; }
bool is_dirty() const { return dirty_; }
void set_dirty(const bool dirty) { dirty_ = dirty; }
DeviceSyncPtr device_address() const { return device_sync_; }
void set_device_address(const DeviceSyncPtr &device_sync) { device_sync_ = device_sync; }
void set_padding_type(std::vector<Axis> padding_type) { padding_type_ = padding_type; }
......@@ -293,17 +290,22 @@ class Tensor : public MetaTensor {
event_ == nullptr;
}
void set_need_sync(bool need_sync) { need_sync_ = need_sync; }
void set_sync_status(TensorSyncStatus sync_status) { sync_status_ = sync_status; }
TensorSyncStatus sync_status() const { return sync_status_; }
bool NeedSyncDeviceToHostImmediately() const { return sync_status_ == kNeedSyncDeviceToHostImmediately; }
bool NeedSyncDeviceToHost() const { return sync_status_ == kNeedSyncDeviceToHost; }
bool need_sync() const { return need_sync_; }
bool NeedSyncHostToDevice() const { return sync_status_ == kNeedSyncHostToDevice; }
private:
bool init_flag_{false};
TensorDataPtr data_{nullptr};
bool dirty_{true};
std::string id_{""};
std::shared_ptr<WaitEvent> event_{nullptr};
bool need_sync_{false};
TensorSyncStatus sync_status_{kNeedSyncHostToDevice};
DeviceSyncPtr device_sync_{nullptr};
std::vector<Axis> padding_type_;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register