提交 647f9806 编写于 作者: H hjchen2

Optimize gru kerenl, thanks to smilejames

上级 9851d10d
...@@ -16,9 +16,9 @@ limitations under the License. */ ...@@ -16,9 +16,9 @@ limitations under the License. */
#ifdef ENABLE_EXCEPTION #ifdef ENABLE_EXCEPTION
#include <stdio.h> #include <stdio.h>
#include <stdlib.h>
#include <exception> #include <exception>
#include <string> #include <string>
#endif #endif
namespace paddle_mobile { namespace paddle_mobile {
......
...@@ -93,18 +93,18 @@ enum RoundType { ...@@ -93,18 +93,18 @@ enum RoundType {
}; };
enum ActivationType { enum ActivationType {
Linear = 0, IDENTITY = 0,
Relu = 1, RELU = 1,
Relu6 = 2, RELU6 = 2,
PRelu = 3, PRELU = 3,
LeakyRelu = 4, LEAKY_RELU = 4,
Tanh = 5, TANH = 5,
Sigmoid = 6, SIGMOID = 6,
}; };
enum PoolingType { enum PoolingType {
Max = 0, MAX = 0,
Avg = 1, AVG = 1,
}; };
extern const char *G_OP_TYPE_CONV; extern const char *G_OP_TYPE_CONV;
......
...@@ -143,12 +143,10 @@ double PaddleMobile<CPU, Precision::FP32>::GetPredictTime() { ...@@ -143,12 +143,10 @@ double PaddleMobile<CPU, Precision::FP32>::GetPredictTime() {
int t1 = 1; int t1 = 1;
int t2 = 1; int t2 = 1;
for (int i = 0; i < m * k; ++i) { for (int i = 0; i < m * k; ++i) {
unsigned int seed = 100; a[i] = t1 + rand() % t2; // NOLINT
a[i] = t1 + rand_r(&seed) % t2;
} }
for (int i = 0; i < k * n; ++i) { for (int i = 0; i < k * n; ++i) {
unsigned int seed = 200; b[i] = t1 + rand() % t2; // NOLINT
b[i] = t1 + rand_r(&seed) % t2;
} }
paddle_mobile::operators::math::Gemm gemm; paddle_mobile::operators::math::Gemm gemm;
auto time1 = paddle_mobile::time(); auto time1 = paddle_mobile::time();
......
...@@ -131,7 +131,7 @@ bool FusionDequantBNKernel<CPU, float>::Init(FusionDequantBNParam<CPU> *param) { ...@@ -131,7 +131,7 @@ bool FusionDequantBNKernel<CPU, float>::Init(FusionDequantBNParam<CPU> *param) {
template <> template <>
void FusionDequantBNKernel<CPU, float>::Compute( void FusionDequantBNKernel<CPU, float>::Compute(
const FusionDequantBNParam<CPU> &param) { const FusionDequantBNParam<CPU> &param) {
DequantBNCompute<Linear>(&param); DequantBNCompute<IDENTITY>(&param);
} }
#endif // FUSION_DEQUANT_BN_OP #endif // FUSION_DEQUANT_BN_OP
...@@ -146,7 +146,7 @@ bool FusionDequantBNReluKernel<CPU, float>::Init( ...@@ -146,7 +146,7 @@ bool FusionDequantBNReluKernel<CPU, float>::Init(
template <> template <>
void FusionDequantBNReluKernel<CPU, float>::Compute( void FusionDequantBNReluKernel<CPU, float>::Compute(
const FusionDequantBNParam<CPU> &param) { const FusionDequantBNParam<CPU> &param) {
DequantBNCompute<Relu>(&param); DequantBNCompute<RELU>(&param);
} }
#endif // FUSION_DEQUANT_BN_RELU_OP #endif // FUSION_DEQUANT_BN_RELU_OP
...@@ -162,7 +162,7 @@ bool FusionDequantAddBNKernel<CPU, float>::Init( ...@@ -162,7 +162,7 @@ bool FusionDequantAddBNKernel<CPU, float>::Init(
template <> template <>
void FusionDequantAddBNKernel<CPU, float>::Compute( void FusionDequantAddBNKernel<CPU, float>::Compute(
const FusionDequantAddBNParam<CPU> &param) { const FusionDequantAddBNParam<CPU> &param) {
DequantBNCompute<Linear>(&param); DequantBNCompute<IDENTITY>(&param);
} }
#endif // FUSION_DEQUANT_ADD_BN_OP #endif // FUSION_DEQUANT_ADD_BN_OP
...@@ -178,7 +178,7 @@ bool FusionDequantAddBNReluKernel<CPU, float>::Init( ...@@ -178,7 +178,7 @@ bool FusionDequantAddBNReluKernel<CPU, float>::Init(
template <> template <>
void FusionDequantAddBNReluKernel<CPU, float>::Compute( void FusionDequantAddBNReluKernel<CPU, float>::Compute(
const FusionDequantAddBNParam<CPU> &param) { const FusionDequantAddBNParam<CPU> &param) {
DequantBNCompute<Relu>(&param); DequantBNCompute<RELU>(&param);
} }
#endif // FUSION_DEQUANT_ADD_BN_RELU_OP #endif // FUSION_DEQUANT_ADD_BN_RELU_OP
...@@ -292,13 +292,13 @@ void FusionDequantAddBNQuantKernel<CPU, float>::Compute( ...@@ -292,13 +292,13 @@ void FusionDequantAddBNQuantKernel<CPU, float>::Compute(
const FusionDequantAddBNQuantParam<CPU> &param) { const FusionDequantAddBNQuantParam<CPU> &param) {
switch (param.round_type_) { switch (param.round_type_) {
case ROUND_NEAREST_TO_EVEN: case ROUND_NEAREST_TO_EVEN:
DequantBNQuantCompute<Linear, ROUND_NEAREST_TO_EVEN>(&param); DequantBNQuantCompute<IDENTITY, ROUND_NEAREST_TO_EVEN>(&param);
break; break;
case ROUND_NEAREST_TOWARDS_ZERO: case ROUND_NEAREST_TOWARDS_ZERO:
DequantBNQuantCompute<Linear, ROUND_NEAREST_TOWARDS_ZERO>(&param); DequantBNQuantCompute<IDENTITY, ROUND_NEAREST_TOWARDS_ZERO>(&param);
break; break;
case ROUND_NEAREST_AWAY_ZERO: case ROUND_NEAREST_AWAY_ZERO:
DequantBNQuantCompute<Linear, ROUND_NEAREST_AWAY_ZERO>(&param); DequantBNQuantCompute<IDENTITY, ROUND_NEAREST_AWAY_ZERO>(&param);
break; break;
default: default:
LOG(kLOG_ERROR) << "round type is not supported."; LOG(kLOG_ERROR) << "round type is not supported.";
...@@ -321,13 +321,13 @@ void FusionDequantAddBNReluQuantKernel<CPU, float>::Compute( ...@@ -321,13 +321,13 @@ void FusionDequantAddBNReluQuantKernel<CPU, float>::Compute(
const FusionDequantAddBNQuantParam<CPU> &param) { const FusionDequantAddBNQuantParam<CPU> &param) {
switch (param.round_type_) { switch (param.round_type_) {
case ROUND_NEAREST_TO_EVEN: case ROUND_NEAREST_TO_EVEN:
DequantBNQuantCompute<Relu, ROUND_NEAREST_TO_EVEN>(&param); DequantBNQuantCompute<RELU, ROUND_NEAREST_TO_EVEN>(&param);
break; break;
case ROUND_NEAREST_TOWARDS_ZERO: case ROUND_NEAREST_TOWARDS_ZERO:
DequantBNQuantCompute<Relu, ROUND_NEAREST_TOWARDS_ZERO>(&param); DequantBNQuantCompute<RELU, ROUND_NEAREST_TOWARDS_ZERO>(&param);
break; break;
case ROUND_NEAREST_AWAY_ZERO: case ROUND_NEAREST_AWAY_ZERO:
DequantBNQuantCompute<Relu, ROUND_NEAREST_AWAY_ZERO>(&param); DequantBNQuantCompute<RELU, ROUND_NEAREST_AWAY_ZERO>(&param);
break; break;
default: default:
LOG(kLOG_ERROR) << "round type is not supported."; LOG(kLOG_ERROR) << "round type is not supported.";
......
...@@ -34,14 +34,66 @@ inline float32_t vmaxvq_f32(float32x4_t r) { ...@@ -34,14 +34,66 @@ inline float32_t vmaxvq_f32(float32x4_t r) {
#endif #endif
template <RoundType R> template <RoundType R>
static void Quantize(const Tensor *input, const float scale, Tensor *output) { inline void QuantizeOffline(const Tensor *input, const float scale,
const float max_abs, Tensor *output) {
const float *x = input->data<const float>(); const float *x = input->data<const float>();
int8_t *y = output->mutable_data<int8_t>(); int8_t *y = output->mutable_data<int8_t>();
size_t remain = input->numel(); size_t remain = input->numel();
#if defined(__ARM_NEON__) || defined(__ARM_NEON) #if defined(__ARM_NEON__) || defined(__ARM_NEON)
size_t loop = remain >> 4; size_t loop = remain >> 4;
remain = remain & 0xF; remain = remain & 0xF;
float32x4_t __scale = vdupq_n_f32(scale);
float32x4_t __postive_max = vdupq_n_f32(max_abs);
float32x4_t __negtive_max = vdupq_n_f32(-max_abs);
#pragma omp parallel for
for (size_t i = 0; i < loop; ++i) {
const float *local_x = x + (i << 4);
int8_t *local_y = y + (i << 4);
float32x4_t r0 = vld1q_f32(local_x);
float32x4_t r1 = vld1q_f32(local_x + 4);
float32x4_t r2 = vld1q_f32(local_x + 8);
float32x4_t r3 = vld1q_f32(local_x + 12);
r0 = vmaxq_f32(vminq_f32(r0, __postive_max), __negtive_max);
r1 = vmaxq_f32(vminq_f32(r1, __postive_max), __negtive_max);
r2 = vmaxq_f32(vminq_f32(r2, __postive_max), __negtive_max);
r3 = vmaxq_f32(vminq_f32(r3, __postive_max), __negtive_max);
r0 = vmulq_f32(r0, __scale);
r1 = vmulq_f32(r1, __scale);
r2 = vmulq_f32(r2, __scale);
r3 = vmulq_f32(r3, __scale);
int32x4_t q0 = math::vRoundq_f32<R>(r0);
int32x4_t q1 = math::vRoundq_f32<R>(r1);
int32x4_t q2 = math::vRoundq_f32<R>(r2);
int32x4_t q3 = math::vRoundq_f32<R>(r3);
int16x4_t d0 = vmovn_s32(q0);
int16x4_t d1 = vmovn_s32(q1);
int16x4_t d2 = vmovn_s32(q2);
int16x4_t d3 = vmovn_s32(q3);
int16x8_t q5 = vcombine_s16(d0, d1);
int16x8_t q6 = vcombine_s16(d2, d3);
int8x8_t d5 = vmovn_s16(q5);
int8x8_t d6 = vmovn_s16(q6);
vst1_s8(local_y, d5);
vst1_s8(local_y + 8, d6);
}
x += (loop << 4);
y += (loop << 4);
#endif
for (size_t i = 0; i < remain; ++i) {
float x_temp = std::max(std::min(x[i], max_abs), -max_abs);
y[i] = math::Round<R>(x_temp * scale);
}
}
template <RoundType R>
inline void QuantizeOnline(const Tensor *input, const float scale,
Tensor *output) {
const float *x = input->data<const float>();
int8_t *y = output->mutable_data<int8_t>();
size_t remain = input->numel();
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
size_t loop = remain >> 4;
remain = remain & 0xF;
float32x4_t __scale = vdupq_n_f32(scale); float32x4_t __scale = vdupq_n_f32(scale);
#pragma omp parallel for #pragma omp parallel for
for (size_t i = 0; i < loop; ++i) { for (size_t i = 0; i < loop; ++i) {
...@@ -78,6 +130,17 @@ static void Quantize(const Tensor *input, const float scale, Tensor *output) { ...@@ -78,6 +130,17 @@ static void Quantize(const Tensor *input, const float scale, Tensor *output) {
} }
} }
template <RoundType R>
static void Quantize(const Tensor *input, const float max_abs,
const bool offline, Tensor *output) {
float scale = 127.f / max_abs;
if (offline) {
QuantizeOffline<R>(input, scale, max_abs, output);
} else {
QuantizeOnline<R>(input, scale, output);
}
}
float find_abs_max(const Tensor *input) { float find_abs_max(const Tensor *input) {
float max_abs = 0.f; float max_abs = 0.f;
const float *x = input->data<const float>(); const float *x = input->data<const float>();
...@@ -133,23 +196,22 @@ void QuantizeKernel<CPU, float>::Compute(const QuantizeParam<CPU> &param) { ...@@ -133,23 +196,22 @@ void QuantizeKernel<CPU, float>::Compute(const QuantizeParam<CPU> &param) {
max_abs = find_abs_max(input); max_abs = find_abs_max(input);
} }
max_abs = std::max(max_abs, 1e-6f); max_abs = std::max(max_abs, 1e-6f);
// only support int8 currently
float scale = 127 / max_abs;
param.online_scale_->mutable_data<float>()[0] = max_abs; param.online_scale_->mutable_data<float>()[0] = max_abs;
switch (param.round_type_) { // switch (param.round_type_) {
case ROUND_NEAREST_TO_EVEN: // case ROUND_NEAREST_TO_EVEN:
Quantize<ROUND_NEAREST_TO_EVEN>(input, scale, output); // Quantize<ROUND_NEAREST_TO_EVEN>(input, scale, output);
break; // break;
case ROUND_NEAREST_TOWARDS_ZERO: // case ROUND_NEAREST_TOWARDS_ZERO:
Quantize<ROUND_NEAREST_TOWARDS_ZERO>(input, scale, output); // Quantize<ROUND_NEAREST_TOWARDS_ZERO>(input, scale, output);
break; // break;
case ROUND_NEAREST_AWAY_ZERO: // case ROUND_NEAREST_AWAY_ZERO:
Quantize<ROUND_NEAREST_AWAY_ZERO>(input, scale, output); // Quantize<ROUND_NEAREST_AWAY_ZERO>(input, scale, output);
break; // break;
default: // default:
LOG(kLOG_ERROR) << "round type is not supported."; // LOG(kLOG_ERROR) << "round type is not supported.";
break; // break;
} // }
Quantize<ROUND_NEAREST_AWAY_ZERO>(input, max_abs, param.offline_, output);
} }
} // namespace operators } // namespace operators
......
...@@ -74,7 +74,7 @@ template <> ...@@ -74,7 +74,7 @@ template <>
void ReluKernel<CPU, float>::Compute(const ReluParam<CPU> &param) { void ReluKernel<CPU, float>::Compute(const ReluParam<CPU> &param) {
const Tensor *input = param.InputX(); const Tensor *input = param.InputX();
Tensor *output = param.Out(); Tensor *output = param.Out();
ReluCompute<float, Relu>()(input, output); ReluCompute<float, RELU>()(input, output);
} }
template <> template <>
...@@ -86,7 +86,7 @@ template <> ...@@ -86,7 +86,7 @@ template <>
void Relu6Kernel<CPU, float>::Compute(const ReluParam<CPU> &param) { void Relu6Kernel<CPU, float>::Compute(const ReluParam<CPU> &param) {
const Tensor *input = param.InputX(); const Tensor *input = param.InputX();
Tensor *output = param.Out(); Tensor *output = param.Out();
ReluCompute<float, Relu6>()(input, output); ReluCompute<float, RELU6>()(input, output);
} }
} // namespace operators } // namespace operators
......
...@@ -40,28 +40,28 @@ void PoolCompute(const PoolParam<CPU> &param) { ...@@ -40,28 +40,28 @@ void PoolCompute(const PoolParam<CPU> &param) {
if (ksize[0] == 3 && ksize[0] == ksize[1]) { if (ksize[0] == 3 && ksize[0] == ksize[1]) {
if (pooling_type == "max" && strides[0] == strides[1]) { if (pooling_type == "max" && strides[0] == strides[1]) {
if (strides[0] == 1) { if (strides[0] == 1) {
math::Pooling3x3<Max, 1>()(*input, paddings, output); math::Pooling3x3<MAX, 1>()(*input, paddings, output);
} else if (strides[0] == 2) { } else if (strides[0] == 2) {
math::Pooling3x3<Max, 2>()(*input, paddings, output); math::Pooling3x3<MAX, 2>()(*input, paddings, output);
} else { } else {
math::Pooling<Max>()(*input, ksize, strides, paddings, output); math::Pooling<MAX>()(*input, ksize, strides, paddings, output);
} }
} else if (pooling_type == "avg" && strides[0] == strides[1]) { } else if (pooling_type == "avg" && strides[0] == strides[1]) {
if (strides[0] == 1) { if (strides[0] == 1) {
math::Pooling3x3<Avg, 1>()(*input, paddings, output); math::Pooling3x3<AVG, 1>()(*input, paddings, output);
} else if (strides[0] == 2) { } else if (strides[0] == 2) {
math::Pooling3x3<Avg, 2>()(*input, paddings, output); math::Pooling3x3<AVG, 2>()(*input, paddings, output);
} else { } else {
math::Pooling<Avg>()(*input, ksize, strides, paddings, output); math::Pooling<AVG>()(*input, ksize, strides, paddings, output);
} }
} else { } else {
// Others // Others
} }
} else { } else {
if (pooling_type == "max") { if (pooling_type == "max") {
math::Pooling<Max>()(*input, ksize, strides, paddings, output); math::Pooling<MAX>()(*input, ksize, strides, paddings, output);
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
math::Pooling<Avg>()(*input, ksize, strides, paddings, output); math::Pooling<AVG>()(*input, ksize, strides, paddings, output);
} else { } else {
// Others // Others
} }
......
...@@ -16,50 +16,109 @@ limitations under the License. */ ...@@ -16,50 +16,109 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include <string>
#include "common/enforce.h"
#include "common/types.h" #include "common/types.h"
#if defined(__ARM_NEON__) || defined(__ARM_NEON) #if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h> #include <arm_neon.h>
#include "operators/math/math_func_neon.h"
#endif #endif
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
namespace math { namespace math {
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
inline ActivationType GetActivationType(const std::string &type) {
if (type == "sigmoid") {
return ActivationType::SIGMOID;
} else if (type == "relu") {
return ActivationType::RELU;
} else if (type == "tanh") {
return ActivationType::TANH;
} else if (type == "identity" || type == "") {
return ActivationType::IDENTITY;
}
PADDLE_MOBILE_THROW_EXCEPTION("Not support activation type.");
}
#if defined(__ARM_NEON__) || defined(__ARM_NEON) #if defined(__ARM_NEON__) || defined(__ARM_NEON)
template <ActivationType Act = Linear> template <ActivationType Act = IDENTITY>
inline float32x4_t vActiveq_f32(const float32x4_t &x) { inline float32x4_t vActiveq_f32(const float32x4_t &x) {
return x; return x;
} }
template <> template <>
inline float32x4_t vActiveq_f32<Relu>(const float32x4_t &x) { inline float32x4_t vActiveq_f32<RELU>(const float32x4_t &x) {
float32x4_t __zero = vdupq_n_f32(0.f); float32x4_t __zero = vdupq_n_f32(0.f);
return vmaxq_f32(x, __zero); return vmaxq_f32(x, __zero);
} }
template <> template <>
inline float32x4_t vActiveq_f32<Relu6>(const float32x4_t &x) { inline float32x4_t vActiveq_f32<RELU6>(const float32x4_t &x) {
float32x4_t __zero = vdupq_n_f32(0.f); float32x4_t __zero = vdupq_n_f32(0.f);
float32x4_t __six = vdupq_n_f32(6.f); float32x4_t __six = vdupq_n_f32(6.f);
return vminq_f32(vmaxq_f32(x, __zero), __six); return vminq_f32(vmaxq_f32(x, __zero), __six);
} }
template <>
inline float32x4_t vActiveq_f32<SIGMOID>(const float32x4_t &x) {
float32x4_t __one = vdupq_n_f32(1.f);
float32x4_t __x = vnegq_f32(x);
__x = exp_ps(__x);
__x = vaddq_f32(__x, __one);
float32x4_t __out = vrecpeq_f32(__x);
return vmulq_f32(vrecpsq_f32(__x, __out), __out);
}
template <>
inline float32x4_t vActiveq_f32<TANH>(const float32x4_t &x) {
float32x4_t __one = vdupq_n_f32(1.f);
float32x4_t __x = vnegq_f32(x);
__x = vmulq_n_f32(__x, 2.f);
__x = exp_ps(__x);
__x = vaddq_f32(__x, __one);
float32x4_t __out = vrecpeq_f32(__x);
__out = vmulq_f32(vrecpsq_f32(__x, __out), __out);
__out = vmulq_n_f32(__out, 2.f);
return vsubq_f32(__out, __one);
}
#endif #endif
template <ActivationType Act = Linear> template <ActivationType Act = IDENTITY>
inline float Active(const float &x) { inline float Active(const float &x) {
return x; return x;
} }
template <> template <>
inline float Active<Relu>(const float &x) { inline float Active<RELU>(const float &x) {
return std::max(x, 0.f); return std::max(x, 0.f);
} }
template <> template <>
inline float Active<Relu6>(const float &x) { inline float Active<RELU6>(const float &x) {
return std::min(std::max(x, 0.f), 6.f); return std::min(std::max(x, 0.f), 6.f);
} }
template <>
inline float Active<SIGMOID>(const float &x) {
// float tmp = x > SIGMOID_THRESHOLD_MAX ? SIGMOID_THRESHOLD_MAX : x;
// tmp = x > SIGMOID_THRESHOLD_MIN ? x : SIGMOID_THRESHOLD_MIN;
// return 1.f / (1.f + exp(-tmp));
return 1.f / (1.f + exp(-x));
}
template <>
inline float Active<TANH>(const float &x) {
// float tmp = -2.f * x;
// tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
// return (2.f / (1.f + exp(tmp))) - 1.f;
return 2.f / (1.f + exp(-2.f * x)) - 1.f;
}
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <math.h>
#include <string>
#include "common/enforce.h"
namespace paddle_mobile {
namespace operators {
namespace math {
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
enum ActivationType {
kSigmoid,
kReLU,
kTanh,
kIdentity,
};
inline ActivationType GetActivationType(const std::string &type) {
if (type == "sigmoid") {
return ActivationType::kSigmoid;
} else if (type == "relu") {
return ActivationType::kReLU;
} else if (type == "tanh") {
return ActivationType::kTanh;
} else if (type == "identity" || type == "") {
return ActivationType::kIdentity;
}
PADDLE_MOBILE_THROW_EXCEPTION("Not support activation type.");
}
namespace forward {
template <typename T>
T Identity(const T a) {
return a;
}
template <typename T>
T Relu(const T a) {
return a > static_cast<T>(0.0) ? a : static_cast<T>(0.0);
}
template <typename T>
T Sigmoid(const T a) {
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
T tmp = (a < min) ? min : ((a > max) ? max : a);
return static_cast<T>(1.0) / (static_cast<T>(1.0) + exp(-tmp));
}
template <typename T>
T Tanh(const T a) {
T tmp = -2.0 * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
return (2.0 / (1.0 + exp(tmp))) - 1.0;
}
} // namespace forward
template <typename T>
struct Active {
typedef T (*Act)(T);
};
static Active<float>::Act kActFloat[] = {
&forward::Sigmoid<float>, &forward::Relu<float>, &forward::Tanh<float>,
&forward::Identity<float>};
namespace forward {
inline float activation(float a, int index) { return kActFloat[index](a); }
} // namespace forward
} // namespace math
} // namespace operators
} // namespace paddle_mobile
...@@ -1260,10 +1260,10 @@ void Gemm::AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) { ...@@ -1260,10 +1260,10 @@ void Gemm::AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) {
"q10", "q11", "q12", "q13"); "q10", "q11", "q12", "q13");
} }
/* void Gemm::VectorKernel(int m, int n, int k, float alpha, const float *A,
void Gemm::VectorKernel(int m, int n, int k, float alpha, const float *A, int int lda, const float *B, int ldb, float beta, float *C,
lda, const float *B, int ldb, float beta, float *C, int ldc, bool relu) { float int ldc, bool relu) {
*bufferC = static_cast<float *>(memory::Alloc(sizeof(float) * n)); float *bufferC = static_cast<float *>(memory::Alloc(sizeof(float) * n));
const float *a0, *b0, *b1, *b2, *b3; const float *a0, *b0, *b1, *b2, *b3;
float *c0, *C0; float *c0, *C0;
...@@ -1482,6 +1482,7 @@ lda, const float *B, int ldb, float beta, float *C, int ldc, bool relu) { float ...@@ -1482,6 +1482,7 @@ lda, const float *B, int ldb, float beta, float *C, int ldc, bool relu) { float
} }
} }
/*
void Gemm::VectorKernelWithBn(int m, int n, int k, float alpha, const float *A, void Gemm::VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta, float *C, int lda, const float *B, int ldb, float beta, float *C,
int ldc, bool relu, float *new_scale, float *new_bias) { int ldc, bool relu, float *new_scale, float *new_bias) {
...@@ -2579,9 +2580,8 @@ void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, ...@@ -2579,9 +2580,8 @@ void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc,
} }
} }
/* // C = A * B
// C = A * B void Gemm::VecWriteBasic(int n, float *c, float *C, int ldc) {
void Gemm::VecWriteBasic(int n, float *c, float *C, int ldc) {
int nc1 = n / 16; int nc1 = n / 16;
int _nc1 = n % 16; int _nc1 = n % 16;
int nc2 = _nc1 / 4; int nc2 = _nc1 / 4;
...@@ -2624,13 +2624,13 @@ void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, ...@@ -2624,13 +2624,13 @@ void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc,
: :
: [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] "r"(nc3) : [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] "r"(nc3)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5"); : "memory", "q0", "q1", "q2", "q3", "q4", "q5");
} }
// C = alpha * A * B + beta * C // C = alpha * A * B + beta * C
void Gemm::VecWriteWithAlphaBeta(int n, float *c, float *C, int ldc) {} void Gemm::VecWriteWithAlphaBeta(int n, float *c, float *C, int ldc) {}
// C = A * B + C // C = A * B + C
void Gemm::VecWriteWithAdd(int n, float *c, float *C, int ldc) { void Gemm::VecWriteWithAdd(int n, float *c, float *C, int ldc) {
int nc1 = n / 16; int nc1 = n / 16;
int _nc1 = n % 16; int _nc1 = n % 16;
...@@ -2657,18 +2657,18 @@ void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, ...@@ -2657,18 +2657,18 @@ void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc,
: [C] "+r"(C), [c] "+r"(c) : [C] "+r"(C), [c] "+r"(c)
: [nc1] "r"(nc1) : [nc1] "r"(nc1)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", "q11",
"q11", "q12", "q13"); "q12", "q13");
if (_nc1 != 0) { if (_nc1 != 0) {
for (int j = 0; j < _nc1; j++) { for (int j = 0; j < _nc1; j++) {
*C++ += *c++; *C++ += *c++;
} }
} }
} }
// C = A * B + C, relu(C) // C = A * B + C, relu(C)
void Gemm::VecWriteWithAddRelu(int n, float *c, float *C, int ldc) { void Gemm::VecWriteWithAddRelu(int n, float *c, float *C, int ldc) {
int nc1 = n / 16; int nc1 = n / 16;
int _nc1 = n % 16; int _nc1 = n % 16;
...@@ -2700,8 +2700,8 @@ void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, ...@@ -2700,8 +2700,8 @@ void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc,
: [C] "+r"(C), [c] "+r"(c) : [C] "+r"(C), [c] "+r"(c)
: [nc1] "r"(nc1) : [nc1] "r"(nc1)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", "q11",
"q11", "q12", "q13"); "q12", "q13");
if (_nc1 != 0) { if (_nc1 != 0) {
for (int j = 0; j < _nc1; j++) { for (int j = 0; j < _nc1; j++) {
...@@ -2713,8 +2713,9 @@ void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, ...@@ -2713,8 +2713,9 @@ void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc,
c++; c++;
} }
} }
} }
/*
// C = A * B, batchnorm(C) // C = A * B, batchnorm(C)
void Gemm::VecWriteWithBn(int n, float *c, float *C, int ldc, float *scale, void Gemm::VecWriteWithBn(int n, float *c, float *C, int ldc, float *scale,
float *bias) { float *bias) {
...@@ -3149,13 +3150,17 @@ void Gemm::SgemmWithPRelu(int m, int n, int k, const float *A, int lda, ...@@ -3149,13 +3150,17 @@ void Gemm::SgemmWithPRelu(int m, int n, int k, const float *A, int lda,
void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *bias) { bool relu, float *bias) {
if (m == 1 && bias == nullptr) {
return VectorKernel(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, relu);
}
#ifdef _OPENMP #ifdef _OPENMP
int max_threads = omp_get_max_threads(); int max_threads = omp_get_max_threads();
#else #else
int max_threads = 1; int max_threads = 1;
#endif #endif
int L1 = 64 / max_threads * 1024; // int L1 = 64 / max_threads * 1024;
int L1 = 32 / max_threads * 1024;
KC = k; KC = k;
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC)); zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
memset(static_cast<void *>(zero), 0, sizeof(float) * KC); memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
......
...@@ -105,12 +105,11 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, ...@@ -105,12 +105,11 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
float *c, float *C, int ldc, float *p, float *c, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1); std::string mode, float *bias, float *bias1);
/*
// 向量矩阵乘法 (M = 1) // 向量矩阵乘法 (M = 1)
void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda, void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, const float *B, int ldb, float beta, float *C, int ldc,
bool relu); bool relu);
/*
void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A, void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta, float int lda, const float *B, int ldb, float beta, float
*C, int ldc, bool relu, float *new_scale, float *new_bias); *C, int ldc, bool relu, float *new_scale, float *new_bias);
...@@ -149,7 +148,6 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, ...@@ -149,7 +148,6 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
void WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, void WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias, float *bias1); float *new_scale, float *new_bias, float *bias1);
/*
// 向量矩阵乘法结果回写 // 向量矩阵乘法结果回写
// C = A * B // C = A * B
void VecWriteBasic(int n, float *c, float *C, int ldc); void VecWriteBasic(int n, float *c, float *C, int ldc);
...@@ -159,12 +157,13 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, ...@@ -159,12 +157,13 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
void VecWriteWithAdd(int n, float *c, float *C, int ldc); void VecWriteWithAdd(int n, float *c, float *C, int ldc);
// C = A * B + C, relu(C) // C = A * B + C, relu(C)
void VecWriteWithAddRelu(int n, float *c, float *C, int ldc); void VecWriteWithAddRelu(int n, float *c, float *C, int ldc);
/*
// C = A * B, batchnorm(C) // C = A * B, batchnorm(C)
void VecWriteWithBn(int n, float *c, float *C, int ldc, float *new_scale, void VecWriteWithBn(int n, float *c, float *C, int ldc, float *new_scale,
float *new_bias); float *new_bias);
// C = A * B, batchnorm(C), relu(C) // C = A * B, batchnorm(C), relu(C)
void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *new_scale, void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float
float *new_bias); *new_scale, float *new_bias);
*/ */
// 32位 float 矩阵乘法 // 32位 float 矩阵乘法
...@@ -392,7 +391,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, ...@@ -392,7 +391,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
packedB_int8 = static_cast<int8_t *>( packedB_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC)); paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC));
#if __aarch64__ #if __aarch64__
// TODO() // TODO(paddle mobile)
#else #else
PackMatrixB_omp_2c_16(k, n, n % NR_INT8, B, ldb, packedB_int8); PackMatrixB_omp_2c_16(k, n, n % NR_INT8, B, ldb, packedB_int8);
#endif #endif
...@@ -414,7 +413,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, ...@@ -414,7 +413,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
packedA_int8 = static_cast<int8_t *>( packedA_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC)); paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC));
#if __aarch64__ #if __aarch64__
// TODO() // TODO(paddle mobile)
#else #else
PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8); PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8);
#endif #endif
...@@ -438,7 +437,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, ...@@ -438,7 +437,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
int8_t *local_A = packedA_int8 + MC * KC * local_threads; int8_t *local_A = packedA_int8 + MC * KC * local_threads;
int32_t *local_C = packedC_int32 + MC * NC * local_threads; int32_t *local_C = packedC_int32 + MC * NC * local_threads;
#if __aarch64__ #if __aarch64__
// TODO() // TODO(paddle mobile)
#else #else
PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A); PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A);
#endif #endif
...@@ -468,7 +467,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, ...@@ -468,7 +467,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
int8_t *local_B = packedB_int8 + KC * NC * local_threads; int8_t *local_B = packedB_int8 + KC * NC * local_threads;
int32_t *local_C = packedC_int32 + MC * NC * local_threads; int32_t *local_C = packedC_int32 + MC * NC * local_threads;
#if __aarch64__ #if __aarch64__
// TODO() // TODO(paddle mobile)
#else #else
PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B); PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B);
#endif #endif
......
...@@ -11,13 +11,14 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,13 +11,14 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef GRU_OP #ifdef GRU_OP
#include "operators/math/gru_compute.h" #include "operators/math/gru_compute.h"
#include "common/types.h" #include "common/types.h"
#include "operators/math/activation_functions.h" #include "operators/math/activation.h"
#include "operators/math/gemm.h" #include "operators/math/gemm.h"
#include "operators/math/gru_cpu_kernel.h" #include "operators/math/gru_cpu_kernel.h"
#include "operators/math/gru_kernel.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
...@@ -43,8 +44,7 @@ struct GRUUnitFunctor<CPU, T> { ...@@ -43,8 +44,7 @@ struct GRUUnitFunctor<CPU, T> {
#endif #endif
} }
forward_reset_output(forward::gru_resetOutput<T>(), value, frame_size, forward_reset_output(value, frame_size, batch_size, active_gate);
batch_size, active_gate);
if (value.prev_out_value) { if (value.prev_out_value) {
#ifdef _OPENMP #ifdef _OPENMP
...@@ -60,8 +60,7 @@ struct GRUUnitFunctor<CPU, T> { ...@@ -60,8 +60,7 @@ struct GRUUnitFunctor<CPU, T> {
#endif #endif
} }
forward_final_output(forward::gru_finalOutput<T>(), value, frame_size, forward_final_output(value, frame_size, batch_size, active_node);
batch_size, active_node);
} }
}; };
......
...@@ -11,7 +11,7 @@ limitations under the License. */ ...@@ -11,7 +11,7 @@ limitations under the License. */
#ifdef GRU_OP #ifdef GRU_OP
#pragma once #pragma once
#include "operators/math/activation_functions.h" #include "operators/math/activation.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
......
...@@ -11,21 +11,22 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,21 +11,22 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef GRU_OP #ifdef GRU_OP
#pragma once #pragma once
#include <type_traits> #include <type_traits>
#include "operators/math/activation_functions.h" #include "operators/math/activation.h"
#include "operators/math/gru_compute.h" #include "operators/math/gru_compute.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
namespace math { namespace math {
template <class OpResetOutput, typename T> template <typename T, ActivationType Act>
void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output, void hl_naive_gru_forward_reset_output(T *gate_value, T *reset_output_value,
T *gate_value, T *reset_output_value, T *prev_output_value, int frame_size) {
T *prev_output_value, int frame_size,
ActivationType active_gate) {
T r_value_update_gate; T r_value_update_gate;
T r_value_reset_gate; T r_value_reset_gate;
T r_value_reset_output; T r_value_reset_output;
...@@ -33,27 +34,57 @@ void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output, ...@@ -33,27 +34,57 @@ void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output,
T *update_gate = gate_value; T *update_gate = gate_value;
T *reset_gate = gate_value + frame_size; T *reset_gate = gate_value + frame_size;
for (int i = 0; i < frame_size; i++) { int remain = frame_size;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
int loop = remain >> 3;
remain = remain & 0x7;
float32x4_t prev0 = vdupq_n_f32(0.f);
float32x4_t prev1 = vdupq_n_f32(0.f);
for (int i = 0; i < loop; ++i) {
float32x4_t update0 = vld1q_f32(update_gate);
float32x4_t update1 = vld1q_f32(update_gate + 4);
float32x4_t reset0 = vld1q_f32(reset_gate);
float32x4_t reset1 = vld1q_f32(reset_gate + 4);
if (prev_output_value) {
prev0 = vld1q_f32(prev_output_value);
prev1 = vld1q_f32(prev_output_value + 4);
prev_output_value += 8;
}
update0 = vActiveq_f32<Act>(update0);
update1 = vActiveq_f32<Act>(update1);
reset0 = vActiveq_f32<Act>(reset0);
reset1 = vActiveq_f32<Act>(reset1);
float32x4_t output0 = vmulq_f32(prev0, reset0);
float32x4_t output1 = vmulq_f32(prev1, reset1);
vst1q_f32(update_gate, update0);
vst1q_f32(update_gate + 4, update1);
vst1q_f32(reset_gate, reset0);
vst1q_f32(reset_gate + 4, reset1);
vst1q_f32(reset_output_value, output0);
vst1q_f32(reset_output_value + 4, output1);
update_gate += 8;
reset_gate += 8;
reset_output_value += 8;
}
#endif // __ARM_NEON__
for (int i = 0; i < remain; i++) {
r_value_update_gate = update_gate[i]; r_value_update_gate = update_gate[i];
r_value_reset_gate = reset_gate[i]; r_value_reset_gate = reset_gate[i];
if (prev_output_value) { if (prev_output_value) {
r_prev_out = prev_output_value[i]; r_prev_out = prev_output_value[i];
} }
r_value_update_gate = Active<Act>(r_value_update_gate);
op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out, r_value_reset_gate = Active<Act>(r_value_reset_gate);
&r_value_reset_output, active_gate); r_value_reset_output = r_prev_out * r_value_reset_gate;
update_gate[i] = r_value_update_gate; update_gate[i] = r_value_update_gate;
reset_gate[i] = r_value_reset_gate; reset_gate[i] = r_value_reset_gate;
reset_output_value[i] = r_value_reset_output; reset_output_value[i] = r_value_reset_output;
} }
} }
template <class OpFinalOutput, typename T> template <typename T, ActivationType Act>
void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output, void hl_naive_gru_forward_final_output(T *gate_value, T *prev_output_value,
T *gate_value, T *prev_output_value, T *output_value, int frame_size) {
T *output_value, int frame_size,
ActivationType active_node) {
T r_value_update_gate; T r_value_update_gate;
T r_value_frame_state; T r_value_frame_state;
T r_prev_out = 0; T r_prev_out = 0;
...@@ -61,30 +92,73 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output, ...@@ -61,30 +92,73 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
T *update_gate = gate_value; T *update_gate = gate_value;
T *frame_state = gate_value + frame_size * 2; T *frame_state = gate_value + frame_size * 2;
for (int i = 0; i < frame_size; i++) { int remain = frame_size;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
int loop = remain >> 3;
remain = remain & 0x7;
float32x4_t prev0 = vdupq_n_f32(0.f);
float32x4_t prev1 = vdupq_n_f32(0.f);
for (int i = 0; i < loop; ++i) {
float32x4_t update0 = vld1q_f32(update_gate);
float32x4_t update1 = vld1q_f32(update_gate + 4);
float32x4_t state0 = vld1q_f32(frame_state);
float32x4_t state1 = vld1q_f32(frame_state + 4);
if (prev_output_value) {
prev0 = vld1q_f32(prev_output_value);
prev1 = vld1q_f32(prev_output_value + 4);
prev_output_value += 8;
}
state0 = vActiveq_f32<Act>(state0);
state1 = vActiveq_f32<Act>(state1);
float32x4_t output0 = vmlsq_f32(prev0, update0, prev0);
float32x4_t output1 = vmlsq_f32(prev1, update1, prev1);
output0 = vmlaq_f32(output0, update0, state0);
output1 = vmlaq_f32(output1, update1, state1);
vst1q_f32(frame_state, state0);
vst1q_f32(frame_state + 4, state1);
vst1q_f32(output_value, output0);
vst1q_f32(output_value + 4, output1);
update_gate += 8;
frame_state += 8;
output_value += 8;
}
#endif // __ARM_NEON__
for (int i = 0; i < remain; i++) {
r_value_update_gate = update_gate[i]; r_value_update_gate = update_gate[i];
r_value_frame_state = frame_state[i]; r_value_frame_state = frame_state[i];
if (prev_output_value) { if (prev_output_value) {
r_prev_out = prev_output_value[i]; r_prev_out = prev_output_value[i];
} }
r_value_frame_state = Active<Act>(r_value_frame_state);
op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out, r_output = r_prev_out - r_value_update_gate * r_prev_out +
&r_output, active_node); r_value_update_gate * r_value_frame_state;
frame_state[i] = r_value_frame_state; frame_state[i] = r_value_frame_state;
output_value[i] = r_output; output_value[i] = r_output;
} }
} }
template <class OpResetOutput, typename T> #define FORWARD_RESET_OUTPUT(active_type, value, frame_size) \
inline void forward_reset_output(OpResetOutput op_reset_output, hl_naive_gru_forward_reset_output<float, active_type>( \
GRUMetaValue<T> value, int frame_size, value.gate_value, value.reset_output_value, value.prev_out_value, \
int batch_size, ActivationType active_gate) { frame_size);
for (int b = 0; b < batch_size; b++) {
hl_naive_gru_forward_reset_output(
op_reset_output, value.gate_value, value.reset_output_value,
value.prev_out_value, frame_size, active_gate);
template <typename T>
inline void forward_reset_output(GRUMetaValue<T> value, int frame_size,
int batch_size, ActivationType active_node) {
for (int b = 0; b < batch_size; ++b) {
switch (active_node) {
case RELU:
FORWARD_RESET_OUTPUT(RELU, value, frame_size);
break;
case SIGMOID:
FORWARD_RESET_OUTPUT(SIGMOID, value, frame_size);
break;
case TANH:
FORWARD_RESET_OUTPUT(TANH, value, frame_size);
break;
default:
FORWARD_RESET_OUTPUT(IDENTITY, value, frame_size);
}
value.gate_value += frame_size * 3; value.gate_value += frame_size * 3;
value.reset_output_value += frame_size; value.reset_output_value += frame_size;
if (value.prev_out_value) { if (value.prev_out_value) {
...@@ -93,15 +167,27 @@ inline void forward_reset_output(OpResetOutput op_reset_output, ...@@ -93,15 +167,27 @@ inline void forward_reset_output(OpResetOutput op_reset_output,
} }
} }
template <class OpFinalOutput, typename T> #define FORWARD_FINAL_OUTPUT(active_type, value, frame_size) \
inline void forward_final_output(OpFinalOutput op_final_output, hl_naive_gru_forward_final_output<float, active_type>( \
GRUMetaValue<T> value, int frame_size, value.gate_value, value.prev_out_value, value.output_value, frame_size)
int batch_size, ActivationType active_node) {
for (int b = 0; b < batch_size; b++) {
hl_naive_gru_forward_final_output(op_final_output, value.gate_value,
value.prev_out_value, value.output_value,
frame_size, active_node);
template <typename T>
inline void forward_final_output(GRUMetaValue<T> value, int frame_size,
int batch_size, ActivationType active_node) {
for (int b = 0; b < batch_size; ++b) {
switch (active_node) {
case RELU:
FORWARD_FINAL_OUTPUT(RELU, value, frame_size);
break;
case SIGMOID:
FORWARD_FINAL_OUTPUT(SIGMOID, value, frame_size);
break;
case TANH:
FORWARD_FINAL_OUTPUT(TANH, value, frame_size);
break;
default:
FORWARD_FINAL_OUTPUT(IDENTITY, value, frame_size);
}
value.gate_value += frame_size * 3; value.gate_value += frame_size * 3;
value.output_value += frame_size; value.output_value += frame_size;
if (value.prev_out_value) { if (value.prev_out_value) {
...@@ -113,4 +199,5 @@ inline void forward_final_output(OpFinalOutput op_final_output, ...@@ -113,4 +199,5 @@ inline void forward_final_output(OpFinalOutput op_final_output,
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
#endif #endif
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef GRU_OP
#pragma once
#include <type_traits>
#include "operators/math/activation_functions.h"
namespace paddle_mobile {
namespace operators {
namespace math {
namespace forward {
template <typename T>
class gru_resetOutput {
public:
void operator()(T *value_update_gate, T *value_reset_gate, T *prev_out,
T *value_reset_output, ActivationType act_gate) {
*value_update_gate = activation(*value_update_gate, act_gate);
*value_reset_gate = activation(*value_reset_gate, act_gate);
*value_reset_output = (*prev_out) * (*value_reset_gate);
}
};
template <typename T>
class gru_finalOutput {
public:
void operator()(T *value_update_gate, T *value_frame_state, T *prev_out,
T *value_output, ActivationType act_input) {
*value_frame_state = activation(*value_frame_state, act_input);
*value_output = *prev_out - ((*value_update_gate) * (*prev_out)) +
((*value_update_gate) * (*value_frame_state));
}
};
} // namespace forward
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif
...@@ -72,8 +72,8 @@ void Pooling<P>::operator()(const framework::Tensor &input, ...@@ -72,8 +72,8 @@ void Pooling<P>::operator()(const framework::Tensor &input,
} }
} }
template struct Pooling<Max>; template struct Pooling<MAX>;
template struct Pooling<Avg>; template struct Pooling<AVG>;
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -30,7 +30,7 @@ namespace paddle_mobile { ...@@ -30,7 +30,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
namespace math { namespace math {
template <PoolingType P = Max> template <PoolingType P = MAX>
struct PoolingVal { struct PoolingVal {
float val; float val;
int count; int count;
...@@ -44,11 +44,11 @@ struct PoolingVal { ...@@ -44,11 +44,11 @@ struct PoolingVal {
}; };
template <> template <>
struct PoolingVal<Avg> { struct PoolingVal<AVG> {
float val; float val;
int count; int count;
PoolingVal() : val(0.f), count(0) {} PoolingVal() : val(0.f), count(0) {}
inline PoolingVal<Avg> &operator+=(const float &x) { inline PoolingVal<AVG> &operator+=(const float &x) {
val += x; val += x;
++count; ++count;
return *this; return *this;
...@@ -57,57 +57,57 @@ struct PoolingVal<Avg> { ...@@ -57,57 +57,57 @@ struct PoolingVal<Avg> {
}; };
#if defined(__ARM_NEON) || defined(__ARM_NEON__) #if defined(__ARM_NEON) || defined(__ARM_NEON__)
template <PoolingType P = Max> template <PoolingType P = MAX>
inline float32x4_t vPoolInitq_f32() { inline float32x4_t vPoolInitq_f32() {
return vdupq_n_f32(-std::numeric_limits<float>::max()); return vdupq_n_f32(-std::numeric_limits<float>::max());
} }
template <> template <>
inline float32x4_t vPoolInitq_f32<Avg>() { inline float32x4_t vPoolInitq_f32<AVG>() {
return vdupq_n_f32(0.f); return vdupq_n_f32(0.f);
} }
template <PoolingType P = Max> template <PoolingType P = MAX>
inline float32x4_t vPoolPreq_f32(const float32x4_t &x1, const float32x4_t &x2) { inline float32x4_t vPoolPreq_f32(const float32x4_t &x1, const float32x4_t &x2) {
return vmaxq_f32(x1, x2); return vmaxq_f32(x1, x2);
} }
template <> template <>
inline float32x4_t vPoolPreq_f32<Avg>(const float32x4_t &x1, inline float32x4_t vPoolPreq_f32<AVG>(const float32x4_t &x1,
const float32x4_t &x2) { const float32x4_t &x2) {
return vaddq_f32(x1, x2); return vaddq_f32(x1, x2);
} }
template <PoolingType P = Max> template <PoolingType P = MAX>
inline float32x4_t vPoolPostq_f32(const float32x4_t &x, inline float32x4_t vPoolPostq_f32(const float32x4_t &x,
const float32x4_t &post) { const float32x4_t &post) {
return x; return x;
} }
template <> template <>
inline float32x4_t vPoolPostq_f32<Avg>(const float32x4_t &x, inline float32x4_t vPoolPostq_f32<AVG>(const float32x4_t &x,
const float32x4_t &post) { const float32x4_t &post) {
return vmulq_f32(x, post); return vmulq_f32(x, post);
} }
#endif // __ARM_NEON__ #endif // __ARM_NEON__
template <PoolingType P = Max> template <PoolingType P = MAX>
inline float PoolPre(const float &x1, const float &x2) { inline float PoolPre(const float &x1, const float &x2) {
return std::max(x1, x2); return std::max(x1, x2);
} }
template <> template <>
inline float PoolPre<Avg>(const float &x1, const float &x2) { inline float PoolPre<AVG>(const float &x1, const float &x2) {
return x1 + x2; return x1 + x2;
} }
template <PoolingType P = Max> template <PoolingType P = MAX>
inline float PoolPost(const float &x, const float &post) { inline float PoolPost(const float &x, const float &post) {
return x; return x;
} }
template <> template <>
inline float PoolPost<Avg>(const float &x, const float &post) { inline float PoolPost<AVG>(const float &x, const float &post) {
return x * post; return x * post;
} }
......
...@@ -1016,10 +1016,10 @@ struct Pooling3x3<P, 2> { ...@@ -1016,10 +1016,10 @@ struct Pooling3x3<P, 2> {
} }
}; };
template struct Pooling3x3<Max, 1>; template struct Pooling3x3<MAX, 1>;
template struct Pooling3x3<Avg, 1>; template struct Pooling3x3<AVG, 1>;
template struct Pooling3x3<Max, 2>; template struct Pooling3x3<MAX, 2>;
template struct Pooling3x3<Avg, 2>; template struct Pooling3x3<AVG, 2>;
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -74,11 +74,11 @@ int TestPoolOp(int in_channels, int in_height, int in_width) { ...@@ -74,11 +74,11 @@ int TestPoolOp(int in_channels, int in_height, int in_width) {
output_cmp.mutable_data<float>(output->dims()); output_cmp.mutable_data<float>(output->dims());
if (pooling_type == "avg") { if (pooling_type == "avg") {
math::Pooling<Avg>()(*input, std::vector<int>{kernel_h, kernel_w}, math::Pooling<AVG>()(*input, std::vector<int>{kernel_h, kernel_w},
std::vector<int>{stride_h, stride_w}, std::vector<int>{stride_h, stride_w},
std::vector<int>{pad_h, pad_w}, &output_cmp); std::vector<int>{pad_h, pad_w}, &output_cmp);
} else { } else {
math::Pooling<Max>()(*input, std::vector<int>{kernel_h, kernel_w}, math::Pooling<MAX>()(*input, std::vector<int>{kernel_h, kernel_w},
std::vector<int>{stride_h, stride_w}, std::vector<int>{stride_h, stride_w},
std::vector<int>{pad_h, pad_w}, &output_cmp); std::vector<int>{pad_h, pad_w}, &output_cmp);
} }
...@@ -117,57 +117,57 @@ int main(int argc, char *argv[]) { ...@@ -117,57 +117,57 @@ int main(int argc, char *argv[]) {
int in_channels = atoi(argv[1]); int in_channels = atoi(argv[1]);
int in_height = atoi(argv[2]); int in_height = atoi(argv[2]);
int in_width = atoi(argv[3]); int in_width = atoi(argv[3]);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=3, pad=0, stride=1"; << "float, pooling_type=max, kernel=3, pad=0, stride=1";
// paddle_mobile::TestPoolOp<0, 3, 0, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 3, 0, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=3, pad=1, stride=1"; << "float, pooling_type=max, kernel=3, pad=1, stride=1";
// paddle_mobile::TestPoolOp<0, 3, 1, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 3, 1, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=3, pad=2, stride=1"; << "float, pooling_type=max, kernel=3, pad=2, stride=1";
// paddle_mobile::TestPoolOp<0, 3, 2, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 3, 2, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=3, pad=5, stride=1"; << "float, pooling_type=max, kernel=3, pad=5, stride=1";
// paddle_mobile::TestPoolOp<0, 3, 5, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 3, 5, 1>(in_channels, in_height, in_width);
//
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=3, pad=0, stride=1"; << "float, pooling_type=avg, kernel=3, pad=0, stride=1";
// paddle_mobile::TestPoolOp<1, 3, 0, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 3, 0, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=3, pad=1, stride=1"; << "float, pooling_type=avg, kernel=3, pad=1, stride=1";
// paddle_mobile::TestPoolOp<1, 3, 1, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 3, 1, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=3, pad=2, stride=1"; << "float, pooling_type=avg, kernel=3, pad=2, stride=1";
// paddle_mobile::TestPoolOp<1, 3, 2, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 3, 2, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=3, pad=5, stride=1"; << "float, pooling_type=avg, kernel=3, pad=5, stride=1";
// paddle_mobile::TestPoolOp<1, 3, 5, 1>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 3, 5, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=3, pad=0, stride=2"; << "float, pooling_type=max, kernel=3, pad=0, stride=2";
paddle_mobile::TestPoolOp<0, 3, 0, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 3, 0, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=3, pad=1, stride=2"; << "float, pooling_type=max, kernel=3, pad=1, stride=2";
// paddle_mobile::TestPoolOp<0, 3, 1, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 3, 1, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=3, pad=2, stride=2"; << "float, pooling_type=max, kernel=3, pad=2, stride=2";
// paddle_mobile::TestPoolOp<0, 3, 2, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 3, 2, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=3, pad=5, stride=2"; << "float, pooling_type=max, kernel=3, pad=5, stride=2";
// paddle_mobile::TestPoolOp<0, 3, 5, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<0, 3, 5, 2>(in_channels, in_height, in_width);
//
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=3, pad=0, stride=2"; << "float, pooling_type=avg, kernel=3, pad=0, stride=2";
// paddle_mobile::TestPoolOp<1, 3, 0, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 3, 0, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=3, pad=1, stride=2"; << "float, pooling_type=avg, kernel=3, pad=1, stride=2";
// paddle_mobile::TestPoolOp<1, 3, 1, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 3, 1, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=3, pad=2, stride=2"; << "float, pooling_type=avg, kernel=3, pad=2, stride=2";
// paddle_mobile::TestPoolOp<1, 3, 2, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 3, 2, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=3, pad=5, stride=2"; << "float, pooling_type=avg, kernel=3, pad=5, stride=2";
// paddle_mobile::TestPoolOp<1, 3, 5, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 3, 5, 2>(in_channels, in_height, in_width);
// // kernel = 5, pad = 0, stride = 1 // // kernel = 5, pad = 0, stride = 1
// LOG(paddle_mobile::kLOG_INFO) // LOG(paddle_mobile::kLOG_INFO)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册