From 65906ef1d0782e76b3bc40c09df30a01c423fb7c Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Fri, 20 Oct 2017 12:52:35 -0700
Subject: [PATCH] Several Enhancements

---
 paddle/operators/lstm_op.cc                | 18 ++---
 paddle/operators/lstm_op.h                 | 18 ++---
 paddle/operators/math/detail/lstm_kernel.h | 83 +++++++++++-----------
 paddle/operators/math/lstm_compute.cc      |  9 +--
 paddle/operators/math/lstm_compute.cu      |  9 +--
 paddle/operators/math/lstm_compute.h       |  9 +--
 paddle/operators/math/sequence2batch.cc    |  2 -
 paddle/operators/math/sequence2batch.cu    |  2 +-
 paddle/operators/math/sequence2batch.h     | 51 ++++++-------
 9 files changed, 103 insertions(+), 98 deletions(-)

diff --git a/paddle/operators/lstm_op.cc b/paddle/operators/lstm_op.cc
index f360502e666..222aeeace5c 100644
--- a/paddle/operators/lstm_op.cc
+++ b/paddle/operators/lstm_op.cc
@@ -68,7 +68,7 @@ class LSTMOp : public framework::OperatorWithKernel {
   } else {
     PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
                       "The second dimension of Input(Bias) should be "
-                      "4 * %d if diable peepholes connection",
+                      "4 * %d if the peephole connection is disabled",
                       frame_size);
   }
   ctx->SetOutputDim("Hidden", {x_dims[0], frame_size});
@@ -86,7 +86,7 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
   AddInput("Input",
            "(LoDTensor) the first input is a LodTensor, which support "
            "variable-time length input sequence. The underlying tensor in "
-           "this LoDTenosr is a matrix with shape (T X 4D), where, T is the "
+           "this LoDTensor is a matrix with shape (T X 4D), where, T is the "
            "total time steps in this mini-batch, D is the hidden size.");
   AddInput("H0",
            "(Tensor, optional) the initial hidden state is an optional "
@@ -112,7 +112,7 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
            " - Bias = {b_i, b_f, b_c, b_o, W_ic, W_fc, W_oc}.");
   AddOutput("BatchGate",
             "(LoDTensor) This LoDTensor contains input gate, forget gate "
-            "and output gate aftern the nonlinear computation. This "
+            "and output gate after the nonlinear computation. This "
             "LoDTensor has the same shape with the reorganized input, which "
             "was also be called batch input. The LoD size is 2. The first "
             "LoD is the batch offsets and the second LoD contains the "
@@ -135,18 +135,18 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
       .SetDefault(false);
   AddAttr<std::string>(
       "gateActivation",
-      "(string, defalut: sigmoid)"
+      "(string, default: sigmoid)"
       "The activation for input gate, forget gate and output "
-      "gate, `sigmoid` by defalut.")
+      "gate, `sigmoid` by default.")
       .SetDefault("sigmoid");
   AddAttr<std::string>("cellActivation",
-                       "(string, defalut: tanh)"
-                       "The activation for cell output, `tanh` by defalut.")
+                       "(string, default: tanh)"
+                       "The activation for cell output, `tanh` by default.")
       .SetDefault("tanh");
   AddAttr<std::string>("candidateActivation",
-                       "(string, defalut: tanh)"
+                       "(string, default: tanh)"
                        "The activation for candidate hidden state, "
-                       "`tanh` by defalut.")
+                       "`tanh` by default.")
       .SetDefault("tanh");
   AddComment(R"DOC(Long-Short Term Memory (LSTM) Operator
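Note on the shape check above: with usePeepholes enabled the bias row is laid out as {b_i, b_f, b_c, b_o, W_ic, W_fc, W_oc}, i.e. 4 * frame_size gate biases followed by 3 * frame_size peephole weights, which is why the kernel in lstm_op.h below reads the peephole weights at offsets 4, 5, and 6 times frame_size. A minimal sketch of that partitioning, using a hypothetical SplitBias helper that is not part of the patch:

    // Illustrative only: views into a bias buffer laid out as
    // {b_i, b_f, b_c, b_o[, W_ic, W_fc, W_oc]}, each chunk of length frame_size.
    struct BiasView {
      const float* gate_bias;  // 4 * frame_size gate biases
      const float* w_ic;       // peephole weights, nullptr if absent
      const float* w_fc;
      const float* w_oc;
    };

    BiasView SplitBias(const float* bias, int frame_size, bool use_peepholes) {
      BiasView v{bias, nullptr, nullptr, nullptr};
      if (use_peepholes) {
        v.w_ic = bias + 4 * frame_size;  // same offsets the kernel below uses
        v.w_fc = bias + 5 * frame_size;
        v.w_oc = bias + 6 * frame_size;
      }
      return v;
    }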
diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h
index b9d4ae3a6fb..5e100367077 100644
--- a/paddle/operators/lstm_op.h
+++ b/paddle/operators/lstm_op.h
@@ -52,7 +52,7 @@ class LSTMKernel : public framework::OpKernel<T> {
     to_batch(ctx.device_context(), *input, *batch_gate, is_reverse);

     auto in_dims = input->dims();
-    int frame_size = in_dims[1] / 4;
+    int frame_size = static_cast<int>(in_dims[1] / 4);
     framework::DDim dims({in_dims[0], frame_size});

     if (bias) {
@@ -70,7 +70,7 @@ class LSTMKernel : public framework::OpKernel<T> {
     math::LstmMetaValue<T> lstm_value;
     T* bias_data = const_cast<T*>(bias->data<T>());
-    // the code styple in LstmMetaValue will be updated later.
+    // the code style in LstmMetaValue will be updated later.
     lstm_value.checkIg = bias_data + 4 * frame_size;
     lstm_value.checkFg = lstm_value.checkIg + frame_size;
     lstm_value.checkOg = lstm_value.checkFg + frame_size;
@@ -83,15 +83,15 @@ class LSTMKernel : public framework::OpKernel<T> {
     framework::LoDTensor batch_cell_pre_act;
     batch_cell_pre_act.mutable_data<T>(dims, ctx.GetPlace());

-    auto batch_lod = batch_gate->lod()[0];
-    int num_batch = batch_lod.size() - 1;
+    auto& batch_starts = batch_gate->lod()[0];
+    size_t num_batch = batch_starts.size() - 1;
     auto gate_act = ctx.Attr<std::string>("gateActivation");
     auto cell_act = ctx.Attr<std::string>("cellActivation");
     auto cand_act = ctx.Attr<std::string>("candidateActivation");

-    for (int n = 0; n < num_batch; n++) {
-      int bstart = batch_lod[n];
-      int bend = batch_lod[n + 1];
+    for (size_t n = 0; n < num_batch; n++) {
+      int bstart = static_cast<int>(batch_starts[n]);
+      int bend = static_cast<int>(batch_starts[n + 1]);
       Tensor gate_t = batch_gate->Slice<T>(bstart, bend);
       Tensor out_t = batch_out.Slice<T>(bstart, bend);
@@ -101,14 +101,14 @@ class LSTMKernel : public framework::OpKernel<T> {
       int cur_batch_size = bend - bstart;

       if (n != 0) {
-        int pre_h_start = batch_lod[n - 1];
+        int pre_h_start = static_cast<int>(batch_starts[n - 1]);
         int pre_h_end = pre_h_start + cur_batch_size;
         auto pre_hidden_t = batch_out.Slice<T>(pre_h_start, pre_h_end);
         math::matmul<Place, T>(ctx.device_context(), pre_hidden_t, false,
                                *weight, false, static_cast<T>(1.0), &gate_t,
                                static_cast<T>(1.0));
       }
-      // else if : support the initial hidden and cell
+      // else if : FIXME support the initial hidden and cell

       lstm_value.gateValue = gate_t.data<T>();
       lstm_value.outputValue = out_t.data<T>();
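The loop above consumes the batch-major layout produced by LoDTensor2BatchFunctor (see sequence2batch.h below): batch_starts[n]..batch_starts[n+1] delimits time step n across all sequences still active at that step, and because sequences are sorted longest-first, the previous step's first cur_batch_size rows line up with the current slice. A standalone sketch of that traversal, using the example offsets from the sequence2batch.h comments (illustrative only, no Paddle APIs):

    #include <cstdio>
    #include <vector>

    int main() {
      // batch_starts from the example in sequence2batch.h below:
      // three sequences of lengths 5, 4, 3 -> 5 time steps.
      std::vector<size_t> batch_starts = {0, 3, 6, 9, 11, 12};
      size_t num_batch = batch_starts.size() - 1;
      for (size_t n = 0; n < num_batch; n++) {
        int bstart = static_cast<int>(batch_starts[n]);
        int bend = static_cast<int>(batch_starts[n + 1]);
        int cur_batch_size = bend - bstart;
        if (n != 0) {
          // The hidden states of step n-1 for the sequences still active at
          // step n occupy the first cur_batch_size rows of the previous slice.
          int pre_h_start = static_cast<int>(batch_starts[n - 1]);
          std::printf("step %zu: rows [%d, %d), prev hidden rows [%d, %d)\n", n,
                      bstart, bend, pre_h_start, pre_h_start + cur_batch_size);
        } else {
          std::printf("step %zu: rows [%d, %d), no previous hidden\n", n,
                      bstart, bend);
        }
      }
      return 0;
    }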
diff --git a/paddle/operators/math/detail/lstm_kernel.h b/paddle/operators/math/detail/lstm_kernel.h
index b1e59a4ee89..6f3ead2397d 100644
--- a/paddle/operators/math/detail/lstm_kernel.h
+++ b/paddle/operators/math/detail/lstm_kernel.h
@@ -13,12 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/operators/math/detail/hl_activation_functions.h"
+#include "paddle/platform/hostdevice.h"

-#ifdef __CUDA_ARCH__
-#define INLINE __device__ inline
-#else
-#define INLINE inline
-#endif
+#include <type_traits>

 namespace paddle {
 namespace operators {
@@ -30,12 +27,12 @@ namespace forward {
 template <class T>
 class lstm {
  public:
-  INLINE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
-                         T &prevState, T &state, T &stateAtv, T &output,
-                         T &checkI, T &checkF, T &checkO,
-                         typename hppl::ForwardActType<T>::type actInput,
-                         typename hppl::ForwardActType<T>::type actGate,
-                         typename hppl::ForwardActType<T>::type actState) {
+  HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
+                             T &prevState, T &state, T &stateAtv, T &output,
+                             T &checkI, T &checkF, T &checkO,
+                             typename hppl::ForwardActType<T>::type actInput,
+                             typename hppl::ForwardActType<T>::type actGate,
+                             typename hppl::ForwardActType<T>::type actState) {
     valueIn = actInput(valueIn);
     valueIg = actGate(valueIg + prevState * checkI);
     valueFg = actGate(valueFg + prevState * checkF);
@@ -45,17 +42,19 @@ class lstm {
     output = valueOg * stateAtv;
   }
 #ifndef __NVCC__
-#ifndef __AVX__
+#ifndef __AVX__  // Disable AVX when not compiled with AVX instructions.
   static const bool avx = false;
 #else
-  static const bool avx = true;
-  INLINE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
-                         __m256 &valueOg, __m256 &prevState, __m256 &state,
-                         __m256 &stateAtv, __m256 &output, __m256 &checkI,
-                         __m256 &checkF, __m256 &checkO,
-                         hppl::Active<__m256>::forward actInput,
-                         hppl::Active<__m256>::forward actGate,
-                         hppl::Active<__m256>::forward actState) {
+  // Only float supports the AVX optimization.
+  static const bool avx = std::is_same<T, float>::value;
+
+  HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
+                             __m256 &valueOg, __m256 &prevState, __m256 &state,
+                             __m256 &stateAtv, __m256 &output, __m256 &checkI,
+                             __m256 &checkF, __m256 &checkO,
+                             hppl::Active<__m256>::forward actInput,
+                             hppl::Active<__m256>::forward actGate,
+                             hppl::Active<__m256>::forward actState) {
     valueIn = actInput(valueIn);
     valueIg = actGate(_mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI)));
     valueFg = actGate(_mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF)));
@@ -76,14 +75,15 @@ namespace backward {
 template <class T>
 class lstm {
  public:
-  INLINE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
-                         T &gradIn, T &gradIg, T &gradFg, T &gradOg,
-                         T &prevState, T &prevStateGrad, T &state, T &stateGrad,
-                         T &stateAtv, T &outputGrad, T &checkI, T &checkF,
-                         T &checkO, T &checkIGrad, T &checkFGrad, T &checkOGrad,
-                         typename hppl::BackwardActType<T>::type actInput,
-                         typename hppl::BackwardActType<T>::type actGate,
-                         typename hppl::BackwardActType<T>::type actState) {
+  HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
+                             T &gradIn, T &gradIg, T &gradFg, T &gradOg,
+                             T &prevState, T &prevStateGrad, T &state,
+                             T &stateGrad, T &stateAtv, T &outputGrad,
+                             T &checkI, T &checkF, T &checkO, T &checkIGrad,
+                             T &checkFGrad, T &checkOGrad,
+                             typename hppl::BackwardActType<T>::type actInput,
+                             typename hppl::BackwardActType<T>::type actGate,
+                             typename hppl::BackwardActType<T>::type actState) {
     gradOg = actGate(outputGrad * stateAtv, valueOg);
     stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO;
     gradIn = actInput(stateGrad * valueIg, valueIn);
@@ -95,21 +95,22 @@ class lstm {
     checkOGrad = gradOg * state;
   }
 #ifndef __NVCC__
-#ifndef __AVX__
+#ifndef __AVX__  // Disable AVX when not compiled with AVX instructions.
   static const bool avx = false;
 #else
-  static const bool avx = true;
-  INLINE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
-                         __m256 &valueOg, __m256 &gradIn, __m256 &gradIg,
-                         __m256 &gradFg, __m256 &gradOg, __m256 &prevState,
-                         __m256 &prevStateGrad, __m256 &state,
-                         __m256 &stateGrad, __m256 &stateAtv,
-                         __m256 &outputGrad, __m256 &checkI, __m256 &checkF,
-                         __m256 &checkO, __m256 &checkIGrad, __m256 &checkFGrad,
-                         __m256 &checkOGrad,
-                         hppl::Active<__m256>::backward actInput,
-                         hppl::Active<__m256>::backward actGate,
-                         hppl::Active<__m256>::backward actState) {
+  // Only float supports the AVX optimization.
+  static const bool avx = std::is_same<T, float>::value;
+  HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
+                             __m256 &valueOg, __m256 &gradIn, __m256 &gradIg,
+                             __m256 &gradFg, __m256 &gradOg, __m256 &prevState,
+                             __m256 &prevStateGrad, __m256 &state,
+                             __m256 &stateGrad, __m256 &stateAtv,
+                             __m256 &outputGrad, __m256 &checkI, __m256 &checkF,
+                             __m256 &checkO, __m256 &checkIGrad,
+                             __m256 &checkFGrad, __m256 &checkOGrad,
+                             hppl::Active<__m256>::backward actInput,
+                             hppl::Active<__m256>::backward actGate,
+                             hppl::Active<__m256>::backward actState) {
     gradOg = actGate(_mm256_mul_ps(outputGrad, stateAtv), valueOg);
     stateGrad = _mm256_add_ps(
         actState(_mm256_mul_ps(outputGrad, valueOg), stateAtv), stateGrad);
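The INLINE macro this file used to define locally is replaced by HOSTDEVICE from paddle/platform/hostdevice.h, so one functor body compiles for both the CPU path and CUDA kernels; and since __m256 packs eight floats, the avx flag is now true only for the float instantiation via std::is_same. A minimal sketch of the macro pattern (a simplified stand-in, not the actual Paddle header):

    // Simplified stand-in for the HOSTDEVICE idea: under NVCC the functor is
    // callable from both host and device; elsewhere it is plain C++.
    #ifdef __CUDACC__
    #define HOSTDEVICE __host__ __device__
    #else
    #define HOSTDEVICE
    #endif

    #include <type_traits>

    template <class T>
    struct Axpy {
      // True only for the float instantiation, mirroring the avx flag above.
      static const bool avx = std::is_same<T, float>::value;
      // One body, usable from a CPU loop or from inside a __global__ kernel.
      HOSTDEVICE void operator()(T a, T x, T& y) const { y += a * x; }
    };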
diff --git a/paddle/operators/math/lstm_compute.cc b/paddle/operators/math/lstm_compute.cc
index d1c63bafe11..0febf8e3b70 100644
--- a/paddle/operators/math/lstm_compute.cc
+++ b/paddle/operators/math/lstm_compute.cc
@@ -24,8 +24,8 @@ template <class T>
 struct LstmUnitFunctor<platform::CPUPlace, T> {
   static void compute(const platform::DeviceContext& context,
                       LstmMetaValue<T> value, int frame_size, int batch_size,
-                      std::string gate_act, std::string cell_act,
-                      std::string cand_act) {
+                      const std::string& gate_act, const std::string& cell_act,
+                      const std::string& cand_act) {
     for (int b = 0; b < batch_size; b++) {
       detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size,
                                ActiveType(cand_act), ActiveType(gate_act),
@@ -45,8 +45,9 @@ template <class T>
 struct LstmUnitGradFunctor<platform::CPUPlace, T> {
   static void compute(const platform::DeviceContext& context,
                       LstmMetaValue<T> value, LstmMetaGrad<T> grad,
-                      int frame_size, int batch_size, std::string gate_act,
-                      std::string cell_act, std::string cand_act) {
+                      int frame_size, int batch_size,
+                      const std::string& gate_act, const std::string& cell_act,
+                      const std::string& cand_act) {
     for (int b = 0; b < batch_size; b++) {
       detail::cpu_lstm_backward(detail::backward::lstm<T>(), value, grad,
                                 frame_size, ActiveType(cand_act),

The signature change in this file (and the matching ones below) passes the activation names as const std::string& instead of std::string, avoiding the construction, and possible heap allocation, of a copy of each string on every call. A minimal illustration of the difference, with hypothetical function names:

    #include <string>

    // Pass-by-value: a copy of the argument is constructed on every call.
    void by_value(std::string act) { (void)act; }
    // Pass-by-const-reference: binds to the caller's string, no copy.
    void by_cref(const std::string& act) { (void)act; }

    int main() {
      std::string gate_act = "sigmoid";
      by_value(gate_act);  // copy constructed here
      by_cref(gate_act);   // no copy
      return 0;
    }

diff --git a/paddle/operators/math/lstm_compute.cu b/paddle/operators/math/lstm_compute.cu
index d942f60a269..b2122f2a5c0 100644
--- a/paddle/operators/math/lstm_compute.cu
+++ b/paddle/operators/math/lstm_compute.cu
@@ -24,8 +24,8 @@ template <class T>
 struct LstmUnitFunctor<platform::GPUPlace, T> {
   static void compute(const platform::DeviceContext& context,
                       LstmMetaValue<T> value, int frame_size, int batch_size,
-                      std::string gate_act, std::string cell_act,
-                      std::string cand_act) {
+                      const std::string& gate_act, const std::string& cell_act,
+                      const std::string& cand_act) {
     detail::gpu_lstm_forward<T>(context, detail::forward::lstm<T>(), value,
                                 frame_size, batch_size, ActiveType(cand_act),
                                 ActiveType(gate_act), ActiveType(cell_act));
@@ -36,8 +36,9 @@ template <class T>
 struct LstmUnitGradFunctor<platform::GPUPlace, T> {
   static void compute(const platform::DeviceContext& context,
                       LstmMetaValue<T> value, LstmMetaGrad<T> grad,
-                      int frame_size, int batch_size, std::string gate_act,
-                      std::string cell_act, std::string cand_act) {
+                      int frame_size, int batch_size,
+                      const std::string& gate_act, const std::string& cell_act,
+                      const std::string& cand_act) {
     detail::gpu_lstm_backward<T>(context, detail::backward::lstm<T>(), value,
                                  grad, frame_size, batch_size,
                                  ActiveType(cand_act), ActiveType(gate_act),
                                  ActiveType(cell_act));
diff --git a/paddle/operators/math/lstm_compute.h b/paddle/operators/math/lstm_compute.h
index c58a1ad0d66..28d2c6fd3b0 100644
--- a/paddle/operators/math/lstm_compute.h
+++ b/paddle/operators/math/lstm_compute.h
@@ -72,8 +72,8 @@ class LstmUnitFunctor {
  public:
   static void compute(const platform::DeviceContext &context,
                       LstmMetaValue<T> value, int frame_size, int batch_size,
-                      std::string gate_act, std::string cell_act,
-                      std::string cand_act);
+                      const std::string &gate_act, const std::string &cell_act,
+                      const std::string &cand_act);
 };

 template <typename Place, typename T>
@@ -81,8 +81,9 @@ class LstmUnitGradFunctor {
  public:
   static void compute(const platform::DeviceContext &context,
                       LstmMetaValue<T> value, LstmMetaGrad<T> grad,
-                      int frame_size, int batch_size, std::string gate_act,
-                      std::string cell_act, std::string cand_act);
+                      int frame_size, int batch_size,
+                      const std::string &gate_act, const std::string &cell_act,
+                      const std::string &cand_act);
 };

 }  // namespace math
diff --git a/paddle/operators/math/sequence2batch.cc b/paddle/operators/math/sequence2batch.cc
index 10c6e105b95..00de56f7cd5 100644
--- a/paddle/operators/math/sequence2batch.cc
+++ b/paddle/operators/math/sequence2batch.cc
@@ -51,8 +51,6 @@ class CopyMatrixRowsFunctor<platform::CPUPlace, T> {
 template class CopyMatrixRowsFunctor<platform::CPUPlace, float>;
 template class CopyMatrixRowsFunctor<platform::CPUPlace, double>;

-template class LoDTensor2BatchFunctor<platform::CPUPlace, float>;
-template class LoDTensor2BatchFunctor<platform::CPUPlace, double>;
 template class Batch2LoDTensorFunctor<platform::CPUPlace, float>;
 template class Batch2LoDTensorFunctor<platform::CPUPlace, double>;
diff --git a/paddle/operators/math/sequence2batch.cu b/paddle/operators/math/sequence2batch.cu
index e478c46db71..4f349946785 100644
--- a/paddle/operators/math/sequence2batch.cu
+++ b/paddle/operators/math/sequence2batch.cu
@@ -21,7 +21,7 @@ namespace math {
 template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
 __global__ void CopyMatrixRowsKernel(const T* src, T* dst, const size_t* index,
                                      int64_t height, int64_t width,
-                                     const bool is_src_index) {
+                                     bool is_src_index) {
   int idx = threadIdx.x;
   int idy = threadIdx.y;
   int id = blockIdx.x + idy * GridDimX;
diff --git a/paddle/operators/math/sequence2batch.h b/paddle/operators/math/sequence2batch.h
index 89b51168043..690cac05870 100644
--- a/paddle/operators/math/sequence2batch.h
+++ b/paddle/operators/math/sequence2batch.h
@@ -31,33 +31,33 @@ class CopyMatrixRowsFunctor {
   // The indexed rows are based on the input index.
   void operator()(const platform::DeviceContext& context,
                   const framework::LoDTensor& src, const size_t* index,
-                  framework::LoDTensor& dst, const bool is_src_index);
+                  framework::LoDTensor& dst, bool is_src_index);
 };

 template <typename Place, typename T>
 class LoDTensor2BatchFunctor {
+  // Calculate the length of each sequence and
+  // sort sequence index by the length.
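CopyMatrixRowsFunctor (and the CUDA kernel above) implements a row gather/scatter: with is_src_index true it reads dst[i] = src[index[i]], otherwise it writes dst[index[i]] = src[i], which is exactly what converting between sequence order and batch order requires in the two directions. A minimal CPU sketch of that contract (not the CUDA kernel itself; assumes trivially copyable T):

    #include <cstddef>
    #include <cstring>

    // Gather  (is_src_index == true):  dst[i]        = src[index[i]]
    // Scatter (is_src_index == false): dst[index[i]] = src[i]
    template <typename T>
    void CopyMatrixRows(const T* src, T* dst, const size_t* index,
                        size_t height, size_t width, bool is_src_index) {
      for (size_t i = 0; i < height; ++i) {
        const T* s = src + (is_src_index ? index[i] : i) * width;
        T* d = dst + (is_src_index ? i : index[i]) * width;
        std::memcpy(d, s, width * sizeof(T));
      }
    }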
+  // example:  sequences = {s0, s1, s2}
+  //           s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
+  //           seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
+  //
+  struct SeqInfo {
+    SeqInfo(int start, int length, int seq_idx)
+        : start(start), length(length), seq_idx(seq_idx) {}
+    int start;
+    int length;
+    int seq_idx;
+  };
+
  public:
   void operator()(const platform::DeviceContext& context,
                   const framework::LoDTensor& lod_tensor,
-                  framework::LoDTensor& batch, const bool is_reverse) const {
+                  framework::LoDTensor& batch, bool is_reverse) const {
     auto lods = lod_tensor.lod();
     PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now.");
     auto lod = lods[0];

-    // Calculate the length of each sequence and
-    // sort sequence index by the length.
-    // example:  sequences = {s0, s1, s2}
-    //           s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
-    //           seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
-    //
-    struct SeqInfo {
-      SeqInfo(int start, int length, int seq_idx)
-          : start(start), length(length), seq_idx(seq_idx) {}
-      int start;
-      int length;
-      int seq_idx;
-    };
-
     std::vector<SeqInfo> seq_info;
     for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) {
       int length = lod[seq_id + 1] - lod[seq_id];
@@ -75,31 +75,34 @@ class LoDTensor2BatchFunctor {
     //            batchIndex = {b0, b1, b2, b3, b4}
     //            b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1
     //            batch_start_positions[6] = {0, 3, 6, 9, 11, 12}
+    //            batch_start_positions[1] = len(b0)
+    //            batch_start_positions[2] = len(b0) + len(b1)
+    //            batch_start_positions[3] = len(b0) + len(b1) + len(b2)
+    //            ...
     //            seq2batch_idx[12] = {4, 0, 9,
     //                                 5, 1, 10,
     //                                 6, 2, 11,
     //                                 7, 3,
     //                                 8}
-
     // The batch number represents batch size after rearranging the
     // input LodTensor. It is also the maximum length of input sequence.

     paddle::framework::LoD batch_lods;
-    batch_lods.push_back(std::vector<size_t>{0});
-    batch_lods.push_back(std::vector<size_t>{0});
+    batch_lods.emplace_back(std::vector<size_t>{0});
+    batch_lods.emplace_back(std::vector<size_t>{0});

     // batch_lods[0] is the start positions for batch LoDTensor
-    int num_batch = (size_t)seq_info[0].length;
-    batch_lods[0].resize(num_batch + 1);
+    int num_batch = seq_info[0].length;
+    batch_lods[0].resize(static_cast<size_t>(num_batch + 1));
     // batch_lods[1] is the raw index in the input LoDTensor
     auto dims = lod_tensor.dims();
-    batch_lods[1].resize(dims[0]);
+    batch_lods[1].resize(static_cast<size_t>(dims[0]));
     size_t* batch_starts = batch_lods[0].data();
     size_t* seq2batch_idx = batch_lods[1].data();
     batch_starts[0] = 0;
     for (size_t n = 0; n < num_batch; n++) {
-      int batch_id = batch_starts[n];
+      auto batch_id = static_cast<int>(batch_starts[n]);
       for (size_t i = 0; i < seq_info.size(); ++i) {
         size_t seq_len = seq_info[i].length;
         int start = seq_info[i].start;
@@ -114,7 +117,7 @@ class LoDTensor2BatchFunctor {
           break;
         }
       }
-      batch_starts[n + 1] = batch_id;
+      batch_starts[n + 1] = static_cast<size_t>(batch_id);
     }
     batch.set_lod(batch_lods);
--
GitLab
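For reference, the sequence-to-batch index computation that LoDTensor2BatchFunctor performs can be reproduced in isolation. The sketch below covers only the forward (is_reverse == false) case and uses the three-sequence example from the comments in sequence2batch.h; it is a standalone illustration under those assumptions, not Paddle code:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    struct SeqInfo {
      size_t start, length, seq_idx;
    };

    int main() {
      // LoD offsets for three sequences of lengths 4, 5, 3 (the example above).
      std::vector<size_t> lod = {0, 4, 9, 12};

      std::vector<SeqInfo> seq_info;
      for (size_t i = 0; i + 1 < lod.size(); ++i)
        seq_info.push_back({lod[i], lod[i + 1] - lod[i], i});
      // Longest sequence first, as in LoDTensor2BatchFunctor.
      std::sort(seq_info.begin(), seq_info.end(),
                [](const SeqInfo& a, const SeqInfo& b) {
                  return a.length > b.length;
                });

      size_t num_batch = seq_info[0].length;  // max length == number of steps
      std::vector<size_t> batch_starts(num_batch + 1, 0);
      std::vector<size_t> seq2batch_idx(lod.back());

      for (size_t n = 0; n < num_batch; ++n) {
        size_t batch_id = batch_starts[n];
        for (const SeqInfo& s : seq_info) {
          // A sequence contributes its n-th element while still active.
          if (n < s.length) seq2batch_idx[batch_id++] = s.start + n;
        }
        batch_starts[n + 1] = batch_id;
      }

      // Expect batch_starts  = {0, 3, 6, 9, 11, 12} and
      // seq2batch_idx = {4, 0, 9, 5, 1, 10, 6, 2, 11, 7, 3, 8}.
      for (size_t v : seq2batch_idx) std::printf("%zu ", v);
      std::printf("\n");
      return 0;
    }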