Commit e9c4a697 authored by mindspore-ci-bot, committed by Gitee

!3868 set tensor allocator

Merge pull request !3868 from 张学同/to_merge
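In short, this change moves allocator ownership into the tensor itself: a Tensor now stores an allocator_ pointer (settable via set_allocator), MallocData()/FreeData() route through it, and call sites such as Executor::Run and LiteKernel::DecOutTensorRefCount no longer thread an Allocator * argument through. A minimal before/after sketch of the call contract (identifiers come from the diff below; the surrounding setup is assumed):

    // Before: the allocator had to be passed at every call site.
    output->MallocData(allocator);
    input_kernel->DecOutTensorRefCount(allocator);

    // After: the allocator is attached once, then the argument disappears.
    output->set_allocator(allocator);      // e.g. done by the Scheduler per CPU subgraph
    output->MallocData();                  // allocates through the stored allocator_
    input_kernel->DecOutTensorRefCount();  // FreeData() likewise uses allocator_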
@@ -39,7 +39,7 @@ int Executor::Run(std::vector<tensor::Tensor *> &inputs, std::vector<tensor::Ten
     auto &outputs = kernel->GetOutputs();
     for (auto *output : outputs) {
       MS_ASSERT(nullptr != output);
-      output->MallocData(allocator);
+      output->MallocData();
     }
     kernel::CallBackParam callbackParam;
     callbackParam.name_callback_aram = kernel->Name();
@@ -62,7 +62,7 @@ int Executor::Run(std::vector<tensor::Tensor *> &inputs, std::vector<tensor::Ten
     }
     for (auto input_kernel : kernel->GetInKernels()) {
       MS_EXCEPTION_IF_NULL(input_kernel);
-      ret = input_kernel->DecOutTensorRefCount(allocator);
+      ret = input_kernel->DecOutTensorRefCount();
       if (0 != ret) {
         MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->Name() << " failed";
       }
...
@@ -112,19 +112,24 @@ class Tensor : public mindspore::tensor::MetaTensor {
       return 0;
     }
     size *= (format_ == schema::Format_NC4HW4 || format_ == schema::Format_NHWC4) ? ElementsC4Num()
                                                                                   : MetaTensor::ElementsNum();
     return size;
   }

+  void set_allocator(mindspore::lite::Allocator *allocator) { allocator_ = allocator; }
+
   int MallocData(mindspore::lite::Allocator *allocator = nullptr) {
     if (nullptr != this->data_) {
       return 0;
     }
-    if (nullptr == allocator) {
+    if (allocator != nullptr) {
+      allocator_ = allocator;
+    }
+    if (allocator_ == nullptr) {
       this->data_ = malloc(this->Size());
     } else {
-      this->data_ = allocator->Malloc(this->Size());
+      this->data_ = allocator_->Malloc(this->Size());
     }
     if (nullptr == this->data_) {
       MS_LOG(ERROR) << "Malloc tensor data failed, size=" << this->Size();
@@ -134,14 +139,14 @@ class Tensor : public mindspore::tensor::MetaTensor {
     return 0;
   }

-  int FreeData(mindspore::lite::Allocator *allocator = nullptr) {
+  int FreeData() {
     if (nullptr == this->data_) {
       return 0;
     }
-    if (nullptr == allocator) {
+    if (nullptr == allocator_) {
       free(this->data_);
     } else {
-      allocator->Free(this->data_);
+      allocator_->Free(this->data_);
       this->data_ = nullptr;
     }
@@ -177,6 +182,7 @@ class Tensor : public mindspore::tensor::MetaTensor {
   schema::Format format_;
   size_t refCount = 0;
   std::vector<tensor::QuantArg> quant_params_;
+  mindspore::lite::Allocator *allocator_ = nullptr;
 };

 class LiteTensor : public mindspore::tensor::MSTensor {
@@ -221,4 +227,3 @@ using TensorPtr = std::shared_ptr<tensor::Tensor>;
 }  // namespace mindspore

 #endif  // MINDSPORE_LITE_SRC_IR_TENSOR_H_
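A hedged usage sketch of the patched Tensor API (a constructed tensor t and a valid mindspore::lite::Allocator *alloc are assumed for illustration; they are not from the diff):

    t->set_allocator(alloc);  // record the allocator; nothing is allocated yet
    t->MallocData();          // first call: data_ = alloc->Malloc(t->Size())
    t->MallocData();          // no-op: data_ is already set, returns 0
    t->FreeData();            // releases data_ through the same allocator_
    // If no allocator is ever attached, MallocData()/FreeData() fall back
    // to plain malloc()/free(), matching the default branch above.

Passing an allocator to MallocData() still works: per the patched body, a non-null argument overwrites allocator_ before the allocation happens.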
@@ -25,11 +25,11 @@ void LiteKernel::InitOutTensorRefCount() {
   }
 }

-int LiteKernel::DecOutTensorRefCount(lite::Allocator *allocator) {
+int LiteKernel::DecOutTensorRefCount() {
   for (auto *tensor : this->outputs_) {
     tensor->decRefCount();
     if (0 >= tensor->RefCount()) {
-      auto ret = tensor->FreeData(allocator);
+      auto ret = tensor->FreeData();
       if (0 != ret) {
         MS_LOG(ERROR) << "Free tensor data failed";
         return ret;
@@ -141,4 +141,3 @@ void LiteKernelUtil::InitTensorRefCount(std::vector<kernel::LiteKernel *> &kerne
 int LiteKernelUtil::SetInput(LiteKernel &kernelMod, std::vector<lite::tensor::Tensor *> inputs) { return -1; }
 }  // namespace mindspore::kernel
@@ -22,7 +22,6 @@
 #include <arm_neon.h>
 #endif
 #include "src/runtime/kernel/arm/opclib/op_base.h"
-// #include "backend/kernel_compiler/kernel.h"
 #include "include/context.h"
 #include "src/ir/tensor.h"
 #include "src/ops/ops.h"
@@ -60,7 +59,6 @@ struct CallBackParam {
 using KernelCallBack = std::function<bool(std::vector<lite::tensor::Tensor *> inputs,
                                           std::vector<lite::tensor::Tensor *> outputs, const CallBackParam &opInfo)>;

-// class LiteKernel : public KernelMod {
 class LiteKernel {
  public:
   LiteKernel() = default;
@@ -73,17 +71,6 @@ class LiteKernel {
   virtual ~LiteKernel() { delete opParameter; }

-  // bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
-  //             const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
-  //   return false;
-  // };
-  //
-  // const std::vector<size_t> &GetInputSizeList() const override { return {}; }
-  //
-  // const std::vector<size_t> &GetOutputSizeList() const override { return {}; }
-  //
-  // const std::vector<size_t> &GetWorkspaceSizeList() const override { return {}; }
-
   virtual int Prepare() { return -1; }
   virtual int Init() { return -1; }
   virtual int ReSize() { return -1; }
@@ -115,7 +102,7 @@ class LiteKernel {
   void InitOutTensorRefCount();

-  int DecOutTensorRefCount(lite::Allocator *allocator = nullptr);
+  int DecOutTensorRefCount();

   const KernelKey Desc() const { return desc; }
...
@@ -134,7 +134,7 @@ int LiteSession::CompileGraph(Model *model) {
   }
   auto ret = ConvertTensors(model);
-  if (0 != ret) {
+  if (ret != RET_OK) {
     MS_LOG(ERROR) << "ConvertTensors failed: " << ret;
     return ret;
   }
@@ -142,9 +142,9 @@ int LiteSession::CompileGraph(Model *model) {
   InitGraphInOutTensor(model);

   // scheduler kernels
-  Scheduler scheduler(context);
+  Scheduler scheduler(context_);
   ret = scheduler.Schedule(model, &tensors, &kernels);
-  if (0 != ret) {
+  if (ret != RET_OK) {
     MS_LOG(ERROR) << "Schedule kernels failed: " << ret;
     return ret;
   }
@@ -166,15 +166,15 @@ std::vector<mindspore::tensor::MSTensor *> LiteSession::GetInputs() {
 }

 int LiteSession::RunGraph() {
-  MS_EXCEPTION_IF_NULL(this->context);
+  MS_EXCEPTION_IF_NULL(this->context_);
   Executor executor;
-  return executor.Run(this->inputs, this->outputs, this->kernels, this->context->allocator.get());
+  return executor.Run(this->inputs, this->outputs, this->kernels, this->context_->allocator.get());
 }

 int LiteSession::RunGraph(const kernel::KernelCallBack &before, const kernel::KernelCallBack &after) {
-  MS_EXCEPTION_IF_NULL(this->context);
+  MS_EXCEPTION_IF_NULL(this->context_);
   Executor executor;
-  return executor.Run(this->inputs, this->outputs, this->kernels, this->context->allocator.get(), before, after);
+  return executor.Run(this->inputs, this->outputs, this->kernels, this->context_->allocator.get(), before, after);
 }

 std::vector<mindspore::tensor::MSTensor *> LiteSession::GetOutputs() {
@@ -190,30 +190,32 @@ std::vector<mindspore::tensor::MSTensor *> LiteSession::GetOutputs() {
   return ret;
 }

-void LiteSession::Init(Context *context) {
+int LiteSession::Init(Context *context) {
   MS_EXCEPTION_IF_NULL(context);
-  this->context = new Context;
-  this->context->cpuBindMode = context->cpuBindMode;
-  this->context->threadNum = context->threadNum;
-  this->context->deviceCtx.type = context->deviceCtx.type;
-  this->context->allocator = std::make_shared<DefaultAllocator>();
+  this->context_ = new (std::nothrow) Context(context->threadNum, context->allocator, context->deviceCtx);
+  if (this->context_ == nullptr) {
+    MS_LOG(ERROR) << "new context failed";
+    return RET_MEMORY_FAILED;
+  }
+  this->context_->cpuBindMode = context->cpuBindMode;
   ConfigThreadPool(context->cpuBindMode, context->threadNum);
   auto ret = KernelRegistry::GetInstance()->Init();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "KernelRegistry Init Failed.";
-    return;
+    return ret;
   }
 #if SUPPORT_GPU
-  if (context->deviceCtx.type == DT_GPU) {
+  if (context_->deviceCtx.type == DT_GPU) {
     auto opencl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
     opencl_runtime->Init();
   }
 #endif
+  return RET_OK;
 }

 void LiteSession::BindThread(bool ifBind) {
-  if (this->context->cpuBindMode != NO_BIND) {
-    DoAllThreadBind(ifBind, static_cast<int>(this->context->cpuBindMode));
+  if (this->context_->cpuBindMode != NO_BIND) {
+    DoAllThreadBind(ifBind, static_cast<int>(this->context_->cpuBindMode));
   }
 }
@@ -234,17 +236,18 @@ LiteSession::~LiteSession() {
   }
 }

-std::vector<mindspore::tensor::MSTensor *> LiteSession::GetInputsByName(std::string name) {
-  return input_map[name];
-}
-
-std::vector<mindspore::tensor::MSTensor *> LiteSession::GetOutputsByName(std::string name) {
-  return output_map[name];
-}
+std::vector<mindspore::tensor::MSTensor *> LiteSession::GetInputsByName(std::string name) { return input_map[name]; }
+
+std::vector<mindspore::tensor::MSTensor *> LiteSession::GetOutputsByName(std::string name) { return output_map[name]; }
 }  // namespace lite

 session::LiteSession *session::LiteSession::CreateSession(lite::Context *context) {
   auto session = new lite::LiteSession();
-  session->Init(context);
+  auto ret = session->Init(context);
+  if (ret != mindspore::lite::RET_OK) {
+    MS_LOG(ERROR) << "init sesssion failed";
+    delete session;
+    return nullptr;
+  }
   return session;
 }
 }  // namespace mindspore
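Because Init() now returns a status and CreateSession() deletes the session and returns nullptr on failure, callers should check the result. A sketch (the context field values are illustrative, not from the diff):

    mindspore::lite::Context ctx;  // threadNum, allocator, deviceCtx as needed
    auto *session = mindspore::session::LiteSession::CreateSession(&ctx);
    if (session == nullptr) {
      // Init failed: context allocation, kernel registry init, etc.
    }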
...
@@ -36,7 +36,7 @@ class LiteSession : public session::LiteSession {
   ~LiteSession() override;

-  void Init(Context *context);
+  int Init(Context *context);

   void BindThread(bool ifBind) override;
@@ -60,7 +60,7 @@ class LiteSession : public session::LiteSession {
   void InitGraphInOutTensor(const lite::Model *model);

  protected:
-  Context *context = nullptr;
+  Context *context_ = nullptr;
   std::vector<kernel::LiteKernel *> kernels;
   std::vector<tensor::Tensor *> tensors;
   // graph input tensors
...
@@ -25,10 +25,10 @@ SubGraphOpenCLKernel::~SubGraphOpenCLKernel() { UnInit(); }
 int SubGraphOpenCLKernel::Init() {
   allocator_ = lite::opencl::OpenCLRuntime::GetInstance()->GetAllocator();
   for (const auto tensor : inputs_) {
-    tensor->MallocData(allocator_);
+    tensor->set_allocator(allocator_);
   }
   for (const auto tensor : outputs_) {
-    tensor->MallocData(allocator_);
+    tensor->set_allocator(allocator_);
   }
   // Map buffer for write, it is not necessary for fine-grained
   for (auto &tensor : inputs_) {
@@ -82,4 +82,3 @@ int SubGraphOpenCLKernel::Run() {
 }
 }  // namespace mindspore::kernel
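Net effect on the GPU path (a reading of the hunk above, not verbatim source): SubGraphOpenCLKernel::Init() no longer allocates buffers eagerly; it only attaches the OpenCL allocator, and the memory appears on demand:

    tensor->set_allocator(allocator_);  // Init(): only record the OpenCL allocator
    tensor->MallocData();               // later: buffer created through allocator_ on first use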
@@ -112,6 +112,11 @@ void Scheduler::ConstructSubgraphs(std::vector<kernel::LiteKernel *> *kernels) {
   for (auto temp_kernels : sub_kernels_list) {
     kernel::KERNEL_ARCH arch = temp_kernels.front()->Desc().arch;
     if (arch == kernel::KERNEL_ARCH::kCPU) {
+      for (auto kernel : temp_kernels) {
+        for (auto tensor : kernel->GetOutputs()) {
+          tensor->set_allocator(context_->allocator.get());
+        }
+      }
       std::copy(temp_kernels.begin(), temp_kernels.end(), std::back_inserter(subgraph_kernels));
     } else {
       auto subgraph_kernel = CreateSubKernel(temp_kernels, arch);
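The loop added above is what makes the parameterless MallocData() in Executor::Run land on the session allocator: every CPU kernel's output tensor gets the allocator attached at scheduling time. Schematically (both lines come from this patch, shown together for clarity):

    tensor->set_allocator(context_->allocator.get());  // Scheduler::ConstructSubgraphs
    output->MallocData();                              // Executor::Run, uses that allocator_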
@@ -154,9 +159,9 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<tensor::Tensor *>
   MS_ASSERT(nullptr != primitive);
   auto data_type = inputs.front()->data_type();
   kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, primitive->Type()};
-  if (context->deviceCtx.type == DT_GPU) {
+  if (context_->deviceCtx.type == DT_GPU) {
     desc.arch = kernel::KERNEL_ARCH::kGPU;
-    auto *kernel = KernelFactory::GetInstance()->GetKernel(inputs, outputs, primitive, context, desc);
+    auto *kernel = KernelFactory::GetInstance()->GetKernel(inputs, outputs, primitive, context_, desc);
     if (nullptr != kernel) {
       kernel->set_desc(desc);
       return kernel;
@@ -168,14 +173,14 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<tensor::Tensor *>
     if (data_type == kNumberTypeFloat32) {
       // check if support fp16
       kernel::KernelKey key{desc.arch, kNumberTypeFloat16, desc.type};
-      kernel = KernelFactory::GetInstance()->GetKernel(inputs, outputs, primitive, context, key);
+      kernel = KernelFactory::GetInstance()->GetKernel(inputs, outputs, primitive, context_, key);
       if (kernel != nullptr) {
         kernel->set_desc(desc);
         return kernel;
       }
-      kernel = KernelFactory::GetInstance()->GetKernel(inputs, outputs, primitive, context, desc);
+      kernel = KernelFactory::GetInstance()->GetKernel(inputs, outputs, primitive, context_, desc);
     } else {
-      kernel = KernelFactory::GetInstance()->GetKernel(inputs, outputs, primitive, context, desc);
+      kernel = KernelFactory::GetInstance()->GetKernel(inputs, outputs, primitive, context_, desc);
     }
     if (kernel != nullptr) {
       kernel->set_desc(desc);
...
@@ -25,7 +25,7 @@
 namespace mindspore::lite {
 class Scheduler {
  public:
-  explicit Scheduler(const Context *ctx) : context(ctx) {}
+  explicit Scheduler(const Context *ctx) : context_(ctx) {}
   int Schedule(const lite::Model *model, std::vector<tensor::Tensor *> *tensors,
                std::vector<kernel::LiteKernel *> *kernels);
@@ -48,7 +48,7 @@ class Scheduler {
  protected:
   std::vector<std::vector<size_t>> markedKernelGroup;
-  const Context *context = nullptr;
+  const Context *context_ = nullptr;
 };
 }  // namespace mindspore::lite
...