未验证 提交 0a22eb4c 编写于 作者: Y Yuan Shuai 提交者: GitHub

[LITE][OPENCL] Fix act bug for opencl kernel (#3085)

* Fix act bug for opencl kernel. test=develop

* add print ActToStr. test=develop

* remove print code. test=develop
上级 b601d81f
...@@ -45,6 +45,21 @@ std::string Place::DebugString() const { ...@@ -45,6 +45,21 @@ std::string Place::DebugString() const {
return os.str(); return os.str();
} }
// Maps an ActivationType enum value to a human-readable name.
// Returns a reference to a function-local static string, so the
// reference stays valid for the lifetime of the program.
const std::string& ActivationTypeToStr(ActivationType act) {
  // Names are indexed by the enum's underlying integer value;
  // slot 0 ("unk") corresponds to the unset/unknown activation.
  static const std::string kActNames[] = {"unk",
                                          "Relu",
                                          "Relu6",
                                          "PRelu",
                                          "LeakyRelu",
                                          "Sigmoid",
                                          "Tanh",
                                          "Swish",
                                          "Exp"};
  const int idx = static_cast<int>(act);
  // Guard against out-of-range enum values before indexing the table.
  CHECK_LT(idx, static_cast<int>(ActivationType::NUM));
  return kActNames[idx];
}
const std::string& TargetToStr(TargetType target) { const std::string& TargetToStr(TargetType target) {
static const std::string target2string[] = {"unk", static const std::string target2string[] = {"unk",
"host", "host",
......
...@@ -97,7 +97,8 @@ enum class ActivationType : int { ...@@ -97,7 +97,8 @@ enum class ActivationType : int {
kSigmoid = 5, kSigmoid = 5,
kTanh = 6, kTanh = 6,
kSwish = 7, kSwish = 7,
kExp = 8 kExp = 8,
NUM = 9,
}; };
static size_t PrecisionTypeLength(PrecisionType type) { static size_t PrecisionTypeLength(PrecisionType type) {
...@@ -149,6 +150,8 @@ _ForEachPrecisionType(DefinePrecisionTypeTrait); ...@@ -149,6 +150,8 @@ _ForEachPrecisionType(DefinePrecisionTypeTrait);
#define PRECISION(item__) paddle::lite_api::PrecisionType::item__ #define PRECISION(item__) paddle::lite_api::PrecisionType::item__
#define DATALAYOUT(item__) paddle::lite_api::DataLayoutType::item__ #define DATALAYOUT(item__) paddle::lite_api::DataLayoutType::item__
const std::string& ActivationTypeToStr(ActivationType act);
const std::string& TargetToStr(TargetType target); const std::string& TargetToStr(TargetType target);
const std::string& PrecisionToStr(PrecisionType precision); const std::string& PrecisionToStr(PrecisionType precision);
......
...@@ -40,6 +40,8 @@ class ActivationComputeImageDefault ...@@ -40,6 +40,8 @@ class ActivationComputeImageDefault
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
act_param_ = param_.get_mutable<param_t>(); act_param_ = param_.get_mutable<param_t>();
int act_type = static_cast<int>(act_param_->active_type); int act_type = static_cast<int>(act_param_->active_type);
VLOG(1) << "ActivationTypeToStr(act_param_->active_type):"
<< ActivationTypeToStr(act_param_->active_type);
switch (act_type) { switch (act_type) {
case 1: case 1:
kernel_func_name_ = "relu"; kernel_func_name_ = "relu";
...@@ -66,9 +68,10 @@ class ActivationComputeImageDefault ...@@ -66,9 +68,10 @@ class ActivationComputeImageDefault
kernel_func_name_ = "exp_act"; kernel_func_name_ = "exp_act";
break; break;
default: default:
printf("This act type: %d doesn't support \n", act_type); LOG(FATAL) << "This act type:" << act_type << " doesn't support.";
return; return;
} }
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
context.cl_context()->AddKernel( context.cl_context()->AddKernel(
kernel_func_name_, "image/activation_kernel.cl", build_options_); kernel_func_name_, "image/activation_kernel.cl", build_options_);
} }
...@@ -87,6 +90,7 @@ class ActivationComputeImageDefault ...@@ -87,6 +90,7 @@ class ActivationComputeImageDefault
STL::stringstream kernel_key; STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_; kernel_key << kernel_func_name_ << build_options_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str()); auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int arg_idx = 0; int arg_idx = 0;
cl_int status = kernel.setArg(arg_idx, *x_img); cl_int status = kernel.setArg(arg_idx, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
......
...@@ -39,6 +39,7 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL), ...@@ -39,6 +39,7 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
} else { } else {
kernel_func_name_ = "concat_mul"; kernel_func_name_ = "concat_mul";
} }
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
context.cl_context()->AddKernel( context.cl_context()->AddKernel(
kernel_func_name_, "image/concat_kernel.cl", build_options_); kernel_func_name_, "image/concat_kernel.cl", build_options_);
......
...@@ -46,7 +46,7 @@ void ElementwiseAddImageCompute::PrepareForRun() { ...@@ -46,7 +46,7 @@ void ElementwiseAddImageCompute::PrepareForRun() {
<< ", x->dims().size():" << x->dims().size() << ", x->dims().size():" << x->dims().size()
<< ", y->dims.size():" << y->dims().size(); << ", y->dims.size():" << y->dims().size();
} }
VLOG(4) << "kernel_func_name_:" << kernel_func_name_; VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel( context.cl_context()->AddKernel(
......
...@@ -63,7 +63,7 @@ class ElementwiseMulImageCompute ...@@ -63,7 +63,7 @@ class ElementwiseMulImageCompute
<< y_dims.size() << y_dims.size()
<< ", x_dims.size():" << ele_param_->X->dims().size(); << ", x_dims.size():" << ele_param_->X->dims().size();
} }
VLOG(4) << "kernel_func_name_:" << kernel_func_name_; VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
VLOG(4) << "y_dims:" << y_dims; VLOG(4) << "y_dims:" << y_dims;
VLOG(4) << "y_dims.size():" << y_dims.size(); VLOG(4) << "y_dims.size():" << y_dims.size();
......
...@@ -38,6 +38,7 @@ class FusionElementwiseAddActivationImageCompute ...@@ -38,6 +38,7 @@ class FusionElementwiseAddActivationImageCompute
if (act_t != "relu") { if (act_t != "relu") {
LOG(FATAL) << "Unsupported Activation type: " << act_t; LOG(FATAL) << "Unsupported Activation type: " << act_t;
} }
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
} }
}; };
......
...@@ -44,6 +44,7 @@ class GridSamplerImageCompute : public KernelLite<TARGET(kOpenCL), ...@@ -44,6 +44,7 @@ class GridSamplerImageCompute : public KernelLite<TARGET(kOpenCL),
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel( context.cl_context()->AddKernel(
kernel_func_name_, "image/grid_sampler_kernel.cl", build_options_); kernel_func_name_, "image/grid_sampler_kernel.cl", build_options_);
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
} }
void Run() override { void Run() override {
......
...@@ -42,7 +42,7 @@ class LayoutComputeBufferChwToImageDefault ...@@ -42,7 +42,7 @@ class LayoutComputeBufferChwToImageDefault
if (param.process_type == 1) { if (param.process_type == 1) {
kernel_func_name_ = "buffer_to_image2d_with_pre255"; kernel_func_name_ = "buffer_to_image2d_with_pre255";
} }
VLOG(2) << "kernel_func_name_:" << kernel_func_name_; VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel( context.cl_context()->AddKernel(
kernel_func_name_, "image/layout_kernel.cl", build_options_); kernel_func_name_, "image/layout_kernel.cl", build_options_);
...@@ -153,7 +153,7 @@ class LayoutComputeImageDefaultToBufferChw ...@@ -153,7 +153,7 @@ class LayoutComputeImageDefaultToBufferChw
if (param.process_type == 1) { if (param.process_type == 1) {
kernel_func_name_ = "image2d_to_buffer_with_post255"; kernel_func_name_ = "image2d_to_buffer_with_post255";
} }
VLOG(2) << "kernel_func_name_:" << kernel_func_name_; VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel( context.cl_context()->AddKernel(
kernel_func_name_, "image/layout_kernel.cl", build_options_); kernel_func_name_, "image/layout_kernel.cl", build_options_);
......
...@@ -40,6 +40,7 @@ class NearestInterpComputeImageDefault ...@@ -40,6 +40,7 @@ class NearestInterpComputeImageDefault
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel( context.cl_context()->AddKernel(
kernel_func_name_, "image/nearest_interp_kernel.cl", build_options_); kernel_func_name_, "image/nearest_interp_kernel.cl", build_options_);
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
} }
void Run() override { void Run() override {
......
...@@ -44,6 +44,7 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL), ...@@ -44,6 +44,7 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL),
if (global_pooling) { if (global_pooling) {
kernel_func_name_ += "_global"; kernel_func_name_ += "_global";
} }
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel( context.cl_context()->AddKernel(
kernel_func_name_, "image/pool_kernel.cl", build_options_); kernel_func_name_, "image/pool_kernel.cl", build_options_);
......
...@@ -35,6 +35,7 @@ class ReshapeComputeFloatImage : public KernelLite<TARGET(kOpenCL), ...@@ -35,6 +35,7 @@ class ReshapeComputeFloatImage : public KernelLite<TARGET(kOpenCL),
void PrepareForRun() override { void PrepareForRun() override {
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
context.cl_context()->AddKernel( context.cl_context()->AddKernel(
kernel_func_name_, "image/reshape_kernel.cl", build_options_); kernel_func_name_, "image/reshape_kernel.cl", build_options_);
} }
......
...@@ -37,6 +37,7 @@ class ScaleComputeImage2D : public KernelLite<TARGET(kOpenCL), ...@@ -37,6 +37,7 @@ class ScaleComputeImage2D : public KernelLite<TARGET(kOpenCL),
void PrepareForRun() override { void PrepareForRun() override {
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
context.cl_context()->AddKernel( context.cl_context()->AddKernel(
kernel_func_name_, "image/scale_kernel.cl", build_options_); kernel_func_name_, "image/scale_kernel.cl", build_options_);
} }
......
...@@ -36,26 +36,44 @@ bool ActivationOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { ...@@ -36,26 +36,44 @@ bool ActivationOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
auto x_name = opdesc.Input("X").front(); auto x_name = opdesc.Input("X").front();
auto out_name = opdesc.Output("Out").front(); auto out_name = opdesc.Output("Out").front();
param_.X = scope->FindVar(x_name)->GetMutable<lite::Tensor>(); param_.X = scope->FindVar(x_name)->GetMutable<lite::Tensor>();
if (opdesc.Type() == "leaky_relu") {
if (opdesc.Type() == "relu") {
// relu
param_.active_type = lite_api::ActivationType::kRelu;
} else if (opdesc.Type() == "leaky_relu") {
// leaky_relu
param_.Leaky_relu_alpha = opdesc.GetAttr<float>("alpha"); param_.Leaky_relu_alpha = opdesc.GetAttr<float>("alpha");
} param_.active_type = lite_api::ActivationType::kLeakyRelu;
if (opdesc.Type() == "relu_clipped") { } else if (opdesc.Type() == "relu_clipped") {
// relu_clipped
param_.Relu_clipped_coef = opdesc.GetAttr<float>("Relu_clipped_coef"); param_.Relu_clipped_coef = opdesc.GetAttr<float>("Relu_clipped_coef");
} } else if (opdesc.Type() == "prelu") {
if (opdesc.Type() == "prelu") { // prelu
param_.Prelu_mode = opdesc.GetAttr<std::string>("mode"); param_.Prelu_mode = opdesc.GetAttr<std::string>("mode");
auto prelu_alpha_name = opdesc.Input("Alpha").front(); auto prelu_alpha_name = opdesc.Input("Alpha").front();
param_.Prelu_alpha = param_.Prelu_alpha =
scope->FindVar(prelu_alpha_name)->GetMutable<lite::Tensor>(); scope->FindVar(prelu_alpha_name)->GetMutable<lite::Tensor>();
} param_.active_type = lite_api::ActivationType::kPRelu;
if (opdesc.Type() == "swish") { } else if (opdesc.Type() == "swish") {
// swish
param_.Swish_beta = opdesc.GetAttr<float>("beta"); param_.Swish_beta = opdesc.GetAttr<float>("beta");
} param_.active_type = lite_api::ActivationType::kSwish;
} else if (opdesc.Type() == "hard_sigmoid") {
if (opdesc.Type() == "hard_sigmoid") { // hard_sigomid
param_.hard_sigmoid_slope = opdesc.GetAttr<float>("slope"); param_.hard_sigmoid_slope = opdesc.GetAttr<float>("slope");
param_.hard_sigmoid_offset = opdesc.GetAttr<float>("offset"); param_.hard_sigmoid_offset = opdesc.GetAttr<float>("offset");
} else if (opdesc.Type() == "sigmoid") {
// sigmoid
param_.active_type = lite_api::ActivationType::kSigmoid;
} else if (opdesc.Type() == "tanh") {
// tanh
param_.active_type = lite_api::ActivationType::kTanh;
} else if (opdesc.Type() == "exp") {
// exp
param_.active_type = lite_api::ActivationType::kExp;
} }
VLOG(4) << "opdesc.Type():" << opdesc.Type();
param_.Out = scope->FindVar(out_name)->GetMutable<lite::Tensor>(); param_.Out = scope->FindVar(out_name)->GetMutable<lite::Tensor>();
return true; return true;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册