From 5ca0bb9aadd50b10dc0e20bbc528604b8937e2c1 Mon Sep 17 00:00:00 2001
From: tensor-tang
Date: Thu, 23 Aug 2018 00:01:45 +0800
Subject: [PATCH] support more activation type and remove some comments

---
 paddle/fluid/operators/attention_lstm_op.cc            | 57 ++++++++++---------
 paddle/fluid/operators/math/cpu_vec.h                  | 26 ++++++++-
 .../tests/unittests/test_attention_lstm_op.py          |  9 +++
 3 files changed, 63 insertions(+), 29 deletions(-)

diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc
index 5d57703c0b9..1cb65346ee2 100644
--- a/paddle/fluid/operators/attention_lstm_op.cc
+++ b/paddle/fluid/operators/attention_lstm_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/attention_lstm_op.h"
+#include <string>
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/cpu_vec.h"
@@ -192,24 +193,23 @@ void AttentionLSTMOpMaker::Make() {
       "(Tensor) the output of LSTM X(1*(D+M))* weight((D+M)*4D) for each step."
       "Shape is (1 x 4D), where M is the x frame size")
       .AsIntermediate();
-  // TODO(TJ): InEnum({"sigmoid", "tanh", "relu", "identity"});
   AddAttr<std::string>("gate_activation",
                        "(string, default: sigmoid)"
                        "The activation for input gate, forget gate and output "
                        "gate, `sigmoid` by default.")
       .SetDefault("sigmoid")
-      .InEnum({"sigmoid"});
+      .InEnum({"sigmoid", "tanh", "relu", "identity"});
   AddAttr<std::string>("cell_activation",
                        "(string, default: tanh)"
                        "The activation for cell output, `tanh` by defalut.")
       .SetDefault("tanh")
-      .InEnum({"tanh"});
+      .InEnum({"sigmoid", "tanh", "relu", "identity"});
   AddAttr<std::string>("candidate_activation",
                        "(string, default: tanh)"
                        "The activation for candidate hidden state, "
                        "`tanh` by default.")
       .SetDefault("tanh")
-      .InEnum({"tanh"});
+      .InEnum({"sigmoid", "tanh", "relu", "identity"});
   AddComment(R"DOC(
 Attention Long-Short Term Memory (LSTM) Operator.
@@ -273,22 +273,23 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     using DeviceContext = paddle::platform::CPUDeviceContext;
-    auto* x = ctx.Input<LoDTensor>("X");                        // T x M
-    auto* h0 = ctx.Input<Tensor>("H0");                         // N x D
-    auto* c0 = ctx.Input<Tensor>("C0");                         // N x D
-    auto* atten_w = ctx.Input<Tensor>("AttentionWeight");       // (M+D) x 1
-    auto* atten_b = ctx.Input<Tensor>("AttentionBias");         // 1x1
-    auto* atten_scalar = ctx.Input<Tensor>("AttentionScalar");  // 1x1
-    auto* atten_scalar_bias = ctx.Input<Tensor>("AttentionScalarBias");  // 1x1
-    auto* lstm_w = ctx.Input<Tensor>("LSTMWeight");             // (D+M) x D*4
-    auto* lstm_b = ctx.Input<Tensor>("LSTMBias");               // 1 x D*4
-
-    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");   // TxD
-    auto* cell_out = ctx.Output<LoDTensor>("Cell");       // TxD
-    auto* atted_x = ctx.Output<Tensor>("AttentionedX");   // T x 1
-    auto* fc_out = ctx.Output<Tensor>("AttentionFCOut");  // max_seq_len x 1
-    auto* lstm_x = ctx.Output<Tensor>("LSTMX");           // 1 x M
-    auto* lstm_out = ctx.Output<Tensor>("LSTMOUT");       // 1 x 4D
+
+    auto* x = ctx.Input<LoDTensor>("X");
+    auto* h0 = ctx.Input<Tensor>("H0");
+    auto* c0 = ctx.Input<Tensor>("C0");
+    auto* atten_w = ctx.Input<Tensor>("AttentionWeight");
+    auto* atten_b = ctx.Input<Tensor>("AttentionBias");
+    auto* atten_scalar = ctx.Input<Tensor>("AttentionScalar");
+    auto* atten_scalar_bias = ctx.Input<Tensor>("AttentionScalarBias");
+    auto* lstm_w = ctx.Input<Tensor>("LSTMWeight");
+    auto* lstm_b = ctx.Input<Tensor>("LSTMBias");
+
+    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
+    auto* cell_out = ctx.Output<LoDTensor>("Cell");
+    auto* atted_x = ctx.Output<Tensor>("AttentionedX");
+    auto* fc_out = ctx.Output<Tensor>("AttentionFCOut");
+    auto* lstm_x = ctx.Output<Tensor>("LSTMX");
+    auto* lstm_out = ctx.Output<Tensor>("LSTMOUT");
 
     // some shape should be reshape here since infershape can not get lod info
     auto x_lod = x->lod();
@@ -310,11 +311,11 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
     fc_out->Resize({max_seq_len, 1});
 
-    // TODO(TJ): act functor init here
-    // if (platform::jit::MayIUse(platform::jit::avx2)) {
-    //  } else if (platform::jit::MayIUse(platform::jit::avx)) {
-    //  } else {
-    //  }
+    math::VecActivations<T> act_functor;
+    std::function<void(const int, const T*, T*)> act_gate, act_cell, act_cand;
+    act_gate = act_functor(ctx.Attr<std::string>("gate_activation"));
+    act_cell = act_functor(ctx.Attr<std::string>("cell_activation"));
+    act_cand = act_functor(ctx.Attr<std::string>("candidate_activation"));
 
     const T* x_data = x->data<T>();
     const T* h0_data = h0 ? h0->data<T>() : NULL;
@@ -381,9 +382,9 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
         blas.VADD(D4, lstm_b_data, lstm_out_data, lstm_out_data);
 
         // gate act: sigmoid
-        math::vec_sigmoid<T>(D3, lstm_out_data, lstm_out_data);
+        act_gate(D3, lstm_out_data, lstm_out_data);
         // candicate act: tanh
-        math::vec_tanh<T>(D, lstm_out_data + D3, lstm_out_data + D3);
+        act_cand(D, lstm_out_data + D3, lstm_out_data + D3);
 
         // a = forget * prev_cell
         blas.VMUL(D, lstm_out_data, prev_cell_data, lstm_out_data);
@@ -395,7 +396,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
         blas.VADD(D, lstm_out_data, lstm_out_data + D, cur_cell_out_data);
 
         // state act tanh(cell_out) * output_gate
-        math::vec_tanh<T>(D, cur_cell_out_data, lstm_out_data);
+        act_cell(D, cur_cell_out_data, lstm_out_data);
         blas.VMUL(D, lstm_out_data, lstm_out_data + D2, cur_hidden_out_data);
 
         prev_hidden_data = cur_hidden_out_data;
diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h
index 29476fce709..48c0da0e368 100644
--- a/paddle/fluid/operators/math/cpu_vec.h
+++ b/paddle/fluid/operators/math/cpu_vec.h
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
-
+#include <string>
 #include "paddle/fluid/platform/cpu_info.h"
 
 namespace paddle {
@@ -34,6 +34,12 @@ inline T tanh(T x) {
   return 2. * sigmoid(2. * x) - 1.;
 }
 
+template <typename T>
+inline void vec_identity(const int n, const T* x, T* y) {
+  // do nothing
+  return;
+}
+
 template <typename T>
 inline void vec_sigmoid(const int n, const T* x, T* y) {
   const T min = SIGMOID_THRESHOLD_MIN;
@@ -76,6 +82,24 @@ inline void vec_relu(const int n, const float* x,
   }
 }
 
+template <typename T>
+class VecActivations {
+ public:
+  std::function<void(const int, const T*, T*)> operator()(
+      const std::string& type) {
+    if (type == "sigmoid") {
+      return vec_sigmoid<T>;
+    } else if (type == "relu") {
+      return vec_relu<T>;
+    } else if (type == "tanh") {
+      return vec_tanh<T>;
+    } else if (type == "identity" || type == "") {
+      return vec_identity<T>;
+    }
+    PADDLE_THROW("Not support type %s.", type);
+  }
+};
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py b/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py
index cb02c7e5868..a7382c2244e 100644
--- a/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py
@@ -160,6 +160,15 @@ class TestAttentionOpNonInit(TestAttentionLSTMOp):
         self.has_initial_hidden = False
 
 
+class TestAttentionOpAct(TestAttentionLSTMOp):
+    def set_conf(self):
+        self.M = 3
+        self.D = 2
+        self.act_gate = 'relu'
+        self.act_cell = 'tanh'
+        self.act_cand = 'sigmoid'
+
+
 class TestAttentionOpMD1(TestAttentionLSTMOp):
     def set_conf(self):
         self.M = 36
-- 
GitLab
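
For reference, the name-to-kernel dispatch introduced by VecActivations can be sketched outside the Paddle tree. The snippet below is a self-contained approximation rather than the in-tree code: the activation bodies are plain scalar loops (the real cpu_vec.h versions clip the sigmoid input range and may be specialized per ISA), vec_identity copies its input here whereas the in-tree one is a no-op applied in place, and the error path uses a standard exception instead of PADDLE_THROW.

    // Standalone sketch of the name-to-kernel dispatch used by VecActivations<T>.
    // Activation bodies are deliberately simple scalar loops for illustration.
    #include <cmath>
    #include <cstdio>
    #include <functional>
    #include <stdexcept>
    #include <string>

    template <typename T>
    void vec_sigmoid(const int n, const T* x, T* y) {
      for (int i = 0; i < n; ++i) {
        y[i] = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-x[i]));
      }
    }

    template <typename T>
    void vec_tanh(const int n, const T* x, T* y) {
      for (int i = 0; i < n; ++i) y[i] = std::tanh(x[i]);
    }

    template <typename T>
    void vec_relu(const int n, const T* x, T* y) {
      for (int i = 0; i < n; ++i) y[i] = x[i] > static_cast<T>(0) ? x[i] : static_cast<T>(0);
    }

    template <typename T>
    void vec_identity(const int n, const T* x, T* y) {
      // The in-tree vec_identity is a no-op (it is always applied in place);
      // this sketch copies so it also works out of place.
      for (int i = 0; i < n; ++i) y[i] = x[i];
    }

    // Maps an activation name to a callable with the (n, x, y) signature,
    // mirroring how the attention_lstm kernel resolves its activation attributes.
    template <typename T>
    class VecActivations {
     public:
      std::function<void(const int, const T*, T*)> operator()(const std::string& type) {
        if (type == "sigmoid") return vec_sigmoid<T>;
        if (type == "relu") return vec_relu<T>;
        if (type == "tanh") return vec_tanh<T>;
        if (type == "identity" || type.empty()) return vec_identity<T>;
        throw std::invalid_argument("Not support type " + type);
      }
    };

    int main() {
      VecActivations<float> act_functor;
      // Analogous to act_functor(ctx.Attr<std::string>("gate_activation")) in the kernel.
      auto act_gate = act_functor("relu");
      const float x[4] = {-1.f, -0.5f, 0.5f, 1.f};
      float y[4];
      act_gate(4, x, y);
      for (float v : y) std::printf("%.2f ", v);  // 0.00 0.00 0.50 1.00
      std::printf("\n");
      return 0;
    }

Because the returned callables all share one signature, the kernel can apply whatever the gate_activation, cell_activation, and candidate_activation attributes name without branching inside the per-step loop.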