Commit d760b6a5 authored by qingqing01

Refine how the activation type is obtained in the LSTM operator, to speed it up.

Parent d3c42f7d
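This commit converts the activation-name attributes ("gate_activation", "cell_activation", "candidate_activation") from std::string values into an ActivationType enum, resolved once via math::detail::GetActivationType before the per-batch loops, so the LSTM compute kernels dispatch on an enum instead of re-parsing strings on every call. The following is a minimal, self-contained sketch of that pattern under simplified assumptions; the names (Activation, ParseActivation, Apply) are illustrative only and are not the Paddle API.

// Sketch: resolve the activation name to an enum once, outside the hot loop,
// so the per-element dispatch is a cheap switch instead of string compares.
#include <cmath>
#include <stdexcept>
#include <string>
#include <vector>

enum class Activation { kSigmoid, kReLU, kTanh, kIdentity };

Activation ParseActivation(const std::string &name) {
  if (name == "sigmoid") return Activation::kSigmoid;
  if (name == "relu") return Activation::kReLU;
  if (name == "tanh") return Activation::kTanh;
  if (name == "identity") return Activation::kIdentity;
  throw std::invalid_argument("Unsupported activation: " + name);
}

float Apply(float x, Activation act) {
  switch (act) {
    case Activation::kSigmoid: return 1.0f / (1.0f + std::exp(-x));
    case Activation::kReLU: return x > 0.0f ? x : 0.0f;
    case Activation::kTanh: return std::tanh(x);
    case Activation::kIdentity: return x;
  }
  return x;  // unreachable; silences -Wreturn-type
}

int main() {
  std::vector<float> gates = {-1.0f, 0.5f, 2.0f};
  const Activation gate_act = ParseActivation("sigmoid");  // parsed once
  for (float &g : gates) {
    g = Apply(g, gate_act);  // enum dispatch inside the loop
  }
  return 0;
}

The diff below applies the same idea: GetActivationType is called once per Compute call, and the math::detail LSTM helpers take ActivationType parameters instead of std::string.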
@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/operators/math/lstm_compute.h"
 #include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/sequence2batch.h"
+#include "paddle/operators/math/detail/activation_functions.h"

 namespace paddle {
 namespace operators {
@@ -102,9 +103,12 @@ class LSTMKernel : public framework::OpKernel<T> {
     auto batch_starts = batch_gate->lod()[0];
     size_t num_batch = batch_starts.size() - 1;
-    auto gate_act = ctx.Attr<std::string>("gate_activation");
-    auto cell_act = ctx.Attr<std::string>("cell_activation");
-    auto cand_act = ctx.Attr<std::string>("candidate_activation");
+    auto gate_act = math::detail::GetActivationType(
+        ctx.Attr<std::string>("gate_activation"));
+    auto cell_act = math::detail::GetActivationType(
+        ctx.Attr<std::string>("cell_activation"));
+    auto cand_act = math::detail::GetActivationType(
+        ctx.Attr<std::string>("candidate_activation"));

     for (size_t n = 0; n < num_batch; n++) {
       int bstart = static_cast<int>(batch_starts[n]);
@@ -264,9 +268,12 @@ class LSTMGradKernel : public framework::OpKernel<T> {
     batch_gate_g.mutable_data<T>(batch_gate->dims(), ctx.GetPlace());
     batch_gate_g.set_lod(batch_gate->lod());
-    auto gate_act = ctx.Attr<std::string>("gate_activation");
-    auto cell_act = ctx.Attr<std::string>("cell_activation");
-    auto cand_act = ctx.Attr<std::string>("candidate_activation");
+    auto gate_act = math::detail::GetActivationType(
+        ctx.Attr<std::string>("gate_activation"));
+    auto cell_act = math::detail::GetActivationType(
+        ctx.Attr<std::string>("cell_activation"));
+    auto cand_act = math::detail::GetActivationType(
+        ctx.Attr<std::string>("candidate_activation"));

     auto batch_starts = batch_gate->lod()[0];
     size_t num_batch = batch_starts.size() - 1;
......
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include <math.h>
 #include "paddle/platform/hostdevice.h"
+#include "paddle/platform/enforce.h"

 #ifdef __AVX__
 #include <immintrin.h>
@@ -29,6 +30,27 @@ namespace detail {
 #define SIGMOID_THRESHOLD_MAX 13.0
 #define EXP_MAX_INPUT 40.0

+enum ActivationType {
+  kSigmoid,
+  kReLU,
+  kTanh,
+  kIdentity,
+};
+
+inline ActivationType GetActivationType(const std::string &type) {
+  if (type == "sigmoid") {
+    return ActivationType::kSigmoid;
+  } else if (type == "relu") {
+    return ActivationType::kReLU;
+  } else if (type == "tanh") {
+    return ActivationType::kTanh;
+  } else if (type == "identity") {
+    return ActivationType::kIdentity;
+  }
+  PADDLE_THROW("Not support type %s.", type);
+}
+
 namespace forward {

 template <typename T>
......
@@ -27,9 +27,9 @@ namespace detail {
 template <class T, class Op>
 void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
                                      int frame_size,
-                                     activation_mode_t active_node,
-                                     activation_mode_t active_gate,
-                                     activation_mode_t active_state) {
+                                     ActivationType active_node,
+                                     ActivationType active_gate,
+                                     ActivationType active_state) {
   T r_value_in;
   T r_value_ig;
   T r_value_fg;
@@ -77,9 +77,9 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
 template <class T, class Op>
 void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
                                       LstmMetaGrad<T> grad, int frame_size,
-                                      activation_mode_t active_node,
-                                      activation_mode_t active_gate,
-                                      activation_mode_t active_state) {
+                                      ActivationType active_node,
+                                      ActivationType active_gate,
+                                      ActivationType active_state) {
   T r_value_in;
   T r_value_ig;
   T r_value_fg;
@@ -150,9 +150,9 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
 template <class T, class Op>
 void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
                                    int frame_size,
-                                   activation_mode_t active_node,
-                                   activation_mode_t active_gate,
-                                   activation_mode_t active_state) {
+                                   ActivationType active_node,
+                                   ActivationType active_gate,
+                                   ActivationType active_state) {
 #ifdef __AVX__
   __m256 r_value_in;
   __m256 r_value_ig;
@@ -204,9 +204,9 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
 template <class T, class Op>
 void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
                                     LstmMetaGrad<T> grad, int frame_size,
-                                    activation_mode_t active_node,
-                                    activation_mode_t active_gate,
-                                    activation_mode_t active_state) {
+                                    ActivationType active_node,
+                                    ActivationType active_gate,
+                                    ActivationType active_state) {
 #ifdef __AVX__
   __m256 r_value_in;
   __m256 r_value_ig;
@@ -281,9 +281,9 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
 template <class T, class Op>
 void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frame_size,
-                      activation_mode_t active_node,
-                      activation_mode_t active_gate,
-                      activation_mode_t active_state) {
+                      ActivationType active_node,
+                      ActivationType active_gate,
+                      ActivationType active_state) {
   if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
     avx_lstm_forward_one_sequence<T>(op, value, frame_size, active_node,
                                      active_gate, active_state);
@@ -295,9 +295,9 @@ void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frame_size,
 template <class T, class Op>
 void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
-                       int frame_size, activation_mode_t active_node,
-                       activation_mode_t active_gate,
-                       activation_mode_t active_state) {
+                       int frame_size, ActivationType active_node,
+                       ActivationType active_gate,
+                       ActivationType active_state) {
   if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
     avx_lstm_backward_one_sequence<T>(op, value, grad, frame_size, active_node,
                                       active_gate, active_state);
......
@@ -31,9 +31,9 @@ namespace detail {
  */
 template <class T, class Op, bool is_batch>
 __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
-                              int batch_size, activation_mode_t active_node,
-                              activation_mode_t active_gate,
-                              activation_mode_t active_state) {
+                              int batch_size, ActivationType active_node,
+                              ActivationType active_gate,
+                              ActivationType active_state) {
   const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (frame_idx >= frame_size) return;
@@ -91,9 +91,9 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
 template <class T, class Op, bool is_batch>
 __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
                                LstmMetaGrad<T> grad, int frame_size,
-                               int batch_size, activation_mode_t active_node,
-                               activation_mode_t active_gate,
-                               activation_mode_t active_state) {
+                               int batch_size, ActivationType active_node,
+                               ActivationType active_gate,
+                               ActivationType active_state) {
   const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (frame_idx >= frame_size) return;
@@ -185,9 +185,9 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
 template <class T, class Op>
 void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
                       LstmMetaValue<T> value, int frame_size, int batch_size,
-                      activation_mode_t active_node,
-                      activation_mode_t active_gate,
-                      activation_mode_t active_state) {
+                      ActivationType active_node,
+                      ActivationType active_gate,
+                      ActivationType active_state) {
   dim3 threads;
   dim3 grid;
   if (batch_size == 1) {
@@ -220,9 +220,9 @@ template <class T, class Op>
 void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
                        LstmMetaValue<T> value, LstmMetaGrad<T> grad,
                        int frame_size, int batch_size,
-                       activation_mode_t active_node,
-                       activation_mode_t active_gate,
-                       activation_mode_t active_state) {
+                       ActivationType active_node,
+                       ActivationType active_gate,
+                       ActivationType active_state) {
   dim3 threads;
   dim3 grid;
   if (batch_size == 1) {
......
@@ -30,9 +30,9 @@ class lstm {
   HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og,
                              T &prev_state, T &state, T &state_atv, T &output,
                              T &checkI, T &checkF, T &checkO,
-                             activation_mode_t active_node,
-                             activation_mode_t active_gate,
-                             activation_mode_t active_state) {
+                             ActivationType active_node,
+                             ActivationType active_gate,
+                             ActivationType active_state) {
     value_in = activation(value_in, active_node);
     value_ig = activation(value_ig + prev_state * checkI, active_gate);
     value_fg = activation(value_fg + prev_state * checkF, active_gate);
@@ -53,9 +53,9 @@ class lstm {
                              __m256 &prev_state, __m256 &state,
                              __m256 &state_atv, __m256 &output, __m256 &checkI,
                              __m256 &checkF, __m256 &checkO,
-                             activation_mode_t active_node,
-                             activation_mode_t active_gate,
-                             activation_mode_t active_state) {
+                             ActivationType active_node,
+                             ActivationType active_gate,
+                             ActivationType active_state) {
     value_in = activation(value_in, active_node);
     value_ig =
         activation(_mm256_add_ps(value_ig, _mm256_mul_ps(prev_state, checkI)),
@@ -87,9 +87,9 @@ class lstm {
                              T &state_grad, T &state_atv, T &output_grad,
                              T &checkI, T &checkF, T &checkO, T &checkIGrad,
                              T &checkFGrad, T &checkOGrad,
-                             activation_mode_t active_node,
-                             activation_mode_t active_gate,
-                             activation_mode_t active_state) {
+                             ActivationType active_node,
+                             ActivationType active_gate,
+                             ActivationType active_state) {
     grad_og = activation(output_grad * state_atv, value_og, active_gate);
     state_grad += activation(output_grad * value_og, state_atv, active_state) +
                   grad_og * checkO;
@@ -114,8 +114,8 @@ class lstm {
       __m256 &prev_state, __m256 &prev_state_grad, __m256 &state,
       __m256 &state_grad, __m256 &state_atv, __m256 &output_grad,
       __m256 &checkI, __m256 &checkF, __m256 &checkO, __m256 &checkIGrad,
-      __m256 &checkFGrad, __m256 &checkOGrad, activation_mode_t active_node,
-      activation_mode_t active_gate, activation_mode_t active_state) {
+      __m256 &checkFGrad, __m256 &checkOGrad, ActivationType active_node,
+      ActivationType active_gate, ActivationType active_state) {
     grad_og = activation(_mm256_mul_ps(output_grad, state_atv), value_og,
                          active_gate);
     state_grad = _mm256_add_ps(activation(_mm256_mul_ps(output_grad, value_og),
......
@@ -24,12 +24,12 @@ template <class T>
 struct LstmUnitFunctor<platform::CPUDeviceContext, T> {
   static void compute(const platform::CPUDeviceContext& context,
                       LstmMetaValue<T> value, int frame_size, int batch_size,
-                      const std::string& gate_act, const std::string& cell_act,
-                      const std::string& cand_act) {
+                      const detail::ActivationType& gate_act,
+                      const detail::ActivationType& cell_act,
+                      const detail::ActivationType& cand_act) {
     for (int b = 0; b < batch_size; b++) {
       detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size,
-                               ActiveType(cand_act), ActiveType(gate_act),
-                               ActiveType(cell_act));
+                               cand_act, gate_act, cell_act);
       value.gate_value += frame_size * 4;
       value.state_value += frame_size;
       value.state_active_value += frame_size;
@@ -46,12 +46,12 @@ struct LstmUnitGradFunctor<platform::CPUDeviceContext, T> {
   static void compute(const platform::CPUDeviceContext& context,
                       LstmMetaValue<T> value, LstmMetaGrad<T> grad,
                       int frame_size, int batch_size,
-                      const std::string& gate_act, const std::string& cell_act,
-                      const std::string& cand_act) {
+                      const detail::ActivationType& gate_act,
+                      const detail::ActivationType& cell_act,
+                      const detail::ActivationType& cand_act) {
     for (int b = 0; b < batch_size; b++) {
       detail::cpu_lstm_backward(detail::backward::lstm<T>(), value, grad,
-                                frame_size, ActiveType(cand_act),
-                                ActiveType(gate_act), ActiveType(cell_act));
+                                frame_size, cand_act, gate_act, cell_act);

       value.gate_value += frame_size * 4;
       value.state_value += frame_size;
......
@@ -24,11 +24,12 @@ template <class T>
 struct LstmUnitFunctor<platform::CUDADeviceContext, T> {
   static void compute(const platform::CUDADeviceContext& context,
                       LstmMetaValue<T> value, int frame_size, int batch_size,
-                      const std::string& gate_act, const std::string& cell_act,
-                      const std::string& cand_act) {
+                      const detail::ActivationType& gate_act,
+                      const detail::ActivationType& cell_act,
+                      const detail::ActivationType& cand_act) {
     detail::gpu_lstm_forward<T>(context, detail::forward::lstm<T>(), value,
-                                frame_size, batch_size, ActiveType(cand_act),
-                                ActiveType(gate_act), ActiveType(cell_act));
+                                frame_size, batch_size, cand_act,
+                                gate_act, cell_act);
   }
 };
@@ -37,11 +38,12 @@ struct LstmUnitGradFunctor<platform::CUDADeviceContext, T> {
   static void compute(const platform::CUDADeviceContext& context,
                       LstmMetaValue<T> value, LstmMetaGrad<T> grad,
                       int frame_size, int batch_size,
-                      const std::string& gate_act, const std::string& cell_act,
-                      const std::string& cand_act) {
+                      const detail::ActivationType& gate_act,
+                      const detail::ActivationType& cell_act,
+                      const detail::ActivationType& cand_act) {
     detail::gpu_lstm_backward(context, detail::backward::lstm<T>(), value, grad,
-                              frame_size, batch_size, ActiveType(cand_act),
-                              ActiveType(gate_act), ActiveType(cell_act));
+                              frame_size, batch_size, cand_act,
+                              gate_act, cell_act);
   }
 };
......
@@ -16,6 +16,7 @@ limitations under the License. */
 #include "paddle/platform/device_context.h"
 #include "paddle/platform/enforce.h"
+#include "paddle/operators/math/detail/activation_functions.h"

 namespace paddle {
 namespace operators {
@@ -29,6 +30,7 @@ typedef enum {
   HL_ACTIVATION_END
 } activation_mode_t;
+
 template <class T>
 struct LstmMetaValue {
   T *gate_value;
@@ -72,8 +74,9 @@ class LstmUnitFunctor {
  public:
   static void compute(const DeviceContext &context, LstmMetaValue<T> value,
                       int frame_size, int batch_size,
-                      const std::string &gate_act, const std::string &cell_act,
-                      const std::string &cand_act);
+                      const detail::ActivationType &gate_act,
+                      const detail::ActivationType &cell_act,
+                      const detail::ActivationType &cand_act);
 };

 template <typename DeviceContext, typename T>
@@ -81,8 +84,9 @@ class LstmUnitGradFunctor {
  public:
   static void compute(const DeviceContext &context, LstmMetaValue<T> value,
                       LstmMetaGrad<T> grad, int frame_size, int batch_size,
-                      const std::string &gate_act, const std::string &cell_act,
-                      const std::string &cand_act);
+                      const detail::ActivationType &gate_act,
+                      const detail::ActivationType &cell_act,
+                      const detail::ActivationType &cand_act);
 };

 }  // namespace math
......