diff --git a/lite/backends/arm/math/elementwise.cc b/lite/backends/arm/math/elementwise.cc index a73a63ddcb67f8790f73aff3fff8368f4005b7e1..3f92300ae84d6b4bd4ef53aa6d58fcd4897a5cab 100644 --- a/lite/backends/arm/math/elementwise.cc +++ b/lite/backends/arm/math/elementwise.cc @@ -120,6 +120,63 @@ void elementwise_add_relu(const float* dinx, } } } +template <> +void elementwise_add_tanh(const float* dinx, + const float* diny, + float* dout, + int num) { + int cnt = num >> 4; + int remain = num % 16; + float32x4_t vzero = vdupq_n_f32(0.f); +#pragma omp parallel for + for (int i = 0; i < cnt; i++) { + const float* dinx_ptr = dinx + (i << 4); + const float* diny_ptr = diny + (i << 4); + float* dout_ptr = dout + (i << 4); + + // Elementwise_add + float32x4_t dinx0 = vld1q_f32(dinx_ptr); + float32x4_t diny0 = vld1q_f32(diny_ptr); + float32x4_t dinx1 = vld1q_f32(dinx_ptr + 4); + float32x4_t diny1 = vld1q_f32(diny_ptr + 4); + float32x4_t dinx2 = vld1q_f32(dinx_ptr + 8); + float32x4_t diny2 = vld1q_f32(diny_ptr + 8); + float32x4_t dinx3 = vld1q_f32(dinx_ptr + 12); + float32x4_t diny3 = vld1q_f32(diny_ptr + 12); + + dinx0 = vaddq_f32(dinx0, diny0); + dinx1 = vaddq_f32(dinx1, diny1); + dinx2 = vaddq_f32(dinx2, diny2); + dinx3 = vaddq_f32(dinx3, diny3); + + for (int j = 0; j < 4; j++) { + dinx0[j] = (expf(dinx0[j]) - expf(-dinx0[j])) / + (expf(dinx0[j]) + expf(-dinx0[j])); + dinx1[j] = (expf(dinx1[j]) - expf(-dinx1[j])) / + (expf(dinx1[j]) + expf(-dinx1[j])); + dinx2[j] = (expf(dinx2[j]) - expf(-dinx2[j])) / + (expf(dinx2[j]) + expf(-dinx2[j])); + dinx3[j] = (expf(dinx3[j]) - expf(-dinx3[j])) / + (expf(dinx3[j]) + expf(-dinx3[j])); + } + vst1q_f32(dout_ptr, dinx0); + vst1q_f32(dout_ptr + 4, dinx1); + vst1q_f32(dout_ptr + 8, dinx2); + vst1q_f32(dout_ptr + 12, dinx3); + } + if (remain > 0) { + const float* dinx_ptr = dinx + (cnt << 4); + const float* diny_ptr = diny + (cnt << 4); + float* dout_ptr = dout + (cnt << 4); + for (int i = 0; i < remain; i++) { + float tmp = *dinx_ptr + *diny_ptr; + *dout_ptr = (expf(tmp) - expf(-tmp)) / (expf(tmp) + expf(-tmp)); + dout_ptr++; + dinx_ptr++; + diny_ptr++; + } + } +} template <> void elementwise_add_broadcast(const float* dinx, diff --git a/lite/backends/arm/math/elementwise.h b/lite/backends/arm/math/elementwise.h index 0b400fcce26c7d307777cc6e25d8d25e0d6234bc..7325f7ea40b86e2b72bb4c9d9a819298b49daa71 100644 --- a/lite/backends/arm/math/elementwise.h +++ b/lite/backends/arm/math/elementwise.h @@ -175,6 +175,9 @@ void elementwise_add(const T* dinx, const T* diny, T* dout, int num); template void elementwise_add_relu(const T* dinx, const T* diny, T* dout, int num); +template +void elementwise_add_tanh(const T* dinx, const T* diny, T* dout, int num); + template void elementwise_add_broadcast( const T* dinx, const T* diny, T* dout, int batch, int channels, int num); diff --git a/lite/kernels/arm/elementwise_compute.cc b/lite/kernels/arm/elementwise_compute.cc index b3dc2e5d7835f62bded2eedb3e4a53f0429242ce..2214bf8ce854c1ab960067989be7caa57dcdb2e1 100644 --- a/lite/kernels/arm/elementwise_compute.cc +++ b/lite/kernels/arm/elementwise_compute.cc @@ -122,6 +122,9 @@ void ElementwiseAddActivationCompute::Run() { if (act_type == "relu") { lite::arm::math::elementwise_add_relu( x_data, y_data, out_data, x_dims.production()); + } else if (act_type == "tanh") { + lite::arm::math::elementwise_add_tanh( + x_data, y_data, out_data, x_dims.production()); } else { LOG(FATAL) << "unsupported Activation type: " << act_type; }