Commit d760b6a5 authored by qingqing01

Refine how the activation type is obtained in the LSTM operator, to speed it up.

Parent d3c42f7d
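For context: before this change, the LSTM forward and backward kernels passed the gate/cell/candidate activation names down as std::string attributes, and the math functors converted each string to an activation mode (via ActiveType) on every call, inside the per-batch loop. After this change, the kernels call math::detail::GetActivationType once and hand the resulting ActivationType enum to the math layer. A minimal standalone sketch of that pattern follows; it mirrors the names in the diff but is not Paddle's implementation (std::runtime_error stands in for PADDLE_THROW, and the per-batch LSTM math is elided):

```cpp
#include <stdexcept>
#include <string>
#include <vector>

// Simplified stand-in for paddle::operators::math::detail::ActivationType.
enum class ActivationType { kSigmoid, kReLU, kTanh, kIdentity };

// Simplified stand-in for GetActivationType(); unknown names fail fast,
// like the PADDLE_THROW in the real helper.
ActivationType GetActivationType(const std::string& type) {
  if (type == "sigmoid") return ActivationType::kSigmoid;
  if (type == "relu") return ActivationType::kReLU;
  if (type == "tanh") return ActivationType::kTanh;
  if (type == "identity") return ActivationType::kIdentity;
  throw std::runtime_error("Unsupported activation type: " + type);
}

void RunLstmLikeLoop(const std::vector<int>& batch_starts,
                     const std::string& gate_activation) {
  // Parse the attribute string into an enum once, before the batch loop,
  // instead of re-parsing it for every batch.
  const ActivationType gate_act = GetActivationType(gate_activation);
  for (size_t n = 0; n + 1 < batch_starts.size(); ++n) {
    // ... the per-batch LSTM math would receive `gate_act` (an enum) here ...
    (void)gate_act;
  }
}
```

The string comparisons and the error check now happen once per kernel invocation rather than once per batch, which is the speed-up the commit message refers to.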
@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/operators/math/lstm_compute.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence2batch.h"
#include "paddle/operators/math/detail/activation_functions.h"
namespace paddle {
namespace operators {
@@ -102,9 +103,12 @@ class LSTMKernel : public framework::OpKernel<T> {
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
auto gate_act = ctx.Attr<std::string>("gate_activation");
auto cell_act = ctx.Attr<std::string>("cell_activation");
auto cand_act = ctx.Attr<std::string>("candidate_activation");
auto gate_act = math::detail::GetActivationType(
ctx.Attr<std::string>("gate_activation"));
auto cell_act = math::detail::GetActivationType(
ctx.Attr<std::string>("cell_activation"));
auto cand_act = math::detail::GetActivationType(
ctx.Attr<std::string>("candidate_activation"));
for (size_t n = 0; n < num_batch; n++) {
int bstart = static_cast<int>(batch_starts[n]);
@@ -264,9 +268,12 @@ class LSTMGradKernel : public framework::OpKernel<T> {
batch_gate_g.mutable_data<T>(batch_gate->dims(), ctx.GetPlace());
batch_gate_g.set_lod(batch_gate->lod());
auto gate_act = ctx.Attr<std::string>("gate_activation");
auto cell_act = ctx.Attr<std::string>("cell_activation");
auto cand_act = ctx.Attr<std::string>("candidate_activation");
auto gate_act = math::detail::GetActivationType(
ctx.Attr<std::string>("gate_activation"));
auto cell_act = math::detail::GetActivationType(
ctx.Attr<std::string>("cell_activation"));
auto cand_act = math::detail::GetActivationType(
ctx.Attr<std::string>("candidate_activation"));
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
......
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <math.h>
#include "paddle/platform/hostdevice.h"
#include "paddle/platform/enforce.h"
#ifdef __AVX__
#include <immintrin.h>
@@ -29,6 +30,27 @@ namespace detail {
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
enum ActivationType {
kSigmoid,
kReLU,
kTanh,
kIdentity,
};
inline ActivationType GetActivationType (const std::string &type) {
if (type == "sigmoid") {
return ActivationType::kSigmoid;
} else if (type == "relu") {
return ActivationType::kReLU;
} else if (type == "tanh") {
return ActivationType::kTanh;
} else if (type == "identity") {
return ActivationType::kIdentity;
}
PADDLE_THROW("Not support type %s.", type);
}
namespace forward {
template <typename T>
......
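The hunk above adds the ActivationType enum and the GetActivationType helper to activation_functions.h; unknown activation names fail fast through PADDLE_THROW. The math routines in the remaining hunks then receive this enum directly. Purely as an illustration of branching on such an enum (the actual activation helpers in math/detail are not shown in this diff), a scalar dispatch could look like:

```cpp
#include <cmath>

// Illustrative only: a scalar activation dispatch keyed on the enum.
enum class ActivationType { kSigmoid, kReLU, kTanh, kIdentity };

inline float ApplyActivation(float x, ActivationType act) {
  switch (act) {
    case ActivationType::kSigmoid:  return 1.0f / (1.0f + std::exp(-x));
    case ActivationType::kReLU:     return x > 0.0f ? x : 0.0f;
    case ActivationType::kTanh:     return std::tanh(x);
    case ActivationType::kIdentity: return x;
  }
  return x;  // unreachable; keeps compilers happy
}
```

In the commit itself the enum value is simply forwarded unchanged through cpu_lstm_forward / gpu_lstm_forward into the per-frame Op, as the hunks below show.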
@@ -27,9 +27,9 @@ namespace detail {
template <class T, class Op>
void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
int frame_size,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) {
T r_value_in;
T r_value_ig;
T r_value_fg;
@@ -77,9 +77,9 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
template <class T, class Op>
void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frame_size,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) {
T r_value_in;
T r_value_ig;
T r_value_fg;
@@ -150,9 +150,9 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
template <class T, class Op>
void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
int frame_size,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) {
#ifdef __AVX__
__m256 r_value_in;
__m256 r_value_ig;
@@ -204,9 +204,9 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
template <class T, class Op>
void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frame_size,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) {
#ifdef __AVX__
__m256 r_value_in;
__m256 r_value_ig;
@@ -281,9 +281,9 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
template <class T, class Op>
void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frame_size,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) {
if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
avx_lstm_forward_one_sequence<T>(op, value, frame_size, active_node,
active_gate, active_state);
@@ -295,9 +295,9 @@ void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frame_size,
template <class T, class Op>
void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
int frame_size, ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) {
if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
avx_lstm_backward_one_sequence<T>(op, value, grad, frame_size, active_node,
active_gate, active_state);
......
@@ -31,9 +31,9 @@ namespace detail {
*/
template <class T, class Op, bool is_batch>
__global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
int batch_size, activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
int batch_size, ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
@@ -91,9 +91,9 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
template <class T, class Op, bool is_batch>
__global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frame_size,
int batch_size, activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
int batch_size, ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
@@ -185,9 +185,9 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
template <class T, class Op>
void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
LstmMetaValue<T> value, int frame_size, int batch_size,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) {
dim3 threads;
dim3 grid;
if (batch_size == 1) {
@@ -220,9 +220,9 @@ template <class T, class Op>
void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, int batch_size,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) {
dim3 threads;
dim3 grid;
if (batch_size == 1) {
......
@@ -30,9 +30,9 @@ class lstm {
HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og,
T &prev_state, T &state, T &state_atv, T &output,
T &checkI, T &checkF, T &checkO,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) {
value_in = activation(value_in, active_node);
value_ig = activation(value_ig + prev_state * checkI, active_gate);
value_fg = activation(value_fg + prev_state * checkF, active_gate);
@@ -53,9 +53,9 @@ class lstm {
__m256 &prev_state, __m256 &state,
__m256 &state_atv, __m256 &output, __m256 &checkI,
__m256 &checkF, __m256 &checkO,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) {
value_in = activation(value_in, active_node);
value_ig =
activation(_mm256_add_ps(value_ig, _mm256_mul_ps(prev_state, checkI)),
@@ -87,9 +87,9 @@ class lstm {
T &state_grad, T &state_atv, T &output_grad,
T &checkI, T &checkF, T &checkO, T &checkIGrad,
T &checkFGrad, T &checkOGrad,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) {
grad_og = activation(output_grad * state_atv, value_og, active_gate);
state_grad += activation(output_grad * value_og, state_atv, active_state) +
grad_og * checkO;
@@ -114,8 +114,8 @@ class lstm {
__m256 &prev_state, __m256 &prev_state_grad, __m256 &state,
__m256 &state_grad, __m256 &state_atv, __m256 &output_grad,
__m256 &checkI, __m256 &checkF, __m256 &checkO, __m256 &checkIGrad,
__m256 &checkFGrad, __m256 &checkOGrad, activation_mode_t active_node,
activation_mode_t active_gate, activation_mode_t active_state) {
__m256 &checkFGrad, __m256 &checkOGrad, ActivationType active_node,
ActivationType active_gate, ActivationType active_state) {
grad_og = activation(_mm256_mul_ps(output_grad, state_atv), value_og,
active_gate);
state_grad = _mm256_add_ps(activation(_mm256_mul_ps(output_grad, value_og),
......
@@ -24,12 +24,12 @@ template <class T>
struct LstmUnitFunctor<platform::CPUDeviceContext, T> {
static void compute(const platform::CPUDeviceContext& context,
LstmMetaValue<T> value, int frame_size, int batch_size,
const std::string& gate_act, const std::string& cell_act,
const std::string& cand_act) {
const detail::ActivationType& gate_act,
const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act) {
for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size,
ActiveType(cand_act), ActiveType(gate_act),
ActiveType(cell_act));
cand_act, gate_act, cell_act);
value.gate_value += frame_size * 4;
value.state_value += frame_size;
value.state_active_value += frame_size;
@@ -46,12 +46,12 @@ struct LstmUnitGradFunctor<platform::CPUDeviceContext, T> {
static void compute(const platform::CPUDeviceContext& context,
LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, int batch_size,
const std::string& gate_act, const std::string& cell_act,
const std::string& cand_act) {
const detail::ActivationType& gate_act,
const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act) {
for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_backward(detail::backward::lstm<T>(), value, grad,
frame_size, ActiveType(cand_act),
ActiveType(gate_act), ActiveType(cell_act));
frame_size, cand_act, gate_act, cell_act);
value.gate_value += frame_size * 4;
value.state_value += frame_size;
......
@@ -24,11 +24,12 @@ template <class T>
struct LstmUnitFunctor<platform::CUDADeviceContext, T> {
static void compute(const platform::CUDADeviceContext& context,
LstmMetaValue<T> value, int frame_size, int batch_size,
const std::string& gate_act, const std::string& cell_act,
const std::string& cand_act) {
const detail::ActivationType& gate_act,
const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act) {
detail::gpu_lstm_forward<T>(context, detail::forward::lstm<T>(), value,
frame_size, batch_size, ActiveType(cand_act),
ActiveType(gate_act), ActiveType(cell_act));
frame_size, batch_size, cand_act,
gate_act, cell_act);
}
};
@@ -37,11 +38,12 @@ struct LstmUnitGradFunctor<platform::CUDADeviceContext, T> {
static void compute(const platform::CUDADeviceContext& context,
LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, int batch_size,
const std::string& gate_act, const std::string& cell_act,
const std::string& cand_act) {
const detail::ActivationType& gate_act,
const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act) {
detail::gpu_lstm_backward(context, detail::backward::lstm<T>(), value, grad,
frame_size, batch_size, ActiveType(cand_act),
ActiveType(gate_act), ActiveType(cell_act));
frame_size, batch_size, cand_act,
gate_act, cell_act);
}
};
......
@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
#include "paddle/operators/math/detail/activation_functions.h"
namespace paddle {
namespace operators {
@@ -29,6 +30,7 @@ typedef enum {
HL_ACTIVATION_END
} activation_mode_t;
template <class T>
struct LstmMetaValue {
T *gate_value;
@@ -72,8 +74,9 @@ class LstmUnitFunctor {
public:
static void compute(const DeviceContext &context, LstmMetaValue<T> value,
int frame_size, int batch_size,
const std::string &gate_act, const std::string &cell_act,
const std::string &cand_act);
const detail::ActivationType &gate_act,
const detail::ActivationType &cell_act,
const detail::ActivationType &cand_act);
};
template <typename DeviceContext, typename T>
@@ -81,8 +84,9 @@ class LstmUnitGradFunctor {
public:
static void compute(const DeviceContext &context, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frame_size, int batch_size,
const std::string &gate_act, const std::string &cell_act,
const std::string &cand_act);
const detail::ActivationType &gate_act,
const detail::ActivationType &cell_act,
const detail::ActivationType &cand_act);
};
} // namespace math
......