diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index f24cf96cce30bf4b21e954ab022b2b99aeff37e6..0dde33813fdb9e942406e0b9949f26a36fb1dbf1 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -71,7 +71,7 @@ paddle.fluid.initializer.NumpyArrayInitializer.__init__ ArgSpec(args=['self', 'v paddle.fluid.layers.fc ArgSpec(args=['input', 'size', 'num_flatten_dims', 'param_attr', 'bias_attr', 'act', 'is_test', 'name'], varargs=None, keywords=None, defaults=(1, None, None, None, False, None)) paddle.fluid.layers.embedding ArgSpec(args=['input', 'size', 'is_sparse', 'is_distributed', 'padding_idx', 'param_attr', 'dtype'], varargs=None, keywords=None, defaults=(False, False, None, None, 'float32')) paddle.fluid.layers.dynamic_lstm ArgSpec(args=['input', 'size', 'h_0', 'c_0', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'float32', None)) -paddle.fluid.layers.dynamic_lstmp ArgSpec(args=['input', 'size', 'proj_size', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'proj_activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'tanh', 'float32', None)) +paddle.fluid.layers.dynamic_lstmp ArgSpec(args=['input', 'size', 'proj_size', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'proj_activation', 'dtype', 'name', 'h_0', 'c_0', 'cell_clip', 'proj_clip'], varargs=None, keywords=None, defaults=(None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'tanh', 'float32', None, None, None, None, None)) paddle.fluid.layers.dynamic_gru ArgSpec(args=['input', 'size', 'param_attr', 'bias_attr', 'is_reverse', 'gate_activation', 'candidate_activation', 'h_0', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, False, 'sigmoid', 'tanh', None, False)) paddle.fluid.layers.gru_unit ArgSpec(args=['input', 'hidden', 'size', 'param_attr', 'bias_attr', 'activation', 'gate_activation', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, 'tanh', 'sigmoid', False)) paddle.fluid.layers.linear_chain_crf ArgSpec(args=['input', 'label', 'param_attr'], varargs=None, keywords=None, defaults=(None,)) diff --git a/paddle/fluid/operators/lstm_op.h b/paddle/fluid/operators/lstm_op.h index 3f110024b285d41ccfe305e35c8efca5ed5ee0fe..ca998826dd0118ab4b1ecc23bed8ef882f1bcc92 100644 --- a/paddle/fluid/operators/lstm_op.h +++ b/paddle/fluid/operators/lstm_op.h @@ -151,9 +151,10 @@ class LSTMKernel : public framework::OpKernel { lstm_value.output_value = out_t.data(); lstm_value.state_value = cell_t.data(); lstm_value.state_active_value = cell_pre_act_t.data(); + T cell_clip = 0.0; math::LstmUnitFunctor::compute( - device_ctx, lstm_value, frame_size, cur_batch_size, gate_act, - cell_act, cand_act); + device_ctx, lstm_value, frame_size, cur_batch_size, cell_clip, + gate_act, cell_act, cand_act); lstm_value.prev_state_value = lstm_value.state_value; } @@ -316,9 +317,10 @@ class LSTMGradKernel : public framework::OpKernel { lstm_value.output_value = nullptr; lstm_grad.state_active_grad = nullptr; int cur_batch_size = bend - bstart; + T cell_clip = 0.0; math::LstmUnitGradFunctor::compute( device_ctx, lstm_value, lstm_grad, frame_size, cur_batch_size, - gate_act, cell_act, cand_act); + cell_clip, gate_act, cell_act, cand_act); if (n > 0) { int pre_h_start = static_cast(batch_starts[n - 1]); diff --git a/paddle/fluid/operators/lstmp_op.cc b/paddle/fluid/operators/lstmp_op.cc index 7a62bc9f828e4d3485628747cdf52c60c5354144..2728aa8a4ee21a9e1fe3deddcdba4c35a6aba7bc 100644 --- a/paddle/fluid/operators/lstmp_op.cc +++ b/paddle/fluid/operators/lstmp_op.cc @@ -73,12 +73,6 @@ class LSTMPOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasInput("C0"), "Input(C0) of LSTMP operator should not be null after " "Input(H0) provided."); - auto h_dims = ctx->GetInputDim("H0"); - auto c_dims = ctx->GetInputDim("C0"); - PADDLE_ENFORCE(h_dims == c_dims, - "The dimension of Input(H0) and Input(C0) " - "should be the same."); - ctx->SetOutputDim("OrderedP0", {h_dims[0], proj_dims[1]}); } auto b_dims = ctx->GetInputDim("Bias"); @@ -180,11 +174,6 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker { "This LoDTensor is obtained in the forward and used in the " "backward.") .AsIntermediate(); - AddOutput("OrderedP0", - "(Tensor) the projection of the initial hidden state " - "H0. This is a tensor with shape (N x P), where N is the " - "batch size and P is the hidden size.") - .AsIntermediate(); AddAttr("use_peepholes", "(bool, defalut: True) " "whether to enable diagonal/peephole connections.") @@ -193,6 +182,16 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, defalut: False) " "whether to compute reversed LSTMP.") .SetDefault(false); + AddAttr("cell_clip", + "(float, defalut: 0.0) " + "Clip for Tensor for cell state tensor when clip value is " + "greater than 0.0") + .SetDefault(0.0); + AddAttr("proj_clip", + "(float, defalut: 0.0) " + "Clip for Tensor for projection tensor when clip value is " + "greater than 0.0") + .SetDefault(0.0); AddAttr( "gate_activation", "(string, default: sigmoid)" diff --git a/paddle/fluid/operators/lstmp_op.h b/paddle/fluid/operators/lstmp_op.h index 1f11e57dcb721012c7b8e50d7e138355685053da..c7d6e4205f8862526904e4fa767a2f4c4a2d8481 100644 --- a/paddle/fluid/operators/lstmp_op.h +++ b/paddle/fluid/operators/lstmp_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once #include +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/activation_op.h" @@ -21,17 +22,50 @@ limitations under the License. */ #include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/lstm_compute.h" #include "paddle/fluid/operators/math/sequence2batch.h" +#include "paddle/fluid/platform/transform.h" namespace paddle { namespace operators { using LoDTensor = framework::LoDTensor; using Tensor = framework::Tensor; +using platform::Transform; template using EigenMatrix = framework::EigenMatrix; +template +class _ClipFunctor { + public: + explicit _ClipFunctor(const T min, const T max) : min_(min), max_(max) {} + HOSTDEVICE T operator()(const T& x) const { + if (x < min_) + return min_; + else if (x > max_) + return max_; + else + return x; + } + + private: + T min_; + T max_; +}; + +template +class _ClipGradFunctor { + public: + explicit _ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {} + HOSTDEVICE T operator()(const T& x, const T& y) const { + return (y > min_ && y < max_) ? x : 0; + } + + private: + T min_; + T max_; +}; + template inline void ReorderInitState(const DeviceContext& ctx, const framework::Tensor& src, @@ -67,9 +101,11 @@ class LSTMPKernel : public framework::OpKernel { auto* bias = ctx.Input("Bias"); auto* hidden_t0 = ctx.Input("H0"); - auto* ordered_proj0 = ctx.Output("OrderedP0"); auto* cell_t0 = ctx.Input("C0"); + auto proj_clip = static_cast(ctx.Attr("proj_clip")); + auto cell_clip = static_cast(ctx.Attr("cell_clip")); + auto* batch_gate = ctx.Output("BatchGate"); batch_gate->mutable_data(ctx.GetPlace()); auto* proj_out = ctx.Output("Projection"); @@ -110,6 +146,7 @@ class LSTMPKernel : public framework::OpKernel { } lstmp_value.prev_state_value = nullptr; Tensor ordered_c0; + Tensor ordered_h0; framework::Vector order(batch_gate->lod()[2]); @@ -169,18 +206,9 @@ class LSTMPKernel : public framework::OpKernel { // Since the batch computing for LSTMP reorders the input sequence // according to their length. The initialized hidden state also needs // to reorder. - - Tensor ordered_h0; - ordered_proj0->mutable_data(ctx.GetPlace()); ReorderInitState(device_ctx, *hidden_t0, order, &ordered_h0, true); - blas.MatMul(ordered_h0, false, *proj_weight, false, static_cast(1.0), - ordered_proj0, static_cast(0.0)); - if (proj_act != math::detail::ActivationType::kIdentity) { - auto proj0_dev = EigenMatrix::From(*ordered_proj0); - ActCompute(cell_act, place, proj0_dev, proj0_dev); - } - blas.MatMul(*ordered_proj0, false, *weight, false, static_cast(1.0), + blas.MatMul(ordered_h0, false, *weight, false, static_cast(1.0), &gate_t, static_cast(1.0)); } @@ -189,8 +217,8 @@ class LSTMPKernel : public framework::OpKernel { lstmp_value.state_value = cell_t.data(); lstmp_value.state_active_value = cell_pre_act_t.data(); math::LstmUnitFunctor::compute( - device_ctx, lstmp_value, frame_size, cur_batch_size, gate_act, - cell_act, cand_act); + device_ctx, lstmp_value, frame_size, cur_batch_size, cell_clip, + gate_act, cell_act, cand_act); lstmp_value.prev_state_value = lstmp_value.state_value; blas.MatMul(hidden_t, false, *proj_weight, false, static_cast(1.0), &proj_t, static_cast(0.0)); @@ -198,6 +226,14 @@ class LSTMPKernel : public framework::OpKernel { auto proj_t_dev = EigenMatrix::From(proj_t); ActCompute(cell_act, place, proj_t_dev, proj_t_dev); } + if (proj_clip && proj_clip > 0.0) { + T* x_data = proj_t.data(); + int64_t numel = proj_t.numel(); + Transform trans; + trans(ctx.template device_context(), x_data, + x_data + numel, x_data, + _ClipFunctor(-1.0 * proj_clip, proj_clip)); + } } math::Batch2LoDTensorFunctor to_seq; @@ -239,6 +275,9 @@ class LSTMPGradKernel : public framework::OpKernel { auto* proj_out = ctx.Input("Projection"); auto* cell_out = ctx.Input("Cell"); + auto proj_clip = static_cast(ctx.Attr("proj_clip")); + auto cell_clip = static_cast(ctx.Attr("cell_clip")); + auto* batch_gate = ctx.Input("BatchGate"); auto* batch_cell_pre_act = ctx.Input("BatchCellPreAct"); auto* batch_hidden = ctx.Input("BatchHidden"); @@ -253,7 +292,6 @@ class LSTMPGradKernel : public framework::OpKernel { auto* bias_g = ctx.Output(framework::GradVarName("Bias")); auto* h0 = ctx.Input("H0"); - auto* ordered_proj0 = ctx.Input("OrderedP0"); auto* c0 = ctx.Input("C0"); auto* h0_g = ctx.Output(framework::GradVarName("H0")); @@ -363,6 +401,17 @@ class LSTMPGradKernel : public framework::OpKernel { Tensor cur_proj = batch_proj.Slice(bstart, bend); Tensor proj_g = batch_proj_g.Slice(bstart, bend); + + if (proj_clip && proj_clip > 0.0) { + T* dx_data = proj_g.data(); + T* x_data = cur_proj.data(); + int64_t numel = proj_g.numel(); + Transform trans; + trans(ctx.template device_context(), dx_data, + dx_data + numel, x_data, dx_data, + _ClipGradFunctor(-1.0 * proj_clip, proj_clip)); + } + if (proj_act != math::detail::ActivationType::kIdentity) { auto cur_proj_dev = EigenMatrix::From(cur_proj); auto proj_g_dev = EigenMatrix::From(proj_g); @@ -412,7 +461,7 @@ class LSTMPGradKernel : public framework::OpKernel { math::LstmUnitGradFunctor::compute( device_ctx, lstmp_value, lstmp_grad, frame_size, cur_batch_size, - gate_act, cell_act, cand_act); + cell_clip, gate_act, cell_act, cand_act); if (n > 0) { int pre_h_start = static_cast(batch_starts[n - 1]); @@ -431,31 +480,14 @@ class LSTMPGradKernel : public framework::OpKernel { ReorderInitState(device_ctx, *h0, order, &ordered_h0, true); if (weight_g) { - blas.MatMul(*ordered_proj0, true, gate_g, false, - static_cast(1.0), weight_g, static_cast(1.0)); + blas.MatMul(ordered_h0, true, gate_g, false, static_cast(1.0), + weight_g, static_cast(1.0)); } } if (h0 && (h0_g || proj_weight_g)) { ordered_h0_g.mutable_data(h0_g->dims(), ctx.GetPlace()); - Tensor proj0_g; - proj0_g.Resize({in_dims[0], proj_weight->dims()[1]}); - proj0_g.mutable_data(ctx.GetPlace()); blas.MatMul(gate_g, false, *weight, true, static_cast(1.0), - &proj0_g, static_cast(0.0)); - if (proj_act != math::detail::ActivationType::kIdentity) { - auto proj0_dev = EigenMatrix::From(*ordered_proj0); - auto proj0_g_dev = EigenMatrix::From(proj0_g); - ActGradCompute(cell_act, place, proj0_dev, proj0_dev, proj0_g_dev, - proj0_g_dev); - } - if (h0_g) { - blas.MatMul(proj0_g, false, *proj_weight, true, static_cast(1.0), - &ordered_h0_g, static_cast(0.0)); - } - if (proj_weight_g) { - blas.MatMul(ordered_h0, true, proj0_g, false, static_cast(1.0), - proj_weight_g, static_cast(1.0)); - } + &ordered_h0_g, static_cast(0.0)); } } } diff --git a/paddle/fluid/operators/math/detail/lstm_cpu_kernel.h b/paddle/fluid/operators/math/detail/lstm_cpu_kernel.h index 2e3779ff0845294e71f27801049c010e0a585e6b..ad79c58063a8a12c703979fe32a8e671a5ade857 100644 --- a/paddle/fluid/operators/math/detail/lstm_cpu_kernel.h +++ b/paddle/fluid/operators/math/detail/lstm_cpu_kernel.h @@ -32,7 +32,8 @@ namespace detail { template void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, - int frame_size, ActivationType active_node, + int frame_size, T cell_clip, + ActivationType active_node, ActivationType active_gate, ActivationType active_state) { T r_value_in; @@ -67,7 +68,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state, &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO, - active_node, active_gate, active_state); + &cell_clip, active_node, active_gate, active_state); value_in[i] = r_value_in; value_ig[i] = r_value_ig; @@ -82,7 +83,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, template void naive_lstm_backward_one_sequence(Op op, LstmMetaValue value, LstmMetaGrad grad, int frame_size, - ActivationType active_node, + T cell_clip, ActivationType active_node, ActivationType active_gate, ActivationType active_state) { T r_value_in; @@ -135,7 +136,7 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue value, &r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad, &r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI, &r_checkF, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad, - active_node, active_gate, active_state); + &cell_clip, active_node, active_gate, active_state); grad_in[i] = r_grad_in; grad_ig[i] = r_grad_ig; @@ -154,7 +155,8 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue value, template void avx_lstm_forward_one_sequence(Op op, LstmMetaValue value, - int frame_size, ActivationType active_node, + int frame_size, T cell_clip, + ActivationType active_node, ActivationType active_gate, ActivationType active_state) { #ifdef __AVX__ @@ -194,7 +196,7 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue value, op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state, &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO, - active_node, active_gate, active_state); + &cell_clip, active_node, active_gate, active_state); value_in[i] = r_value_in; value_ig[i] = r_value_ig; @@ -210,7 +212,7 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue value, template void avx_lstm_backward_one_sequence(Op op, LstmMetaValue value, LstmMetaGrad grad, int frame_size, - ActivationType active_node, + T cell_clip, ActivationType active_node, ActivationType active_gate, ActivationType active_state) { #ifdef __AVX__ @@ -268,7 +270,7 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue value, &r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad, &r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI, &r_checkF, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad, - active_node, active_gate, active_state); + &cell_clip, active_node, active_gate, active_state); grad_in[i] = r_grad_in; grad_ig[i] = r_grad_ig; @@ -292,27 +294,27 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue value, template void cpu_lstm_forward(Op op, LstmMetaValue value, int frame_size, - ActivationType active_node, ActivationType active_gate, - ActivationType active_state) { + T cell_clip, ActivationType active_node, + ActivationType active_gate, ActivationType active_state) { if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same::value)) { - avx_lstm_forward_one_sequence(op, value, frame_size, active_node, - active_gate, active_state); + avx_lstm_forward_one_sequence(op, value, frame_size, cell_clip, + active_node, active_gate, active_state); } else { - naive_lstm_forward_one_sequence(op, value, frame_size, active_node, - active_gate, active_state); + naive_lstm_forward_one_sequence(op, value, frame_size, cell_clip, + active_node, active_gate, active_state); } } template void cpu_lstm_backward(Op op, LstmMetaValue value, LstmMetaGrad grad, - int frame_size, ActivationType active_node, + int frame_size, T cell_clip, ActivationType active_node, ActivationType active_gate, ActivationType active_state) { if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same::value)) { - avx_lstm_backward_one_sequence(op, value, grad, frame_size, active_node, - active_gate, active_state); + avx_lstm_backward_one_sequence(op, value, grad, frame_size, cell_clip, + active_node, active_gate, active_state); } else { - naive_lstm_backward_one_sequence(op, value, grad, frame_size, + naive_lstm_backward_one_sequence(op, value, grad, frame_size, cell_clip, active_node, active_gate, active_state); } } diff --git a/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h b/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h index 2aecb69237fdf344ebc0bfe72d9c7c147f06358d..e0ca9e7f5b2f4a8bb837768d645b5103aa3e6760 100644 --- a/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h +++ b/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h @@ -31,7 +31,8 @@ namespace detail { */ template __global__ void KeLstmForward(Op op, LstmMetaValue value, int frame_size, - int batch_size, ActivationType active_node, + int batch_size, T cell_clip, + ActivationType active_node, ActivationType active_gate, ActivationType active_state) { const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; @@ -72,7 +73,7 @@ __global__ void KeLstmForward(Op op, LstmMetaValue value, int frame_size, op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state, &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO, - active_node, active_gate, active_state); + &cell_clip, active_node, active_gate, active_state); value.gate_value[frame_idx] = r_value_in; value.gate_value[frame_idx + frame_size] = r_value_ig; @@ -91,7 +92,8 @@ __global__ void KeLstmForward(Op op, LstmMetaValue value, int frame_size, template __global__ void KeLstmBackward(Op op, LstmMetaValue value, LstmMetaGrad grad, int frame_size, - int batch_size, ActivationType active_node, + int batch_size, T cell_clip, + ActivationType active_node, ActivationType active_gate, ActivationType active_state) { const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; @@ -148,8 +150,8 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue value, op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in, &r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad, &r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI, &r_checkF, - &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad, active_node, - active_gate, active_state); + &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad, &cell_clip, + active_node, active_gate, active_state); grad.gate_grad[frame_idx] = r_grad_in; grad.gate_grad[frame_idx + frame_size] = r_grad_ig; @@ -185,8 +187,8 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue value, template void gpu_lstm_forward(const platform::DeviceContext& context, Op op, LstmMetaValue value, int frame_size, int batch_size, - ActivationType active_node, ActivationType active_gate, - ActivationType active_state) { + T cell_clip, ActivationType active_node, + ActivationType active_gate, ActivationType active_state) { dim3 threads; dim3 grid; if (batch_size == 1) { @@ -205,12 +207,12 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op, if (batch_size == 1) { KeLstmForward<<>>( - op, value, frame_size, batch_size, active_node, active_gate, + op, value, frame_size, batch_size, cell_clip, active_node, active_gate, active_state); } else { KeLstmForward<<>>( - op, value, frame_size, batch_size, active_node, active_gate, + op, value, frame_size, batch_size, cell_clip, active_node, active_gate, active_state); } } @@ -218,7 +220,7 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op, template void gpu_lstm_backward(const platform::DeviceContext& context, Op op, LstmMetaValue value, LstmMetaGrad grad, - int frame_size, int batch_size, + int frame_size, int batch_size, T cell_clip, ActivationType active_node, ActivationType active_gate, ActivationType active_state) { dim3 threads; @@ -239,13 +241,13 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op, if (batch_size == 1) { KeLstmBackward<<>>( - op, value, grad, frame_size, batch_size, active_node, active_gate, - active_state); + op, value, grad, frame_size, batch_size, cell_clip, active_node, + active_gate, active_state); } else { KeLstmBackward<<>>( - op, value, grad, frame_size, batch_size, active_node, active_gate, - active_state); + op, value, grad, frame_size, batch_size, cell_clip, active_node, + active_gate, active_state); } } diff --git a/paddle/fluid/operators/math/detail/lstm_kernel.h b/paddle/fluid/operators/math/detail/lstm_kernel.h index cbe73d62938d7c4c03a2c8731665260624417fd7..8149686c97a030b91e0c4de708b9abf07f83203d 100644 --- a/paddle/fluid/operators/math/detail/lstm_kernel.h +++ b/paddle/fluid/operators/math/detail/lstm_kernel.h @@ -29,7 +29,7 @@ class lstm { public: HOSTDEVICE void operator()(T *value_in, T *value_ig, T *value_fg, T *value_og, T *prev_state, T *state, T *state_atv, T *output, - T *checkI, T *checkF, T *checkO, + T *checkI, T *checkF, T *checkO, T *cell_clip, ActivationType active_node, ActivationType active_gate, ActivationType active_state) { @@ -37,6 +37,15 @@ class lstm { *value_ig = activation(*value_ig + (*prev_state) * (*checkI), active_gate); *value_fg = activation(*value_fg + (*prev_state) * (*checkF), active_gate); *state = (*value_in) * (*value_ig) + (*prev_state) * (*value_fg); + + if (*cell_clip > 0.0) { + if (*state < -1.0 * (*cell_clip)) { + *state = -1.0 * (*cell_clip); + } + if (*state > *cell_clip) { + *state = *cell_clip; + } + } *value_og = activation(*value_og + (*state) * (*checkO), active_gate); *state_atv = activation(*state, active_state); *output = (*value_og) * (*state_atv); @@ -52,7 +61,7 @@ class lstm { __m256 *value_fg, __m256 *value_og, __m256 *prev_state, __m256 *state, __m256 *state_atv, __m256 *output, __m256 *checkI, - __m256 *checkF, __m256 *checkO, + __m256 *checkF, __m256 *checkO, T *cell_clip, ActivationType active_node, ActivationType active_gate, ActivationType active_state) { @@ -65,6 +74,13 @@ class lstm { active_gate); *state = _mm256_add_ps(_mm256_mul_ps(*value_in, *value_ig), _mm256_mul_ps(*prev_state, *value_fg)); + + if (*cell_clip > 0.0f) { + __m256 min = _mm256_set1_ps(0.0f - *cell_clip); + __m256 max = _mm256_set1_ps(*cell_clip); + *state = _mm256_min_ps(max, *state); + *state = _mm256_max_ps(min, *state); + } *value_og = activation( _mm256_add_ps(*value_og, _mm256_mul_ps(*state, *checkO)), active_gate); *state_atv = activation(*state, active_state); @@ -86,15 +102,26 @@ class lstm { T *prev_state, T *prev_state_grad, T *state, T *state_grad, T *state_atv, T *output_grad, T *checkI, T *checkF, T *checkO, T *checkIGrad, - T *checkFGrad, T *checkOGrad, + T *checkFGrad, T *checkOGrad, T *cell_clip, ActivationType active_node, ActivationType active_gate, ActivationType active_state) { *grad_og = activation((*output_grad) * (*state_atv), *value_og, active_gate); - *state_grad += - activation((*output_grad) * (*value_og), *state_atv, active_state) + - (*grad_og) * (*checkO); + if (*cell_clip > 0.0f) { + if (*state >= (*cell_clip) || *state <= (0.0f - (*cell_clip))) { + *state_grad = 0.0f; + } else { + *state_grad += + activation((*output_grad) * (*value_og), *state_atv, active_state) + + (*grad_og) * (*checkO); + } + } else { + *state_grad += + activation((*output_grad) * (*value_og), *state_atv, active_state) + + (*grad_og) * (*checkO); + } + *grad_in = activation((*state_grad) * (*value_ig), *value_in, active_node); *grad_ig = activation((*state_grad) * (*value_in), *value_ig, active_gate); *grad_fg = @@ -117,15 +144,24 @@ class lstm { __m256 *prev_state, __m256 *prev_state_grad, __m256 *state, __m256 *state_grad, __m256 *state_atv, __m256 *output_grad, __m256 *checkI, __m256 *checkF, __m256 *checkO, __m256 *checkIGrad, - __m256 *checkFGrad, __m256 *checkOGrad, ActivationType active_node, - ActivationType active_gate, ActivationType active_state) { + __m256 *checkFGrad, __m256 *checkOGrad, T *cell_clip, + ActivationType active_node, ActivationType active_gate, + ActivationType active_state) { *grad_og = activation(_mm256_mul_ps(*output_grad, *state_atv), *value_og, active_gate); - *state_grad = - _mm256_add_ps(activation(_mm256_mul_ps(*output_grad, *value_og), - *state_atv, active_state), - *state_grad); - *state_grad = _mm256_add_ps(_mm256_mul_ps(*grad_og, *checkO), *state_grad); + if (*cell_clip > 0.0f) { + T *state_ = reinterpret_cast(state); + if (*state_ >= (*cell_clip) || *state_ <= (0.0f - (*cell_clip))) { + *state_grad = _mm256_set1_ps(0.0f); + } else { + *state_grad = + _mm256_add_ps(activation(_mm256_mul_ps(*output_grad, *value_og), + *state_atv, active_state), + *state_grad); + *state_grad = + _mm256_add_ps(_mm256_mul_ps(*grad_og, *checkO), *state_grad); + } + } *grad_in = activation(_mm256_mul_ps(*state_grad, *value_ig), *value_in, active_node); *grad_ig = activation(_mm256_mul_ps(*state_grad, *value_in), *value_ig, diff --git a/paddle/fluid/operators/math/lstm_compute.cc b/paddle/fluid/operators/math/lstm_compute.cc index b6882b4fd8e6db8592a282410888d5625bae742a..94bbcbb50670d9f0b11b77cf6a54a99c227521bf 100644 --- a/paddle/fluid/operators/math/lstm_compute.cc +++ b/paddle/fluid/operators/math/lstm_compute.cc @@ -24,12 +24,12 @@ template struct LstmUnitFunctor { static void compute(const platform::CPUDeviceContext& context, LstmMetaValue value, int frame_size, int batch_size, - const detail::ActivationType& gate_act, + T cell_clip, const detail::ActivationType& gate_act, const detail::ActivationType& cell_act, const detail::ActivationType& cand_act) { for (int b = 0; b < batch_size; b++) { detail::cpu_lstm_forward(detail::forward::lstm(), value, frame_size, - cand_act, gate_act, cell_act); + cell_clip, cand_act, gate_act, cell_act); value.gate_value += frame_size * 4; value.state_value += frame_size; value.state_active_value += frame_size; @@ -45,13 +45,14 @@ template struct LstmUnitGradFunctor { static void compute(const platform::CPUDeviceContext& context, LstmMetaValue value, LstmMetaGrad grad, - int frame_size, int batch_size, + int frame_size, int batch_size, T cell_clip, const detail::ActivationType& gate_act, const detail::ActivationType& cell_act, const detail::ActivationType& cand_act) { for (int b = 0; b < batch_size; b++) { detail::cpu_lstm_backward(detail::backward::lstm(), value, grad, - frame_size, cand_act, gate_act, cell_act); + frame_size, cell_clip, cand_act, gate_act, + cell_act); value.gate_value += frame_size * 4; value.state_value += frame_size; diff --git a/paddle/fluid/operators/math/lstm_compute.cu b/paddle/fluid/operators/math/lstm_compute.cu index 1233000083d6efc31fcbc527e8e9efb83224b4e3..e7445d3d40ae92ff66e7d33a38bfdebfc8455f0a 100644 --- a/paddle/fluid/operators/math/lstm_compute.cu +++ b/paddle/fluid/operators/math/lstm_compute.cu @@ -24,12 +24,12 @@ template struct LstmUnitFunctor { static void compute(const platform::CUDADeviceContext& context, LstmMetaValue value, int frame_size, int batch_size, - const detail::ActivationType& gate_act, + T cell_clip, const detail::ActivationType& gate_act, const detail::ActivationType& cell_act, const detail::ActivationType& cand_act) { detail::gpu_lstm_forward(context, detail::forward::lstm(), value, - frame_size, batch_size, cand_act, gate_act, - cell_act); + frame_size, batch_size, cell_clip, cand_act, + gate_act, cell_act); } }; @@ -37,13 +37,13 @@ template struct LstmUnitGradFunctor { static void compute(const platform::CUDADeviceContext& context, LstmMetaValue value, LstmMetaGrad grad, - int frame_size, int batch_size, + int frame_size, int batch_size, T cell_clip, const detail::ActivationType& gate_act, const detail::ActivationType& cell_act, const detail::ActivationType& cand_act) { detail::gpu_lstm_backward(context, detail::backward::lstm(), value, grad, - frame_size, batch_size, cand_act, gate_act, - cell_act); + frame_size, batch_size, cell_clip, cand_act, + gate_act, cell_act); } }; diff --git a/paddle/fluid/operators/math/lstm_compute.h b/paddle/fluid/operators/math/lstm_compute.h index ca2f78e6f318ce39bd2272bbce20f6a6f98fe430..80af5639387aaf6a983365e13c3478353c27a617 100644 --- a/paddle/fluid/operators/math/lstm_compute.h +++ b/paddle/fluid/operators/math/lstm_compute.h @@ -50,7 +50,7 @@ template class LstmUnitFunctor { public: static void compute(const DeviceContext &context, LstmMetaValue value, - int frame_size, int batch_size, + int frame_size, int batch_size, T cell_clip, const detail::ActivationType &gate_act, const detail::ActivationType &cell_act, const detail::ActivationType &cand_act); @@ -61,7 +61,7 @@ class LstmUnitGradFunctor { public: static void compute(const DeviceContext &context, LstmMetaValue value, LstmMetaGrad grad, int frame_size, int batch_size, - const detail::ActivationType &gate_act, + T cell_clip, const detail::ActivationType &gate_act, const detail::ActivationType &cell_act, const detail::ActivationType &cand_act); }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 1a7d076835841e3c1b8a90a43437eff13645fb8a..de2cb46cff1ed9c2360d36c0f07e43c7fdf1d8fe 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -668,7 +668,11 @@ def dynamic_lstmp(input, candidate_activation='tanh', proj_activation='tanh', dtype='float32', - name=None): + name=None, + h_0=None, + c_0=None, + cell_clip=None, + proj_clip=None): """ **Dynamic LSTMP Layer** @@ -785,6 +789,17 @@ def dynamic_lstmp(input, dtype(str): Data type. Choices = ["float32", "float64"], default "float32". name(str|None): A name for this layer(optional). If set None, the layer will be named automatically. + h_0(Variable): The initial hidden state is an optional input, default is zero. + This is a tensor with shape (N x D), where N is the + batch size and D is the projection size. + c_0(Variable): The initial cell state is an optional input, default is zero. + This is a tensor with shape (N x D), where N is the + batch size. `h_0` and `c_0` can be NULL but only at the same time. + cell_clip(float): If provided the cell state is clipped + by this value prior to the cell output activation. + proj_clip(float): If `num_proj > 0` and `proj_clip` is + provided, then the projected values are clipped elementwise to within + `[-proj_clip, proj_clip]`. Returns: tuple: A tuple of two output variable: the projection of hidden state, \ @@ -831,25 +846,41 @@ def dynamic_lstmp(input, batch_hidden = helper.create_variable_for_type_inference(dtype) batch_gate = helper.create_variable_for_type_inference(dtype) batch_cell_pre_act = helper.create_variable_for_type_inference(dtype) + inputs = { + 'Input': input, + 'Weight': weight, + 'ProjWeight': proj_weight, + 'Bias': bias + } + batch_size = input.shape[0] + if h_0: + assert h_0.shape == (batch_size, proj_size), \ + 'The shape of h0 should be (batch_size, %d)' % proj_size + inputs['H0'] = h_0 + if c_0: + assert c_0.shape == (batch_size, size), \ + 'The shape of c0 should be (batch_size, %d)' % size + inputs['C0'] = c_0 + + if cell_clip: + assert cell_clip >= 0, "cell_clip should not be negtive." + if proj_clip: + assert proj_clip >= 0, "proj_clip should not be negtive." helper.append_op( type='lstmp', - inputs={ - 'Input': input, - 'Weight': weight, - 'ProjWeight': proj_weight, - 'Bias': bias - }, + inputs=inputs, outputs={ 'Projection': projection, 'Cell': cell, - 'OrderedP0': ordered_proj0, 'BatchHidden': batch_hidden, 'BatchGate': batch_gate, 'BatchCellPreAct': batch_cell_pre_act }, attrs={ 'use_peepholes': use_peepholes, + 'cell_clip': cell_clip, + 'proj_clip': proj_clip, 'is_reverse': is_reverse, 'gate_activation': gate_activation, 'cell_activation': cell_activation, diff --git a/python/paddle/fluid/tests/unittests/test_lstmp_op.py b/python/paddle/fluid/tests/unittests/test_lstmp_op.py index 9c3ec45515ffe0a07541fd9cfb7e92b079264071..0645cfedb8089f5618c54672cac91343e5dee285 100644 --- a/python/paddle/fluid/tests/unittests/test_lstmp_op.py +++ b/python/paddle/fluid/tests/unittests/test_lstmp_op.py @@ -36,12 +36,14 @@ def lstmp( w_b=None, # 1 x 4D w_c=None, # 1 x 3D is_reverse=False, + proj_clip=0.0, + cell_clip=0.0, act_gate=None, act_cell=None, act_cand=None, act_proj=None): - def _step(x, w_r, w_rh, w_c, r_pre, c_pre, act_gate, act_cell, act_cand, - act_proj): + def _step(x, w_r, w_rh, w_c, r_pre, c_pre, proj_clip, cell_clip, act_gate, + act_cell, act_cand, act_proj): g = np.dot(r_pre, w_r) # 1 x 4D g = g + x g = np.reshape(g, (1, g.size)) @@ -55,6 +57,17 @@ def lstmp( g_f = act_gate(g_f + w_fc * c_pre) # 1 x D c = g_f * c_pre + g_i * act_cand(c) # 1 x D + def array_clip(a, clip): + size = np.prod(a.shape) + new_a = np.reshape(a, (size)) + for i in range(size): + new_a[i] = max(new_a[i], -1.0 * clip) + new_a[i] = min(new_a[i], clip) + new_a = np.reshape(new_a, a.shape) + return new_a + + if cell_clip > 0.0: + c = array_clip(c, cell_clip) if w_c is None: g_o = act_gate(g_o) # 1 x D else: @@ -64,6 +77,8 @@ def lstmp( # projection r = np.dot(h, w_rh) r = act_proj(r) + if proj_clip > 0.0: + r = array_clip(r, proj_clip) return r, c def _reverse(x, offset): @@ -87,13 +102,13 @@ def lstmp( # compute one sequence seq_len = lod[0][i] x = input[offset[i]:offset[i + 1], :] - r_pre = np.dot(h0[i], w_rh) # 1 x P - r_pre = act_proj(r_pre) + r_pre = h0[i] c_pre = c0[i] # 1 x D for j in range(seq_len): # compute one step - r_pre, c_pre = _step(x[j], w_r, w_rh, w_c, r_pre, c_pre, act_gate, - act_cell, act_cand, act_proj) + r_pre, c_pre = _step(x[j], w_r, w_rh, w_c, r_pre, c_pre, proj_clip, + cell_clip, act_gate, act_cell, act_cand, + act_proj) projection.append(r_pre.flatten()) cell.append(c_pre.flatten()) @@ -123,13 +138,12 @@ class TestLstmpOp(LstmTest.TestLstmOp): T = sum(self.lod[0]) N = len(self.lod[0]) - x = np.random.normal(size=(T, 4 * self.D)).astype('float64') if self.has_initial_state: - h0 = np.random.normal(size=(N, self.D)).astype('float64') + h0 = np.random.normal(size=(N, self.P)).astype('float64') c0 = np.random.normal(size=(N, self.D)).astype('float64') else: - h0 = np.zeros((N, self.D)).astype('float64') + h0 = np.zeros((N, self.P)).astype('float64') c0 = np.zeros((N, self.D)).astype('float64') w = np.random.normal(size=(self.P, 4 * self.D)).astype('float64') if self.use_peepholes: @@ -140,9 +154,12 @@ class TestLstmpOp(LstmTest.TestLstmOp): w_b = b[:, 0:4 * self.D] w_c = b[:, 4 * self.D:] if self.use_peepholes else None w_rh = np.random.normal(size=(self.D, self.P)).astype('float64') + proj_clip = 0.1 + cell_clip = 0.1 r, c = lstmp(x, self.lod, h0, c0, w, w_rh, w_b, w_c, self.is_reverse, - ACTIVATION[self.act_gate], ACTIVATION[self.act_cell], - ACTIVATION[self.act_cand], ACTIVATION[self.act_proj]) + proj_clip, cell_clip, ACTIVATION[self.act_gate], + ACTIVATION[self.act_cell], ACTIVATION[self.act_cand], + ACTIVATION[self.act_proj]) self.inputs = {'Input': (x, self.lod), 'Weight': w, 'ProjWeight': w_rh} @@ -159,6 +176,8 @@ class TestLstmpOp(LstmTest.TestLstmOp): self.attrs = { 'use_peepholes': self.use_peepholes, 'is_reverse': self.is_reverse, + 'proj_clip': proj_clip, + 'cell_clip': cell_clip, 'gate_activation': self.act_gate, 'cell_activation': self.act_cell, 'candidate_activation': self.act_cand, @@ -171,14 +190,14 @@ class TestLstmpOp(LstmTest.TestLstmOp): def test_check_grad(self): # TODO(qingqing) remove folowing lines after the check_grad is refined. N = len(self.lod[0]) - self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchCellPreAct'] = np.zeros( (N, self.D)).astype('float64') self.check_grad( ['Input', 'Weight', 'ProjWeight', 'Bias'], ['Projection'], - max_relative_error=1e-2) + max_relative_error=1e-2, + numeric_grad_delta=0.0000005) class TestLstmpOpHasInitial(TestLstmpOp): @@ -188,7 +207,6 @@ class TestLstmpOpHasInitial(TestLstmpOp): def test_check_grad(self): # TODO(qingqing) remove folowing lines after the check_grad is refined. N = len(self.lod[0]) - self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchCellPreAct'] = np.zeros( @@ -196,11 +214,11 @@ class TestLstmpOpHasInitial(TestLstmpOp): self.check_grad( ['Input', 'Weight', 'ProjWeight', 'Bias', 'H0', 'C0'], ['Projection'], + numeric_grad_delta=0.0000005, max_relative_error=1e-2) def test_check_grad_ingore_bias(self): N = len(self.lod[0]) - self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchCellPreAct'] = np.zeros( @@ -208,11 +226,11 @@ class TestLstmpOpHasInitial(TestLstmpOp): self.check_grad( ['Input', 'ProjWeight', 'Weight'], ['Projection'], max_relative_error=1e-2, + numeric_grad_delta=0.0000005, no_grad_set=set('Bias')) def test_check_grad_ingore_weight(self): N = len(self.lod[0]) - self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchCellPreAct'] = np.zeros( @@ -220,11 +238,11 @@ class TestLstmpOpHasInitial(TestLstmpOp): self.check_grad( ['Input', 'ProjWeight', 'Bias'], ['Projection'], max_relative_error=1e-2, + numeric_grad_delta=0.0000005, no_grad_set=set('Weight')) def test_check_grad_ingore_proj_weight(self): N = len(self.lod[0]) - self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchCellPreAct'] = np.zeros( @@ -232,11 +250,11 @@ class TestLstmpOpHasInitial(TestLstmpOp): self.check_grad( ['Input', 'Weight', 'Bias'], ['Projection'], max_relative_error=1e-2, + numeric_grad_delta=0.0000005, no_grad_set=set('ProjWeight')) def test_check_grad_ingore_input(self): N = len(self.lod[0]) - self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchCellPreAct'] = np.zeros( @@ -244,11 +262,11 @@ class TestLstmpOpHasInitial(TestLstmpOp): self.check_grad( ['Weight', 'ProjWeight', 'Bias'], ['Projection'], max_relative_error=1e-2, + numeric_grad_delta=0.0000005, no_grad_set=set('Input')) def test_check_grad_ingore_h0(self): N = len(self.lod[0]) - self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchCellPreAct'] = np.zeros( @@ -256,11 +274,11 @@ class TestLstmpOpHasInitial(TestLstmpOp): self.check_grad( ['Input', 'Weight', 'ProjWeight', 'Bias', 'C0'], ['Projection'], max_relative_error=1e-2, + numeric_grad_delta=0.0000005, no_grad_set=set('H0')) def test_check_grad_ingore_c0(self): N = len(self.lod[0]) - self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchCellPreAct'] = np.zeros( @@ -268,6 +286,7 @@ class TestLstmpOpHasInitial(TestLstmpOp): self.check_grad( ['Input', 'Weight', 'ProjWeight', 'Bias', 'H0'], ['Projection'], max_relative_error=1e-2, + numeric_grad_delta=0.0000005, no_grad_set=set('C0'))