Unverified · Commit 0a22eb4c · Authored by: Yuan Shuai · Committed by: GitHub

[LITE][OPENCL] Fix act bug for opencl kernel (#3085)

* Fix act bug for opencl kernel. test=develop

* add print ActToStr. test=develop

* remove print code. test=develop
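
Context for the fix: the OpenCL activation kernel picks its kernel function by switching on the integer value of `param_.active_type`, but `ActivationOp::AttachImpl` (last hunk below) previously set `active_type` for none of the supported op types, so a default-constructed enum read as 0 ("unk") and the switch fell into its unsupported default branch. A minimal standalone sketch of that failure mode — stand-in names (`kUnk`, `PickKernel`), not the actual Paddle-Lite sources:

```cpp
#include <cstdio>

// Stand-in enum: this excerpt only shows that index 0 of the string
// table is "unk", so the zero value is named kUnk here.
enum class ActivationType : int { kUnk = 0, kRelu = 1, kExp = 8, NUM = 9 };

struct ActivationParam {
  ActivationType active_type{ActivationType::kUnk};  // default-constructed: 0
};

// Mirrors the kernel's switch; only the names visible in the diff.
const char* PickKernel(ActivationType act) {
  switch (static_cast<int>(act)) {
    case 1: return "relu";
    case 8: return "exp_act";
    default: return nullptr;  // before the fix, even "relu" ops landed here
  }
}

int main() {
  ActivationParam param;  // parser never assigned it: active_type stays kUnk
  const char* name = PickKernel(param.active_type);
  std::printf("kernel: %s\n", name ? name : "unsupported");
  return 0;
}
```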
Parent: b601d81f
@@ -45,6 +45,21 @@ std::string Place::DebugString() const {
   return os.str();
 }
 
+const std::string& ActivationTypeToStr(ActivationType act) {
+  static const std::string act2string[] = {"unk",
+                                           "Relu",
+                                           "Relu6",
+                                           "PRelu",
+                                           "LeakyRelu",
+                                           "Sigmoid",
+                                           "Tanh",
+                                           "Swish",
+                                           "Exp"};
+  auto x = static_cast<int>(act);
+  CHECK_LT(x, static_cast<int>(ActivationType::NUM));
+  return act2string[x];
+}
+
 const std::string& TargetToStr(TargetType target) {
   static const std::string target2string[] = {"unk",
                                               "host",
......
@@ -97,7 +97,8 @@ enum class ActivationType : int {
   kSigmoid = 5,
   kTanh = 6,
   kSwish = 7,
-  kExp = 8
+  kExp = 8,
+  NUM = 9,
 };
 
 static size_t PrecisionTypeLength(PrecisionType type) {
@@ -149,6 +150,8 @@ _ForEachPrecisionType(DefinePrecisionTypeTrait);
 #define PRECISION(item__) paddle::lite_api::PrecisionType::item__
 #define DATALAYOUT(item__) paddle::lite_api::DataLayoutType::item__
 
+const std::string& ActivationTypeToStr(ActivationType act);
+
 const std::string& TargetToStr(TargetType target);
 const std::string& PrecisionToStr(PrecisionType precision);
......
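
The `NUM = 9` sentinel added to `ActivationType` doubles as the size of the `act2string` table, which is exactly what `CHECK_LT` tests against; the pattern only stays correct while the enumerators remain contiguous from 0. A short usage fragment — hypothetical, assuming these declarations live in `paddle::lite_api` as the `PRECISION`/`DATALAYOUT` macros above suggest:

```cpp
// Hypothetical fragment, not part of the patch.
using paddle::lite_api::ActivationType;
LOG(INFO) << ActivationTypeToStr(ActivationType::kRelu);  // prints "Relu"
LOG(INFO) << ActivationTypeToStr(ActivationType::kExp);   // prints "Exp"
// Passing ActivationType::NUM itself would trip the CHECK_LT bounds check.
```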
@@ -40,6 +40,8 @@ class ActivationComputeImageDefault
     auto& context = ctx_->As<OpenCLContext>();
     act_param_ = param_.get_mutable<param_t>();
     int act_type = static_cast<int>(act_param_->active_type);
+    VLOG(1) << "ActivationTypeToStr(act_param_->active_type):"
+            << ActivationTypeToStr(act_param_->active_type);
     switch (act_type) {
       case 1:
         kernel_func_name_ = "relu";
@@ -66,9 +68,10 @@ class ActivationComputeImageDefault
         kernel_func_name_ = "exp_act";
         break;
       default:
-        printf("This act type: %d doesn't support \n", act_type);
+        LOG(FATAL) << "This act type: " << act_type << " is not supported.";
         return;
     }
+    VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
     context.cl_context()->AddKernel(
         kernel_func_name_, "image/activation_kernel.cl", build_options_);
   }
@@ -87,6 +90,7 @@ class ActivationComputeImageDefault
     STL::stringstream kernel_key;
     kernel_key << kernel_func_name_ << build_options_;
     auto kernel = context.cl_context()->GetKernel(kernel_key.str());
+    int arg_idx = 0;
     cl_int status = kernel.setArg(arg_idx, *x_img);
     CL_CHECK_FATAL(status);
......
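
Two things change in the activation kernel above: the unsupported-type path now aborts via `LOG(FATAL)` instead of printf-ing and returning with `kernel_func_name_` left empty (which would only fail later, inside `AddKernel`), and the chosen kernel name is traced at `VLOG(1)`. The switch itself is a plain int-to-name mapping; an equivalent table-driven sketch, not the patch's code — only the two names visible in this excerpt are filled in, the cases elided by the diff are omitted:

```cpp
#include <map>
#include <string>

// Sketch of the dispatch as a lookup table; nullptr plays the role of
// the switch's default branch.
const std::string* LookupKernelName(int act_type) {
  static const std::map<int, std::string> table = {
      {1, "relu"},     // case 1 above
      {8, "exp_act"},  // case 8 above
  };
  auto it = table.find(act_type);
  return it == table.end() ? nullptr : &it->second;
}
```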
@@ -39,6 +39,7 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
     } else {
       kernel_func_name_ = "concat_mul";
     }
+    VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
     context.cl_context()->AddKernel(
         kernel_func_name_, "image/concat_kernel.cl", build_options_);
......
@@ -46,7 +46,7 @@ void ElementwiseAddImageCompute::PrepareForRun() {
             << ", x->dims().size():" << x->dims().size()
             << ", y->dims.size():" << y->dims().size();
   }
-  VLOG(4) << "kernel_func_name_:" << kernel_func_name_;
+  VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
 
   auto& context = ctx_->As<OpenCLContext>();
   context.cl_context()->AddKernel(
......
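
The recurring one-line change in this and the following kernels lowers the `kernel_func_name_` trace from `VLOG(4)`/`VLOG(2)` to `VLOG(1)`, so a single low verbosity setting surfaces the kernel chosen by every OpenCL op. A usage sketch, assuming glog-compatible verbosity control:

```cpp
// VLOG(N) only prints when the active verbosity is >= N, e.g.
//   GLOG_v=1 ./test_binary     (or --v=1 with gflags-style flags)
// now shows every "kernel_func_name_:..." line without also enabling
// the chattier level-4 shape dumps that follow.
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
```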
@@ -63,7 +63,7 @@ class ElementwiseMulImageCompute
               << y_dims.size()
               << ", x_dims.size():" << ele_param_->X->dims().size();
     }
-    VLOG(4) << "kernel_func_name_:" << kernel_func_name_;
+    VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
 
     VLOG(4) << "y_dims:" << y_dims;
     VLOG(4) << "y_dims.size():" << y_dims.size();
......
@@ -38,6 +38,7 @@ class FusionElementwiseAddActivationImageCompute
     if (act_t != "relu") {
       LOG(FATAL) << "Unsupported Activation type: " << act_t;
     }
+    VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
   }
 };
......
@@ -44,6 +44,7 @@ class GridSamplerImageCompute : public KernelLite<TARGET(kOpenCL),
     auto& context = ctx_->As<OpenCLContext>();
     context.cl_context()->AddKernel(
         kernel_func_name_, "image/grid_sampler_kernel.cl", build_options_);
+    VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
   }
 
   void Run() override {
......
@@ -42,7 +42,7 @@ class LayoutComputeBufferChwToImageDefault
     if (param.process_type == 1) {
       kernel_func_name_ = "buffer_to_image2d_with_pre255";
     }
-    VLOG(2) << "kernel_func_name_:" << kernel_func_name_;
+    VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
     auto& context = ctx_->As<OpenCLContext>();
     context.cl_context()->AddKernel(
         kernel_func_name_, "image/layout_kernel.cl", build_options_);
@@ -153,7 +153,7 @@ class LayoutComputeImageDefaultToBufferChw
     if (param.process_type == 1) {
       kernel_func_name_ = "image2d_to_buffer_with_post255";
     }
-    VLOG(2) << "kernel_func_name_:" << kernel_func_name_;
+    VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
     auto& context = ctx_->As<OpenCLContext>();
     context.cl_context()->AddKernel(
         kernel_func_name_, "image/layout_kernel.cl", build_options_);
......
@@ -40,6 +40,7 @@ class NearestInterpComputeImageDefault
     auto& context = ctx_->As<OpenCLContext>();
     context.cl_context()->AddKernel(
         kernel_func_name_, "image/nearest_interp_kernel.cl", build_options_);
+    VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
   }
 
   void Run() override {
......
@@ -44,6 +44,7 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL),
     if (global_pooling) {
       kernel_func_name_ += "_global";
     }
+    VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
     auto& context = ctx_->As<OpenCLContext>();
     context.cl_context()->AddKernel(
         kernel_func_name_, "image/pool_kernel.cl", build_options_);
......
@@ -35,6 +35,7 @@ class ReshapeComputeFloatImage : public KernelLite<TARGET(kOpenCL),
   void PrepareForRun() override {
     auto& context = ctx_->As<OpenCLContext>();
+    VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
     context.cl_context()->AddKernel(
         kernel_func_name_, "image/reshape_kernel.cl", build_options_);
   }
......
@@ -37,6 +37,7 @@ class ScaleComputeImage2D : public KernelLite<TARGET(kOpenCL),
   void PrepareForRun() override {
     auto& context = ctx_->As<OpenCLContext>();
+    VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
     context.cl_context()->AddKernel(
         kernel_func_name_, "image/scale_kernel.cl", build_options_);
   }
......
@@ -36,26 +36,44 @@ bool ActivationOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
   auto x_name = opdesc.Input("X").front();
   auto out_name = opdesc.Output("Out").front();
   param_.X = scope->FindVar(x_name)->GetMutable<lite::Tensor>();
-  if (opdesc.Type() == "leaky_relu") {
+  if (opdesc.Type() == "relu") {
+    // relu
+    param_.active_type = lite_api::ActivationType::kRelu;
+  } else if (opdesc.Type() == "leaky_relu") {
+    // leaky_relu
     param_.Leaky_relu_alpha = opdesc.GetAttr<float>("alpha");
-  }
-  if (opdesc.Type() == "relu_clipped") {
+    param_.active_type = lite_api::ActivationType::kLeakyRelu;
+  } else if (opdesc.Type() == "relu_clipped") {
+    // relu_clipped
     param_.Relu_clipped_coef = opdesc.GetAttr<float>("Relu_clipped_coef");
-  }
-  if (opdesc.Type() == "prelu") {
+  } else if (opdesc.Type() == "prelu") {
+    // prelu
     param_.Prelu_mode = opdesc.GetAttr<std::string>("mode");
     auto prelu_alpha_name = opdesc.Input("Alpha").front();
     param_.Prelu_alpha =
         scope->FindVar(prelu_alpha_name)->GetMutable<lite::Tensor>();
-  }
-  if (opdesc.Type() == "swish") {
+    param_.active_type = lite_api::ActivationType::kPRelu;
+  } else if (opdesc.Type() == "swish") {
+    // swish
     param_.Swish_beta = opdesc.GetAttr<float>("beta");
-  }
-  if (opdesc.Type() == "hard_sigmoid") {
+    param_.active_type = lite_api::ActivationType::kSwish;
+  } else if (opdesc.Type() == "hard_sigmoid") {
+    // hard_sigmoid
     param_.hard_sigmoid_slope = opdesc.GetAttr<float>("slope");
     param_.hard_sigmoid_offset = opdesc.GetAttr<float>("offset");
+  } else if (opdesc.Type() == "sigmoid") {
+    // sigmoid
+    param_.active_type = lite_api::ActivationType::kSigmoid;
+  } else if (opdesc.Type() == "tanh") {
+    // tanh
+    param_.active_type = lite_api::ActivationType::kTanh;
+  } else if (opdesc.Type() == "exp") {
+    // exp
+    param_.active_type = lite_api::ActivationType::kExp;
   }
+  VLOG(4) << "opdesc.Type():" << opdesc.Type();
   param_.Out = scope->FindVar(out_name)->GetMutable<lite::Tensor>();
   return true;
 }
......
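
This last hunk is the bug fix named in the title: the independent `if` blocks become one `else if` chain, and every kernel-relevant op type now assigns `param_.active_type`, where previously `relu`, `sigmoid`, `tanh`, and `exp` never set it at all. The string-to-enum mapping could equally be table-driven; a self-contained sketch, not the patch's code (enum values inferred from the `act2string` table above), with the caveat that the patch's chain also reads per-op attributes such as `alpha` and `beta`:

```cpp
#include <string>
#include <unordered_map>

// Values follow the act2string table: "Relu"=1 ... "Exp"=8.
enum class ActivationType : int {
  kRelu = 1, kLeakyRelu = 4, kSigmoid = 5, kTanh = 6, kSwish = 7, kExp = 8
};

// Hypothetical helper, not in the patch: returns false for op types
// (e.g. "relu_clipped", "hard_sigmoid") that set no kernel enum here.
bool LookupActiveType(const std::string& op_type, ActivationType* out) {
  static const std::unordered_map<std::string, ActivationType> table = {
      {"relu", ActivationType::kRelu},
      {"leaky_relu", ActivationType::kLeakyRelu},
      {"sigmoid", ActivationType::kSigmoid},
      {"tanh", ActivationType::kTanh},
      {"swish", ActivationType::kSwish},
      {"exp", ActivationType::kExp},
  };
  auto it = table.find(op_type);
  if (it == table.end()) return false;
  *out = it->second;
  return true;
}
```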