提交 16cb216e 编写于 作者: C chenjiaoAngel

add elu act

上级 8eeaa0ac
...@@ -55,7 +55,8 @@ const std::string& ActivationTypeToStr(ActivationType act) { ...@@ -55,7 +55,8 @@ const std::string& ActivationTypeToStr(ActivationType act) {
"Tanh", "Tanh",
"Swish", "Swish",
"Exp", "Exp",
"ThresholdedRelu"}; "ThresholdedRelu",
"Elu"};
auto x = static_cast<int>(act); auto x = static_cast<int>(act);
CHECK_LT(x, static_cast<int>(ActivationType::NUM)); CHECK_LT(x, static_cast<int>(ActivationType::NUM));
return act2string[x]; return act2string[x];
......
...@@ -107,7 +107,8 @@ enum class ActivationType : int { ...@@ -107,7 +107,8 @@ enum class ActivationType : int {
kHardSwish = 10, kHardSwish = 10,
kReciprocal = 11, kReciprocal = 11,
kThresholdedRelu = 12, kThresholdedRelu = 12,
NUM = 13, kElu = 13,
NUM = 14,
}; };
static size_t PrecisionTypeLength(PrecisionType type) { static size_t PrecisionTypeLength(PrecisionType type) {
......
...@@ -763,6 +763,86 @@ void act_thresholded_relu<float>( ...@@ -763,6 +763,86 @@ void act_thresholded_relu<float>(
} }
} }
// elu: out = max(0,x) + min(0, alpha *(exp(x) - 1)
template <>
void act_elu<float>(
const float* din, float* dout, int size, float alpha, int threads) {
int nums_per_thread = size / threads;
int thread_remain = size % threads;
int neon_loop_cnt_dim16 = nums_per_thread >> 4;
int neon_loop_remain_dim16 = nums_per_thread & 15;
float32x4_t alpha = vdupq_n_f32(alpha);
float32x4_t vzero = vdupq_n_f32(0.f);
float32x4_t vone = vdupq_n_f32(1.f);
int cnt = neon_loop_remain_dim16 >> 2;
int remain = neon_loop_remain_dim16 & 3;
#pragma omp parallel for
for (int i = 0; i < threads; i++) {
const float* ptr_in_thread = din + i * nums_per_thread;
float* ptr_out_thread = dout + i * nums_per_thread;
for (int k = 0; k < neon_loop_cnt_dim16; ++k) {
float32x4_t va = vld1q_f32(ptr_in_thread); // x
float32x4_t vb = vld1q_f32(ptr_in_thread + 4);
float32x4_t vc = vld1q_f32(ptr_in_thread + 8);
float32x4_t vd = vld1q_f32(ptr_in_thread + 12);
float32x4_t va_exp = exp_ps(va);
float32x4_t va_max = vmaxq_f32(va, vzero);
float32x4_t vb_exp = exp_ps(vb);
float32x4_t vb_max = vmaxq_f32(vb, vzero);
float32x4_t vc_exp = exp_ps(vc);
float32x4_t vc_max = vmaxq_f32(vc, vzero);
float32x4_t vd_exp = exp_ps(vd);
float32x4_t vd_max = vmaxq_f32(vd, vzero);
float32x4_t va_sub = vsubq_f32(va_exp, vone);
float32x4_t vb_sub = vsubq_f32(vb_exp, vone);
float32x4_t vc_sub = vsubq_f32(vc_exp, vone);
float32x4_t vd_sub = vsubq_f32(vd_exp, vone);
float32x4_t va_min = vminq_f32(va_sub, vzero);
float32x4_t vb_min = vminq_f32(vb_sub, vzero);
float32x4_t vc_min = vminq_f32(vc_sub, vzero);
float32x4_t vd_min = vminq_f32(vd_sub, vzero);
float32x4_t va_rst = vaddq_f32(va_max, va_min);
float32x4_t vb_rst = vaddq_f32(vb_max, vb_min);
float32x4_t vc_rst = vaddq_f32(vc_max, vc_min);
float32x4_t vd_rst = vaddq_f32(vd_max, vd_min);
vst1q_f32(ptr_out_thread, va_rst);
vst1q_f32(ptr_out_thread + 4, vb_rst);
vst1q_f32(ptr_out_thread + 8, vc_rst);
vst1q_f32(ptr_out_thread + 12, vd_rst);
ptr_out_thread += 16;
ptr_in_thread += 16;
}
for (int j = 0; j < cnt; j++) {
float32x4_t va = vld1q_f32(ptr_in_thread); // x
float32x4_t va_exp = exp_ps(va);
float32x4_t va_max = vmaxq_f32(va, vzero);
float32x4_t va_sub = vsubq_f32(va_exp, vone);
float32x4_t va_min = vminq_f32(va_sub, vzero);
float32x4_t va_rst = vaddq_f32(va_max, va_min);
vst1q_f32(ptr_out_thread, va_rst);
ptr_out_thread += 4;
ptr_in_thread += 4;
}
for (int j = 0; j < remain; j++) {
float beta = alpha * (expf(ptr_in_thread[0]) - 1);
float max = ptr_in[0] >= 0.f ? ptr_in_thread[0] : 0.f;
float min = beta <= 0.f ? beta : 0.f;
ptr_out_thread[0] = min + max;
ptr_in_thread++;
ptr_out_thread++;
}
}
float* ptr_out = dout + threads * nums_per_thread;
const float* ptr_in = din + threads * nums_per_thread;
for (int j = 0; j < thread_remain; j++) {
float beta = alpha * (expf(ptr_in[0]) - 1);
float max = ptr_in[0] >= 0.f ? ptr_in[0] : 0.f;
float min = beta <= 0.f ? beta : 0.f;
ptr_out[0] = max + min;
ptr_in++;
ptr_out++;
}
}
} // namespace math } // namespace math
} // namespace arm } // namespace arm
} // namespace lite } // namespace lite
......
...@@ -90,6 +90,10 @@ template <typename T> ...@@ -90,6 +90,10 @@ template <typename T>
void act_thresholded_relu( void act_thresholded_relu(
const T* din, T* dout, int size, float threshold, int threads); const T* din, T* dout, int size, float threshold, int threads);
template <typename T>
void act_elu(const T* din, T* dout, int size, float alpha, int threads);
} // namespace math } // namespace math
} // namespace arm } // namespace arm
} // namespace lite } // namespace lite
......
...@@ -228,6 +228,17 @@ void ThresholdedReluCompute::Run() { ...@@ -228,6 +228,17 @@ void ThresholdedReluCompute::Run() {
x_data, output_data, x_dims.production(), threshold, ctx.threads()); x_data, output_data, x_dims.production(), threshold, ctx.threads());
} }
void EluCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
auto x_dims = param.X->dims();
auto x_data = param.X->data<float>();
auto output_data = param.Out->mutable_data<float>();
float alpha = param.Elu_alpha;
lite::arm::math::act_elu<float>(
x_data, output_data, x_dims.production(), alpha, ctx.threads());
}
} // namespace arm } // namespace arm
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
...@@ -356,3 +367,12 @@ REGISTER_LITE_KERNEL(thresholded_relu, ...@@ -356,3 +367,12 @@ REGISTER_LITE_KERNEL(thresholded_relu,
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL(elu,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::EluCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
\ No newline at end of file
...@@ -185,6 +185,16 @@ class ThresholdedReluCompute ...@@ -185,6 +185,16 @@ class ThresholdedReluCompute
virtual ~ThresholdedReluCompute() = default; virtual ~ThresholdedReluCompute() = default;
}; };
class EluCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;
void Run() override;
virtual ~EluCompute() = default;
};
} // namespace arm } // namespace arm
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
......
...@@ -85,6 +85,9 @@ bool ActivationOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { ...@@ -85,6 +85,9 @@ bool ActivationOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
} else if (opdesc.Type() == "thresholded_relu") { } else if (opdesc.Type() == "thresholded_relu") {
param_.active_type = lite_api::ActivationType::kThresholdedRelu; param_.active_type = lite_api::ActivationType::kThresholdedRelu;
param_.relu_threshold = opdesc.GetAttr<float>("threshold"); param_.relu_threshold = opdesc.GetAttr<float>("threshold");
} else if (opdesc.Type() == "elu") {
param_.active_type = lite_api::ActivationType::kElu;
param_.param_.Elu_alpha = opdesc.GetAttr<float>("alpha");
} }
VLOG(4) << "opdesc.Type():" << opdesc.Type(); VLOG(4) << "opdesc.Type():" << opdesc.Type();
...@@ -105,3 +108,4 @@ REGISTER_LITE_OP(leaky_relu, paddle::lite::operators::ActivationOp); ...@@ -105,3 +108,4 @@ REGISTER_LITE_OP(leaky_relu, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(relu6, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(relu6, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(prelu, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(prelu, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(thresholded_relu, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(thresholded_relu, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(elu, paddle::lite::operators::ActivationOp);
...@@ -83,6 +83,9 @@ class ActivationOp : public OpLite { ...@@ -83,6 +83,9 @@ class ActivationOp : public OpLite {
case lite_api::ActivationType::kThresholdedRelu: case lite_api::ActivationType::kThresholdedRelu:
ch->macs = param_.X->numel(); ch->macs = param_.X->numel();
break; break;
case lite_api::ActivationType::kElu:
ch->macs = param_.X->numel();
break;
default: default:
LOG(FATAL) << "This Type of Activation:" LOG(FATAL) << "This Type of Activation:"
<< static_cast<int>(param_.active_type) << static_cast<int>(param_.active_type)
......
...@@ -359,6 +359,8 @@ struct ActivationParam : ParamBase { ...@@ -359,6 +359,8 @@ struct ActivationParam : ParamBase {
float hard_swish_offset{3.0}; float hard_swish_offset{3.0};
// thresholded_relu // thresholded_relu
float relu_threshold{1.0f}; float relu_threshold{1.0f};
// elu
float Elu_alpha{1.0f};
}; };
struct ActivationGradParam : ParamBase { struct ActivationGradParam : ParamBase {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册