diff --git a/paddle/operators/lstm_op.cc b/paddle/operators/lstm_op.cc index 73ab9b18dcb276f5768d32bad2aecbb48668fb2b..10b60e3de627207210e26839339c8bd7046deb9b 100644 --- a/paddle/operators/lstm_op.cc +++ b/paddle/operators/lstm_op.cc @@ -82,6 +82,13 @@ class LSTMOp : public framework::OperatorWithKernel { ctx->ShareLoD("Input", "Hidden"); ctx->ShareLoD("Input", "Cell"); } + + protected: + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType( + ctx.Input("Input")->type()); + } }; class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { @@ -239,6 +246,13 @@ class LSTMGradOp : public framework::OperatorWithKernel { if (ctx->HasOutput(b_g_name)) ctx->SetOutputDim(b_g_name, ctx->GetInputDim("Bias")); } + + protected: + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType( + ctx.Input("Input")->type()); + } }; } // namespace operators diff --git a/paddle/operators/math/detail/lstm_cpu_kernel.h b/paddle/operators/math/detail/lstm_cpu_kernel.h index 74d51d7bc9b91f4c8088384d77183131f57aafab..d0ed55ea168bc3e701c421c51d662c646e475351 100644 --- a/paddle/operators/math/detail/lstm_cpu_kernel.h +++ b/paddle/operators/math/detail/lstm_cpu_kernel.h @@ -26,10 +26,7 @@ namespace detail { template void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, - int frameSize, - activation_mode_t active_node, - activation_mode_t active_gate, - activation_mode_t active_state) { + int frameSize) { T rValueIn; T rValueIg; T rValueFg; @@ -60,10 +57,8 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, rPrevState = value.prevStateValue[i]; } - hppl::cpu::ForwardAct act; op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, - rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate), - act(active_state)); + rOut, rCheckI, rCheckF, rCheckO); valueIn[i] = rValueIn; valueIg[i] = rValueIg; @@ -77,10 +72,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, template void naive_lstm_backward_one_sequence(Op op, LstmMetaValue value, - LstmMetaGrad grad, int frameSize, - activation_mode_t active_node, - activation_mode_t active_gate, - activation_mode_t active_state) { + LstmMetaGrad grad, int frameSize) { T rValueIn; T rValueIg; T rValueFg; @@ -127,11 +119,10 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue value, rPrevState = value.prevStateValue[i]; } - hppl::cpu::BackwardAct act; op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, - rCheckOGrad, act(active_node), act(active_gate), act(active_state)); + rCheckOGrad); gradIn[i] = rGradIn; gradIg[i] = rGradIg; @@ -283,8 +274,7 @@ void cpu_lstm_forward(Op op, LstmMetaValue value, int frameSize, avx_lstm_forward_one_sequence(op, value, frameSize, active_node, active_gate, active_state); } else { - naive_lstm_forward_one_sequence(op, value, frameSize, active_node, - active_gate, active_state); + naive_lstm_forward_one_sequence(op, value, frameSize); } } @@ -297,8 +287,7 @@ void cpu_lstm_backward(Op op, LstmMetaValue value, LstmMetaGrad grad, avx_lstm_backward_one_sequence(op, value, grad, frameSize, active_node, active_gate, active_state); } else { - naive_lstm_backward_one_sequence(op, value, grad, frameSize, active_node, - active_gate, active_state); + naive_lstm_backward_one_sequence(op, value, grad, frameSize); } } diff --git a/paddle/operators/math/detail/lstm_gpu_kernel.h b/paddle/operators/math/detail/lstm_gpu_kernel.h index 9573eaefb6a9d678ef70f2e2bffdc6a3011b21ea..c06f164f84a92d31f89901e2656bdb8e69c533b7 100644 --- a/paddle/operators/math/detail/lstm_gpu_kernel.h +++ b/paddle/operators/math/detail/lstm_gpu_kernel.h @@ -32,9 +32,7 @@ namespace detail { */ template __global__ void KeLstmForward(Op op, LstmMetaValue value, int frameSize, - int batchSize, activation_mode_t active_node, - activation_mode_t active_gate, - activation_mode_t active_state) { + int batchSize) { const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; if (frameIdx >= frameSize) return; @@ -70,10 +68,8 @@ __global__ void KeLstmForward(Op op, LstmMetaValue value, int frameSize, rPrevState = value.prevStateValue[frameIdx]; } - hppl::gpu::ForwardAct act; op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, - rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate), - act(active_state)); + rOut, rCheckI, rCheckF, rCheckO); value.gateValue[frameIdx] = rValueIn; value.gateValue[frameIdx + frameSize] = rValueIg; @@ -92,9 +88,7 @@ __global__ void KeLstmForward(Op op, LstmMetaValue value, int frameSize, template __global__ void KeLstmBackward(Op op, LstmMetaValue value, LstmMetaGrad grad, int frameSize, - int batchSize, activation_mode_t active_node, - activation_mode_t active_gate, - activation_mode_t active_state) { + int batchSize) { const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; if (frameIdx >= frameSize) return; @@ -145,11 +139,9 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue value, rPrevState = value.prevStateValue[frameIdx]; } - hppl::gpu::BackwardAct act; op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad, - rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad, - act(active_node), act(active_gate), act(active_state)); + rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad); grad.gateGrad[frameIdx] = rGradIn; grad.gateGrad[frameIdx + frameSize] = rGradIg; @@ -205,13 +197,11 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op, if (batchSize == 1) { KeLstmForward<<>>( - op, value, frameSize, batchSize, active_node, active_gate, - active_state); + op, value, frameSize, batchSize); } else { KeLstmForward<<>>( - op, value, frameSize, batchSize, active_node, active_gate, - active_state); + op, value, frameSize, batchSize); } } @@ -240,13 +230,11 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op, if (batchSize == 1) { KeLstmBackward<<>>( - op, value, grad, frameSize, batchSize, active_node, active_gate, - active_state); + op, value, grad, frameSize, batchSize); } else { KeLstmBackward<<>>( - op, value, grad, frameSize, batchSize, active_node, active_gate, - active_state); + op, value, grad, frameSize, batchSize); } } diff --git a/paddle/operators/math/detail/lstm_kernel.h b/paddle/operators/math/detail/lstm_kernel.h index 6f3ead2397d5131b4468d0ad288513cedb289594..461039a4d51a2b9b8a55d3101bdf4c511907597e 100644 --- a/paddle/operators/math/detail/lstm_kernel.h +++ b/paddle/operators/math/detail/lstm_kernel.h @@ -24,15 +24,29 @@ namespace detail { namespace forward { +template +DEVICE inline T sigmoid(const T a) { + const T min = SIGMOID_THRESHOLD_MIN; + const T max = SIGMOID_THRESHOLD_MAX; + T tmp = (a < min) ? min : ((a > max) ? max : a); + return static_cast(1.0) / (static_cast(1.0) + exp(-tmp)); +} + +template +DEVICE inline T tanh(const T a) { + T tmp = -2.0 * a; + tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; + return (2.0 / (1.0 + exp(tmp))) - 1.0; +} + template class lstm { public: HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg, T &prevState, T &state, T &stateAtv, T &output, - T &checkI, T &checkF, T &checkO, - typename hppl::ForwardActType::type actInput, - typename hppl::ForwardActType::type actGate, - typename hppl::ForwardActType::type actState) { + T &checkI, T &checkF, T &checkO) { +#if 0 + // TODO(qingqing) support to activation speficed by users valueIn = actInput(valueIn); valueIg = actGate(valueIg + prevState * checkI); valueFg = actGate(valueFg + prevState * checkF); @@ -40,6 +54,15 @@ class lstm { valueOg = actGate(valueOg + state * checkO); stateAtv = actState(state); output = valueOg * stateAtv; +#else + valueIn = tanh(valueIn); + valueIg = sigmoid(valueIg + prevState * checkI); + valueFg = sigmoid(valueFg + prevState * checkF); + state = valueIn * valueIg + prevState * valueFg; + valueOg = sigmoid(valueOg + state * checkO); + stateAtv = tanh(state); + output = valueOg * stateAtv; +#endif } #ifndef __NVCC__ #ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default @@ -72,6 +95,16 @@ class lstm { namespace backward { +template +DEVICE inline T sigmoid(const T a, const T b) { + return a * b * (1.0 - b); +} + +template +DEVICE inline T tanh(const T a, const T b) { + return a * (1.0 - b * b); +} + template class lstm { public: @@ -80,10 +113,9 @@ class lstm { T &prevState, T &prevStateGrad, T &state, T &stateGrad, T &stateAtv, T &outputGrad, T &checkI, T &checkF, T &checkO, T &checkIGrad, - T &checkFGrad, T &checkOGrad, - typename hppl::BackwardActType::type actInput, - typename hppl::BackwardActType::type actGate, - typename hppl::BackwardActType::type actState) { + T &checkFGrad, T &checkOGrad) { +#if 0 + // TODO(qingqing) support to activation speficed by users gradOg = actGate(outputGrad * stateAtv, valueOg); stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO; gradIn = actInput(stateGrad * valueIg, valueIn); @@ -93,6 +125,17 @@ class lstm { checkIGrad = gradIg * prevState; checkFGrad = gradFg * prevState; checkOGrad = gradOg * state; +#else + gradOg = sigmoid(outputGrad * stateAtv, valueOg); + stateGrad += tanh(outputGrad * valueOg, stateAtv) + gradOg * checkO; + gradIn = tanh(stateGrad * valueIg, valueIn); + gradIg = sigmoid(stateGrad * valueIn, valueIg); + gradFg = sigmoid(stateGrad * prevState, valueFg); + prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg; + checkIGrad = gradIg * prevState; + checkFGrad = gradFg * prevState; + checkOGrad = gradOg * state; +#endif } #ifndef __NVCC__ #ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default diff --git a/python/paddle/v2/framework/tests/test_lstm_op.py b/python/paddle/v2/framework/tests/test_lstm_op.py index 7f428cd617cb57076e4e2e5a9f649fa52ccfbbfb..f308ba82fa6f6b3a973fa636e7d1b31fc5f77fb5 100644 --- a/python/paddle/v2/framework/tests/test_lstm_op.py +++ b/python/paddle/v2/framework/tests/test_lstm_op.py @@ -110,7 +110,7 @@ def lstm( class TestLstmOp(OpTest): def set_argument(self): - self.lod = [[0, 2, 6]] + self.lod = [[0, 2, 5, 7]] self.D = 16 self.act_gate = 'sigmoid' @@ -164,12 +164,13 @@ class TestLstmOp(OpTest): # TODO(qingqing) remove folowing two lines after the check_grad is refined. self.outputs['BatchGate'] = None self.outputs['BatchCellPreAct'] = None - self.check_grad(['Input', 'Weight', 'Bias'], ['Hidden']) + self.check_grad( + ['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=0.02) class TestLstmOpHasNoInitial(TestLstmOp): def set_argument(self): - self.lod = [[0, 2, 6]] + self.lod = [[0, 2, 5, 7]] self.D = 16 self.act_gate = 'sigmoid' @@ -182,7 +183,7 @@ class TestLstmOpHasNoInitial(TestLstmOp): class TestLstmOpRerverse(TestLstmOp): def set_argument(self): - self.lod = [[0, 2, 6]] + self.lod = [[0, 2, 5, 7]] self.D = 16 self.act_gate = 'sigmoid'