Commit a11e0e35 authored by mindspore-ci-bot, committed by Gitee

!3872 add internal output tensor

Merge pull request !3872 from kisnwang/cache-internal-tensor
@@ -961,18 +961,40 @@ void KernelGraph::PrintGraphExecuteOrder() const {
}
}
void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node) {
void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node, int output_idx,
bool unique_target) {
if (front_node == nullptr || node == nullptr) {
MS_LOG(INFO) << "Front node or node is nullptr";
return;
}
MS_LOG(INFO) << "Add internal node " << node->DebugString() << " with front node " << front_node->DebugString();
front_to_internal_outputs_map_[front_node] = node;
int output_idx = 0;
if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) {
output_idx = AnfAlgo::GetTupleGetItemOutIndex(front_node->cast<CNodePtr>());
}
internal_outputs_to_front_map_[node][output_idx] = front_node;
internal_outputs_to_front_map_[node][output_idx] = std::pair<AnfNodePtr, bool>(front_node, unique_target);
}
void KernelGraph::AddInternalOutputTensor(const AnfNodePtr &node, int output_idx, const tensor::TensorPtr &tensor) {
if (node == nullptr) {
return;
}
internal_outputs_tensor_map_[node][output_idx] = tensor;
}
tensor::TensorPtr KernelGraph::GetInternalOutputTensor(const AnfNodePtr &node, int output_idx) {
if (node == nullptr) {
return nullptr;
}
auto iter = internal_outputs_tensor_map_.find(node);
if (iter == internal_outputs_tensor_map_.end()) {
return nullptr;
}
auto idx_iter = iter->second.find(output_idx);
if (idx_iter == iter->second.end()) {
return nullptr;
}
return idx_iter->second;
}
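Note: the two methods above add a per-node, per-output-index tensor cache to KernelGraph. A minimal self-contained sketch of the same two-level map pattern (Node and Tensor are stand-in types, not the real AnfNode/tensor::Tensor classes):
#include <memory>
#include <unordered_map>
struct Node {};
struct Tensor {};
using NodePtr = std::shared_ptr<Node>;
using TensorPtr = std::shared_ptr<Tensor>;
class InternalOutputTensorCache {
 public:
  // Mirrors AddInternalOutputTensor: remember the tensor handed out for
  // output output_idx of node.
  void Add(const NodePtr &node, int output_idx, const TensorPtr &tensor) {
    if (node == nullptr) {
      return;
    }
    cache_[node][output_idx] = tensor;
  }
  // Mirrors GetInternalOutputTensor: nullptr means "not cached yet".
  TensorPtr Get(const NodePtr &node, int output_idx) const {
    if (node == nullptr) {
      return nullptr;
    }
    auto iter = cache_.find(node);
    if (iter == cache_.end()) {
      return nullptr;
    }
    auto idx_iter = iter->second.find(output_idx);
    return idx_iter == iter->second.end() ? nullptr : idx_iter->second;
  }
 private:
  std::unordered_map<NodePtr, std::unordered_map<int, TensorPtr>> cache_;
};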
void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx,
@@ -996,7 +1018,7 @@ void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr
if (src_output_idx == -1) {
internal_outputs_to_front_map_[new_node] = front_nodes;
for (const auto &front_node_iter : front_nodes) {
front_to_internal_outputs_map_[front_node_iter.second] = new_node;
front_to_internal_outputs_map_[front_node_iter.second.first] = new_node;
}
internal_outputs_to_front_map_.erase(iter);
return;
@@ -1008,9 +1030,9 @@ void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr
MS_LOG(INFO) << "The output " << src_output_idx << " of node " << node->DebugString() << " is not an internal node";
return;
}
auto front_node = front_node_iter->second;
internal_outputs_to_front_map_[new_node][dst_output_idx] = front_node;
front_to_internal_outputs_map_[front_node] = new_node;
auto front_node_pair = front_node_iter->second;
internal_outputs_to_front_map_[new_node][dst_output_idx] = front_node_pair;
front_to_internal_outputs_map_[front_node_pair.first] = new_node;
front_nodes.erase(index);
if (front_nodes.empty()) {
internal_outputs_to_front_map_.erase(iter);
@@ -1027,16 +1049,30 @@ AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_nod
bool KernelGraph::IsInternalOutput(const AnfNodePtr &node, int output_idx) const {
auto front_nodes_iter = internal_outputs_to_front_map_.find(node);
if (front_nodes_iter != internal_outputs_to_front_map_.end()) {
if (output_idx == -1) {
return true;
}
auto &front_nodes = front_nodes_iter->second;
if (front_nodes.find(output_idx) != front_nodes.end()) {
return true;
}
if (front_nodes_iter == internal_outputs_to_front_map_.end()) {
return false;
}
return false;
if (output_idx == -1) {
return true;
}
auto &front_nodes = front_nodes_iter->second;
if (front_nodes.find(output_idx) == front_nodes.end()) {
return false;
}
return true;
}
bool KernelGraph::IsUniqueTargetInternalOutput(const AnfNodePtr &node, int output_idx) const {
auto front_nodes_iter = internal_outputs_to_front_map_.find(node);
if (front_nodes_iter == internal_outputs_to_front_map_.end()) {
return false;
}
auto &front_nodes = front_nodes_iter->second;
auto idx_iter = front_nodes.find(output_idx);
if (idx_iter == front_nodes.end()) {
return false;
}
return idx_iter->second.second;
}
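The companion change is that internal_outputs_to_front_map_ now stores a (front node, unique_target) pair instead of a bare front node, which is what IsUniqueTargetInternalOutput reads. A condensed sketch of that lookup, under the same stub-type assumption:
#include <memory>
#include <unordered_map>
#include <utility>
struct NodeStub {};
using NodeStubPtr = std::shared_ptr<NodeStub>;
class FrontNodeMap {
 public:
  void Add(const NodeStubPtr &node, int output_idx,
           const NodeStubPtr &front_node, bool unique_target) {
    map_[node][output_idx] = std::make_pair(front_node, unique_target);
  }
  // false for non-internal outputs; true only when the stored flag is set.
  bool IsUniqueTarget(const NodeStubPtr &node, int output_idx) const {
    auto iter = map_.find(node);
    if (iter == map_.end()) {
      return false;
    }
    auto idx_iter = iter->second.find(output_idx);
    return idx_iter != iter->second.end() && idx_iter->second.second;
  }
 private:
  std::unordered_map<
      NodeStubPtr, std::unordered_map<int, std::pair<NodeStubPtr, bool>>> map_;
};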
void KernelGraph::UpdateChildGraphOrder() {
......
@@ -143,11 +143,16 @@ class KernelGraph : public FuncGraph {
void PrintGraphExecuteOrder() const;
const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes() const { return summary_nodes_; }
void set_summary_nodes(const std::map<std::string, std::pair<AnfNodePtr, int>> &nodes) { summary_nodes_ = nodes; }
void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node);
void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node, int output_idx = 0,
bool unique_target = false);
void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx = -1,
int dst_output_idx = -1);
AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const;
bool IsInternalOutput(const AnfNodePtr &node, int output_idx = -1) const;
bool IsUniqueTargetInternalOutput(const AnfNodePtr &node, int output_idx) const;
void AddInternalOutputTensor(const AnfNodePtr &node, int output_idx, const tensor::TensorPtr &tensor);
tensor::TensorPtr GetInternalOutputTensor(const AnfNodePtr &node, int output_idx);
uint32_t current_epoch() const { return current_epoch_; }
void set_current_epoch(uint32_t epoch) { current_epoch_ = epoch; }
void UpdateChildGraphOrder();
@@ -217,7 +222,8 @@ class KernelGraph : public FuncGraph {
CNodePtr end_goto_;
bool null_output_;
std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_internal_outputs_map_;
std::unordered_map<AnfNodePtr, std::unordered_map<int, AnfNodePtr>> internal_outputs_to_front_map_;
std::unordered_map<AnfNodePtr, std::unordered_map<int, std::pair<AnfNodePtr, bool>>> internal_outputs_to_front_map_;
std::unordered_map<AnfNodePtr, std::unordered_map<int, tensor::TensorPtr>> internal_outputs_tensor_map_;
uint32_t current_epoch_;
};
} // namespace session
......
@@ -58,51 +58,38 @@ ValuePtr GetParamDefaultValue(const AnfNodePtr &node) {
return parameter->default_param();
}
BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph,
const std::vector<tensor::TensorPtr> &input_tensors) {
tensor::TensorPtr CreateOutputTensor(const AnfNodePtr &node, size_t output_index, const KernelGraphPtr &graph,
const DeviceAddressPtr &address) {
MS_EXCEPTION_IF_NULL(node);
MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
// if node is a value node, no need to sync addr from device to host
if (!AnfAlgo::OutputAddrExist(node, output_index)) {
if (node->isa<ValueNode>()) {
auto value_node = node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
return value_node->value();
}
if (node->isa<Parameter>()) {
for (size_t input_idx = 0; input_idx < graph.inputs().size(); input_idx++) {
if (input_idx >= input_tensors.size()) {
MS_LOG(EXCEPTION) << "Input idx:" << input_idx << "out of range:" << input_tensors.size();
}
if (graph.inputs()[input_idx] == node) {
return input_tensors[input_idx];
}
}
MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << " has no output addr";
}
}
// if the process reaches here, it means item_with_index is a real node (Parameter, or executable CNode)
auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
MS_EXCEPTION_IF_NULL(address);
auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
TypeId type_id = kNumberTypeFloat32;
type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
MS_EXCEPTION_IF_NULL(graph);
TypeId type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
if (type_id == kTypeUnknown) {
type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
}
tensor::TensorPtr tensor;
std::vector<int> temp_shape;
if (graph.IsInternalOutput(node, output_index)) {
if (graph->IsUniqueTargetInternalOutput(node, output_index)) {
temp_shape.emplace_back(1);
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
tensor->set_device_address(address);
tensor->set_dirty(false);
return tensor;
}
(void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
tensor = graph->GetInternalOutputTensor(node, output_index);
if (tensor == nullptr) {
auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
(void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
bool is_internal_output = graph->IsInternalOutput(node, output_index);
if (is_internal_output) {
graph->AddInternalOutputTensor(node, output_index, tensor);
}
}
// in PyNative mode, data is only copied to host when the user wants to print it
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
MS_EXCEPTION_IF_NULL(address);
if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) {
tensor->set_device_address(address);
tensor->set_dirty(false);
@@ -114,7 +101,35 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
return tensor;
}
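CreateOutputTensor above branches three ways: a one-element placeholder tensor for unique-target internal outputs, a tensor reused from the graph-level cache, or a freshly built tensor that gets cached when the output is internal. A sketch of that flow (Cache is a hypothetical stand-in for the new KernelGraph methods; signatures are simplified):
#include <functional>
#include <memory>
#include <vector>
struct DeviceAddress {};
struct Tensor {
  std::vector<int> shape;
  std::shared_ptr<DeviceAddress> device_address;
  bool dirty = true;
};
using TensorPtr = std::shared_ptr<Tensor>;
using AddressPtr = std::shared_ptr<DeviceAddress>;
// Hypothetical stand-in for the four KernelGraph methods used above.
struct Cache {
  std::function<bool(int)> is_unique_target_internal_output;
  std::function<bool(int)> is_internal_output;
  std::function<TensorPtr(int)> get_internal_output_tensor;
  std::function<void(int, const TensorPtr &)> add_internal_output_tensor;
};
TensorPtr CreateOutputTensorSketch(int idx, const std::vector<int> &infer_shape,
                                   const Cache &g, const AddressPtr &address) {
  // Case 1: every user sits on the same target, so the real tensor lives in
  // the consuming graph; hand back a one-element placeholder that only
  // carries the device address.
  if (g.is_unique_target_internal_output(idx)) {
    auto t = std::make_shared<Tensor>();
    t->shape = {1};
    t->device_address = address;
    t->dirty = false;
    return t;
  }
  // Case 2: reuse the tensor cached on a previous run, so repeated launches
  // keep handing out the same host tensor.
  TensorPtr t = g.get_internal_output_tensor(idx);
  if (t == nullptr) {
    // Case 3: first run; build the tensor from the inferred shape and cache
    // it when another graph consumes this output internally.
    t = std::make_shared<Tensor>();
    t->shape = infer_shape;
    if (g.is_internal_output(idx)) {
      g.add_internal_output_tensor(idx, t);
    }
  }
  // (mode-dependent device-address binding and host sync omitted)
  return t;
}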
BaseRef CreateTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraphPtr &graph,
const std::vector<tensor::TensorPtr> &input_tensors) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
// if node is a value node, no need to sync addr from device to host
if (!AnfAlgo::OutputAddrExist(node, output_index)) {
if (node->isa<ValueNode>()) {
auto value_node = node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
return value_node->value();
}
if (node->isa<Parameter>()) {
for (size_t input_idx = 0; input_idx < graph->inputs().size(); input_idx++) {
if (input_idx >= input_tensors.size()) {
MS_LOG(EXCEPTION) << "Input idx:" << input_idx << "out of range:" << input_tensors.size();
}
if (graph->inputs()[input_idx] == node) {
return input_tensors[input_idx];
}
}
MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << " has no output addr";
}
}
auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
return CreateOutputTensor(node, output_index, graph, address);
}
BaseRef CreateTensorForOutput(const AnfNodePtr &anf, const KernelGraphPtr &graph,
const std::vector<tensor::TensorPtr> &input_tensors) {
MS_EXCEPTION_IF_NULL(anf);
MS_LOG(INFO) << "Create tensor for output[" << anf->DebugString() << "]";
@@ -308,7 +323,7 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
auto real_kernel = AnfAlgo::VisitKernel(ref_node, output_idx);
auto ref_real_node = real_kernel.first;
auto ref_real_node_index = real_kernel.second;
if (ref_real_node->isa<CNode>() && node_graph->IsInternalOutput(ref_real_node, ref_real_node_index)) {
if (ref_real_node->isa<CNode>() && node_graph->IsUniqueTargetInternalOutput(ref_real_node, ref_real_node_index)) {
auto kernel_info = ref_real_node->kernel_info();
if (kernel_info == nullptr || !kernel_info->has_build_info()) {
MS_LOG(INFO) << "No kernel info";
@@ -888,7 +903,7 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap
for (auto &item : anf_outputs) {
MS_EXCEPTION_IF_NULL(item);
MS_LOG(INFO) << "Update output[" << item->DebugString() << "]";
outputs->emplace_back(CreateTensorForOutput(item, *kernel_graph, input_tensors));
outputs->emplace_back(CreateTensorForOutput(item, kernel_graph, input_tensors));
}
}
@@ -967,6 +982,71 @@ void SessionBasic::Summary(KernelGraph *graph) {
summary_callback_(0, params_list);
}
namespace {
bool CNodePrimIsValueNode(const AnfNodePtr &node) {
if (node == nullptr) {
return false;
}
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
return false;
}
auto prim = cnode->input(kAnfPrimitiveIndex);
if (prim == nullptr || !prim->isa<ValueNode>()) {
return false;
}
return true;
}
void HandleInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &backend_node,
const FuncGraphManagerPtr &front_func_graph_manager,
const std::shared_ptr<KernelGraph> &backend_graph) {
auto node_users = front_func_graph_manager->node_users();
auto users = node_users[front_node];
auto front_real_kernel_pair = AnfAlgo::VisitKernel(front_node, 0);
auto backend_real_kernel_pair = AnfAlgo::VisitKernel(backend_node, 0);
auto front_real_kernel = front_real_kernel_pair.first;
std::string kernel_target = GetCNodeTarget(front_real_kernel);
bool internal_output = CNodePrimIsValueNode(front_real_kernel);
bool unique_target = true;
if (internal_output && opt::IsNopNode(front_real_kernel)) {
auto pre_node_pair = AnfAlgo::GetPrevNodeOutput(front_real_kernel, 0);
auto pre_node_target = GetCNodeTarget(pre_node_pair.first);
if (pre_node_target != kernel_target) {
unique_target = false;
}
}
if (internal_output) {
for (auto user : users) {
auto cnode = user.first->cast<CNodePtr>();
if (cnode == nullptr) {
internal_output = false;
break;
}
auto prim = cnode->input(kAnfPrimitiveIndex);
if (prim == nullptr || !prim->isa<ValueNode>()) {
internal_output = false;
break;
}
if (!AnfAlgo::IsRealKernel(user.first)) {
internal_output = false;
break;
}
if (kernel_target != GetCNodeTarget(user.first)) {
unique_target = false;
}
}
}
if (internal_output) {
MS_LOG(INFO) << "Internal output: " << front_node->DebugString() << "To "
<< backend_real_kernel_pair.first->DebugString();
backend_graph->AddInternalOutput(front_node, backend_real_kernel_pair.first, backend_real_kernel_pair.second,
unique_target);
}
}
} // namespace
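HandleInternalOutput separates two questions the old inline code in ConstructOutput conflated: whether the front node qualifies as an internal output at all, and whether all of its users share one device target. Notably, a target mismatch no longer disqualifies the node; it only clears unique_target. A sketch of the per-user classification loop (User is a hypothetical struct bundling the facts the real code derives via AnfAlgo::IsRealKernel and GetCNodeTarget; the NopNode pre-check is omitted):
#include <string>
#include <vector>
// Hypothetical: bundles what the real code derives from each front-node user.
struct User {
  bool is_valid_cnode;   // cast to CNode succeeded and its prim is a ValueNode
  bool is_real_kernel;   // AnfAlgo::IsRealKernel(user)
  std::string target;    // GetCNodeTarget(user)
};
struct Classification {
  bool internal_output = true;
  bool unique_target = true;
};
Classification Classify(const std::string &kernel_target,
                        const std::vector<User> &users) {
  Classification c;
  for (const auto &u : users) {
    if (!u.is_valid_cnode || !u.is_real_kernel) {
      // Any non-kernel user disqualifies the node as an internal output.
      c.internal_output = false;
      break;
    }
    if (u.target != kernel_target) {
      // Still an internal output, but its users span device targets, so the
      // unique-target fast path must not be taken.
      c.unique_target = false;
    }
  }
  return c;
}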
CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> output_args;
@@ -982,9 +1062,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
if (context_ptr->execution_mode() == kPynativeMode) {
return backend_anf;
}
auto front_real_kernel_pair = AnfAlgo::VisitKernel(out, 0);
auto front_real_kernel = front_real_kernel_pair.first;
auto backend_real_kernel_pair = AnfAlgo::VisitKernel(backend_anf, 0);
MS_EXCEPTION_IF_NULL(out);
auto out_func_graph = out->func_graph();
MS_EXCEPTION_IF_NULL(out_func_graph);
@@ -992,51 +1070,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
if (out_func_graph_manager == nullptr) {
return backend_anf;
}
auto node_users = out_func_graph_manager->node_users();
auto users = node_users[out];
bool internal_output = true;
std::string kernel_target = GetCNodeTarget(front_real_kernel);
if (front_real_kernel != nullptr && front_real_kernel->isa<CNode>()) {
auto front_cnode = front_real_kernel->cast<CNodePtr>();
if (front_cnode != nullptr) {
auto prim = front_cnode->input(kAnfPrimitiveIndex);
if (prim == nullptr || !prim->isa<ValueNode>()) {
internal_output = false;
}
} else {
internal_output = false;
}
}
if (internal_output && opt::IsNopNode(front_real_kernel)) {
auto pre_node_pair = AnfAlgo::GetPrevNodeOutput(front_real_kernel, 0);
auto pre_node_target = GetCNodeTarget(pre_node_pair.first);
if (pre_node_target != kernel_target) {
internal_output = false;
}
}
if (internal_output) {
for (auto user : users) {
auto cnode = user.first->cast<CNodePtr>();
if (cnode == nullptr) {
internal_output = false;
break;
}
auto prim = cnode->input(kAnfPrimitiveIndex);
if (prim == nullptr || !prim->isa<ValueNode>()) {
internal_output = false;
break;
}
if (!AnfAlgo::IsRealKernel(user.first) || kernel_target != GetCNodeTarget(user.first)) {
internal_output = false;
break;
}
}
}
if (internal_output) {
MS_LOG(INFO) << "Internal output: " << out->DebugString() << "To "
<< backend_real_kernel_pair.first->DebugString();
graph->AddInternalOutput(out, backend_real_kernel_pair.first);
}
HandleInternalOutput(out, backend_anf, out_func_graph_manager, graph);
return backend_anf;
}
MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
......
@@ -20,7 +20,7 @@
#include <numeric>
#include <utility>
#include <functional>
#include <unordered_map>
#include <map>
#include <set>
#include "backend/kernel_compiler/kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
@@ -124,11 +124,10 @@ DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t
return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id);
}
tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(const CNodePtr &node, size_t index,
std::set<DeviceAddressPtr> *bound_addresses,
tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *kernel_graph, const CNodePtr &node,
size_t index,
std::vector<tensor::TensorPtr> *need_sync_outputs) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(bound_addresses);
MS_EXCEPTION_IF_NULL(need_sync_outputs);
size_t output_size = AnfAlgo::GetOutputTensorNum(node);
if (index >= output_size) {
@@ -136,14 +135,21 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(const CNodePtr &node, s
}
auto address = AnfAlgo::GetMutableOutputAddr(node, index);
MS_EXCEPTION_IF_NULL(address);
auto shape = AnfAlgo::GetOutputInferShape(node, index);
std::vector<int> temp_shape;
(void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end());
TypeId infer_type_id = AnfAlgo::GetOutputInferDataType(node, index);
TypeId device_type_id = AnfAlgo::GetOutputDeviceDataType(node, index);
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(infer_type_id, temp_shape);
MS_EXCEPTION_IF_NULL(tensor);
if (bound_addresses->find(address) != bound_addresses->end()) {
tensor::TensorPtr tensor = kernel_graph->GetInternalOutputTensor(node, index);
if (tensor == nullptr) {
auto shape = AnfAlgo::GetOutputInferShape(node, index);
std::vector<int> temp_shape;
(void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end());
tensor = std::make_shared<tensor::Tensor>(infer_type_id, temp_shape);
bool is_internal_output = kernel_graph->IsInternalOutput(node, index);
if (is_internal_output) {
kernel_graph->AddInternalOutputTensor(node, index, tensor);
}
}
if (bound_addresses_.find(address) != bound_addresses_.end()) {
tensor->set_device_address(address);
need_sync_outputs->emplace_back(tensor);
} else {
@@ -159,15 +165,14 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(const CNodePtr &node, s
address->ptr_ = tensor->data_c();
}
address->ref_count_ = INIT_NODE_REF;
(void)bound_addresses->insert(address);
(void)bound_addresses_.insert(address);
}
tensor->set_dirty(false);
return tensor;
}
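In the CPU runtime, the formerly per-call bound-address set becomes the member bound_addresses_ (cleared at the start of each BindInputOutput), so it no longer has to be threaded through the recursive helpers, and output tensors are looked up in the graph-level cache before being recreated. A small sketch of the bind-or-reuse rule applied above (stub DeviceAddress, not the real runtime types):
#include <memory>
#include <set>
struct DeviceAddress {
  void *ptr_ = nullptr;
};
using AddressPtr = std::shared_ptr<DeviceAddress>;
class AddressBinder {
 public:
  // Returns true when the address was freshly bound to host_buffer; false
  // when it was already bound, in which case the caller must sync the data
  // from device to host instead of re-pointing the address.
  bool BindOrReuse(const AddressPtr &address, void *host_buffer) {
    if (bound_.find(address) != bound_.end()) {
      return false;
    }
    address->ptr_ = host_buffer;  // alias the tensor's host memory directly
    bound_.insert(address);
    return true;
  }
  // Cleared at the start of each BindInputOutput in the patched runtime.
  void Clear() { bound_.clear(); }
 private:
  std::set<AddressPtr> bound_;
};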
BaseRef CPUKernelRuntime::CreatTensorForOutput(const session::KernelWithIndex &kernel_with_index,
const std::unordered_map<AnfNode *, tensor::TensorPtr> &input_map,
std::set<DeviceAddressPtr> *bound_addresses,
BaseRef CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *kernel_graph,
const session::KernelWithIndex &kernel_with_index,
std::vector<tensor::TensorPtr> *need_sync_outputs) {
auto &input_node = kernel_with_index.first;
auto index = kernel_with_index.second;
@@ -179,15 +184,15 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(const session::KernelWithIndex &k
VectorRef ret;
for (size_t i = 1; i < node->inputs().size(); i++) {
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node->input(i), 0);
auto out = CreatTensorForOutput(item_with_index, input_map, bound_addresses, need_sync_outputs);
auto out = CreatTensorForOutput(kernel_graph, item_with_index, need_sync_outputs);
ret.push_back(out);
}
return ret;
}
return CreatTensorForOutput(node, index, bound_addresses, need_sync_outputs);
return CreatTensorForOutput(kernel_graph, node, index, need_sync_outputs);
} else if (input_node->isa<Parameter>()) {
auto iter = input_map.find(input_node.get());
if (iter != input_map.end()) {
auto iter = input_param_tensor_map_.find(input_node);
if (iter != input_param_tensor_map_.end()) {
return iter->second;
}
} else if (input_node->isa<ValueNode>()) {
@@ -197,10 +202,8 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(const session::KernelWithIndex &k
}
return BaseRef();
}
void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph,
const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs,
std::vector<tensor::TensorPtr> *need_sync_outputs) {
void CPUKernelRuntime::BindInputOutput(session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs, std::vector<tensor::TensorPtr> *need_sync_outputs) {
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_EXCEPTION_IF_NULL(outputs);
// bind input ptr
@@ -208,11 +211,11 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph,
if (input_nodes.size() != inputs.size()) {
MS_LOG(EXCEPTION) << "Input size not equal to input node size!";
}
std::unordered_map<AnfNode *, tensor::TensorPtr> input_map;
input_param_tensor_map_.clear();
size_t input_idx = 0;
for (auto &item : input_nodes) {
MS_EXCEPTION_IF_NULL(item);
input_map[item.get()] = inputs[input_idx];
input_param_tensor_map_[item] = inputs[input_idx];
if (item->isa<Parameter>()) {
auto address = AnfAlgo::GetMutableOutputAddr(item, 0);
auto tensor = inputs[input_idx];
@@ -222,7 +225,6 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph,
if (tensor_address != nullptr && tensor_address != address) {
(void)tensor->data_sync();
}
if (tensor->data_type() == address->type_id_ || tensor->data_type() == kNumberTypeFloat32 ||
tensor->data_type() == kNumberTypeInt32) {
address->ptr_ = tensor->data_c();
@@ -243,11 +245,11 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph,
input_idx++;
}
// new output and bind ptr
std::set<DeviceAddressPtr> bound_addresses;
bound_addresses_.clear();
auto output_nodes = kernel_graph->outputs();
for (const auto &item : output_nodes) {
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, true);
auto out = CreatTensorForOutput(item_with_index, input_map, &bound_addresses, need_sync_outputs);
auto out = CreatTensorForOutput(kernel_graph, item_with_index, need_sync_outputs);
outputs->push_back(std::move(out));
}
}
......
@@ -19,7 +19,7 @@
#include <memory>
#include <vector>
#include <string>
#include <unordered_map>
#include <map>
#include <set>
#include "runtime/device/kernel_runtime.h"
#include "backend/session/kernel_graph.h"
@@ -38,7 +38,7 @@ class CPUKernelRuntime : public KernelRuntime {
bool Init() override { return true; }
bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
void AssignKernelAddress(session::KernelGraph *kernel_graph);
void BindInputOutput(const session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
void BindInputOutput(session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs, std::vector<tensor::TensorPtr> *need_sync_outputs);
void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs);
void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs);
@@ -49,19 +49,18 @@ class CPUKernelRuntime : public KernelRuntime {
TypeId type_id) override;
private:
tensor::TensorPtr CreatTensorForOutput(const CNodePtr &node, size_t index,
std::set<DeviceAddressPtr> *bound_addresses,
tensor::TensorPtr CreatTensorForOutput(session::KernelGraph *kernel_graph, const CNodePtr &node, size_t index,
std::vector<tensor::TensorPtr> *need_sync_outputs);
BaseRef CreatTensorForOutput(const session::KernelWithIndex &kernel_with_index,
const std::unordered_map<AnfNode *, tensor::TensorPtr> &input_map,
std::set<DeviceAddressPtr> *bound_addresses,
BaseRef CreatTensorForOutput(session::KernelGraph *kernel_graph, const session::KernelWithIndex &kernel_with_index,
std::vector<tensor::TensorPtr> *need_sync_outputs);
void AssignValueNodeAddress(session::KernelGraph *kernel_graph);
void AssignInputNodeAddress(const session::KernelGraph *kernel_graph);
void AssignKernelOutputAddress(const session::KernelGraph *kernel_graph);
void AddRuntimeAddress(DeviceAddress *address, std::vector<kernel::AddressPtr> *input_list);
CPUResourceManager resource_manager_;
std::set<DeviceAddressPtr> bound_addresses_;
std::map<AnfNodePtr, tensor::TensorPtr> input_param_tensor_map_;
};
} // namespace cpu
} // namespace device
......