Commit aedb6d73 authored by H hangq

fix bug where graph output tensors were being freed

Parent c3ab044f
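The failure mode, in brief: after each kernel executes, the executor decrements the reference count on the output tensors of that kernel's producers and frees a tensor's data once the count reaches zero. When a producer's outputs are also the graph's outputs, this freed buffers the caller had not read yet. The commit adds an `is_model_output_` flag to `LiteKernel`, has the scheduler set it for graph-output nodes, and makes the executor skip the decrement for flagged kernels. A minimal self-contained sketch of the mechanism (`Tensor`, `FreeData`, `ref_count_`, and `ReleaseUpstream` are illustrative stand-ins, not the real MindSpore Lite types):

```cpp
#include <cstdlib>
#include <vector>

// Illustrative stand-ins for the runtime's tensor and kernel types.
struct Tensor {
  int ref_count_ = 1;
  void *data_ = nullptr;
  void FreeData() { free(data_); data_ = nullptr; }
};

struct Kernel {
  std::vector<Tensor *> outputs_;
  bool is_model_output_ = false;  // the flag this commit introduces

  int DecOutTensorRefCount() {
    for (auto *t : outputs_) {
      if (--t->ref_count_ <= 0) {
        t->FreeData();  // pre-fix, this also released graph output buffers
      }
    }
    return 0;
  }
};

// Executor side of the fix: skip producers whose outputs are the
// model's outputs so their buffers survive until the caller reads them.
void ReleaseUpstream(const std::vector<Kernel *> &in_kernels) {
  for (auto *in : in_kernels) {
    if (in->is_model_output_) {
      continue;  // keep graph output tensors alive
    }
    in->DecOutTensorRefCount();
  }
}
```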
@@ -58,7 +58,10 @@ int Executor::Run(std::vector<tensor::Tensor *> &inputs, std::vector<tensor::Ten
       }
     }
     for (auto input_kernel : kernel->GetInKernels()) {
-      MS_EXCEPTION_IF_NULL(input_kernel);
+      MS_ASSERT(input_kernel != nullptr);
+      if (input_kernel->is_model_output()) {
+        continue;
+      }
       ret = input_kernel->DecOutTensorRefCount();
       if (0 != ret) {
         MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->Name() << " failed";
...
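Side note: replacing MS_EXCEPTION_IF_NULL with MS_ASSERT also softens the failure mode on this hot path. Roughly, and hedged (these are the conventional shapes of such macros, not their verbatim MindSpore definitions):

```cpp
#include <cassert>

// Typical shape of such macros: the assert variant vanishes in release
// builds, while the exception variant stays active regardless.
#ifdef NDEBUG
#define MS_ASSERT(expr) ((void)0)     // no cost in release builds
#else
#define MS_ASSERT(expr) assert(expr)  // debug-only check
#endif
```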
@@ -60,8 +60,7 @@ class LiteKernel {
   explicit LiteKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                       const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
                       const lite::Primitive *primitive)
-      : opParameter(parameter), inputs_(inputs), outputs_(outputs), primitive_(primitive),
-        context_(ctx) {
+      : opParameter(parameter), inputs_(inputs), outputs_(outputs), primitive_(primitive), context_(ctx) {
     opParameter->thread_num_ = ctx->thread_num_;
     this->in_kernel_.clear();
     this->out_kernel_.clear();
@@ -95,6 +94,10 @@ class LiteKernel {
   virtual bool is_eval() { return train_mode == false; }
 
   void set_name(const std::string &name) { this->name = name; }
 
+  void set_is_model_output(bool is_model_output) { this->is_model_output_ = is_model_output; }
+
+  bool is_model_output() { return this->is_model_output_; }
+
   schema::PrimitiveType type() { return (schema::PrimitiveType)this->opParameter->type_; }
   std::string type_str() { return schema::EnumNamePrimitiveType((schema::PrimitiveType)this->opParameter->type_); }
@@ -123,9 +126,7 @@ class LiteKernel {
   void set_desc(const KernelKey kernel_key) { desc = kernel_key; }
 
-  void SetNeedReInit() {
-    need_reinit = true;
-  }
+  void SetNeedReInit() { need_reinit = true; }
 
 protected:
   bool InferShapeDone() {
@@ -138,8 +139,8 @@ class LiteKernel {
   KernelKey desc;
   std::string name;
   OpParameter *opParameter = nullptr;
-  const lite::Primitive *primitive_;
-  const lite::Context *context_;
+  const lite::Primitive *primitive_ = nullptr;
+  const lite::Context *context_ = nullptr;
   // tensor will free in ~lite_session()
   std::vector<lite::tensor::Tensor *> inputs_;
   std::vector<lite::tensor::Tensor *> outputs_;
@@ -147,6 +148,7 @@ class LiteKernel {
   std::vector<LiteKernel *> out_kernel_;
   bool train_mode = false;
   bool need_reinit = false;
+  bool is_model_output_ = false;
 };
 
 class SubGraphKernel : public LiteKernel {
...
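The two one-line accessors are the entire new API surface on LiteKernel: the scheduler writes the flag once at compile time, the executor reads it on every run. A stripped-down stand-in of the tagging step (only `set_is_model_output`/`is_model_output` mirror the real additions; the class and loop are illustrative):

```cpp
#include <vector>

// Stand-in mirroring the accessors added to LiteKernel above.
class KernelStub {
 public:
  void set_is_model_output(bool is_model_output) { is_model_output_ = is_model_output; }
  bool is_model_output() { return is_model_output_; }

 private:
  bool is_model_output_ = false;
};

// Compile-time tagging: kernel i is a model output iff node i appears
// in the graph's output-node index list (cf. InitOp2Kernel below).
void TagModelOutputs(std::vector<KernelStub> &kernels, const std::vector<size_t> &output_nodes) {
  for (size_t i = 0; i < kernels.size(); ++i) {
    bool is_out = false;
    for (size_t idx : output_nodes) {
      if (idx == i) {
        is_out = true;
        break;
      }
    }
    kernels[i].set_is_model_output(is_out);
  }
}
```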
@@ -79,7 +79,33 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
   return RET_OK;
 }
 
-void LiteSession::InitGraphInOutTensor(const lite::Model *model) {
+void LiteSession::InitGraphInputTensors(const lite::Model *model) {
+  auto meta_graph = model->GetMetaGraph();
+  MS_ASSERT(this->inputs.empty());
+  MS_ASSERT(meta_graph != nullptr);
+  for (size_t i = 0; i < meta_graph->inputIndex()->size(); i++) {
+    auto in_tensor_idx = size_t(meta_graph->inputIndex()->GetAs<uint32_t>(i));
+    MS_ASSERT(in_tensor_idx < this->tensors.size());
+    auto *in_tensor = this->tensors.at(in_tensor_idx);
+    MS_ASSERT(in_tensor != nullptr);
+    this->inputs.emplace_back(in_tensor);
+  }
+}
+
+void LiteSession::InitGraphOutputTensors(const lite::Model *model) {
+  auto meta_graph = model->GetMetaGraph();
+  MS_ASSERT(this->outputs.empty());
+  MS_ASSERT(meta_graph != nullptr);
+  for (size_t i = 0; i < meta_graph->outputIndex()->size(); i++) {
+    auto out_tensor_idx = size_t(meta_graph->outputIndex()->GetAs<uint32_t>(i));
+    MS_ASSERT(out_tensor_idx < this->tensors.size());
+    auto *out_tensor = this->tensors.at(out_tensor_idx);
+    MS_ASSERT(out_tensor != nullptr);
+    this->outputs.emplace_back(out_tensor);
+  }
+}
+
+void LiteSession::InitGraphInputMap(const lite::Model *model) {
   auto meta_graph = model->GetMetaGraph();
   MS_ASSERT(this->input_map.empty());
   MS_ASSERT(meta_graph != nullptr);
@@ -108,7 +134,12 @@ void LiteSession::InitGraphInOutTensor(const lite::Model *model) {
       this->input_map[in_node->name()->str()].emplace_back(ms_tensor);
     }
   }
+}
 
+void LiteSession::InitGraphOutputMap(const lite::Model *model) {
+  auto meta_graph = model->GetMetaGraph();
+  MS_ASSERT(this->output_map.empty());
+  MS_ASSERT(meta_graph != nullptr);
   auto graph_output_node_indexes = GetGraphOutputNodes(meta_graph);
   for (auto out_node_index : graph_output_node_indexes) {
     auto *out_node = meta_graph->nodes()->GetAs<schema::CNode>(out_node_index);
@@ -136,6 +167,13 @@ void LiteSession::InitGraphInOutTensor(const lite::Model *model) {
   }
 }
 
+void LiteSession::InitGraphInOutTensors(const lite::Model *model) {
+  InitGraphInputTensors(model);
+  InitGraphOutputTensors(model);
+  InitGraphInputMap(model);
+  InitGraphOutputMap(model);
+}
+
 int LiteSession::CompileGraph(Model *model) {
   // model.MetaGraph ==> kernels
   if (model == nullptr) {
@@ -149,7 +187,7 @@ int LiteSession::CompileGraph(Model *model) {
     return ret;
   }
 
-  InitGraphInOutTensor(model);
+  InitGraphInOutTensors(model);
 
   // scheduler kernels
   Scheduler scheduler(context_);
@@ -228,15 +266,7 @@ LiteSession::~LiteSession() {
     }
     delete tensor;
   }
-  // inputs outputs input_map output_map are freed in tensors
-  for (auto *input : inputs) {
-    ((tensor::LiteTensor *)input)->SetTensorImpl(nullptr);
-    delete input;
-  }
-  for (auto *output : outputs) {
-    ((tensor::LiteTensor *)output)->SetTensorImpl(nullptr);
-    delete output;
-  }
+  // tensor::Tensor * in input_map output_map are freed in tensors
   for (auto iter : this->input_map) {
     for (auto *ms_tensor : iter.second) {
       ((tensor::LiteTensor *)ms_tensor)->SetTensorImpl(nullptr);
...
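The destructor change encodes the new ownership rule: graph `inputs`/`outputs` now alias tensors owned by the session's `tensors` vector (populated by the new InitGraphInputTensors/InitGraphOutputTensors), so deleting them here would double-free. Only the map wrappers are still deleted, after detaching their shared impl. The detach-then-delete idiom in stand-alone form (types are illustrative stand-ins):

```cpp
struct Impl {};  // stands in for the session-owned tensor::Tensor

struct Wrapper {  // stands in for the tensor::LiteTensor wrapper
  Impl *impl_ = nullptr;
  void SetTensorImpl(Impl *impl) { impl_ = impl; }
  ~Wrapper() { delete impl_; }  // a naive delete would free the shared Impl
};

// Pattern from the destructor above: detach before deleting, so the
// shared Impl is freed exactly once, by the owner of `tensors`.
void DestroyWrapperOnly(Wrapper *w) {
  w->SetTensorImpl(nullptr);  // detach the session-owned tensor
  delete w;                   // frees only the wrapper shell
}
```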
@@ -56,7 +56,15 @@ class LiteSession : public session::LiteSession {
 protected:
   int ConvertTensors(const lite::Model *model);
 
-  void InitGraphInOutTensor(const lite::Model *model);
+  void InitGraphInOutTensors(const lite::Model *model);
+
+  void InitGraphInputTensors(const lite::Model *model);
+
+  void InitGraphOutputTensors(const lite::Model *model);
+
+  void InitGraphInputMap(const lite::Model *model);
+
+  void InitGraphOutputMap(const lite::Model *model);
 
 protected:
   Context *context_ = nullptr;
...
@@ -19,6 +19,8 @@
 #include <algorithm>
 #include "include/errorcode.h"
 #include "src/kernel_factory.h"
+#include "src/common/graph_util.h"
+#include "src/common/utils.h"
 #if SUPPORT_GPU
 #include "src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
 #endif
@@ -51,6 +53,7 @@ int Scheduler::InitOp2Kernel(const lite::Model *model, std::vector<tensor::Tenso
   auto meta_graph = model->GetMetaGraph();
   MS_EXCEPTION_IF_NULL(meta_graph);
   uint32_t kernelCount = meta_graph->nodes()->size();
+  auto graph_output_node_indexes = GetGraphOutputNodes(meta_graph);
   for (uint32_t i = 0; i < kernelCount; i++) {
     auto cNode = meta_graph->nodes()->GetAs<schema::CNode>(i);
     std::vector<tensor::Tensor *> inputs;
@@ -93,6 +96,7 @@ int Scheduler::InitOp2Kernel(const lite::Model *model, std::vector<tensor::Tenso
       return RET_ERROR;
     }
     kernel->set_name(cNode->name()->str());
+    kernel->set_is_model_output(IsContain(graph_output_node_indexes, size_t(i)));
     kernels->emplace_back(kernel);
   }
   return RET_OK;
@@ -158,10 +162,10 @@ kernel::LiteKernel *Scheduler::CreateSubKernel(const std::vector<kernel::LiteKer
         output_tensors.emplace_back(tensor);
       }
     }
     // std::vector<tensor::Tensor *> input_tensors = kernel::LiteKernelUtil::SubgraphInputTensors(kernels);
     // std::vector<tensor::Tensor *> output_tensors = kernel::LiteKernelUtil::SubgraphOutputTensors(kernels);
     // std::vector<kernel::LiteKernel *> input_kernels = kernel::LiteKernelUtil::SubgraphInputKernels(kernels);
     // std::vector<kernel::LiteKernel *> output_kernels = kernel::LiteKernelUtil::SubgraphOutputKernels(kernels);
     sub_kernel =
       new kernel::SubGraphOpenCLKernel(input_tensors, output_tensors, input_kernels, output_kernels, kernels);
     sub_kernel->Init();
...
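`IsContain` comes in via the newly included src/common/utils.h; it is presumably just a linear membership test over the output-node index list. A minimal equivalent for reference (the real signature may differ):

```cpp
#include <algorithm>
#include <vector>

// Minimal equivalent of the IsContain helper used in InitOp2Kernel:
// true iff `vec` contains `element`.
template <typename T>
bool IsContain(const std::vector<T> &vec, const T &element) {
  return std::find(vec.begin(), vec.end(), element) != vec.end();
}
```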