Commit 3f1062d7 authored by Q qingqing01, committed by GitHub

Merge pull request #4929 from qingqing01/lstm

Forward implementation for LSTM operator.
@@ -115,7 +115,8 @@ set(DEPS_OPS
 softmax_with_cross_entropy_op
 sum_op
 pool_op
-pool_with_index_op)
+pool_with_index_op
+lstm_op)
 op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
@@ -126,6 +127,7 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
 op_library(sum_op DEPS net_op)
 op_library(pool_op DEPS pooling)
 op_library(pool_with_index_op DEPS pooling)
+op_library(lstm_op DEPS sequence2batch lstm_compute)
 list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
 foreach(src ${GENERAL_OPS})
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/lstm_op.h"
namespace paddle {
namespace operators {
class LSTMOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Output(Cell) of LSTM should not be null.");
auto x_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(Input)'s rank must be 2.");
if (ctx->HasInput("H0")) {
PADDLE_ENFORCE(ctx->HasInput("C0"),
"Input(Cell) and Input(Hidden) of LSTM should not "
"be null at the same time.");
auto h_dims = ctx->GetInputDim("H0");
auto c_dims = ctx->GetInputDim("C0");
PADDLE_ENFORCE(h_dims == c_dims,
"The dimension of Input(H0) and Input(C0) "
"should be the same.");
}
int frame_size = x_dims[1] / 4;
auto w_dims = ctx->GetInputDim("Weight");
PADDLE_ENFORCE_EQ(w_dims.size(), 2,
"The rank of Input(Weight) should be 2.");
PADDLE_ENFORCE_EQ(w_dims[0], frame_size,
"The first dimension of Input(Weight) "
"should be %d.",
frame_size);
PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size,
"The second dimension of Input(Weight) "
"should be 4 * %d.",
frame_size);
auto b_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1,
"The first dimension of Input(Bias) should be 1.");
if (ctx->Attrs().Get<bool>("usePeepholes")) {
PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
"The second dimension of Input(Bias) should be "
"7 * %d if enable peepholes connection",
frame_size);
} else {
PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
"The second dimension of Input(Bias) should be "
"4 * %d if disable peepholes connection",
frame_size);
}
ctx->SetOutputDim("Hidden", {x_dims[0], frame_size});
ctx->SetOutputDim("Cell", {x_dims[0], frame_size});
ctx->SetOutputDim("BatchGate", x_dims);
ctx->ShareLoD("Input", "Hidden");
ctx->ShareLoD("Input", "Cell");
}
};
class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LSTMOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input",
"(LoDTensor) the first input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T X 4D), where, T is the "
"total time steps in this mini-batch, D is the hidden size.");
AddInput("H0",
"(Tensor, optional) the initial hidden state is an optional "
"input. This is a tensor with shape (N x D), where N is the "
"batch size, D is the hidden size.");
AddInput("C0",
"(Tensor, optional) the initial cell state is an optional "
"input. This is a tensor with shape (N x D), where N is the "
"batch size. `H0` and `C0` can be NULL but only at the same time");
AddInput("Weight",
"(Tensor) the learnable hidden-hidden weights."
" - The shape is (D x 4D), where D is the hidden size. "
" - Weight = {W_ch, W_ih, W_fh, W_oh}");
AddInput("Bias",
"(Tensor) the learnable weights, which contains two parts: "
"input-hidden bias weight and peephole connections weight if "
"setting `usePeepholes` True. "
"1. `usePeepholes = False` "
" - The shape is (1 x 4D). "
" - Bias = {b_c, b_i, b_f, b_o}."
"2. `usePeepholes = True` "
" - The shape is (1 x 7D). "
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
AddOutput("BatchGate",
"(LoDTensor) This LoDTensor contains input gate, forget gate "
"and output gate after the nonlinear computation. This "
"LoDTensor has the same shape with the reorganized input, which "
"was also be called batch input. The LoD size is 2. The first "
"LoD is the batch offsets and the second LoD contains the "
"indexes, which denote the position of reorganized sequence "
"in the raw input.")
.AsIntermediate();
AddOutput("Hidden",
"(LoDTensor) the hidden state lod tensor of LSTM operator. "
"The shape and lod is the same with the `Input`.");
AddOutput("Cell",
"(LoDTensor) the cell state lod tensor of LSTM operator. "
"The shape and lod is the same with the `Input`.");
AddAttr<bool>("usePeepholes",
"(bool, defalut: True) "
"whether to enable diagonal/peephole connections.")
.SetDefault(true);
AddAttr<bool>("isReverse",
"(bool, defalut: False) "
"whether to compute reversed LSTM.")
.SetDefault(false);
AddAttr<std::string>(
"gateActivation",
"(string, default: sigmoid)"
"The activation for input gate, forget gate and output "
"gate, `sigmoid` by default.")
.SetDefault("sigmoid");
AddAttr<std::string>("cellActivation",
"(string, default: tanh)"
"The activation for cell output, `tanh` by defalut.")
.SetDefault("tanh");
AddAttr<std::string>("candidateActivation",
"(string, default: tanh)"
"The activation for candidate hidden state, "
"`tanh` by default.")
.SetDefault("tanh");
AddComment(R"DOC(Long-Short Term Memory (LSTM) Operator
The default implementation uses the diagonal/peephole connections [1]; the
formulas are as follows
i_t = \sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + W_{ic}c_{t-1} + b_i)
f_t = \sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + W_{fc}c_{t-1} + b_f)
\tilde{c_t} = act_g(W_{cx}x_t + W_{ch}h_{t-1} + b_c)
o_t = \sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + W_{oc}c_t + b_o)
c_t = f_t ⊙ c_{t-1} + i_t ⊙ \tilde{c_t}
h_t = o_t ⊙ act_h(c_t)
where the W terms denote weight matrices (e.g. \f$W_{ix}\f$ is the matrix
of weights from the input to the input gate), and \f$W_{ic}, W_{fc}, W_{oc}\f$
are diagonal weight matrices for the peephole connections. In our
implementation, we use vectors to represent these diagonal weight matrices.
The b terms denote bias vectors (\f$b_i\f$ is the input gate bias vector),
\f$\sigma\f$ is the non-linear activation function, such as the logistic
sigmoid, and \f$i, f, o\f$ and \f$c\f$ are respectively the input gate,
forget gate, output gate and cell activation vectors, all of which have the
same size as the cell output activation vector \f$h\f$.
Here ⊙ denotes the element-wise product of vectors. \f$act_g\f$ and \f$act_h\f$
are the cell input and cell output activation functions; `tanh` is usually
used for both. \f$\tilde{c_t}\f$ is also called the candidate hidden state,
which is computed from the current input and the previous hidden state.
Set `usePeepholes` to False to disable the peephole connections [2]; the
corresponding formulas are omitted here.
@note The projections \f$W_{ix}x_{t}, W_{fx}x_{t}, W_{cx}x_{t}, W_{ox}x_{t}\f$
of the input \f$x_{t}\f$ are NOT computed by this operator.
Users can apply a fully-connected operator before the LSTM operator to compute them.
[1] Hasim Sak, Andrew Senior, and Francoise Beaufays. Long short-term memory
recurrent neural network architectures for large scale acoustic modeling.
INTERSPEECH, 2014.
[2] S. Hochreiter and J. Schmidhuber. Long Short-Term Memory.
Neural Computation, 9(8):1735-1780, 1997.
)DOC");
}
};
class LSTMGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
"Input(Hidden@GRAD) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cell")),
"Input(Cell@GRAD) should not be null");
ctx->SetOutputDim(framework::GradVarName("Weight"),
ctx->GetInputDim("Weight"));
ctx->SetOutputDim(framework::GradVarName("Bias"), ctx->GetInputDim("Bias"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(lstm, ops::LSTMOp, ops::LSTMOpMaker, lstm_grad, ops::LSTMGradOp);
REGISTER_OP_CPU_KERNEL(lstm, ops::LSTMKernel<paddle::platform::CPUPlace, float>,
ops::LSTMKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(lstm_grad,
ops::LSTMGradKernel<paddle::platform::CPUPlace, float>,
ops::LSTMGradKernel<paddle::platform::CPUPlace, double>);
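// A minimal standalone sketch (not part of this operator) of the shape
// contract that LSTMOp::InferShape above enforces, spelled out for a concrete
// hidden size. Names below are illustrative assumptions only.
namespace lstm_shape_sketch {
constexpr int D = 32;                       // hidden size (frame_size)
constexpr int T = 100;                      // total time steps in the mini-batch
constexpr int kInputCols = 4 * D;           // Input:  (T x 4D)  = (100 x 128)
constexpr int kWeightRows = D;              // Weight: (D x 4D)  = (32 x 128)
constexpr int kWeightCols = 4 * D;
constexpr int kBiasColsPeephole = 7 * D;    // Bias with usePeepholes:    (1 x 224)
constexpr int kBiasColsNoPeephole = 4 * D;  // Bias without usePeepholes: (1 x 128)
static_assert(kInputCols / 4 == D, "frame_size is recovered as x_dims[1] / 4");
static_assert(kWeightRows == kWeightCols / 4, "Weight must be (D x 4D)");
}  // namespace lstm_shape_sketch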
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/lstm_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(lstm, ops::LSTMKernel<paddle::platform::GPUPlace, float>,
ops::LSTMKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(lstm_grad,
ops::LSTMGradKernel<paddle::platform::GPUPlace, float>,
ops::LSTMGradKernel<paddle::platform::GPUPlace, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/lstm_compute.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence2batch.h"
namespace paddle {
namespace operators {
using framework::LoDTensor;
using framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
class LSTMKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<framework::LoDTensor>("Input");
auto* weight = ctx.Input<framework::Tensor>("Weight");
auto* bias = ctx.Input<framework::Tensor>("Bias");
auto* batch_gate = ctx.Output<framework::LoDTensor>("BatchGate");
batch_gate->mutable_data<T>(ctx.GetPlace());
auto* hidden_out = ctx.Output<framework::LoDTensor>("Hidden");
hidden_out->mutable_data<T>(ctx.GetPlace());
auto* cell_out = ctx.Output<framework::LoDTensor>("Cell");
cell_out->mutable_data<T>(ctx.GetPlace());
// The ShareLoD function in InferShape is not implemented yet,
// so copy the LoD here.
ctx.ShareLoD("Input", "Hidden");
ctx.ShareLoD("Input", "Cell");
bool is_reverse = ctx.Attr<bool>("isReverse");
math::LoDTensor2BatchFunctor<Place, T> to_batch;
to_batch(ctx.device_context(), *input, *batch_gate, is_reverse);
auto in_dims = input->dims();
int frame_size = static_cast<int>(in_dims[1] / 4);
framework::DDim dims({in_dims[0], frame_size});
if (bias) {
Eigen::array<int, 2> extents({{1, 4 * frame_size}});
Eigen::array<int, 2> offsets({{0, 0}});
auto b = EigenMatrix<T>::From(*bias);
auto gate = EigenMatrix<T>::From(*batch_gate);
gate.device(ctx.GetEigenDevice<Place>()) =
gate +
b.slice(offsets, extents)
.reshape(Eigen::array<int, 2>({{1, frame_size * 4}}))
.broadcast(
Eigen::array<int, 2>({{static_cast<int>(in_dims[0]), 1}}));
}
math::LstmMetaValue<T> lstm_value;
T* bias_data = const_cast<T*>(bias->data<T>());
// the code style in LstmMetaValue will be updated later.
lstm_value.checkIg = bias_data + 4 * frame_size;
lstm_value.checkFg = lstm_value.checkIg + frame_size;
lstm_value.checkOg = lstm_value.checkFg + frame_size;
lstm_value.prevStateValue = nullptr;
framework::LoDTensor batch_out, batch_cell, batch_cell_pre_act;
batch_out.mutable_data<T>(dims, ctx.GetPlace());
batch_cell.mutable_data<T>(dims, ctx.GetPlace());
batch_cell_pre_act.mutable_data<T>(dims, ctx.GetPlace());
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
auto gate_act = ctx.Attr<std::string>("gateActivation");
auto cell_act = ctx.Attr<std::string>("cellActivation");
auto cand_act = ctx.Attr<std::string>("candidateActivation");
for (size_t n = 0; n < num_batch; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
Tensor gate_t = batch_gate->Slice(bstart, bend);
Tensor out_t = batch_out.Slice(bstart, bend);
Tensor cell_t = batch_cell.Slice(bstart, bend);
Tensor cell_pre_act_t = batch_cell_pre_act.Slice(bstart, bend);
int cur_batch_size = bend - bstart;
if (n != 0) {
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size;
auto pre_hidden_t = batch_out.Slice(pre_h_start, pre_h_end);
math::matmul<Place, T>(ctx.device_context(), pre_hidden_t, false,
*weight, false, static_cast<T>(1.0), &gate_t,
static_cast<T>(1.0));
}
// else: FIXME, support the initial hidden state and cell state
lstm_value.gateValue = gate_t.data<T>();
lstm_value.outputValue = out_t.data<T>();
lstm_value.stateValue = cell_t.data<T>();
lstm_value.stateActiveValue = cell_pre_act_t.data<T>();
math::LstmUnitFunctor<Place, T>::compute(ctx.device_context(), lstm_value,
frame_size, cur_batch_size,
gate_act, cell_act, cand_act);
lstm_value.prevStateValue = lstm_value.stateValue;
}
math::Batch2LoDTensorFunctor<Place, T> to_seq;
batch_out.set_lod(batch_gate->lod());
// restore the output hidden in LoDTensor from the batch hidden
to_seq(ctx.device_context(), batch_out, *hidden_out);
batch_cell.set_lod(batch_gate->lod());
// restore the output cell state in LoDTensor from the batch cell
to_seq(ctx.device_context(), batch_cell, *cell_out);
}
};
template <typename Place, typename T>
class LSTMGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {}
};
} // namespace operators
} // namespace paddle
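// A rough standalone sketch of the batching idea LSTMKernel relies on above
// (the real reordering is done by math::LoDTensor2BatchFunctor in
// sequence2batch.h and may differ in detail): sequences are processed longest
// first, and the n-th batch holds time step n of every sequence that still has
// elements, so batch_starts[n + 1] - batch_starts[n] shrinks as n grows.
// Names below are illustrative assumptions only.
#include <algorithm>
#include <functional>
#include <vector>
inline std::vector<size_t> batch_starts_sketch(std::vector<size_t> seq_lens) {
  std::sort(seq_lens.begin(), seq_lens.end(), std::greater<size_t>());
  size_t max_len = seq_lens.empty() ? 0 : seq_lens.front();
  std::vector<size_t> starts(1, 0);
  for (size_t t = 0; t < max_len; ++t) {
    // number of sequences still active at time step t
    size_t active = std::count_if(seq_lens.begin(), seq_lens.end(),
                                  [t](size_t len) { return len > t; });
    starts.push_back(starts.back() + active);
  }
  return starts;  // e.g. lengths {4, 3, 2} -> starts {0, 3, 6, 8, 9}
}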
@@ -19,7 +19,6 @@
 namespace paddle {
 namespace operators {
-using framework::LoDTensor;
 using framework::Tensor;
 template <typename T>
......
+add_subdirectory(detail)
 if(WITH_GPU)
 nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context operator)
 nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function tensor)
@@ -7,6 +9,8 @@ if(WITH_GPU)
 nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
 nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context)
 nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context)
+nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
+nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
 else()
 cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator)
 cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
@@ -14,6 +18,8 @@ else()
 cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
 cc_library(pooling SRCS pooling.cc DEPS device_context)
 cc_library(vol2col SRCS vol2col.cc DEPS device_context)
+cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
+cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
 endif()
 cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
......
@@ -22,8 +22,6 @@ namespace {
 template <typename T>
 __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
                                    const int N, const int D) {
-  // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
-  // CUDA_1D_KERNEL_LOOP(i, N) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
        i += blockDim.x * gridDim.x) {
     PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
......
if(WITH_AVX)
cc_library(activation_functions SRCS hl_cpu_functions.cc hl_avx_functions.cc)
else()
cc_library(activation_functions SRCS hl_cpu_functions.cc)
endif()
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_ACTIVATION_FUNCTIONS_H_
#define HL_ACTIVATION_FUNCTIONS_H_
#include "hl_functions.h"
#include "paddle/operators/math/lstm_compute.h"
/**
* Activation functions: sigmoid, relu, tanh and linear.
*/
#define FLOAT_ACTIVE_FUNCTION \
{ \
hppl::typef::sigmoid, hppl::typef::relu, hppl::typef::tanh, \
hppl::typef::linear \
}
#define DOUBLE_ACTIVE_FUNCTION \
{ \
hppl::typed::sigmoid, hppl::typed::relu, hppl::typed::tanh, \
hppl::typed::linear \
}
#define AVX_ACTIVE_FUNCTION \
{ hppl::sigmoid, hppl::relu, hppl::tanh, hppl::linear }
namespace hppl {
using activation_mode_t = paddle::operators::math::activation_mode_t;
/**
* hppl supports the sigmoid, relu, tanh and linear activation functions
* for the forward and backward passes of neural networks.
*/
template <class T>
class Active {
public:
typedef T (*forward)(T);
typedef T (*backward)(T, T);
};
template <typename T>
struct ForwardActType;
template <>
struct ForwardActType<float> {
using type = Active<float>::forward;
};
template <>
struct ForwardActType<double> {
using type = Active<double>::forward;
};
template <typename T>
struct BackwardActType;
template <>
struct BackwardActType<float> {
using type = Active<float>::backward;
};
template <>
struct BackwardActType<double> {
using type = Active<double>::backward;
};
#ifdef __NVCC__
namespace gpu {
static __device__ Active<float>::forward forward[] = FLOAT_ACTIVE_FUNCTION;
static __device__ Active<float>::backward backward[] = FLOAT_ACTIVE_FUNCTION;
static __device__ Active<double>::forward forward_d[] = DOUBLE_ACTIVE_FUNCTION;
static __device__ Active<double>::backward backward_d[] =
DOUBLE_ACTIVE_FUNCTION;
template <typename T>
struct ForwardAct {
__device__ typename ForwardActType<T>::type operator()(
activation_mode_t type);
};
template <>
struct ForwardAct<float> {
__device__ ForwardActType<float>::type operator()(activation_mode_t type) {
return forward[type];
}
};
template <>
struct ForwardAct<double> {
__device__ ForwardActType<double>::type operator()(activation_mode_t type) {
return forward_d[type];
}
};
template <typename T>
struct BackwardAct {
__device__ typename BackwardActType<T>::type operator()(
activation_mode_t type);
};
template <>
struct BackwardAct<float> {
__device__ BackwardActType<float>::type operator()(activation_mode_t type) {
return backward[type];
}
};
template <>
struct BackwardAct<double> {
__device__ BackwardActType<double>::type operator()(activation_mode_t type) {
return backward_d[type];
}
};
} // namespace gpu
#else
namespace cpu {
static Active<float>::forward forward[] = FLOAT_ACTIVE_FUNCTION;
static Active<float>::backward backward[] = FLOAT_ACTIVE_FUNCTION;
static Active<double>::forward forward_d[] = DOUBLE_ACTIVE_FUNCTION;
static Active<double>::backward backward_d[] = DOUBLE_ACTIVE_FUNCTION;
template <typename T>
struct ForwardAct {
typename ForwardActType<T>::type operator()(activation_mode_t type);
};
template <>
struct ForwardAct<float> {
ForwardActType<float>::type operator()(activation_mode_t type) {
return forward[type];
}
};
template <>
struct ForwardAct<double> {
ForwardActType<double>::type operator()(activation_mode_t type) {
return forward_d[type];
}
};
template <typename T>
struct BackwardAct {
typename BackwardActType<T>::type operator()(activation_mode_t type);
};
template <>
struct BackwardAct<float> {
BackwardActType<float>::type operator()(activation_mode_t type) {
return backward[type];
}
};
template <>
struct BackwardAct<double> {
BackwardActType<double>::type operator()(activation_mode_t type) {
return backward_d[type];
}
};
} // namespace cpu
#ifdef __AVX__
namespace avx {
static Active<__m256>::forward forward[] = AVX_ACTIVE_FUNCTION;
static Active<__m256>::backward backward[] = AVX_ACTIVE_FUNCTION;
} // namespace avx
#endif
#endif
} // namespace hppl
#endif // HL_ACTIVATION_FUNCTIONS_H_
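// An illustrative host-side usage sketch (not part of the PR) of the CPU
// dispatch above: the functors turn an activation_mode_t into a plain
// function pointer, which the LSTM kernels then call element-wise.
#ifndef __NVCC__
inline float apply_activation_sketch(float x) {
  hppl::cpu::ForwardAct<float> act;
  hppl::ForwardActType<float>::type f =
      act(paddle::operators::math::HL_ACTIVATION_SIGMOID);
  return f(x);  // same as hppl::typef::sigmoid(x)
}
#endif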
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <immintrin.h>
#include "hl_functions.h"
// TODO(qingqing) refine this dependence
#include "paddle/cuda/src/avx_mathfun.h"
namespace hppl {
__m256 exp(__m256 a) { return exp256_ps(a); }
__m256 relu(const __m256 a) {
__m256 tmp = _mm256_set1_ps(0.0f);
return _mm256_max_ps(a, tmp);
}
__m256 sigmoid(const __m256 a) {
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX);
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN);
__m256 tmp = _mm256_max_ps(a, min);
tmp = _mm256_min_ps(tmp, max);
tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp);
tmp = exp(tmp);
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp);
tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp);
return tmp;
}
__m256 tanh(const __m256 a) {
__m256 max = _mm256_set1_ps(EXP_MAX_INPUT);
__m256 tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a);
tmp = _mm256_min_ps(tmp, max);
tmp = exp(tmp);
return _mm256_sub_ps(_mm256_div_ps(_mm256_set1_ps(2.0f),
_mm256_add_ps(_mm256_set1_ps(1.0f), tmp)),
_mm256_set1_ps(1.0f));
}
__m256 linear(const __m256 a) { return a; }
__m256 relu(const __m256 a, const __m256 b) {
return _mm256_mul_ps(
a, _mm256_and_ps(_mm256_cmp_ps(b, _mm256_set1_ps(0.0f), _CMP_GT_OS),
_mm256_set1_ps(1.0f)));
}
__m256 sigmoid(const __m256 a, const __m256 b) {
return _mm256_mul_ps(_mm256_mul_ps(a, b),
_mm256_sub_ps(_mm256_set1_ps(1.0f), b));
}
__m256 tanh(const __m256 a, const __m256 b) {
return _mm256_mul_ps(
a, _mm256_sub_ps(_mm256_set1_ps(1.0f), _mm256_mul_ps(b, b)));
}
__m256 linear(const __m256 a, const __m256 b) { return a; }
} // namespace hppl
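// A small host-side check sketch (illustrative only, assumes AVX is
// available): one lane of the vector sigmoid above should match the scalar
// sigmoid in hl_cpu_functions.cc up to the accuracy of exp256_ps.
#ifdef __AVX__
#include <cmath>
inline bool avx_sigmoid_matches_scalar_sketch() {
  __m256 x = _mm256_set1_ps(0.3f);
  float out[8];
  _mm256_storeu_ps(out, hppl::sigmoid(x));
  return std::fabs(out[0] - hppl::typef::sigmoid(0.3f)) < 1e-5f;
}
#endif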
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_AVX_FUNCTIONS_H_
#define HL_AVX_FUNCTIONS_H_
#include <immintrin.h>
namespace hppl {
__m256 relu(const __m256 a);
__m256 sigmoid(const __m256 a);
__m256 tanh(const __m256 a);
__m256 linear(const __m256 a);
__m256 relu(const __m256 a, const __m256 b);
__m256 sigmoid(const __m256 a, const __m256 b);
__m256 tanh(const __m256 a, const __m256 b);
__m256 linear(const __m256 a, const __m256 b);
} // namespace hppl
#endif // HL_AVX_FUNCTIONS_H_
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <math.h>
#include "hl_functions.h"
namespace hppl {
namespace typef {
float relu(const float a) {
return a > static_cast<float>(0.0) ? a : static_cast<float>(0.0);
}
float sigmoid(const float a) {
const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX;
float tmp = (a < min) ? min : ((a > max) ? max : a);
return static_cast<float>(1.0) / (static_cast<float>(1.0) + exp(-tmp));
}
float tanh(const float a) {
float tmp = -2.0 * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
return (2.0 / (1.0 + exp(tmp))) - 1.0;
}
float linear(const float a) { return a; }
float relu(const float a, const float b) { return a * (b > 0.0 ? 1.0 : 0.0); }
float sigmoid(const float a, const float b) {
return a * b * (static_cast<float>(1) - b);
}
float tanh(const float a, const float b) {
return a * (static_cast<float>(1) - b * b);
}
float linear(const float a, const float b) { return a; }
} // namespace typef
namespace typed {
double relu(const double a) {
return a > static_cast<double>(0.0) ? a : static_cast<double>(0.0);
}
double sigmoid(const double a) {
const double min = SIGMOID_THRESHOLD_MIN;
const double max = SIGMOID_THRESHOLD_MAX;
double tmp = (a < min) ? min : ((a > max) ? max : a);
return static_cast<double>(1.0) / (static_cast<double>(1.0) + exp(-tmp));
}
double tanh(const double a) {
double tmp = -2.0 * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
return (2.0 / (1.0 + exp(tmp))) - 1.0;
}
double linear(const double a) { return a; }
double relu(const double a, const double b) {
return a * (b > 0.0 ? 1.0 : 0.0);
}
double sigmoid(const double a, const double b) {
return a * b * (static_cast<double>(1) - b);
}
double tanh(const double a, const double b) {
return a * (static_cast<double>(1) - b * b);
}
double linear(const double a, const double b) { return a; }
} // namespace typed
} // namespace hppl
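// An illustrative reading (not part of the PR) of the overloads above: the
// one-argument functions are the forward activations, and the two-argument
// overloads compute grad_input = grad_output * f'(x) with the derivative
// expressed through the forward output b, e.g. sigmoid'(x) = b * (1 - b).
#include <cassert>
inline void activation_pair_sketch() {
  float b = hppl::typef::sigmoid(0.0f);     // forward output: 0.5
  float g = hppl::typef::sigmoid(1.0f, b);  // 1.0 * 0.5 * (1 - 0.5)
  assert(b == 0.5f && g == 0.25f);
}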
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_FUNCTIONS_H_
#define HL_FUNCTIONS_H_
/**
* sigmoid threshold minimum
*/
#define SIGMOID_THRESHOLD_MIN -40.0
/**
* sigmoid threshold maximum
*/
#define SIGMOID_THRESHOLD_MAX 13.0
/**
* The maximum input value for exp, used to avoid the overflow problem.
* Currently it is only used by the tanh function.
*/
#define EXP_MAX_INPUT 40.0
#ifndef __NVCC__
namespace hppl {
namespace typef {
float relu(const float a);
float sigmoid(const float a);
float tanh(const float a);
float linear(const float a);
float relu(const float a, const float b);
float sigmoid(const float a, const float b);
float tanh(const float a, const float b);
float linear(const float a, const float b);
} // namespace typef
namespace typed {
double relu(const double a);
double sigmoid(const double a);
double tanh(const double a);
double linear(const double a);
double relu(const double a, const double b);
double sigmoid(const double a, const double b);
double tanh(const double a, const double b);
double linear(const double a, const double b);
} // namespace typed
} // namespace hppl
#ifdef __AVX__
#include "hl_avx_functions.h"
#endif
#else
#include "hl_gpu_functions.h"
#endif
#endif // HL_FUNCTIONS_H_
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_GPU_FUNCTIONS_CUH_
#define HL_GPU_FUNCTIONS_CUH_
#include "hl_base.h"
namespace hppl {
namespace typef {
__device__ static float relu(const float a) { return a > 0.0f ? a : 0.0f; }
__device__ static float sigmoid(const float a) {
const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX;
float tmp = (a < min) ? min : ((a > max) ? max : a);
return __fdividef(1.0f, 1.0f + __expf(-tmp));
}
__device__ static float tanh(const float a) {
float tmp = -2.0 * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
return __fdividef(2.0f, (1.0f + __expf(tmp))) - 1.0f;
}
__device__ static float linear(const float a) { return a; }
__device__ static float relu(const float a, const float b) {
return a * (b > 0.0f ? 1.0f : 0.0f);
}
__device__ static float sigmoid(const float a, const float b) {
return a * b * (1.0f - b);
}
__device__ static float tanh(const float a, const float b) {
return a * (1.0f - b * b);
}
__device__ static float linear(const float a, const float b) { return a; }
} // namespace typef
namespace typed {
__device__ static double relu(const double a) { return a > 0.0 ? a : 0.0; }
__device__ static double sigmoid(const double a) {
const double min = SIGMOID_THRESHOLD_MIN;
const double max = SIGMOID_THRESHOLD_MAX;
double tmp = (a < min) ? min : ((a > max) ? max : a);
return 1.0 / (1.0 + exp(-tmp));
}
__device__ static double tanh(const double a) {
double tmp = -2.0 * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
return (2.0 / (1.0 + exp(tmp))) - 1.0;
}
__device__ static double linear(const double a) { return a; }
__device__ static double relu(const double a, const double b) {
return a * (b > 0.0 ? 1.0 : 0.0);
}
__device__ static double sigmoid(const double a, const double b) {
return a * b * (1 - b);
}
__device__ static double tanh(const double a, const double b) {
return a * (1.0 - b * b);
}
__device__ static double linear(const double a, const double b) { return a; }
} // namespace typed
} // namespace hppl
#endif // HL_GPU_FUNCTIONS_CUH_
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/operators/math/detail/hl_activation_functions.h"
#include "paddle/operators/math/lstm_compute.h"
namespace paddle {
namespace operators {
namespace math {
namespace detail {
#ifndef __NVCC__
template <class T, class Op>
void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
int frameSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
T rValueIn;
T rValueIg;
T rValueFg;
T rValueOg;
T rCheckI;
T rCheckF;
T rCheckO;
T rState;
T rPrevState = 0;
T rStateAtv;
T rOut;
T *valueIn = value.gateValue;
T *valueIg = value.gateValue + frameSize;
T *valueFg = value.gateValue + frameSize * 2;
T *valueOg = value.gateValue + frameSize * 3;
for (int i = 0; i < frameSize; i++) {
rValueIn = valueIn[i];
rValueIg = valueIg[i];
rValueFg = valueFg[i];
rValueOg = valueOg[i];
rCheckI = value.checkIg[i];
rCheckF = value.checkFg[i];
rCheckO = value.checkOg[i];
if (value.prevStateValue) {
rPrevState = value.prevStateValue[i];
}
hppl::cpu::ForwardAct<T> act;
op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate),
act(active_state));
valueIn[i] = rValueIn;
valueIg[i] = rValueIg;
valueFg[i] = rValueFg;
valueOg[i] = rValueOg;
value.stateValue[i] = rState;
value.stateActiveValue[i] = rStateAtv;
value.outputValue[i] = rOut;
}
}
template <class T, class Op>
void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frameSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
T rValueIn;
T rValueIg;
T rValueFg;
T rValueOg;
T rGradIn;
T rGradIg;
T rGradFg;
T rGradOg;
T rPrevState = 0;
T rPrevStateGrad;
T rState;
T rStateGrad;
T rStateAtv;
T rOutputGrad;
T rCheckI;
T rCheckF;
T rCheckO;
T rCheckIGrad;
T rCheckFGrad;
T rCheckOGrad;
T *valueIn = value.gateValue;
T *valueIg = value.gateValue + frameSize;
T *valueFg = value.gateValue + frameSize * 2;
T *valueOg = value.gateValue + frameSize * 3;
T *gradIn = grad.gateGrad;
T *gradIg = grad.gateGrad + frameSize;
T *gradFg = grad.gateGrad + frameSize * 2;
T *gradOg = grad.gateGrad + frameSize * 3;
for (int i = 0; i < frameSize; i++) {
rValueIn = valueIn[i];
rValueIg = valueIg[i];
rValueFg = valueFg[i];
rValueOg = valueOg[i];
rCheckI = value.checkIg[i];
rCheckF = value.checkFg[i];
rCheckO = value.checkOg[i];
rState = value.stateValue[i];
rStateAtv = value.stateActiveValue[i];
rOutputGrad = grad.outputGrad[i];
rStateGrad = grad.stateGrad[i];
if (value.prevStateValue) {
rPrevState = value.prevStateValue[i];
}
hppl::cpu::BackwardAct<T> act;
op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg,
rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv,
rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad,
rCheckOGrad, act(active_node), act(active_gate), act(active_state));
gradIn[i] = rGradIn;
gradIg[i] = rGradIg;
gradFg[i] = rGradFg;
gradOg[i] = rGradOg;
grad.stateGrad[i] = rStateGrad;
if (grad.prevStateGrad) grad.prevStateGrad[i] = rPrevStateGrad;
if (value.prevStateValue) {
if (grad.checkIgGrad) grad.checkIgGrad[i] += rCheckIGrad;
if (grad.checkFgGrad) grad.checkFgGrad[i] += rCheckFGrad;
}
if (grad.checkOgGrad) grad.checkOgGrad[i] += rCheckOGrad;
}
}
template <class T, class Op>
void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, int frameSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
#ifdef __AVX__
__m256 rValueIn;
__m256 rValueIg;
__m256 rValueFg;
__m256 rValueOg;
__m256 rCheckI;
__m256 rCheckF;
__m256 rCheckO;
__m256 rState;
__m256 rPrevState = _mm256_set1_ps(0.0f);
__m256 rStateAtv;
__m256 rOut;
__m256 *valueIn = (__m256 *)value.gateValue;
__m256 *valueIg = (__m256 *)(value.gateValue + frameSize);
__m256 *valueFg = (__m256 *)(value.gateValue + frameSize * 2);
__m256 *valueOg = (__m256 *)(value.gateValue + frameSize * 3);
for (int i = 0; i < frameSize / 8; i++) {
rValueIn = valueIn[i];
rValueIg = valueIg[i];
rValueFg = valueFg[i];
rValueOg = valueOg[i];
rCheckI = ((__m256 *)value.checkIg)[i];
rCheckF = ((__m256 *)value.checkFg)[i];
rCheckO = ((__m256 *)value.checkOg)[i];
if (value.prevStateValue) {
rPrevState = ((__m256 *)value.prevStateValue)[i];
}
op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
rOut, rCheckI, rCheckF, rCheckO, hppl::avx::forward[active_node],
hppl::avx::forward[active_gate], hppl::avx::forward[active_state]);
valueIn[i] = rValueIn;
valueIg[i] = rValueIg;
valueFg[i] = rValueFg;
valueOg[i] = rValueOg;
((__m256 *)value.stateValue)[i] = rState;
((__m256 *)value.stateActiveValue)[i] = rStateAtv;
((__m256 *)value.outputValue)[i] = rOut;
}
#endif
}
template <class T, class Op>
void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frameSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
#ifdef __AVX__
__m256 rValueIn;
__m256 rValueIg;
__m256 rValueFg;
__m256 rValueOg;
__m256 rGradIn;
__m256 rGradIg;
__m256 rGradFg;
__m256 rGradOg;
__m256 rPrevState = _mm256_set1_ps(0.0f);
__m256 rPrevStateGrad;
__m256 rStateGrad;
__m256 rState;
__m256 rStateAtv;
__m256 rOutputGrad;
__m256 rCheckI;
__m256 rCheckF;
__m256 rCheckO;
__m256 rCheckIGrad;
__m256 rCheckFGrad;
__m256 rCheckOGrad;
__m256 *valueIn = (__m256 *)value.gateValue;
__m256 *valueIg = (__m256 *)(value.gateValue + frameSize);
__m256 *valueFg = (__m256 *)(value.gateValue + frameSize * 2);
__m256 *valueOg = (__m256 *)(value.gateValue + frameSize * 3);
__m256 *gradIn = (__m256 *)grad.gateGrad;
__m256 *gradIg = (__m256 *)(grad.gateGrad + frameSize);
__m256 *gradFg = (__m256 *)(grad.gateGrad + frameSize * 2);
__m256 *gradOg = (__m256 *)(grad.gateGrad + frameSize * 3);
for (int i = 0; i < frameSize / 8; i++) {
rValueIn = valueIn[i];
rValueIg = valueIg[i];
rValueFg = valueFg[i];
rValueOg = valueOg[i];
rCheckI = ((__m256 *)value.checkIg)[i];
rCheckF = ((__m256 *)value.checkFg)[i];
rCheckO = ((__m256 *)value.checkOg)[i];
rState = ((__m256 *)value.stateValue)[i];
rStateAtv = ((__m256 *)value.stateActiveValue)[i];
rOutputGrad = ((__m256 *)grad.outputGrad)[i];
rStateGrad = ((__m256 *)grad.stateGrad)[i];
if (value.prevStateValue) {
rPrevState = ((__m256 *)value.prevStateValue)[i];
}
op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg,
rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv,
rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad,
rCheckOGrad, hppl::avx::backward[active_node],
hppl::avx::backward[active_gate], hppl::avx::backward[active_state]);
gradIn[i] = rGradIn;
gradIg[i] = rGradIg;
gradFg[i] = rGradFg;
gradOg[i] = rGradOg;
((__m256 *)grad.stateGrad)[i] = rStateGrad;
if (grad.prevStateGrad) ((__m256 *)grad.prevStateGrad)[i] = rPrevStateGrad;
if (value.prevStateValue) {
if (grad.checkIgGrad) ((__m256 *)grad.checkIgGrad)[i] += rCheckIGrad;
if (grad.checkFgGrad) ((__m256 *)grad.checkFgGrad)[i] += rCheckFGrad;
}
if (grad.checkOgGrad) ((__m256 *)grad.checkOgGrad)[i] += rCheckOGrad;
}
#endif
}
template <class T, class Op>
void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frameSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
if (Op::avx && !(frameSize & (8 - 1)) && (std::is_same<T, float>::value)) {
avx_lstm_forward_one_sequence<T>(op, value, frameSize, active_node,
active_gate, active_state);
} else {
naive_lstm_forward_one_sequence<T>(op, value, frameSize, active_node,
active_gate, active_state);
}
}
template <class T, class Op>
void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frameSize, activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
if (Op::avx && !(frameSize & (8 - 1)) && (std::is_same<T, float>::value)) {
avx_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node,
active_gate, active_state);
} else {
naive_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node,
active_gate, active_state);
}
}
#endif
} // namespace detail
} // namespace math
} // namespace operators
} // namespace paddle
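// A restatement (illustrative only) of the dispatch condition used by
// cpu_lstm_forward/cpu_lstm_backward above: the AVX path handles 8 floats per
// __m256 register, so it is taken only for float and only when frameSize is a
// multiple of 8; otherwise the naive per-element loop runs.
template <class T>
inline bool takes_avx_path_sketch(bool op_supports_avx, int frameSize) {
  return op_supports_avx && (frameSize % 8 == 0) && std::is_same<T, float>::value;
}
// e.g. takes_avx_path_sketch<float>(true, 256)  -> true
//      takes_avx_path_sketch<float>(true, 100)  -> false (100 is not a multiple of 8)
//      takes_avx_path_sketch<double>(true, 256) -> false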
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/operators/math/detail/hl_activation_functions.h"
#include "paddle/operators/math/lstm_compute.h"
#include "paddle/platform/cuda_helper.h"
#include "paddle/platform/device_context.h"
#include <glog/logging.h>
namespace paddle {
namespace operators {
namespace math {
namespace detail {
/*
* threads(framePerBlock, batchPerBlock)
* grid(frameBlocks, batchBlocks)
*/
template <class T, class Op, bool isBatch>
__global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
int batchSize, activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return;
int batchIdx = 0;
if (isBatch) {
batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
if (batchIdx >= batchSize) return;
value.gateValue += batchIdx * frameSize * 4;
value.outputValue += batchIdx * frameSize;
value.stateValue += batchIdx * frameSize;
value.stateActiveValue += batchIdx * frameSize;
}
T rState;
T rPrevState = 0;
T rStateAtv;
T rOut;
T rValueIn;
T rValueIg;
T rValueFg;
T rValueOg;
T rCheckI = value.checkIg[frameIdx];
T rCheckF = value.checkFg[frameIdx];
T rCheckO = value.checkOg[frameIdx];
rValueIn = value.gateValue[frameIdx];
rValueIg = value.gateValue[frameIdx + frameSize];
rValueFg = value.gateValue[frameIdx + frameSize * 2];
rValueOg = value.gateValue[frameIdx + frameSize * 3];
if (value.prevStateValue) {
if (isBatch) value.prevStateValue += batchIdx * frameSize;
rPrevState = value.prevStateValue[frameIdx];
}
hppl::gpu::ForwardAct<T> act;
op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate),
act(active_state));
value.gateValue[frameIdx] = rValueIn;
value.gateValue[frameIdx + frameSize] = rValueIg;
value.gateValue[frameIdx + frameSize * 2] = rValueFg;
value.gateValue[frameIdx + frameSize * 3] = rValueOg;
value.stateValue[frameIdx] = rState;
value.stateActiveValue[frameIdx] = rStateAtv;
value.outputValue[frameIdx] = rOut;
}
/*
* threads(framePerBlock, batchPerBlock)
* grid(frameBlocks, batchBlocks)
*/
template <class T, class Op, bool isBatch>
__global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frameSize,
int batchSize, activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return;
int batchIdx = 0;
if (isBatch) {
batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
if (batchIdx >= batchSize) return;
value.gateValue += batchIdx * frameSize * 4;
value.stateValue += batchIdx * frameSize;
value.stateActiveValue += batchIdx * frameSize;
grad.gateGrad += batchIdx * frameSize * 4;
grad.stateGrad += batchIdx * frameSize;
grad.outputGrad += batchIdx * frameSize;
}
T rValueIn;
T rValueIg;
T rValueFg;
T rValueOg;
T rGradIn;
T rGradIg;
T rGradFg;
T rGradOg;
T rPrevState = 0;
T rPrevStateGrad;
T rState;
T rStateGrad;
T rStateAtv;
T rOutputGrad;
T rCheckI = value.checkIg[frameIdx];
T rCheckF = value.checkFg[frameIdx];
T rCheckO = value.checkOg[frameIdx];
T rCheckIGrad;
T rCheckFGrad;
T rCheckOGrad;
rValueIn = value.gateValue[frameIdx];
rValueIg = value.gateValue[frameIdx + frameSize];
rValueFg = value.gateValue[frameIdx + frameSize * 2];
rValueOg = value.gateValue[frameIdx + frameSize * 3];
rState = value.stateValue[frameIdx];
rStateAtv = value.stateActiveValue[frameIdx];
rOutputGrad = grad.outputGrad[frameIdx];
rStateGrad = grad.stateGrad[frameIdx];
if (value.prevStateValue) {
if (isBatch) value.prevStateValue += batchIdx * frameSize;
rPrevState = value.prevStateValue[frameIdx];
}
hppl::gpu::BackwardAct<T> act;
op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg,
rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad,
rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad,
act(active_node), act(active_gate), act(active_state));
grad.gateGrad[frameIdx] = rGradIn;
grad.gateGrad[frameIdx + frameSize] = rGradIg;
grad.gateGrad[frameIdx + frameSize * 2] = rGradFg;
grad.gateGrad[frameIdx + frameSize * 3] = rGradOg;
grad.stateGrad[frameIdx] = rStateGrad;
if (grad.prevStateGrad) {
if (isBatch) grad.prevStateGrad += batchIdx * frameSize;
grad.prevStateGrad[frameIdx] = rPrevStateGrad;
}
if (isBatch) {
if (value.prevStateValue) {
if (grad.checkIgGrad)
paddle::platform::CudaAtomicAdd(grad.checkIgGrad + frameIdx,
rCheckIGrad);
if (grad.checkFgGrad)
paddle::platform::CudaAtomicAdd(grad.checkFgGrad + frameIdx,
rCheckFGrad);
}
if (grad.checkOgGrad)
paddle::platform::CudaAtomicAdd(grad.checkOgGrad + frameIdx, rCheckOGrad);
} else {
if (value.prevStateValue) {
if (grad.checkIgGrad) grad.checkIgGrad[frameIdx] += rCheckIGrad;
if (grad.checkFgGrad) grad.checkFgGrad[frameIdx] += rCheckFGrad;
}
if (grad.checkOgGrad) grad.checkOgGrad[frameIdx] += rCheckOGrad;
}
}
template <class T, class Op>
void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
LstmMetaValue<T> value, int frameSize, int batchSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
dim3 threads;
dim3 grid;
if (batchSize == 1) {
int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
int frameBlocks = (frameSize + 1024 - 1) / 1024;
threads = dim3(framePerBlock, 1);
grid = dim3(frameBlocks, 1);
} else {
/* framePerBlock = 32 batchPerBlock = 32 */
threads = dim3(32, 32);
grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
}
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();
if (batchSize == 1) {
KeLstmForward<T, Op,
/* isBatch= */ false><<<grid, threads, 0, stream>>>(
op, value, frameSize, batchSize, active_node, active_gate,
active_state);
} else {
KeLstmForward<T, Op,
/* isBatch= */ true><<<grid, threads, 0, stream>>>(
op, value, frameSize, batchSize, active_node, active_gate,
active_state);
}
}
template <class T, class Op>
void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frameSize, int batchSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
dim3 threads;
dim3 grid;
if (batchSize == 1) {
int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
int frameBlocks = (frameSize + 1024 - 1) / 1024;
threads = dim3(framePerBlock, 1);
grid = dim3(frameBlocks, 1);
} else {
/* framePerBlock = 32 batchPerBlock = 32 */
threads = dim3(32, 32);
grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
}
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();
if (batchSize == 1) {
KeLstmBackward<T, Op,
/* isBatch= */ false><<<grid, threads, 0, stream>>>(
op, value, grad, frameSize, batchSize, active_node, active_gate,
active_state);
} else {
KeLstmBackward<T, Op,
/* isBatch= */ true><<<grid, threads, 0, stream>>>(
op, value, grad, frameSize, batchSize, active_node, active_gate,
active_state);
}
}
} // namespace detail
} // namespace math
} // namespace operators
} // namespace paddle
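// A restatement (illustrative only) of the launch-configuration logic above,
// so a concrete case is easy to check: for frameSize = 200 and batchSize = 5
// the batched branch gives threads = dim3(32, 32) and
// grid = dim3((200 + 31) / 32, (5 + 31) / 32) = dim3(7, 1); the bounds checks
// at the top of KeLstmForward/KeLstmBackward guard the remainder threads.
inline void lstm_launch_config_sketch(int frameSize, int batchSize,
                                      dim3* threads, dim3* grid) {
  if (batchSize == 1) {
    int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
    int frameBlocks = (frameSize + 1024 - 1) / 1024;
    *threads = dim3(framePerBlock, 1);
    *grid = dim3(frameBlocks, 1);
  } else {
    *threads = dim3(32, 32);
    *grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
  }
}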
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/detail/hl_activation_functions.h"
#include "paddle/platform/hostdevice.h"
#include <type_traits>
namespace paddle {
namespace operators {
namespace math {
namespace detail {
namespace forward {
template <class T>
class lstm {
public:
HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
T &prevState, T &state, T &stateAtv, T &output,
T &checkI, T &checkF, T &checkO,
typename hppl::ForwardActType<T>::type actInput,
typename hppl::ForwardActType<T>::type actGate,
typename hppl::ForwardActType<T>::type actState) {
valueIn = actInput(valueIn);
valueIg = actGate(valueIg + prevState * checkI);
valueFg = actGate(valueFg + prevState * checkF);
state = valueIn * valueIg + prevState * valueFg;
valueOg = actGate(valueOg + state * checkO);
stateAtv = actState(state);
output = valueOg * stateAtv;
}
#ifndef __NVCC__
#ifndef __AVX__  // If not compiled with AVX instructions, disable AVX by default
static const bool avx = false;
#else
// Only float supports the AVX optimization
static const bool avx = std::is_same<T, float>::value;
HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
__m256 &valueOg, __m256 &prevState, __m256 &state,
__m256 &stateAtv, __m256 &output, __m256 &checkI,
__m256 &checkF, __m256 &checkO,
hppl::Active<__m256>::forward actInput,
hppl::Active<__m256>::forward actGate,
hppl::Active<__m256>::forward actState) {
valueIn = actInput(valueIn);
valueIg = actGate(_mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI)));
valueFg = actGate(_mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF)));
state = _mm256_add_ps(_mm256_mul_ps(valueIn, valueIg),
_mm256_mul_ps(prevState, valueFg));
valueOg = actGate(_mm256_add_ps(valueOg, _mm256_mul_ps(state, checkO)));
stateAtv = actState(state);
output = _mm256_mul_ps(valueOg, stateAtv);
}
#endif
#endif
};
} // namespace forward
namespace backward {
template <class T>
class lstm {
public:
HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
T &gradIn, T &gradIg, T &gradFg, T &gradOg,
T &prevState, T &prevStateGrad, T &state,
T &stateGrad, T &stateAtv, T &outputGrad,
T &checkI, T &checkF, T &checkO, T &checkIGrad,
T &checkFGrad, T &checkOGrad,
typename hppl::BackwardActType<T>::type actInput,
typename hppl::BackwardActType<T>::type actGate,
typename hppl::BackwardActType<T>::type actState) {
gradOg = actGate(outputGrad * stateAtv, valueOg);
stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO;
gradIn = actInput(stateGrad * valueIg, valueIn);
gradIg = actGate(stateGrad * valueIn, valueIg);
gradFg = actGate(stateGrad * prevState, valueFg);
prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg;
checkIGrad = gradIg * prevState;
checkFGrad = gradFg * prevState;
checkOGrad = gradOg * state;
}
#ifndef __NVCC__
#ifndef __AVX__  // If not compiled with AVX instructions, disable AVX by default
static const bool avx = false;
#else
// Only float supports the AVX optimization
static const bool avx = std::is_same<T, float>::value;
HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
__m256 &valueOg, __m256 &gradIn, __m256 &gradIg,
__m256 &gradFg, __m256 &gradOg, __m256 &prevState,
__m256 &prevStateGrad, __m256 &state,
__m256 &stateGrad, __m256 &stateAtv,
__m256 &outputGrad, __m256 &checkI, __m256 &checkF,
__m256 &checkO, __m256 &checkIGrad,
__m256 &checkFGrad, __m256 &checkOGrad,
hppl::Active<__m256>::backward actInput,
hppl::Active<__m256>::backward actGate,
hppl::Active<__m256>::backward actState) {
gradOg = actGate(_mm256_mul_ps(outputGrad, stateAtv), valueOg);
stateGrad = _mm256_add_ps(
actState(_mm256_mul_ps(outputGrad, valueOg), stateAtv), stateGrad);
stateGrad = _mm256_add_ps(_mm256_mul_ps(gradOg, checkO), stateGrad);
gradIn = actInput(_mm256_mul_ps(stateGrad, valueIg), valueIn);
gradIg = actGate(_mm256_mul_ps(stateGrad, valueIn), valueIg);
gradFg = actGate(_mm256_mul_ps(stateGrad, prevState), valueFg);
prevStateGrad = _mm256_add_ps(_mm256_mul_ps(gradIg, checkI),
_mm256_mul_ps(gradFg, checkF));
prevStateGrad =
_mm256_add_ps(_mm256_mul_ps(stateGrad, valueFg), prevStateGrad);
checkIGrad = _mm256_mul_ps(gradIg, prevState);
checkFGrad = _mm256_mul_ps(gradFg, prevState);
checkOGrad = _mm256_mul_ps(gradOg, state);
}
#endif
#endif
};
} // namespace backward
} // namespace detail
} // namespace math
} // namespace operators
} // namespace paddle
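// A host-side numeric sketch (illustrative only, not part of the PR) of the
// scalar forward functor above, using the "linear" activation so the result
// can be checked by hand against the formulas in lstm_op.cc.
#ifndef __NVCC__
#include <cassert>
inline void lstm_forward_scalar_sketch() {
  using paddle::operators::math::detail::forward::lstm;
  float in = 0.5f, ig = 1.0f, fg = 1.0f, og = 1.0f;
  float prev = 2.0f, state = 0.0f, state_atv = 0.0f, out = 0.0f;
  float ci = 0.0f, cf = 0.0f, co = 0.0f;  // peephole weights set to zero
  lstm<float>()(in, ig, fg, og, prev, state, state_atv, out, ci, cf, co,
                hppl::typef::linear, hppl::typef::linear, hppl::typef::linear);
  // state = in * ig + prev * fg = 0.5 * 1 + 2 * 1 = 2.5; out = og * state_atv = 2.5
  assert(state == 2.5f && out == 2.5f);
}
#endif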
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/lstm_compute.h"
#include "paddle/operators/math/detail/lstm_cpu_kernel.h"
#include "paddle/operators/math/detail/lstm_kernel.h"
namespace paddle {
namespace operators {
namespace math {
template <class T>
struct LstmUnitFunctor<platform::CPUPlace, T> {
static void compute(const platform::DeviceContext& context,
LstmMetaValue<T> value, int frame_size, int batch_size,
const std::string& gate_act, const std::string& cell_act,
const std::string& cand_act) {
for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size,
ActiveType(cand_act), ActiveType(gate_act),
ActiveType(cell_act));
value.gateValue += frame_size * 4;
value.stateValue += frame_size;
value.stateActiveValue += frame_size;
value.outputValue += frame_size;
if (value.prevStateValue) {
value.prevStateValue += frame_size;
}
}
}
};
template <class T>
struct LstmUnitGradFunctor<platform::CPUPlace, T> {
static void compute(const platform::DeviceContext& context,
LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, int batch_size,
const std::string& gate_act, const std::string& cell_act,
const std::string& cand_act) {
for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_backward(detail::backward::lstm<T>(), value, grad,
frame_size, ActiveType(cand_act),
ActiveType(gate_act), ActiveType(cell_act));
value.gateValue += frame_size * 4;
value.stateValue += frame_size;
value.stateActiveValue += frame_size;
value.outputValue += frame_size;
if (value.prevStateValue) {
value.prevStateValue += frame_size;
}
grad.gateGrad += frame_size * 4;
grad.stateGrad += frame_size;
grad.stateActiveGrad += frame_size;
grad.outputGrad += frame_size;
if (grad.prevStateGrad) {
grad.prevStateGrad += frame_size;
}
}
}
};
template class LstmUnitFunctor<platform::CPUPlace, float>;
template class LstmUnitFunctor<platform::CPUPlace, double>;
template class LstmUnitGradFunctor<platform::CPUPlace, float>;
template class LstmUnitGradFunctor<platform::CPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/detail/lstm_gpu_kernel.h"
#include "paddle/operators/math/detail/lstm_kernel.h"
#include "paddle/operators/math/lstm_compute.h"
namespace paddle {
namespace operators {
namespace math {
template <class T>
struct LstmUnitFunctor<platform::GPUPlace, T> {
static void compute(const platform::DeviceContext& context,
LstmMetaValue<T> value, int frame_size, int batch_size,
const std::string& gate_act, const std::string& cell_act,
const std::string& cand_act) {
detail::gpu_lstm_forward<T>(context, detail::forward::lstm<T>(), value,
frame_size, batch_size, ActiveType(cand_act),
ActiveType(gate_act), ActiveType(cell_act));
}
};
template <class T>
struct LstmUnitGradFunctor<platform::GPUPlace, T> {
static void compute(const platform::DeviceContext& context,
LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, int batch_size,
const std::string& gate_act, const std::string& cell_act,
const std::string& cand_act) {
detail::gpu_lstm_backward(context, detail::backward::lstm<T>(), value, grad,
frame_size, batch_size, ActiveType(cand_act),
ActiveType(gate_act), ActiveType(cell_act));
}
};
template class LstmUnitFunctor<platform::GPUPlace, float>;
template class LstmUnitFunctor<platform::GPUPlace, double>;
template class LstmUnitGradFunctor<platform::GPUPlace, float>;
template class LstmUnitGradFunctor<platform::GPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
namespace paddle {
namespace operators {
namespace math {
typedef enum {
HL_ACTIVATION_SIGMOID = 0,
HL_ACTIVATION_RELU = 1,
HL_ACTIVATION_TANH = 2,
HL_ACTIVATION_LINEAR = 3,
HL_ACTIVATION_END
} activation_mode_t;
template <class T>
struct LstmMetaValue {
T *gateValue;
T *prevStateValue;
T *stateValue;
T *stateActiveValue;
T *outputValue;
T *checkIg;
T *checkFg;
T *checkOg;
};
template <class T>
struct LstmMetaGrad {
T *gateGrad;
T *prevStateGrad;
T *stateGrad;
T *stateActiveGrad;
T *outputGrad;
T *checkIgGrad;
T *checkFgGrad;
T *checkOgGrad;
};
inline activation_mode_t ActiveType(const std::string &type) {
if (type == "sigmoid") {
return HL_ACTIVATION_SIGMOID;
} else if (type == "relu") {
return HL_ACTIVATION_RELU;
} else if (type == "tanh") {
return HL_ACTIVATION_TANH;
} else if (type == "linear" || type == "identity" || type == "") {
return HL_ACTIVATION_LINEAR;
} else {
PADDLE_THROW("Do not support activation type.");
}
}
template <typename Place, typename T>
class LstmUnitFunctor {
public:
static void compute(const platform::DeviceContext &context,
LstmMetaValue<T> value, int frame_size, int batch_size,
const std::string &gate_act, const std::string &cell_act,
const std::string &cand_act);
};
template <typename Place, typename T>
class LstmUnitGradFunctor {
public:
static void compute(const platform::DeviceContext &context,
LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, int batch_size,
const std::string &gate_act, const std::string &cell_act,
const std::string &cand_act);
};
} // namespace math
} // namespace operators
} // namespace paddle
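For reference, a minimal NumPy sketch of the per-frame forward recurrence that LstmUnitFunctor applies, assuming the candidate/input/forget/output gate ordering used by the Python reference implementation in the test below; check_i/check_f/check_o correspond to the checkIg/checkFg/checkOg peephole weights, and the function itself is illustrative only.
import numpy as np

def lstm_forward_step(gates, prev_state, check_i, check_f, check_o):
    # gates holds the pre-activation values [candidate, input, forget, output].
    sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))
    a_c, a_i, a_f, a_o = np.split(gates, 4)
    i = sigmoid(a_i + check_i * prev_state)
    f = sigmoid(a_f + check_f * prev_state)
    c = f * prev_state + i * np.tanh(a_c)
    o = sigmoid(a_o + check_o * c)
    return o * np.tanh(c), c

frame_size = 8
h, c = lstm_forward_step(np.random.randn(4 * frame_size),
                         np.zeros(frame_size), np.zeros(frame_size),
                         np.zeros(frame_size), np.zeros(frame_size))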
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/sequence2batch.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
class CopyMatrixRowsFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& src, const size_t* index,
framework::LoDTensor& dst, bool is_src_index) {
auto src_dims = src.dims();
auto dst_dims = dst.dims();
    PADDLE_ENFORCE_EQ(src_dims.size(), 2UL,
                      "The src must be a matrix with rank 2.");
    PADDLE_ENFORCE_EQ(dst_dims.size(), 2UL,
                      "The dst must be a matrix with rank 2.");
    PADDLE_ENFORCE_EQ(src_dims[1], dst_dims[1],
                      "The width of src and dst must be the same.");
auto height = dst_dims[0];
auto width = dst_dims[1];
auto* src_data = src.data<T>();
auto* dst_data = dst.data<T>();
for (int i = 0; i < height; ++i) {
if (is_src_index) {
memcpy(dst_data + i * width, src_data + index[i] * width,
width * sizeof(T));
} else {
memcpy(dst_data + index[i] * width, src_data + i * width,
width * sizeof(T));
}
}
}
};
template class CopyMatrixRowsFunctor<platform::CPUPlace, float>;
template class CopyMatrixRowsFunctor<platform::CPUPlace, double>;
template class LoDTensor2BatchFunctor<platform::CPUPlace, float>;
template class LoDTensor2BatchFunctor<platform::CPUPlace, double>;
template class Batch2LoDTensorFunctor<platform::CPUPlace, float>;
template class Batch2LoDTensorFunctor<platform::CPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
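The CPU functor above is a plain row gather/scatter; a self-contained NumPy sketch of the same semantics (the names are illustrative, not part of the library):
import numpy as np

def copy_matrix_rows(src, index, dst, is_src_index):
    # Gather: dst[i] = src[index[i]]; scatter: dst[index[i]] = src[i].
    for i, idx in enumerate(index):
        if is_src_index:
            dst[i, :] = src[idx, :]
        else:
            dst[idx, :] = src[i, :]
    return dst

src = np.arange(12.0).reshape(4, 3)
gathered = copy_matrix_rows(src, [2, 0, 3, 1], np.zeros_like(src), True)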
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/sequence2batch.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void CopyMatrixRowsKernel(const T* src, T* dst, const size_t* index,
int64_t height, int64_t width,
bool is_src_index) {
int idx = threadIdx.x;
int idy = threadIdx.y;
int id = blockIdx.x + idy * GridDimX;
while (id < height) {
int src_idx = is_src_index ? index[id] : id;
int dst_idx = is_src_index ? id : index[id];
const T* src_data = src + src_idx * width;
T* dst_data = dst + dst_idx * width;
for (int i = idx; i < width; i += BlockDimX) {
dst_data[i] = src_data[i];
}
id += BlockDimY * GridDimX;
}
}
template <typename T>
class CopyMatrixRowsFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& src, const size_t* index,
framework::LoDTensor& dst, bool is_src_index) {
auto src_dims = src.dims();
auto dst_dims = dst.dims();
    PADDLE_ENFORCE_EQ(src_dims.size(), 2,
                      "The src must be a matrix with rank 2.");
    PADDLE_ENFORCE_EQ(dst_dims.size(), 2,
                      "The dst must be a matrix with rank 2.");
    PADDLE_ENFORCE_EQ(src_dims[1], dst_dims[1],
                      "The width of src and dst must be the same.");
auto height = dst_dims[0];
auto width = dst_dims[1];
auto* src_data = src.data<T>();
auto* dst_data = dst.data<T>();
dim3 threads(128, 8);
dim3 grid(8, 1);
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();
CopyMatrixRowsKernel<T, 128, 8, 8><<<grid, threads, 0, stream>>>(
src_data, dst_data, index, height, width, is_src_index);
}
};
template class CopyMatrixRowsFunctor<platform::GPUPlace, float>;
template class CopyMatrixRowsFunctor<platform::GPUPlace, double>;
template class LoDTensor2BatchFunctor<platform::GPUPlace, float>;
template class LoDTensor2BatchFunctor<platform::GPUPlace, double>;
template class Batch2LoDTensorFunctor<platform::GPUPlace, float>;
template class Batch2LoDTensorFunctor<platform::GPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
namespace math {
template <typename Place, typename T>
class CopyMatrixRowsFunctor {
public:
// If is_src_index is true,
// copy the indexed rows of input src to the output dst.
// If is_src_index is false,
// copy the input src to the indexed rows of output dst.
// The indexed rows are based on the input index.
void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& src, const size_t* index,
framework::LoDTensor& dst, bool is_src_index);
};
template <typename Place, typename T>
class LoDTensor2BatchFunctor {
// Calculate the length of each sequence and
// sort sequence index by the length.
// example: sequences = {s0, s1, s2}
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
//
struct SeqInfo {
SeqInfo(int start, int length, int seq_idx)
: start(start), length(length), seq_idx(seq_idx) {}
int start;
int length;
int seq_idx;
};
public:
void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& lod_tensor,
framework::LoDTensor& batch, bool is_reverse) const {
auto lods = lod_tensor.lod();
    PADDLE_ENFORCE_EQ(lods.size(), 1UL,
                      "Only one-level sequence is supported now.");
auto lod = lods[0];
std::vector<SeqInfo> seq_info;
for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) {
int length = lod[seq_id + 1] - lod[seq_id];
seq_info.emplace_back(lod[seq_id], length, seq_id);
}
std::sort(seq_info.begin(), seq_info.end(),
[](SeqInfo a, SeqInfo b) { return a.length > b.length; });
    // Calculate the start position of each batch.
    // (num_batch equals the maximum length of the sequences)
// example: sequences = {s0, s1, s2}
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// num_batch = 5,
// batchIndex = {b0, b1, b2, b3, b4}
// b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1
// batch_start_positions[6] = {0, 3, 6, 9, 11, 12}
// batch_start_positions[0] = len(b0)
// batch_start_positions[1] = len(b0) + len(b1)
// batch_start_positions[2] = len(b0) + len(b1) + len(b2)
// ...
// seq2batch_idx[12] = {4, 0, 9,
// 5, 1, 10,
// 6, 2, 11,
// 7, 3,
// 8}
    // The batch number represents the batch size after rearranging the
    // input LoDTensor. It is also the maximum length of the input sequences.
paddle::framework::LoD batch_lods;
batch_lods.emplace_back(std::vector<size_t>{0});
batch_lods.emplace_back(std::vector<size_t>{0});
// batch_lods[0] is the start positions for batch LoDTensor
int num_batch = seq_info[0].length;
batch_lods[0].resize(static_cast<size_t>(num_batch + 1));
// batch_lods[1] is the raw index in the input LoDTensor
auto dims = lod_tensor.dims();
batch_lods[1].resize(static_cast<size_t>(dims[0]));
size_t* batch_starts = batch_lods[0].data();
size_t* seq2batch_idx = batch_lods[1].data();
batch_starts[0] = 0;
for (size_t n = 0; n < num_batch; n++) {
auto batch_id = static_cast<int>(batch_starts[n]);
for (size_t i = 0; i < seq_info.size(); ++i) {
size_t seq_len = seq_info[i].length;
int start = seq_info[i].start;
if (n < seq_len) {
seq2batch_idx[batch_id] =
is_reverse ? start + seq_len - 1 - n : start + n;
batch_id++;
} else {
break;
}
}
batch_starts[n + 1] = static_cast<size_t>(batch_id);
}
batch.set_lod(batch_lods);
CopyMatrixRowsFunctor<Place, T> to_batch;
to_batch(context, lod_tensor, seq2batch_idx, batch, true);
}
};
template <typename Place, typename T>
class Batch2LoDTensorFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& batch,
framework::LoDTensor& lod_tensor) const {
auto in_lod = batch.lod();
PADDLE_ENFORCE_EQ(in_lod.size(), 2UL,
"The LoD size of input `batch` should be 2.");
auto out_lod = lod_tensor.lod()[0];
auto num = out_lod[out_lod.size() - 1];
PADDLE_ENFORCE_EQ(num, lod_tensor.dims()[0]);
PADDLE_ENFORCE_EQ(num, in_lod[1].size());
PADDLE_ENFORCE_EQ(num, batch.dims()[0]);
CopyMatrixRowsFunctor<Place, T> to_seq;
size_t* index = in_lod[1].data();
to_seq(context, batch, index, lod_tensor, false);
}
};
} // namespace math
} // namespace operators
} // namespace paddle
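A self-contained Python sketch of the index bookkeeping documented in LoDTensor2BatchFunctor above, reproducing the commented example (lod = [0, 4, 9, 12], i.e. sequence lengths 4, 5, 3); it is for illustration only and has no Paddle dependency.
def lod_to_batch_index(lod, is_reverse=False):
    # (start, length, seq_idx), sorted by length with the longest sequence first.
    seq_info = sorted(
        [(lod[i], lod[i + 1] - lod[i], i) for i in range(len(lod) - 1)],
        key=lambda s: s[1], reverse=True)
    num_batch = seq_info[0][1]
    batch_starts = [0]
    seq2batch_idx = []
    for n in range(num_batch):
        for start, length, _ in seq_info:
            if n >= length:
                break
            seq2batch_idx.append(start + length - 1 - n if is_reverse else start + n)
        batch_starts.append(len(seq2batch_idx))
    return batch_starts, seq2batch_idx

starts, idx = lod_to_batch_index([0, 4, 9, 12])
assert starts == [0, 3, 6, 9, 11, 12]
assert idx == [4, 0, 9, 5, 1, 10, 6, 2, 11, 7, 3, 8]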
...@@ -242,7 +242,7 @@ class OpTest(unittest.TestCase): ...@@ -242,7 +242,7 @@ class OpTest(unittest.TestCase):
self.assertTrue( self.assertTrue(
np.allclose( np.allclose(
actual, expect, atol=atol), actual, expect, atol=atol),
"output name: " + out_name + " has diff.") "Output (" + out_name + ") has diff at " + str(place))
else: else:
actual = np.array(self.scope.find_var(out_name).get_tensor()) actual = np.array(self.scope.find_var(out_name).get_tensor())
expect = self.outputs[out_name] expect = self.outputs[out_name]
...@@ -250,7 +250,7 @@ class OpTest(unittest.TestCase): ...@@ -250,7 +250,7 @@ class OpTest(unittest.TestCase):
self.assertTrue( self.assertTrue(
np.allclose( np.allclose(
actual, expect, atol=atol), actual, expect, atol=atol),
"output name: " + out_name + " has diff.") "Output (" + out_name + ") has diff at " + str(place))
def check_output(self, atol=1e-5): def check_output(self, atol=1e-5):
places = [core.CPUPlace()] places = [core.CPUPlace()]
......
import unittest
import numpy as np
from op_test import OpTest
SIGMOID_THRESHOLD_MIN = -40.0
SIGMOID_THRESHOLD_MAX = 13.0
EXP_MAX_INPUT = 40.0
def identity(x):
return x
def sigmoid(x):
y = np.copy(x)
y[x < SIGMOID_THRESHOLD_MIN] = SIGMOID_THRESHOLD_MIN
y[x > SIGMOID_THRESHOLD_MAX] = SIGMOID_THRESHOLD_MAX
return 1. / (1. + np.exp(-y))
def tanh(x):
y = -2. * x
y[y > EXP_MAX_INPUT] = EXP_MAX_INPUT
return (2. / (1. + np.exp(y))) - 1.
def relu(x):
return np.maximum(x, 0)
ACTIVATION = {
'identity': identity,
'sigmoid': sigmoid,
'tanh': tanh,
'relu': relu
}
def lstm(
input, # T x 4D
lod, # 1 x N
h0=None, # N x D
c0=None, # N x D
w_h=None, # D x 4D
w_b=None, # 1 x 4D
w_c=None, # 1 x 3D
is_reverse=False,
act_gate=None,
act_cell=None,
act_cand=None):
def _step(x, w_h, w_c, h_pre, c_pre, act_gate, act_cell, act_cand):
g = np.dot(h_pre, w_h) # 1 x 4D
g = g + x
g = np.reshape(g, (1, g.size))
c_tmp, g_i, g_f, g_o = np.split(g, 4, axis=1)
if w_c is None:
g_i = act_gate(g_i) # 1 x D
g_f = act_gate(g_f) # 1 x D
else:
w_ic, w_fc, w_oc = np.split(w_c, 3, axis=1)
g_i = act_gate(g_i + w_ic * c_pre) # 1 x D
g_f = act_gate(g_f + w_fc * c_pre) # 1 x D
c = g_f * c_pre + g_i * act_cand(c_tmp) # 1 x D
if w_c is None:
g_o = act_gate(g_o) # 1 x D
else:
_, _, w_oc = np.split(w_c, 3, axis=1)
g_o = act_gate(g_o + w_oc * c) # 1 x D
h = g_o * act_cell(c)
bg = np.concatenate((act_cand(c_tmp), g_i, g_f, g_o), axis=1)
return h, c, bg
def _reverse(x, lod):
y = np.zeros_like(x)
for i in range(len(lod) - 1):
b, e = lod[i], lod[i + 1]
y[b:e, :] = np.flip(x[b:e, :], 0)
return y
offset = lod[0]
batch_size = len(offset) - 1
hidden = []
cell = []
gate = []
input = _reverse(input, offset) if is_reverse else input
if w_b is not None:
input = input + np.tile(w_b, (offset[-1], 1))
for i in range(batch_size):
# compute one sequence
seq_len = offset[i + 1] - offset[i]
x = input[offset[i]:offset[i + 1], :]
h_pre = h0[i] # 1 x D
c_pre = c0[i] # 1 x D
for j in range(seq_len):
# compute one step
h_pre, c_pre, g_pre = _step(x[j], w_h, w_c, h_pre, c_pre, act_gate,
act_cell, act_cand)
hidden.append(h_pre.flatten())
cell.append(c_pre.flatten())
gate.append(g_pre.flatten())
hidden = np.array(hidden).astype("float64")
cell = np.array(cell).astype("float64")
gate = np.array(gate).astype("float64")
hidden = _reverse(hidden, offset) if is_reverse else hidden
cell = _reverse(cell, offset) if is_reverse else cell
assert gate.shape == input.shape
    assert hidden.shape == (input.shape[0], input.shape[1] // 4)
    assert cell.shape == (input.shape[0], input.shape[1] // 4)
return hidden, cell, gate
class TestLstmOp(OpTest):
def set_data(self):
self.lod = [[0, 2, 6, 9]]
self.D = 64
self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]
self.act_gate = "sigmoid"
self.act_cell = "tanh"
self.act_cand = "tanh"
self.is_reverse = False
def setUp(self):
self.set_data()
self.op_type = "lstm"
T = self.lod[0][-1]
N = len(self.lod[0]) - 1
x = np.random.normal(size=(T, 4 * self.D)).astype("float64")
h0 = np.zeros((N, self.D)).astype("float64")
c0 = np.zeros((N, self.D)).astype("float64")
w = np.random.normal(size=(self.D, 4 * self.D)).astype("float64")
b = np.random.normal(size=(1, 7 * self.D)).astype("float64")
w_b = b[:, 0:4 * self.D]
w_c = b[:, 4 * self.D:]
h, c, g = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse,
                       ACTIVATION[self.act_gate], ACTIVATION[self.act_cell],
                       ACTIVATION[self.act_cand])
g_sort = np.zeros_like(x)
for i, j in enumerate(self.sort_idx):
g_sort[i, :] = g[j, :]
self.inputs = {
'Input': (x, self.lod),
'H0': h0,
'C0': c0,
'Weight': w,
'Bias': b
}
self.outputs = {'Hidden': h, 'Cell': c, 'BatchGate': g_sort}
self.attrs = {
'usePeepholes': True,
'isReverse': self.is_reverse,
'gateActivation': 'sigmoid',
'cellActivation': 'tanh',
'candidateActivation': 'tanh'
}
def test_check_output(self):
self.check_output()
class TestLstmOpRerverse(TestLstmOp):
def set_data(self):
self.lod = [[0, 2, 6, 9]]
self.D = 64
self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]
self.act_gate = "sigmoid"
self.act_cell = "tanh"
self.act_cand = "tanh"
self.is_reverse = True
if __name__ == "__main__":
unittest.main()