提交 fac96456 编写于 作者: X xzl

add prelu neon impl

上级 aabe1db1
......@@ -28,6 +28,7 @@ limitations under the License. */
#include "hl_top_k.h"
#include "paddle/utils/Logging.h"
#include "NEONFunctions.h"
#include "paddle/function/GemmFunctor.h"
#include "paddle/utils/ThreadLocal.h"
......@@ -4157,16 +4158,36 @@ void CpuMatrix::print(std::ostream& os) const {
// Parametric ReLU forward pass.
//
// For every element x read from `data`, writes into this matrix's buffer
// (data_):  x            if x > 0
//           x * slope    otherwise
// where the slope comes from W. Each row of `data` (getWidth() elements) is
// divided into paraSize = W.getHeight() * W.getWidth() consecutive groups of
// partial_sum elements, and group i uses slope w[i]. When paraSize equals the
// row width, every column has its own slope.
//
// NOTE(review): the code indexes data_ with the same offsets as `data`, so it
// presumably expects this matrix to have the same shape as `data` -- confirm
// against the callers.
void CpuMatrix::paramReluForward(Matrix& data, Matrix& W) {
  real* input = data.getData();
  real* w = W.getData();
  real* output = data_;
  size_t numElements = data.getWidth();
  size_t numSamples = data.getHeight();
  size_t paraSize = W.getHeight() * W.getWidth();
  // Guard the modulo/division below against an empty parameter matrix.
  CHECK(paraSize != 0);
  CHECK(!(numElements % paraSize));  // this check from ParameterReluLayer::init
  size_t partial_sum = numElements / paraSize;

  // One slope per column: plain element-wise loop over the whole buffer.
  if (paraSize == numElements) {
    for (size_t n = 0; n < numSamples * numElements; ++n) {
      output[n] = input[n] > 0 ? input[n] : input[n] * w[n % numElements];
    }
    return;
  }

#if defined(__ARM_NEON__) || defined(__ARM_NEON)
  // Each group of partial_sum elements shares one slope, so it can be handed
  // to the vectorized kernel as a single contiguous run.
  for (size_t n = 0; n < numSamples; ++n) {
    for (size_t i = 0; i < paraSize; i++) {
      // neon::prelu takes the run length as an int.
      neon::prelu(input + i * partial_sum,
                  w[i],
                  output + i * partial_sum,
                  static_cast<int>(partial_sum));
    }
    input = input + numElements;
    output = output + numElements;
  }
#else
  for (size_t n = 0, k = 0; n < numSamples; ++n) {
    for (size_t i = 0; i < numElements; ++i, ++k) {
      // Store exactly once per element: a second evaluation would re-read an
      // already-updated value if `data` shares its buffer with this matrix.
      output[k] = input[k] > 0 ? input[k] : input[k] * w[i / partial_sum];
    }
  }
#endif
}
void CpuMatrix::paramReluBackwardW(Matrix& oGrad, Matrix& data) {
......
......@@ -49,6 +49,46 @@ void relu(const float* a, float* b, int len) {
}
}
// PReLU over a float buffer: for each i in [0, len),
//   b[i] = a[i] > 0.0f ? a[i] : a[i] * w
// Complete groups of 16 floats are processed with four 4-lane NEON vectors;
// the remaining len % 16 elements fall through to a scalar loop.
void prelu(const float* a, float w, float* b, int len) {
  const float32x4_t zeros = vdupq_n_f32(0.f);
  const float32x4_t slope = vdupq_n_f32(w);
  const int numBlocks = len / 16;
  const int tail = len % 16;

  for (int blk = 0; blk < numBlocks; ++blk) {
    float32x4_t lanes[4];

    // Read the whole 16-float group before anything is written back.
    for (int q = 0; q < 4; ++q) {
      lanes[q] = vld1q_f32(a + 4 * q);
    }

    // Per lane: keep the input where it is > 0, otherwise take input * w.
    for (int q = 0; q < 4; ++q) {
      const uint32x4_t positive = vcgtq_f32(lanes[q], zeros);
      const float32x4_t scaled = vmulq_f32(lanes[q], slope);
      lanes[q] = vbslq_f32(positive, lanes[q], scaled);
    }

    for (int q = 0; q < 4; ++q) {
      vst1q_f32(b + 4 * q, lanes[q]);
    }

    a += 16;
    b += 16;
  }

  // Scalar handling of the leftover elements.
  for (int r = 0; r < tail; ++r) {
    const float x = a[r];
    b[r] = x > 0.0f ? x : x * w;
  }
}
} // namespace neon
} // namespace paddle
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace neon {
void relu(const float* a, float* b, int len);
// Parametric ReLU: b[i] = a[i] > 0.0f ? a[i] : a[i] * w, for i in [0, len).
void prelu(const float* a, float w, float* b, int len);
} // namespace neon
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册