未验证 提交 69113320 编写于 作者: H huzhiqiang 提交者: GitHub

[Arm] add elementwise_add_tanh method (#4204)

上级 6f648cfc
...@@ -120,6 +120,63 @@ void elementwise_add_relu<float>(const float* dinx, ...@@ -120,6 +120,63 @@ void elementwise_add_relu<float>(const float* dinx,
} }
} }
} }
// Fused elementwise add + tanh for float buffers:
//   dout[i] = tanh(dinx[i] + diny[i])  for i in [0, num).
//
// dinx, diny : input buffers, each read at indices [0, num).
// dout       : output buffer, written at indices [0, num).
// num        : number of elements to process.
//
// The bulk of the data is processed in chunks of 16 floats (four NEON
// registers per input) inside an OpenMP parallel-for loop; the remaining
// num % 16 elements are handled by a scalar tail loop.
//
// tanhf() is used instead of the textbook expression
// (exp(x) - exp(-x)) / (exp(x) + exp(-x)): that form evaluates to
// inf / inf = NaN once expf overflows (|x| above roughly 88) and suffers
// from cancellation near zero, whereas tanhf saturates cleanly to +/-1.
template <>
void elementwise_add_tanh<float>(const float* dinx,
                                 const float* diny,
                                 float* dout,
                                 int num) {
  int cnt = num >> 4;
  int remain = num % 16;
#pragma omp parallel for
  for (int i = 0; i < cnt; i++) {
    const float* dinx_ptr = dinx + (i << 4);
    const float* diny_ptr = diny + (i << 4);
    float* dout_ptr = dout + (i << 4);
    // Elementwise_add: all 16 inputs of both operands are loaded and summed
    // before anything is written to dout, so in-place use stays safe.
    float32x4_t sum0 = vaddq_f32(vld1q_f32(dinx_ptr), vld1q_f32(diny_ptr));
    float32x4_t sum1 =
        vaddq_f32(vld1q_f32(dinx_ptr + 4), vld1q_f32(diny_ptr + 4));
    float32x4_t sum2 =
        vaddq_f32(vld1q_f32(dinx_ptr + 8), vld1q_f32(diny_ptr + 8));
    float32x4_t sum3 =
        vaddq_f32(vld1q_f32(dinx_ptr + 12), vld1q_f32(diny_ptr + 12));
    // Spill the sums to a plain array so the scalar tanh can index them
    // without subscripting the NEON vector type (a compiler extension).
    float sums[16];
    vst1q_f32(sums, sum0);
    vst1q_f32(sums + 4, sum1);
    vst1q_f32(sums + 8, sum2);
    vst1q_f32(sums + 12, sum3);
    // Tanh
    for (int j = 0; j < 16; j++) {
      dout_ptr[j] = tanhf(sums[j]);
    }
  }
  if (remain > 0) {
    // Scalar tail for the last num % 16 elements.
    const float* dinx_ptr = dinx + (cnt << 4);
    const float* diny_ptr = diny + (cnt << 4);
    float* dout_ptr = dout + (cnt << 4);
    for (int i = 0; i < remain; i++) {
      dout_ptr[i] = tanhf(dinx_ptr[i] + diny_ptr[i]);
    }
  }
}
template <> template <>
void elementwise_add_broadcast<float>(const float* dinx, void elementwise_add_broadcast<float>(const float* dinx,
......
...@@ -175,6 +175,9 @@ void elementwise_add(const T* dinx, const T* diny, T* dout, int num); ...@@ -175,6 +175,9 @@ void elementwise_add(const T* dinx, const T* diny, T* dout, int num);
template <typename T> template <typename T>
void elementwise_add_relu(const T* dinx, const T* diny, T* dout, int num); void elementwise_add_relu(const T* dinx, const T* diny, T* dout, int num);
// Fused elementwise add followed by tanh over `num` elements. The float
// specialization in this change writes tanh(dinx[i] + diny[i]) to dout[i]
// for every i in [0, num).
template <typename T>
void elementwise_add_tanh(const T* dinx, const T* diny, T* dout, int num);
template <typename T> template <typename T>
void elementwise_add_broadcast( void elementwise_add_broadcast(
const T* dinx, const T* diny, T* dout, int batch, int channels, int num); const T* dinx, const T* diny, T* dout, int batch, int channels, int num);
......
...@@ -122,6 +122,9 @@ void ElementwiseAddActivationCompute::Run() { ...@@ -122,6 +122,9 @@ void ElementwiseAddActivationCompute::Run() {
if (act_type == "relu") { if (act_type == "relu") {
lite::arm::math::elementwise_add_relu( lite::arm::math::elementwise_add_relu(
x_data, y_data, out_data, x_dims.production()); x_data, y_data, out_data, x_dims.production());
} else if (act_type == "tanh") {
lite::arm::math::elementwise_add_tanh(
x_data, y_data, out_data, x_dims.production());
} else { } else {
LOG(FATAL) << "unsupported Activation type: " << act_type; LOG(FATAL) << "unsupported Activation type: " << act_type;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册