未验证 提交 84b08a9b 编写于 作者: Y Yuan Shuai 提交者: GitHub

[LITE][OPENCL] Fix OpenCL global static resources of CXX API and Light API (#3373)

* [LITE][OPENCL] fix OpenCL global static resources. test=develop

* Fix Cxx and light api. test=develop
上级 8492ba5c
...@@ -43,16 +43,7 @@ class LITE_API Predictor { ...@@ -43,16 +43,7 @@ class LITE_API Predictor {
public: public:
// Create an empty predictor. // Create an empty predictor.
Predictor() { scope_ = std::make_shared<Scope>(); } Predictor() { scope_ = std::make_shared<Scope>(); }
// Destructor: tears down the predictor's owned state in an explicit order.
~Predictor() {
#ifdef LITE_WITH_OPENCL
// Release OpenCL programs/kernels held by the process-wide CLRuntime
// before the scope (and the tensors it owns) are destroyed.
// NOTE(review): ReleaseResources() frees global state shared by every
// predictor — destroying one predictor may invalidate others that are
// still alive. Confirm this is intended for multi-predictor use.
CLRuntime::Global()->ReleaseResources();
#endif
// Drop the variable scope first, then the program that references it.
scope_.reset();
exec_scope_ = nullptr;
program_.reset();
input_names_.clear();
output_names_.clear();
}
// Create a predictor with the weight variable scope set. // Create a predictor with the weight variable scope set.
explicit Predictor(const std::shared_ptr<lite::Scope>& root_scope) explicit Predictor(const std::shared_ptr<lite::Scope>& root_scope)
: scope_(root_scope) {} : scope_(root_scope) {}
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "lite/api/light_api.h" #include "lite/api/light_api.h"
#include <algorithm> #include <algorithm>
#include <unordered_map>
#include "paddle_use_kernels.h" // NOLINT #include "paddle_use_kernels.h" // NOLINT
#include "paddle_use_ops.h" // NOLINT #include "paddle_use_ops.h" // NOLINT
...@@ -135,7 +136,15 @@ void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) { ...@@ -135,7 +136,15 @@ void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) {
// 1. Create op first // 1. Create op first
Program program(prog, scope_, {}); Program program(prog, scope_, {});
// 2. Create Instructs // 2. Create Instructs
#ifdef LITE_WITH_OPENCL
using WaitListType =
std::unordered_map<decltype(static_cast<const void*>(nullptr)),
std::shared_ptr<cl::Event>>;
using OpenCLContext = Context<TargetType::kOpenCL>;
std::unique_ptr<KernelContext> local_ctx(new KernelContext());
local_ctx->As<OpenCLContext>().InitOnce();
#endif
// Create the kernels of the target places, and filter out the specific // Create the kernels of the target places, and filter out the specific
// kernel with the target alias. // kernel with the target alias.
...@@ -151,7 +160,18 @@ void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) { ...@@ -151,7 +160,18 @@ void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) {
return it->alias() == alias; return it->alias() == alias;
}); });
CHECK(it != kernels.end()); CHECK(it != kernels.end());
#ifdef LITE_WITH_OPENCL
if ((*it)->target() == TARGET(kOpenCL)) {
std::unique_ptr<KernelContext> ctx(new KernelContext());
(*local_ctx).As<OpenCLContext>().CopySharedTo(&ctx->As<OpenCLContext>());
(*it)->SetContext(std::move(ctx));
} else {
(*it)->SetContext(ContextScheduler::Global().NewContext((*it)->target()));
}
#else
(*it)->SetContext(ContextScheduler::Global().NewContext((*it)->target())); (*it)->SetContext(ContextScheduler::Global().NewContext((*it)->target()));
#endif
insts.emplace_back(op, std::move(*it)); insts.emplace_back(op, std::move(*it));
} }
......
...@@ -107,8 +107,6 @@ class LightPredictorImpl : public lite_api::PaddlePredictor { ...@@ -107,8 +107,6 @@ class LightPredictorImpl : public lite_api::PaddlePredictor {
public: public:
LightPredictorImpl() = default; LightPredictorImpl() = default;
~LightPredictorImpl();
std::unique_ptr<lite_api::Tensor> GetInput(int i) override; std::unique_ptr<lite_api::Tensor> GetInput(int i) override;
std::unique_ptr<const lite_api::Tensor> GetOutput(int i) const override; std::unique_ptr<const lite_api::Tensor> GetOutput(int i) const override;
......
...@@ -21,13 +21,6 @@ ...@@ -21,13 +21,6 @@
namespace paddle { namespace paddle {
namespace lite { namespace lite {
// Destructor: destroys the wrapped predictor before tearing down the
// global OpenCL runtime resources it may have been using.
LightPredictorImpl::~LightPredictorImpl() {
// Must run before ReleaseResources(): the predictor's kernels still
// reference the runtime's programs/queues while alive.
raw_predictor_.reset();
#ifdef LITE_WITH_OPENCL
// NOTE(review): this releases process-wide CLRuntime state; if several
// LightPredictorImpl instances coexist, destroying one affects the
// others — confirm single-predictor usage is assumed.
CLRuntime::Global()->ReleaseResources();
#endif
}
void LightPredictorImpl::Init(const lite_api::MobileConfig& config) { void LightPredictorImpl::Init(const lite_api::MobileConfig& config) {
// LightPredictor Only support NaiveBuffer backend in publish lib // LightPredictor Only support NaiveBuffer backend in publish lib
if (config.lite_model_file().empty()) { if (config.lite_model_file().empty()) {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -13,7 +10,6 @@ See the License for the specific language governing permissions and ...@@ -13,7 +10,6 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "lite/backends/opencl/cl_context.h" #include "lite/backends/opencl/cl_context.h"
#include <algorithm>
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
...@@ -36,10 +32,8 @@ cl::Program &CLContext::GetProgram(const std::string &file_name, ...@@ -36,10 +32,8 @@ cl::Program &CLContext::GetProgram(const std::string &file_name,
STL::stringstream program_key_ss; STL::stringstream program_key_ss;
program_key_ss << file_name << options; program_key_ss << file_name << options;
std::string program_key = program_key_ss.str(); std::string program_key = program_key_ss.str();
auto it = programs_.find(program_key);
auto &programs = CLRuntime::Global()->programs(); if (it != programs_.end()) {
auto it = programs.find(program_key);
if (it != programs.end()) {
VLOG(3) << " --- program -> " << program_key << " has been built --- "; VLOG(3) << " --- program -> " << program_key << " has been built --- ";
return *(it->second); return *(it->second);
} }
...@@ -50,9 +44,9 @@ cl::Program &CLContext::GetProgram(const std::string &file_name, ...@@ -50,9 +44,9 @@ cl::Program &CLContext::GetProgram(const std::string &file_name,
CLRuntime::Global()->BuildProgram(program.get(), options); CLRuntime::Global()->BuildProgram(program.get(), options);
VLOG(3) << " --- end build program -> " << program_key << " --- "; VLOG(3) << " --- end build program -> " << program_key << " --- ";
programs[program_key] = std::move(program); programs_[program_key] = std::move(program);
return *(programs[program_key]); return *(programs_[program_key]);
} }
void CLContext::AddKernel(const std::string &kernel_name, void CLContext::AddKernel(const std::string &kernel_name,
...@@ -68,30 +62,25 @@ void CLContext::AddKernel(const std::string &kernel_name, ...@@ -68,30 +62,25 @@ void CLContext::AddKernel(const std::string &kernel_name,
new cl::Kernel(program, kernel_name.c_str(), &status)); new cl::Kernel(program, kernel_name.c_str(), &status));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
VLOG(3) << " --- end create kernel --- "; VLOG(3) << " --- end create kernel --- ";
kernels_.emplace_back(std::move(kernel));
auto &kernels = CLRuntime::Global()->kernels();
auto &kernel_offset_map = CLRuntime::Global()->kernel_offset();
kernels.emplace_back(std::move(kernel));
STL::stringstream kernel_key; STL::stringstream kernel_key;
kernel_key << kernel_name << options << time_stamp; kernel_key << kernel_name << options << time_stamp;
kernel_offset_map[kernel_key.str()] = kernels.size() - 1; kernel_offset_[kernel_key.str()] = kernels_.size() - 1;
} }
cl::Kernel &CLContext::GetKernel(const int index) { cl::Kernel &CLContext::GetKernel(const int index) {
auto &kernels = CLRuntime::Global()->kernels(); VLOG(3) << " --- kernel count: " << kernels_.size() << " --- ";
VLOG(3) << " --- kernel count: " << kernels.size() << " --- "; CHECK(static_cast<size_t>(index) < kernels_.size())
CHECK(static_cast<size_t>(index) < kernels.size())
<< "The index must be less than the size of kernels."; << "The index must be less than the size of kernels.";
CHECK(kernels[index] != nullptr) CHECK(kernels_[index] != nullptr)
<< "The target kernel pointer cannot be null."; << "The target kernel pointer cannot be null.";
return *(kernels[index]); return *(kernels_[index]);
} }
cl::Kernel &CLContext::GetKernel(const std::string &name) { cl::Kernel &CLContext::GetKernel(const std::string &name) {
auto &kernel_offset_map = CLRuntime::Global()->kernel_offset(); auto it = kernel_offset_.find(name);
auto it = kernel_offset_map.find(name); CHECK(it != kernel_offset_.end()) << "Cannot find the kernel function: "
CHECK(it != kernel_offset_map.end()) << "Cannot find the kernel function: " << name;
<< name;
return GetKernel(it->second); return GetKernel(it->second);
} }
......
...@@ -27,6 +27,20 @@ namespace lite { ...@@ -27,6 +27,20 @@ namespace lite {
class CLContext { class CLContext {
public: public:
// Destructor: releases every cached OpenCL kernel and program owned by
// this context.
//
// NOTE(review): the previous version called clReleaseKernel() /
// clReleaseProgram() explicitly and then destroyed the cl::Kernel /
// cl::Program wrapper objects. The OpenCL C++ wrappers already release
// their underlying handles in their destructors (detail::Wrapper RAII),
// so the explicit calls decremented each object's reference count twice,
// which is undefined once the count underflows. Rely on RAII so each
// handle is released exactly once.
~CLContext() {
  // Destroying the unique_ptr elements releases each cl_kernel once.
  kernels_.clear();
  kernel_offset_.clear();
  // Destroying the map's unique_ptr<cl::Program> values releases each
  // cl_program once.
  programs_.clear();
  LOG(INFO) << "release cl::Program, cl::Kernel finished.";
}
cl::CommandQueue &GetCommandQueue(); cl::CommandQueue &GetCommandQueue();
cl::Context &GetContext(); cl::Context &GetContext();
...@@ -52,6 +66,10 @@ class CLContext { ...@@ -52,6 +66,10 @@ class CLContext {
int divitor = 2); int divitor = 2);
// cl::NDRange LocalWorkSizeConv1x1(cl::NDRange global_work_size, // cl::NDRange LocalWorkSizeConv1x1(cl::NDRange global_work_size,
// size_t max_work_size); // size_t max_work_size);
private:
std::unordered_map<std::string, std::unique_ptr<cl::Program>> programs_;
std::vector<std::unique_ptr<cl::Kernel>> kernels_;
std::map<std::string, int> kernel_offset_;
}; };
} // namespace lite } // namespace lite
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -14,7 +11,6 @@ limitations under the License. */ ...@@ -14,7 +11,6 @@ limitations under the License. */
#include "lite/backends/opencl/cl_runtime.h" #include "lite/backends/opencl/cl_runtime.h"
#include <string> #include <string>
#include <unordered_map>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/utils/cp_logging.h" #include "lite/utils/cp_logging.h"
...@@ -29,38 +25,16 @@ CLRuntime* CLRuntime::Global() { ...@@ -29,38 +25,16 @@ CLRuntime* CLRuntime::Global() {
} }
CLRuntime::~CLRuntime() { CLRuntime::~CLRuntime() {
LOG(INFO) << "CLRuntime::~CLRuntime()";
// Note: do ReleaseResources() in predictor
command_queue_&& clReleaseCommandQueue(command_queue_->get());
command_queue_.reset();
context_&& clReleaseContext(context_->get());
context_.reset();
device_.reset();
platform_.reset();
initialized_ = false;
}
void CLRuntime::ReleaseResources() {
// if (is_resources_released_) {
// return;
// }
if (command_queue_ != nullptr) { if (command_queue_ != nullptr) {
command_queue_->flush(); command_queue_->flush();
command_queue_->finish(); command_queue_->finish();
} }
for (size_t kidx = 0; kidx < kernels_.size(); ++kidx) { // For controlling the destruction order:
clReleaseKernel(kernels_[kidx]->get()); command_queue_.reset();
kernels_[kidx].reset(); context_.reset();
} device_.reset();
kernels_.clear(); platform_.reset();
kernel_offset_.clear(); LOG(INFO) << "release ~CLRuntime() ";
for (auto& p : programs_) {
clReleaseProgram(p.second->get());
}
programs_.clear();
LOG(INFO) << "release resources finished.";
is_resources_released_ = true;
} }
bool CLRuntime::Init() { bool CLRuntime::Init() {
...@@ -98,14 +72,14 @@ cl::CommandQueue& CLRuntime::command_queue() { ...@@ -98,14 +72,14 @@ cl::CommandQueue& CLRuntime::command_queue() {
return *command_queue_; return *command_queue_;
} }
std::shared_ptr<cl::Program> CLRuntime::CreateProgram( std::unique_ptr<cl::Program> CLRuntime::CreateProgram(
const cl::Context& context, std::string file_name) { const cl::Context& context, std::string file_name) {
auto cl_file = opencl_kernels_files.find(file_name); auto cl_file = opencl_kernels_files.find(file_name);
std::string content(cl_file->second.begin(), cl_file->second.end()); std::string content(cl_file->second.begin(), cl_file->second.end());
cl::Program::Sources sources; cl::Program::Sources sources;
sources.push_back(content); sources.push_back(content);
auto prog = auto prog =
std::shared_ptr<cl::Program>(new cl::Program(context, sources, &status_)); std::unique_ptr<cl::Program>(new cl::Program(context, sources, &status_));
VLOG(4) << "OpenCL kernel file name: " << file_name; VLOG(4) << "OpenCL kernel file name: " << file_name;
VLOG(4) << "Program source size: " << content.size(); VLOG(4) << "Program source size: " << content.size();
CL_CHECK_FATAL(status_); CL_CHECK_FATAL(status_);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -18,7 +15,6 @@ limitations under the License. */ ...@@ -18,7 +15,6 @@ limitations under the License. */
#include <map> #include <map>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map>
#include <vector> #include <vector>
#include "lite/backends/opencl/cl_include.h" #include "lite/backends/opencl/cl_include.h"
#include "lite/backends/opencl/cl_utility.h" #include "lite/backends/opencl/cl_utility.h"
...@@ -33,8 +29,6 @@ class CLRuntime { ...@@ -33,8 +29,6 @@ class CLRuntime {
public: public:
static CLRuntime* Global(); static CLRuntime* Global();
void ReleaseResources();
bool Init(); bool Init();
cl::Platform& platform(); cl::Platform& platform();
...@@ -45,7 +39,7 @@ class CLRuntime { ...@@ -45,7 +39,7 @@ class CLRuntime {
cl::CommandQueue& command_queue(); cl::CommandQueue& command_queue();
std::shared_ptr<cl::Program> CreateProgram(const cl::Context& context, std::unique_ptr<cl::Program> CreateProgram(const cl::Context& context,
std::string file_name); std::string file_name);
std::unique_ptr<cl::UserEvent> CreateEvent(const cl::Context& context); std::unique_ptr<cl::UserEvent> CreateEvent(const cl::Context& context);
...@@ -60,12 +54,6 @@ class CLRuntime { ...@@ -60,12 +54,6 @@ class CLRuntime {
std::map<std::string, size_t>& GetDeviceInfo(); std::map<std::string, size_t>& GetDeviceInfo();
std::unordered_map<std::string, std::shared_ptr<cl::Program>>& programs() {
return programs_;
}
std::vector<std::unique_ptr<cl::Kernel>>& kernels() { return kernels_; }
std::map<std::string, int>& kernel_offset() { return kernel_offset_; }
private: private:
CLRuntime() = default; CLRuntime() = default;
...@@ -107,19 +95,11 @@ class CLRuntime { ...@@ -107,19 +95,11 @@ class CLRuntime {
std::shared_ptr<cl::CommandQueue> command_queue_{nullptr}; std::shared_ptr<cl::CommandQueue> command_queue_{nullptr};
std::unordered_map<std::string, std::shared_ptr<cl::Program>> programs_{};
std::vector<std::unique_ptr<cl::Kernel>> kernels_{};
std::map<std::string, int> kernel_offset_{};
cl_int status_{CL_SUCCESS}; cl_int status_{CL_SUCCESS};
bool initialized_{false}; bool initialized_{false};
bool is_init_success_{false}; bool is_init_success_{false};
bool is_resources_released_{false};
}; };
} // namespace lite } // namespace lite
......
...@@ -24,11 +24,31 @@ class RuntimeContextAssignPass : public StmtPass { ...@@ -24,11 +24,31 @@ class RuntimeContextAssignPass : public StmtPass {
RuntimeContextAssignPass() {} RuntimeContextAssignPass() {}
void Apply(const std::unique_ptr<SSAGraph>& graph) override { void Apply(const std::unique_ptr<SSAGraph>& graph) override {
#ifdef LITE_WITH_OPENCL
using OpenCLContext = Context<TargetType::kOpenCL>;
std::unique_ptr<KernelContext> local_ctx(new KernelContext());
local_ctx->As<OpenCLContext>().InitOnce();
#endif
for (auto& node : graph->mutable_nodes()) { for (auto& node : graph->mutable_nodes()) {
if (!node.IsStmt()) continue; if (!node.IsStmt()) continue;
auto& inst = node.AsStmt(); auto& inst = node.AsStmt();
#ifdef LITE_WITH_OPENCL
if (inst.picked_kernel().target() == TARGET(kOpenCL)) {
std::unique_ptr<KernelContext> ctx(new KernelContext());
(*local_ctx)
.As<OpenCLContext>()
.CopySharedTo(&ctx->As<OpenCLContext>());
inst.picked_kernel().SetContext(std::move(ctx));
} else {
inst.picked_kernel().SetContext(ContextScheduler::Global().NewContext(
inst.picked_kernel().target()));
}
#else
inst.picked_kernel().SetContext( inst.picked_kernel().SetContext(
ContextScheduler::Global().NewContext(inst.picked_kernel().target())); ContextScheduler::Global().NewContext(inst.picked_kernel().target()));
#endif
} }
} }
}; };
......
...@@ -106,6 +106,7 @@ class IoCopykOpenCLToHostCompute ...@@ -106,6 +106,7 @@ class IoCopykOpenCLToHostCompute
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
auto* wait_list = context.cl_wait_list(); auto* wait_list = context.cl_wait_list();
auto it = wait_list->find(x_ptr); auto it = wait_list->find(x_ptr);
if (it != wait_list->end()) { if (it != wait_list->end()) {
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
...@@ -113,6 +114,9 @@ class IoCopykOpenCLToHostCompute ...@@ -113,6 +114,9 @@ class IoCopykOpenCLToHostCompute
#endif #endif
auto& event = *(it->second); auto& event = *(it->second);
event.wait(); event.wait();
auto command_queue = CLRuntime::Global()->command_queue();
command_queue.flush();
command_queue.finish();
} else { } else {
LOG(FATAL) << "Could not find the sync event for the target cl tensor."; LOG(FATAL) << "Could not find the sync event for the target cl tensor.";
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册