提交 bf82e8d0 编写于 作者: C chenzupeng

fix bug in opencl subgraph output tensor

上级 dc961e46
...@@ -61,7 +61,9 @@ class LiteKernel { ...@@ -61,7 +61,9 @@ class LiteKernel {
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx, const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const lite::Primitive *primitive) const lite::Primitive *primitive)
: opParameter(parameter), inputs_(inputs), outputs_(outputs), primitive_(primitive), context_(ctx) { : opParameter(parameter), inputs_(inputs), outputs_(outputs), primitive_(primitive), context_(ctx) {
if (opParameter && ctx) {
opParameter->thread_num_ = ctx->thread_num_; opParameter->thread_num_ = ctx->thread_num_;
}
this->in_kernel_.clear(); this->in_kernel_.clear();
this->out_kernel_.clear(); this->out_kernel_.clear();
} }
...@@ -100,7 +102,10 @@ class LiteKernel { ...@@ -100,7 +102,10 @@ class LiteKernel {
schema::PrimitiveType type() { return (schema::PrimitiveType)this->opParameter->type_; } schema::PrimitiveType type() { return (schema::PrimitiveType)this->opParameter->type_; }
std::string type_str() { return schema::EnumNamePrimitiveType((schema::PrimitiveType)this->opParameter->type_); } std::string type_str() {
return this->opParameter ? schema::EnumNamePrimitiveType((schema::PrimitiveType)this->opParameter->type_)
: "ERROR:undefined primitive!";
}
void SetInputs(const std::vector<lite::tensor::Tensor *> &inputs) { this->inputs_ = inputs; } void SetInputs(const std::vector<lite::tensor::Tensor *> &inputs) { this->inputs_ = inputs; }
......
...@@ -56,6 +56,7 @@ int OpenCLExecutor::Run(std::vector<tensor::Tensor *> &inputs, std::vector<tenso ...@@ -56,6 +56,7 @@ int OpenCLExecutor::Run(std::vector<tensor::Tensor *> &inputs, std::vector<tenso
} else { } else {
output->MallocData(allocator); output->MallocData(allocator);
} }
output->set_allocator(allocator);
} }
session::CallBackParam callbackParam; session::CallBackParam callbackParam;
callbackParam.name_callback_param = kernel->Name(); callbackParam.name_callback_param = kernel->Name();
......
...@@ -157,7 +157,7 @@ kernel::LiteKernel *Scheduler::CreateSubKernel(const std::vector<kernel::LiteKer ...@@ -157,7 +157,7 @@ kernel::LiteKernel *Scheduler::CreateSubKernel(const std::vector<kernel::LiteKer
input_tensors.emplace_back(tensor); input_tensors.emplace_back(tensor);
} }
} }
for (auto tensor : tail_kernel->GetInputs()) { for (auto tensor : tail_kernel->GetOutputs()) {
if (tensor->Data() == nullptr) { if (tensor->Data() == nullptr) {
output_tensors.emplace_back(tensor); output_tensors.emplace_back(tensor);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册