diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 7ce774a285087aca11e4bfe6f829bff5296c70ed..0c53ed3cdc5e5190c0f1078ddc37da74bf15d94c 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -127,7 +127,7 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) op_library(sum_op DEPS net_op) op_library(pool_op DEPS pooling) op_library(pool_with_index_op DEPS pooling) -op_library(lstm_op DEPS sequence2batch) +op_library(lstm_op DEPS sequence2batch lstm_compute math_function) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) diff --git a/paddle/operators/lstm_op.cc b/paddle/operators/lstm_op.cc index 7a72a08c50aa0528c23ecd174f3662cddceb7734..f360502e666fe75b939a94f45df4ddca846f5c43 100644 --- a/paddle/operators/lstm_op.cc +++ b/paddle/operators/lstm_op.cc @@ -44,7 +44,7 @@ class LSTMOp : public framework::OperatorWithKernel { "should be the same."); } - int frame_size = x_dims[1]; + int frame_size = x_dims[1] / 4; auto w_dims = ctx->GetInputDim("Weight"); PADDLE_ENFORCE_EQ(w_dims.size(), 2, "The rank of Input(Weight) should be 2."); @@ -71,9 +71,9 @@ class LSTMOp : public framework::OperatorWithKernel { "4 * %d if diable peepholes connection", frame_size); } - ctx->SetOutputDim("Hidden", x_dims); - ctx->SetOutputDim("Cell", x_dims); - ctx->SetOutputDim("Batch", x_dims); + ctx->SetOutputDim("Hidden", {x_dims[0], frame_size}); + ctx->SetOutputDim("Cell", {x_dims[0], frame_size}); + ctx->SetOutputDim("BatchGate", x_dims); ctx->ShareLoD("Input", "Hidden"); ctx->ShareLoD("Input", "Cell"); } diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h index 6924cba68fbd0e13499b5e814c00f7b2ccfc01cb..affa44c6fb1ac0b194b86c7012fc73e6ace14268 100644 --- a/paddle/operators/lstm_op.h +++ b/paddle/operators/lstm_op.h @@ -52,9 +52,14 @@ class LSTMKernel : public framework::OpKernel { to_batch(ctx.device_context(), *input, *batch_gate, is_reverse); auto in_dims = input->dims(); - int frame_size = in_dims[1]; + int frame_size = in_dims[1] / 4; + framework::DDim dims({in_dims[0], frame_size}); if (bias) { + // framework::Tensor cpu_t; + // cpu_t.mutable_data(in_dims, platform::CPUPlace()); + // cpu_t.CopyFrom(*batch_gate, platform::CPUPlace(), + // ctx.device_context()); Eigen::array extents({{1, 4 * frame_size}}); Eigen::array offsets({{0, 0}}); auto b = EigenMatrix::From(*bias); @@ -76,15 +81,14 @@ class LSTMKernel : public framework::OpKernel { lstm_value.prevStateValue = nullptr; framework::LoDTensor batch_out; - batch_out.mutable_data(in_dims, ctx.GetPlace()); + batch_out.mutable_data(dims, ctx.GetPlace()); framework::LoDTensor batch_cell; - batch_cell.mutable_data(in_dims, ctx.GetPlace()); + batch_cell.mutable_data(dims, ctx.GetPlace()); framework::LoDTensor batch_cell_pre_act; - batch_cell_pre_act.mutable_data(in_dims, ctx.GetPlace()); + batch_cell_pre_act.mutable_data(dims, ctx.GetPlace()); auto batch_lod = batch_gate->lod()[0]; int num_batch = batch_lod.size() - 1; - auto gate_act = ctx.Attr("gateActivation"); auto cell_act = ctx.Attr("cellActivation"); auto cand_act = ctx.Attr("candidateActivation"); @@ -125,9 +129,12 @@ class LSTMKernel : public framework::OpKernel { // restore the output hidden in LoDTensor from the batch hidden to_seq(ctx.device_context(), batch_out, *hidden_out); - batch_out.set_lod(batch_gate->lod()); + batch_cell.set_lod(batch_gate->lod()); // restore the output cell state in LoDTensor from the batch cell to_seq(ctx.device_context(), batch_cell, *cell_out); + + auto t = framework::EigenVector::Flatten(*batch_gate); + t.device(ctx.GetEigenDevice()) = t.constant(static_cast(0)); } }; diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 794ffc39974eceb48524b0b53162266f7538e13f..2771b5de40311d367b899bbae48a2e1708b6e2a9 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,3 +1,5 @@ +add_subdirectory(detail) + if(WITH_GPU) nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context operator) nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) @@ -6,7 +8,7 @@ if(WITH_GPU) nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context) nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context) nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context) - nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context) + nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator) cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) @@ -14,7 +16,7 @@ else() cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) cc_library(pooling SRCS pooling.cc DEPS device_context) cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context) - cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context) + cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions) endif() cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/detail/hl_avx_functions.cc b/paddle/operators/math/detail/hl_avx_functions.cc index 70e7d80304859b8532718f2d4841f071c88cd100..415bac5d93ee00244d072b0998c6941b14d4f8d8 100644 --- a/paddle/operators/math/detail/hl_avx_functions.cc +++ b/paddle/operators/math/detail/hl_avx_functions.cc @@ -14,10 +14,12 @@ limitations under the License. */ #include #include "hl_functions.h" +// TODO(qingqing) refine this dependence +#include "paddle/cuda/src/avx_mathfun.h" namespace hppl { -extern __m256 exp(__m256 a); +__m256 exp(__m256 a) { return exp256_ps(a); } __m256 relu(const __m256 a) { __m256 tmp = _mm256_set1_ps(0.0f); diff --git a/paddle/operators/math/detail/hl_cpu_functions.cc b/paddle/operators/math/detail/hl_cpu_functions.cc new file mode 100644 index 0000000000000000000000000000000000000000..21ec78f9629af0e4673a56517d76ac6734f57db8 --- /dev/null +++ b/paddle/operators/math/detail/hl_cpu_functions.cc @@ -0,0 +1,89 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "hl_functions.h" + +namespace hppl { +namespace typef { + +float relu(const float a) { + return a > static_cast(0.0) ? a : static_cast(0.0); +} + +float sigmoid(const float a) { + const float min = SIGMOID_THRESHOLD_MIN; + const float max = SIGMOID_THRESHOLD_MAX; + float tmp = (a < min) ? min : ((a > max) ? max : a); + return static_cast(1.0) / (static_cast(1.0) + exp(-tmp)); +} + +float tanh(const float a) { + float tmp = -2.0 * a; + tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; + return (2.0 / (1.0 + exp(tmp))) - 1.0; +} + +float linear(const float a) { return a; } + +float relu(const float a, const float b) { return a * (b > 0.0 ? 1.0 : 0.0); } + +float sigmoid(const float a, const float b) { + return a * b * (static_cast(1) - b); +} + +float tanh(const float a, const float b) { + return a * (static_cast(1) - b * b); +} + +float linear(const float a, const float b) { return a; } + +} // namespace typef + +namespace typed { +double relu(const double a) { + return a > static_cast(0.0) ? a : static_cast(0.0); +} + +double sigmoid(const double a) { + const double min = SIGMOID_THRESHOLD_MIN; + const double max = SIGMOID_THRESHOLD_MAX; + double tmp = (a < min) ? min : ((a > max) ? max : a); + return static_cast(1.0) / (static_cast(1.0) + exp(-tmp)); +} + +double tanh(const double a) { + double tmp = -2.0 * a; + tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; + return (2.0 / (1.0 + exp(tmp))) - 1.0; +} + +double linear(const double a) { return a; } + +double relu(const double a, const double b) { + return a * (b > 0.0 ? 1.0 : 0.0); +} + +double sigmoid(const double a, const double b) { + return a * b * (static_cast(1) - b); +} + +double tanh(const double a, const double b) { + return a * (static_cast(1) - b * b); +} + +double linear(const double a, const double b) { return a; } + +} // namespace typed +} // namespace hppl diff --git a/paddle/operators/math/detail/hl_functions.h b/paddle/operators/math/detail/hl_functions.h index c77c119dfed205135dca728392a1b5c10f7b5118..3e2f0c9ee6d3ae2ed598c4d5f09b85b7d61fdd51 100644 --- a/paddle/operators/math/detail/hl_functions.h +++ b/paddle/operators/math/detail/hl_functions.h @@ -34,83 +34,28 @@ limitations under the License. */ #ifndef __NVCC__ namespace hppl { namespace typef { -/* - * forward activation - */ -float relu(const float a) { - return a > static_cast(0.0) ? a : static_cast(0.0); -} - -float sigmoid(const float a) { - const float min = SIGMOID_THRESHOLD_MIN; - const float max = SIGMOID_THRESHOLD_MAX; - float tmp = (a < min) ? min : ((a > max) ? max : a); - return static_cast(1.0) / (static_cast(1.0) + exp(-tmp)); -} - -float tanh(const float a) { - float tmp = -2.0 * a; - tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; - return (2.0 / (1.0 + exp(tmp))) - 1.0; -} - -float linear(const float a) { return a; } - -/* - * backward activation - */ -float relu(const float a, const float b) { return a * (b > 0.0 ? 1.0 : 0.0); } +float relu(const float a); +float sigmoid(const float a); +float tanh(const float a); +float linear(const float a); -float sigmoid(const float a, const float b) { - return a * b * (static_cast(1) - b); -} +float relu(const float a, const float b); +float sigmoid(const float a, const float b); +float tanh(const float a, const float b); +float linear(const float a, const float b); -float tanh(const float a, const float b) { - return a * (static_cast(1) - b * b); -} - -float linear(const float a, const float b) { return a; } } // namespace typef namespace typed { -/* - * forward activation - */ -double relu(const double a) { - return a > static_cast(0.0) ? a : static_cast(0.0); -} - -double sigmoid(const double a) { - const double min = SIGMOID_THRESHOLD_MIN; - const double max = SIGMOID_THRESHOLD_MAX; - double tmp = (a < min) ? min : ((a > max) ? max : a); - return static_cast(1.0) / (static_cast(1.0) + exp(-tmp)); -} - -double tanh(const double a) { - double tmp = -2.0 * a; - tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; - return (2.0 / (1.0 + exp(tmp))) - 1.0; -} - -double linear(const double a) { return a; } - -/* - * backward activation - */ -double relu(const double a, const double b) { - return a * (b > 0.0 ? 1.0 : 0.0); -} - -double sigmoid(const double a, const double b) { - return a * b * (static_cast(1) - b); -} - -double tanh(const double a, const double b) { - return a * (static_cast(1) - b * b); -} - -double linear(const double a, const double b) { return a; } +double relu(const double a); +double sigmoid(const double a); +double tanh(const double a); +double linear(const double a); + +double relu(const double a, const double b); +double sigmoid(const double a, const double b); +double tanh(const double a, const double b); +double linear(const double a, const double b); } // namespace typed } // namespace hppl diff --git a/paddle/operators/math/detail/lstm_gpu_kernel.h b/paddle/operators/math/detail/lstm_gpu_kernel.h index 01310a49f8c1903f40e9a7b9ac975a79df5fde5f..36f303034854e3d179f5975144b1b365a488810c 100644 --- a/paddle/operators/math/detail/lstm_gpu_kernel.h +++ b/paddle/operators/math/detail/lstm_gpu_kernel.h @@ -19,6 +19,8 @@ limitations under the License. */ #include "paddle/platform/cuda_helper.h" #include "paddle/platform/device_context.h" +#include + namespace paddle { namespace operators { namespace math { @@ -29,11 +31,10 @@ namespace detail { * grid(frameBlocks, batchBlocks) */ template -__global__ void KeLstmForward( - Op op, LstmMetaValue value, int frameSize, int batchSize, - typename hppl::ForwardActType::type active_node, - typename hppl::ForwardActType::type active_gate, - typename hppl::ForwardActType::type active_state) { +__global__ void KeLstmForward(Op op, LstmMetaValue value, int frameSize, + int batchSize, activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; if (frameIdx >= frameSize) return; @@ -69,8 +70,10 @@ __global__ void KeLstmForward( rPrevState = value.prevStateValue[frameIdx]; } + hppl::gpu::ForwardAct act; op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, - rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate, active_state); + rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate), + act(active_state)); value.gateValue[frameIdx] = rValueIn; value.gateValue[frameIdx + frameSize] = rValueIg; @@ -87,11 +90,11 @@ __global__ void KeLstmForward( * grid(frameBlocks, batchBlocks) */ template -__global__ void KeLstmBackward( - Op op, LstmMetaValue value, LstmMetaGrad grad, int frameSize, - int batchSize, typename hppl::BackwardActType::type active_node, - typename hppl::BackwardActType::type active_gate, - typename hppl::BackwardActType::type active_state) { +__global__ void KeLstmBackward(Op op, LstmMetaValue value, + LstmMetaGrad grad, int frameSize, + int batchSize, activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; if (frameIdx >= frameSize) return; @@ -142,10 +145,11 @@ __global__ void KeLstmBackward( rPrevState = value.prevStateValue[frameIdx]; } + hppl::gpu::BackwardAct act; op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad, - active_node, active_gate, active_state); + act(active_node), act(active_gate), act(active_state)); grad.gateGrad[frameIdx] = rGradIn; grad.gateGrad[frameIdx + frameSize] = rGradIg; @@ -196,22 +200,16 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op, grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32); } - using type = typename hppl::ForwardActType::type; - hppl::gpu::ForwardAct act; - type act_node = act(active_node); - type act_gate = act(active_gate); - type act_state = act(active_state); - auto stream = reinterpret_cast(context).stream(); if (batchSize == 1) { KeLstmForward<<>>( - op, value, frameSize, batchSize, act_node, act_gate, act_state); + op, value, frameSize, batchSize, active_node, active_gate, active_gate); } else { KeLstmForward<<>>( - op, value, frameSize, batchSize, act_node, act_gate, act_state); + op, value, frameSize, batchSize, active_node, active_gate, active_gate); } } @@ -235,22 +233,18 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op, grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32); } - using type = typename hppl::BackwardActType::type; - hppl::gpu::BackwardAct act; - type act_node = act(active_node); - type act_gate = act(active_gate); - type act_state = act(active_state); - auto stream = reinterpret_cast(context).stream(); if (batchSize == 1) { KeLstmBackward<<>>( - op, value, grad, frameSize, batchSize, act_node, act_gate, act_state); + op, value, grad, frameSize, batchSize, active_node, active_gate, + active_state); } else { KeLstmBackward<<>>( - op, value, grad, frameSize, batchSize, act_node, act_gate, act_state); + op, value, grad, frameSize, batchSize, active_node, active_gate, + active_state); } } diff --git a/paddle/operators/math/lstm_compute.cc b/paddle/operators/math/lstm_compute.cc index 293c9da3a0c65c33e484777c21d09e765b0e793b..d1c63bafe112468e26b3dbe3ba92354a6393846c 100644 --- a/paddle/operators/math/lstm_compute.cc +++ b/paddle/operators/math/lstm_compute.cc @@ -72,6 +72,8 @@ struct LstmUnitGradFunctor { }; template class LstmUnitFunctor; +template class LstmUnitFunctor; +template class LstmUnitGradFunctor; template class LstmUnitGradFunctor; } // namespace math diff --git a/paddle/operators/math/lstm_compute.cu b/paddle/operators/math/lstm_compute.cu index aade604b9ec555a91490bc535f07e9dba454a59c..d942f60a269eda283e68975de6f9eec0eb1c2733 100644 --- a/paddle/operators/math/lstm_compute.cu +++ b/paddle/operators/math/lstm_compute.cu @@ -26,18 +26,9 @@ struct LstmUnitFunctor { LstmMetaValue value, int frame_size, int batch_size, std::string gate_act, std::string cell_act, std::string cand_act) { - for (int b = 0; b < batch_size; b++) { - detail::gpu_lstm_forward(context, detail::forward::lstm(), value, - frame_size, batch_size, ActiveType(cand_act), - ActiveType(gate_act), ActiveType(cell_act)); - value.gateValue += frame_size * 4; - value.stateValue += frame_size; - value.stateActiveValue += frame_size; - value.outputValue += frame_size; - if (value.prevStateValue) { - value.prevStateValue += frame_size; - } - } + detail::gpu_lstm_forward(context, detail::forward::lstm(), value, + frame_size, batch_size, ActiveType(cand_act), + ActiveType(gate_act), ActiveType(cell_act)); } }; @@ -47,32 +38,15 @@ struct LstmUnitGradFunctor { LstmMetaValue value, LstmMetaGrad grad, int frame_size, int batch_size, std::string gate_act, std::string cell_act, std::string cand_act) { - for (int b = 0; b < batch_size; b++) { - detail::gpu_lstm_backward(context, detail::backward::lstm(), value, - grad, frame_size, batch_size, - ActiveType(cand_act), ActiveType(gate_act), - ActiveType(cell_act)); - - value.gateValue += frame_size * 4; - value.stateValue += frame_size; - value.stateActiveValue += frame_size; - value.outputValue += frame_size; - if (value.prevStateValue) { - value.prevStateValue += frame_size; - } - - grad.gateGrad += frame_size * 4; - grad.stateGrad += frame_size; - grad.stateActiveGrad += frame_size; - grad.outputGrad += frame_size; - if (grad.prevStateGrad) { - grad.prevStateGrad += frame_size; - } - } + detail::gpu_lstm_backward(context, detail::backward::lstm(), value, grad, + frame_size, batch_size, ActiveType(cand_act), + ActiveType(gate_act), ActiveType(cell_act)); } }; template class LstmUnitFunctor; +template class LstmUnitFunctor; +template class LstmUnitGradFunctor; template class LstmUnitGradFunctor; } // namespace math diff --git a/paddle/operators/math/lstm_compute.h b/paddle/operators/math/lstm_compute.h index ebf765c02e380d0aa4d57e24cb736c065553920a..bff9dd3ea407176cb29e639dfcea844b0739cb0d 100644 --- a/paddle/operators/math/lstm_compute.h +++ b/paddle/operators/math/lstm_compute.h @@ -53,7 +53,7 @@ struct LstmMetaGrad { T *checkOgGrad; }; -activation_mode_t ActiveType(const std::string &type) { +inline activation_mode_t ActiveType(const std::string &type) { if (type == "sigmoid") { return HL_ACTIVATION_SIGMOID; } else if (type == "relu") { diff --git a/paddle/operators/math/sequence2batch.h b/paddle/operators/math/sequence2batch.h index 3813d712388524714ef389d33253da8d89346da0..89b511680435812202f7770ff0e587c5b77227e0 100644 --- a/paddle/operators/math/sequence2batch.h +++ b/paddle/operators/math/sequence2batch.h @@ -59,7 +59,7 @@ class LoDTensor2BatchFunctor { }; std::vector seq_info; - for (size_t seq_id = 0; seq_id < lod.size(); ++seq_id) { + for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) { int length = lod[seq_id + 1] - lod[seq_id]; seq_info.emplace_back(lod[seq_id], length, seq_id); } @@ -83,10 +83,11 @@ class LoDTensor2BatchFunctor { // The batch number represents batch size after rearranging the // input LodTensor. It is also the maximum length of input sequence. - auto batch_lods = batch.lod(); - if (batch_lods.size() == 0) { - batch_lods.resize(2); - } + + paddle::framework::LoD batch_lods; + batch_lods.push_back(std::vector{0}); + batch_lods.push_back(std::vector{0}); + // batch_lods[0] is the start positions for batch LoDTensor int num_batch = (size_t)seq_info[0].length; batch_lods[0].resize(num_batch + 1); @@ -115,6 +116,7 @@ class LoDTensor2BatchFunctor { } batch_starts[n + 1] = batch_id; } + batch.set_lod(batch_lods); CopyMatrixRowsFunctor to_batch; to_batch(context, lod_tensor, seq2batch_idx, batch, true); @@ -130,12 +132,13 @@ class Batch2LoDTensorFunctor { auto in_lod = batch.lod(); PADDLE_ENFORCE_EQ(in_lod.size(), 2UL, "The LoD size of input `batch` should be 2."); - auto out_lod = lod_tensor.lod(); - PADDLE_ENFORCE_EQ(out_lod[0][0], out_lod[1].size()); - PADDLE_ENFORCE_EQ(out_lod[0][0], lod_tensor.dims()[0]); - PADDLE_ENFORCE_EQ(out_lod[0][0], batch.dims()[0]); + auto out_lod = lod_tensor.lod()[0]; + auto num = out_lod[out_lod.size() - 1]; + PADDLE_ENFORCE_EQ(num, lod_tensor.dims()[0]); + PADDLE_ENFORCE_EQ(num, in_lod[1].size()); + PADDLE_ENFORCE_EQ(num, batch.dims()[0]); CopyMatrixRowsFunctor to_seq; - size_t* index = out_lod[1].data(); + size_t* index = in_lod[1].data(); to_seq(context, batch, index, lod_tensor, false); } }; diff --git a/python/paddle/v2/framework/tests/test_lstm_op.py b/python/paddle/v2/framework/tests/test_lstm_op.py index f3f4c84b2a2c169572419b27d399cdb763cd4c22..aa6a21b54750fdf8519ad61818b29c94dbf3666d 100644 --- a/python/paddle/v2/framework/tests/test_lstm_op.py +++ b/python/paddle/v2/framework/tests/test_lstm_op.py @@ -2,17 +2,26 @@ import unittest import numpy as np from op_test import OpTest +SIGMOID_THRESHOLD_MIN = -40.0 +SIGMOID_THRESHOLD_MAX = 13.0 +EXP_MAX_INPUT = 40.0 + def identity(x): return x def sigmoid(x): - return 1. / (1. + np.exp(-x)) + y = np.copy(x) + y[x < SIGMOID_THRESHOLD_MIN] = SIGMOID_THRESHOLD_MIN + y[x > SIGMOID_THRESHOLD_MAX] = SIGMOID_THRESHOLD_MAX + return 1. / (1. + np.exp(-y)) def tanh(x): - return 2. * sigmoid(2. * x) - 1. + y = -2. * x + y[y > EXP_MAX_INPUT] = EXP_MAX_INPUT + return (2. / (1. + np.exp(y))) - 1. def relu(x): @@ -35,7 +44,7 @@ def lstm( g = np.dot(h_pre, w_h) # 1 x 4D g = g + x g = np.reshape(g, (1, g.size)) - c, g_i, g_f, g_o = np.split(g, 4, axis=1) + c_tmp, g_i, g_f, g_o = np.split(g, 4, axis=1) if w_c is None: g_i = gate_act(g_i) # 1 x D g_f = gate_act(g_f) # 1 x D @@ -43,7 +52,7 @@ def lstm( w_ic, w_fc, w_oc = np.split(w_c, 3, axis=1) g_i = gate_act(g_i + w_ic * c_pre) # 1 x D g_f = gate_act(g_f + w_fc * c_pre) # 1 x D - c = g_f * c_pre + g_i * cand_act(c) # 1 x D + c = g_f * c_pre + g_i * cand_act(c_tmp) # 1 x D if w_c is None: g_o = gate_act(g_o) # 1 x D @@ -51,12 +60,14 @@ def lstm( _, _, w_oc = np.split(w_c, 3, axis=1) g_o = gate_act(g_o + w_oc * c) # 1 x D h = g_o * cell_act(c) - return h, c + bg = np.concatenate((cand_act(c_tmp), g_i, g_f, g_o), axis=1) + return h, c, bg offset = lod[0] batch_size = len(offset) - 1 hidden = [] cell = [] + gate = [] if w_b is not None: input = input + np.tile(w_b, (offset[-1], 1)) for i in range(batch_size): @@ -64,44 +75,62 @@ def lstm( seq_len = offset[i + 1] - offset[i] x = input[offset[i]:offset[i + 1], :] h_pre = h0[i] # 1 x D - c_pre = h0[i] # 1 x D + c_pre = c0[i] # 1 x D for j in range(seq_len): # compute one step - h_pre, c_pre = _step(x[j], w_h, w_c, h_pre, c_pre, gate_act, - cell_act, cand_act) + h_pre, c_pre, g_pre = _step(x[j], w_h, w_c, h_pre, c_pre, gate_act, + cell_act, cand_act) hidden.append(h_pre.flatten()) cell.append(c_pre.flatten()) + gate.append(g_pre.flatten()) hidden = np.array(hidden).astype("float64") cell = np.array(cell).astype("float64") + gate = np.array(gate).astype("float64") + assert gate.shape == input.shape assert hidden.shape == (input.shape[0], input.shape[1] / 4) assert cell.shape == (input.shape[0], input.shape[1] / 4) - return hidden, cell + return hidden, cell, gate class LstmUnitTest(OpTest): def set_data(self): - lod = [[0, 2, 6, 9]] - shape = (9, 64) - - x = np.random.normal(size=(9, 4 * 64)).astype("float64") - h0 = np.random.normal(size=(4, 64)).astype("float64") - c0 = np.random.normal(size=(4, 64)).astype("float64") - w = np.random.normal(size=(64, 4 * 64)).astype("float64") - b = np.random.normal(size=(1, 7 * 64)).astype("float64") - - w_b = b[:, 4 * 64] - w_c = b[:, 4 * 64:] - h, c = lstm(x, lod, h0, c0, w, w_b, w_c, False, sigmoid, tanh, tanh) - - self.inputs = {'Input': x, 'H0': h0, 'C0': c0, 'Weight': w, 'Bias': b} - self.inputs = {'Hidden': h, 'Cell': c} + D = 4 + #lod = [[0, 2, 6, 9]] + lod = [[0, 1]] + shape = (1, D) + + x = np.random.normal(size=(1, 4 * D)).astype("float64") + h0 = np.zeros((4, D)).astype("float64") + c0 = np.zeros((4, D)).astype("float64") + w = np.random.normal(size=(D, 4 * D)).astype("float64") + b = np.random.normal(size=(1, 7 * D)).astype("float64") + + w_b = b[:, 0:4 * D] + w_c = b[:, 4 * D:] + #h, c, g = lstm(x, lod, h0, c0, w, w_b, w_c, False, sigmoid, tanh, tanh) + h, c, g = lstm(x, lod, h0, c0, w, w_b, w_c, False, identity, identity, + identity) + + g_sort = np.zeros_like(x) + #idx = [2,6,0,3,7,1,4,8,5] + #for i, j in enumerate(idx): + # g_sort[i, :] = g[j, :] + + self.inputs = { + 'Input': (x, lod), + 'H0': h0, + 'C0': c0, + 'Weight': w, + 'Bias': b + } + self.outputs = {'Hidden': h, 'Cell': c, 'BatchGate': g_sort} self.attrs = { 'usePeepholes': True, 'isReverse': False, - 'gateActivation': 'sigmoid', - 'cellActivation': 'tanh', - 'candidateActivation': 'tanh' + 'gateActivation': 'linear', + 'cellActivation': 'linear', + 'candidateActivation': 'linear' } def setUp(self):