未验证 提交 dd3150a4 编写于 作者: Y ysh329 提交者: GitHub

fix conflict and cherry pick 1d0f70ae: add opencl tune api. test=develop (#4020)

上级 61ec5d82
......@@ -215,6 +215,18 @@ ConfigBase::ConfigBase(PowerMode mode, int threads) {
#endif
}
void ConfigBase::set_opencl_tune(bool enable_tune) {
#ifdef LITE_WITH_OPENCL
if (paddle::lite_api::IsOpenCLBackendValid()) {
enable_opencl_tune_ = enable_tune;
paddle::lite::CLRuntime::Global()->set_auto_tune(enable_opencl_tune_);
#ifdef LITE_WITH_OPENCL
LOG(INFO) << "auto_tune:" << paddle::lite::CLRuntime::Global()->auto_tune();
#endif
}
#endif
}
void ConfigBase::set_power_mode(paddle::lite_api::PowerMode mode) {
#ifdef LITE_WITH_ARM
lite::DeviceInfo::Global().SetRunMode(mode, threads_);
......
......@@ -121,6 +121,8 @@ class LITE_API ConfigBase {
std::string model_dir_;
int threads_{1};
PowerMode mode_{LITE_POWER_NO_BIND};
// gpu
bool enable_opencl_tune_{false};
// to save subgraph model for npu/xpu/...
std::string subgraph_model_cache_dir_{""};
......@@ -135,6 +137,9 @@ class LITE_API ConfigBase {
// set Thread
void set_threads(int threads);
int threads() const { return threads_; }
// set GPU opencl tune
void set_opencl_tune(bool enable_tune);
bool opencl_tune() const { return enable_opencl_tune_; }
// set subgraph_model_dir
void set_subgraph_model_cache_dir(std::string subgraph_model_cache_dir) {
subgraph_model_cache_dir_ = subgraph_model_cache_dir;
......
......@@ -70,6 +70,7 @@ class CLContext {
cl::NDRange LocalWorkSizeTuneReverse(cl::NDRange global_work_size,
size_t max_work_size,
int divitor = 2);
bool IsArmMali();
private:
......
......@@ -91,6 +91,10 @@ class CLRuntime {
return is_device_avaliable_for_opencl_;
}
void set_auto_tune(bool enable_tune) { auto_tune_ = enable_tune; }
bool auto_tune() { return auto_tune_; }
bool Init();
cl::Platform& platform();
......@@ -195,6 +199,8 @@ class CLRuntime {
bool is_cl_runtime_initialized_{false};
bool is_platform_device_init_success_{false};
bool auto_tune_{false};
};
} // namespace lite
......
......@@ -92,6 +92,7 @@ void RunModel(std::string model_dir,
if (is_opencl_backend_valid) {
// give opencl nb model dir
config.set_model_from_file(model_dir);
config.set_opencl_tune(false); // default is false
} else {
std::cout << "Unsupport opencl nb model." << std::endl;
exit(1);
......
......@@ -32,16 +32,24 @@ namespace opencl {
void ConvImageCompute::PrepareForRun() {
ReInitWhenNeeded();
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
const bool is_mali = context.cl_context()->IsArmMali();
use_tune_ = CLRuntime::Global()->auto_tune();
if (!is_mali) {
use_tune_ = false;
}
#ifdef LITE_WITH_LOG
LOG(INFO) << "use_tune_" << use_tune_;
#endif
auto filter_dims = conv_param_->filter->dims();
filter_tensor_n_ = filter_dims[0];
filter_tensor_c_ = filter_dims[1];
filter_tensor_h_ = filter_dims[2];
filter_tensor_w_ = filter_dims[3];
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
const bool is_mali = context.cl_context()->IsArmMali();
auto paddings = *conv_param_->paddings;
pad_up_ = paddings[0];
pad_down_ = paddings[1];
......@@ -65,6 +73,7 @@ void ConvImageCompute::PrepareForRun() {
bool stride_equal = stride_h_ == stride_w_;
bool dilation_equal = dilation_h_ == dilation_w_;
#ifdef LITE_WITH_LOG
VLOG(3) << "Is arm mali / " << (is_mali ? "Yes" : "No");
VLOG(3) << "Is relu fused? / " << (relu_fused_ ? "Yes" : "No");
VLOG(3) << "groups:" << groups_ << " stride_h_:" << stride_h_
......@@ -83,6 +92,8 @@ void ConvImageCompute::PrepareForRun() {
VLOG(3) << "dilation_equal:" << dilation_equal;
VLOG(3) << "padding :" << pad_up_ << " " << pad_down_ << " " << pad_left_
<< " " << pad_right_;
#endif
CHECK(pad_equal && stride_equal && dilation_equal);
CHECK_GE(conv_param_->dilations->size(), 2);
CHECK(dilation_h_ == dilation_w_);
......@@ -91,10 +102,6 @@ void ConvImageCompute::PrepareForRun() {
CHECK_GE(conv_param_->strides.size(), 2);
CHECK(stride_h_ == stride_w_);
if (!is_mali) {
use_tune_ = false;
}
/*********************************************
* Upload filter, bias to opencl device
*********************************************/
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册