diff --git a/CMakeLists.txt b/CMakeLists.txt
index 48e52961a95d50264b201eec50ccb3a462f39c54..317f7f9eb46a96e9f6ea393abf82d608af50fc4b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -138,12 +138,6 @@ else()
   set(THIRD_PARTY_BUILD_TYPE Release)
 endif()
 
-if(WITH_MKL)
-  option(MKL_SPLIT_GEMM "PaddlePaddle MKL gemm would split to small ones" OFF)
-  if (MKL_SPLIT_GEMM)
-    add_definitions(-DPADDLE_MKL_SPLIT_GEMM)
-  endif()
-endif()
 set(WITH_MKLML ${WITH_MKL})
 if (NOT DEFINED WITH_MKLDNN)
   if (WITH_MKL AND AVX2_FOUND)
diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..1cb65346ee2b755b48f8dd8f1456a32861c3a0b6
--- /dev/null
+++ b/paddle/fluid/operators/attention_lstm_op.cc
@@ -0,0 +1,422 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/attention_lstm_op.h"
+#include <functional>
+#include <string>
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/cpu_vec.h"
+#include "paddle/fluid/operators/math/fc_compute.h"
+#include "paddle/fluid/platform/cpu_info.h"
+
+namespace paddle {
+namespace operators {
+
+void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
+  PADDLE_ENFORCE(ctx->HasInput("X"),
+                 "Input(X) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("C0"),
+                 "Input(C0) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"),
+                 "Input(LSTMWeight) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("LSTMBias"),
+                 "Input(LSTMBias) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"),
+                 "Input(AttentionWeight) of AttentionLSTM should not be null.");
+
+  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
+                 "Output(Hidden) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("Cell"),
+                 "Output(Cell) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"),
+                 "Output(AttentionedX) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"),
+                 "Output(AttentionFCOut) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("LSTMX"),
+                 "Output(LSTMX) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"),
+                 "Output(LSTMOUT) of AttentionLSTM should not be null.");
+
+  auto x_dims = ctx->GetInputDim("X");
+  const int M = x_dims[1];
+  PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
+
+  auto w_dims = ctx->GetInputDim("LSTMWeight");
+  const int D = w_dims[1] / 4;
+  PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2.");
+  PADDLE_ENFORCE_EQ(w_dims[0], D + M,
+                    "LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D);
+
+  auto b_dims = ctx->GetInputDim("LSTMBias");
+  PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
+  PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D);
+  PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.",
+                    4 * D);
+
+  auto c_dims = ctx->GetInputDim("C0");
+  PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
+  PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
+  if (ctx->HasInput("H0")) {
+    auto h_dims = ctx->GetInputDim("H0");
+    PADDLE_ENFORCE(h_dims == c_dims,
+                   "The dimension of Input(H0) and Input(C0) "
+                   "should be the same.");
+  }
+
+  auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
+  PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2,
+                    "Input(AttentionWeight)'s rank must be 2.");
+  PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D,
+                    "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
+  PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
+                    "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
+  if (ctx->HasInput("AttentionBias")) {
+    auto atten_b_dims = ctx->GetInputDim("AttentionBias");
+    PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
+                      "Input(AttentionBias)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(atten_b_dims[0], 1,
+                      "AttentionBias shapes must be 1 * 1.");
+    PADDLE_ENFORCE_EQ(atten_b_dims[1], 1,
+                      "AttentionBias shapes must be 1 * 1.");
+  }
+
+  if (ctx->HasInput("AttentionScalar")) {
+    auto dims = ctx->GetInputDim("AttentionScalar");
+    PADDLE_ENFORCE_EQ(dims.size(), 2,
+                      "Input(AttentionScalar)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1.");
+    PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
+  }
+
+  if (ctx->HasInput("AttentionScalarBias")) {
+    auto dims = ctx->GetInputDim("AttentionScalarBias");
+    PADDLE_ENFORCE(
+        ctx->HasInput("AttentionScalar"),
+        "AttentionScalar should not be null when AttentionScalarBias is set.");
+    PADDLE_ENFORCE_EQ(dims.size(), 2,
+                      "Input(AttentionScalarBias)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalarBias shapes must be 1 * 1.");
+    PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalarBias shapes must be 1 * 1.");
+  }
+
+  framework::DDim out_dims({x_dims[0], D});
+  ctx->SetOutputDim("Hidden", out_dims);
+  ctx->SetOutputDim("Cell", out_dims);
+  ctx->SetOutputDim("AttentionedX", {x_dims[0], 1});
+  ctx->SetOutputDim("LSTMX", {1, M});
+  ctx->SetOutputDim("LSTMOUT", {1, 4 * D});
+  // AttentionFCOut should be reshaped to (maxseqlen, 1) at runtime.
+  ctx->ShareLoD("X", "Hidden");
+  ctx->ShareLoD("X", "Cell");
+}
+
+framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType(
+    const framework::ExecutionContext& ctx) const {
+  return framework::OpKernelType(
+      framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
+      ctx.device_context());
+}
+
+void AttentionLSTMOpMaker::Make() {
+  AddInput("X",
+           "(LoDTensor) the input is a LoDTensor, which supports "
+           "variable-length input sequences. The underlying tensor in "
+           "this LoDTensor is a matrix with shape (T x M), where T is the "
+           "total time steps in this mini-batch and M is the dim size of x.");
+  AddInput("C0",
+           "(Tensor) LSTM C0. "
+           "This is a tensor with shape (N x D), where N is the batch size "
+           "and D is the gate size. "
+           "C0 is necessary because of attention.");
+  AddInput("H0",
+           "(Tensor, optional) LSTM H0. "
+           "This is a tensor with shape (N x D), where N is the "
+           "batch size and D is the gate size.")
+      .AsDispensable();
+  AddInput("AttentionWeight",
+           "(Tensor) the weights of the attention fc. Always apply relu on "
+           "the fc result. "
+           "The shape is ((M+D) x 1), where M is the dim size of x and D is "
+           "the gate size of the LSTM.");
+  AddInput("AttentionBias",
+           "(Tensor, optional) the bias of the attention fc. "
+ "The shape is (1 x 1)") + .AsDispensable(); + AddInput("AttentionScalar", + "(Tensor, optional) the scalar on the result of attentioned fc. " + "Always relu the Scalar." + "The shape is (1 x 1)") + .AsDispensable(); + AddInput("AttentionScalarBias", + "(Tensor, optional) the scalar bias of attention fc." + "The shape is (1 x 1)") + .AsDispensable(); + AddInput("LSTMWeight", + "(Tensor) the combined weight of LSTM" + " - The shape is ((D+M) x 4D), where D is the hidden gate size, M " + "is the dim size of x" + " - Weight = {W_forget, W_input, W_output, W_cell}"); + AddInput("LSTMBias", + "(Tensor) the combined bias of LSTM, shape (1x4D)." + "Note: we should add the bias of hidden and context accorindg to " + "the same gate: " + "{B_forget, B_input, B_output, B_cell}"); + AddOutput("Hidden", + "(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. " + "The shape is (T x D), and lod is the same with the `Input`."); + AddOutput("Cell", + "(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. " + "The shape is (T x D), and lod is the same with the `Input`."); + AddOutput("AttentionedX", + "(Tensor) shape is (T x 1), the result after X * AttentionWeight," + " where T is the total time steps in this mini-batch," + " D is the hidden size.") + .AsIntermediate(); + AddOutput("AttentionFCOut", + "(Tensor) (max_seq_len, 1), compute at each step.") + .AsIntermediate(); + AddOutput("LSTMX", + "(Tensor) the input X of LSTM for each step." + "Shape is (1 x M), where M is the x frame size") + .AsIntermediate(); + AddOutput( + "LSTMOUT", + "(Tensor) the output of LSTM X(1*(D+M))* weight((D+M)*4D) for each step." + "Shape is (1 x 4D), where M is the x frame size") + .AsIntermediate(); + AddAttr("gate_activation", + "(string, default: sigmoid)" + "The activation for input gate, forget gate and output " + "gate, `sigmoid` by default.") + .SetDefault("sigmoid") + .InEnum({"sigmoid", "tanh", "relu", "identity"}); + AddAttr("cell_activation", + "(string, default: tanh)" + "The activation for cell output, `tanh` by defalut.") + .SetDefault("tanh") + .InEnum({"sigmoid", "tanh", "relu", "identity"}); + AddAttr("candidate_activation", + "(string, default: tanh)" + "The activation for candidate hidden state, " + "`tanh` by default.") + .SetDefault("tanh") + .InEnum({"sigmoid", "tanh", "relu", "identity"}); + AddComment(R"DOC( +Attention Long-Short Term Memory (LSTM) Operator. + +Attention part: +concat( x(seqlen * M), expand( cell_t-1(1,D) ) ) => tmp(seqlen*(M+D)) + +tmp(seqlen*(M+D)) * fc((M+D)*1) => fcout(seqlen*1) with bias, relu + +fcout(seqlen*1) * scalar => fcout(seqlen*1) with bias, relu + +dotmul and sum pool ( fcout(seqlen*1), x(seqlen * M) ) => lstm_x_t(1, M) + +LSTM part: +use lstm_x_t as input and compute as standard LSTM. + +)DOC"); +} + +// y[i] = (x[i] + bias[0]) > 0 ? (x[i] + bias[0]) : 0; +template +inline void bias_relu(const int n, const T* x, const T* bias, T* y) { + if (bias) { + for (int i = 0; i < n; ++i) { + y[i] = x[i] + bias[0]; + } + math::vec_relu(n, y, y); + } else { + math::vec_relu(n, x, y); + } +} + +template +inline void vec_softmax(const math::BlasT& blas, const int n, + const T* x, T* y) { + T scalar = x[0]; + // max + for (int i = 1; i < n; ++i) { + scalar = scalar < x[i] ? 
+  }
+
+  // sub
+  for (int i = 0; i < n; ++i) {
+    y[i] = x[i] - scalar;
+  }
+
+  // exp
+  blas.VEXP(n, y, y);
+
+  // sum
+  scalar = T(0);
+  for (int i = 0; i < n; ++i) {
+    scalar += y[i];
+  }
+
+  // scale
+  blas.SCAL(n, static_cast<T>(1) / scalar, y);
+}
+
+template <typename T>
+class AttentionLSTMKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    using DeviceContext = paddle::platform::CPUDeviceContext;
+
+    auto* x = ctx.Input<LoDTensor>("X");
+    auto* h0 = ctx.Input<Tensor>("H0");
+    auto* c0 = ctx.Input<Tensor>("C0");
+    auto* atten_w = ctx.Input<Tensor>("AttentionWeight");
+    auto* atten_b = ctx.Input<Tensor>("AttentionBias");
+    auto* atten_scalar = ctx.Input<Tensor>("AttentionScalar");
+    auto* atten_scalar_bias = ctx.Input<Tensor>("AttentionScalarBias");
+    auto* lstm_w = ctx.Input<Tensor>("LSTMWeight");
+    auto* lstm_b = ctx.Input<Tensor>("LSTMBias");
+
+    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
+    auto* cell_out = ctx.Output<LoDTensor>("Cell");
+    auto* atted_x = ctx.Output<Tensor>("AttentionedX");
+    auto* fc_out = ctx.Output<Tensor>("AttentionFCOut");
+    auto* lstm_x = ctx.Output<Tensor>("LSTMX");
+    auto* lstm_out = ctx.Output<Tensor>("LSTMOUT");
+
+    // Some shapes must be resized here, since InferShape cannot get the LoD
+    // info.
+    auto x_lod = x->lod();
+    PADDLE_ENFORCE_EQ(x_lod.size(), 1, "Input(X)'s lod size must be 1.");
+    const int N = x_lod[0].size() - 1;  // batch size
+    auto x_dims = x->dims();            // T x M
+    auto w_dims = lstm_w->dims();       // (D+M) x 4D
+    const int total_T = x_dims[0];
+    const int M = x_dims[1];      // x frame size
+    const int D = w_dims[1] / 4;  // gate frame size
+    const int D2 = D * 2;
+    const int D3 = D * 3;
+    const int D4 = w_dims[1];
+    int max_seq_len = x_lod[0][1];
+    for (int i = 1; i < N; ++i) {
+      int len = x_lod[0][i + 1] - x_lod[0][i];
+      max_seq_len = max_seq_len < len ? len : max_seq_len;
+    }
+    PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
+    fc_out->Resize({max_seq_len, 1});
+
+    math::VecActivations<T> act_functor;
+    std::function<void(const int, const T*, T*)> act_gate, act_cell, act_cand;
+    act_gate = act_functor(ctx.Attr<std::string>("gate_activation"));
+    act_cell = act_functor(ctx.Attr<std::string>("cell_activation"));
+    act_cand = act_functor(ctx.Attr<std::string>("candidate_activation"));
+
+    const T* x_data = x->data<T>();
+    const T* h0_data = h0 ? h0->data<T>() : NULL;
+    const T* c0_data = c0->data<T>();
+    const T* lstm_w_data = lstm_w->data<T>();
+    const T* lstm_b_data = lstm_b->data<T>();
+    const T* atten_w_data = atten_w->data<T>();
+    const T* atten_b_data = atten_b ? atten_b->data<T>() : NULL;
+    const T* atten_scalar_data = atten_scalar ? atten_scalar->data<T>() : NULL;
+    const T* atten_scalar_bias_data =
+        atten_scalar_bias ? atten_scalar_bias->data<T>() : NULL;
+
+    T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
+    T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());
+    T* atted_x_data = atted_x->mutable_data<T>(ctx.GetPlace());
+    T* fc_out_data = fc_out->mutable_data<T>(ctx.GetPlace());
+    T* lstm_x_data = lstm_x->mutable_data<T>(ctx.GetPlace());
+    T* lstm_out_data = lstm_out->mutable_data<T>(ctx.GetPlace());
+
+    // x(TxM) * fc(Mx1), the X part of atten_wgt((M+D)x1)
+    auto blas = math::GetBlas<DeviceContext, T>(ctx);
+    math::FCCompute<DeviceContext, T>(blas, total_T, 1, M, x_data,
+                                      atten_w_data, atted_x_data,
+                                      atten_b_data);
+
+    const T* cur_atten_x_data = atted_x_data;
+    const T* cur_x_data = x_data;
+    const T* prev_cell_data = NULL;
+    const T* prev_hidden_data = NULL;
+    T* cur_cell_out_data = cell_out_data;
+    T* cur_hidden_out_data = hidden_out_data;
+    for (int i = 0; i < N; ++i) {
+      int seq_len = x_lod[0][i + 1] - x_lod[0][i];
+      prev_cell_data = c0_data + i * D;
+      prev_hidden_data = h0_data ? h0_data + i * D : NULL;
+      for (int step = 0; step < seq_len; ++step) {
+        /// 1. compute attention vector
+        // 1a. prev_cell(1xD) * fc(Dx1), the rest part of atten_wgt
+        T prev_cell_bias = blas.DOT(D, prev_cell_data, atten_w_data + M);
+        // 1b. add the cell bias and apply relu
+        bias_relu<T>(seq_len, cur_atten_x_data, &prev_cell_bias, fc_out_data);
+        // 1c. fc scalar
+        if (atten_scalar_data) {
+          blas.SCAL(seq_len, *atten_scalar_data, fc_out_data);
+          bias_relu<T>(seq_len, fc_out_data, atten_scalar_bias_data,
+                       fc_out_data);
+        }
+        // 1d. softmax
+        vec_softmax<DeviceContext, T>(blas, seq_len, fc_out_data, fc_out_data);
+        // mul x(seq_len*M) and sum pool
+        math::FCCompute<DeviceContext, T>(blas, 1, M, seq_len, fc_out_data,
+                                          cur_x_data, lstm_x_data);
+
+        /// 2. compute LSTM step
+        // lstm weight: concat[forget, input, output, tilde]
+        // shape: (D + M) x (4 * D)
+        // fc inputX(1xM) * weightX(Mx4D) => 1 x 4D
+        blas.MatMul(1, D4, M, lstm_x_data, lstm_w_data + D * D4,
+                    lstm_out_data);
+        if (prev_hidden_data) {
+          blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1),
+                    prev_hidden_data, D, lstm_w_data, D4, static_cast<T>(1),
+                    lstm_out_data, D4);
+        }
+        // since the input is 1 x M, the bias can be added directly
+        blas.VADD(D4, lstm_b_data, lstm_out_data, lstm_out_data);
+
+        // gate act: sigmoid
+        act_gate(D3, lstm_out_data, lstm_out_data);
+        // candidate act: tanh
+        act_cand(D, lstm_out_data + D3, lstm_out_data + D3);
+
+        // a = forget * prev_cell
+        blas.VMUL(D, lstm_out_data, prev_cell_data, lstm_out_data);
+
+        // b = input * tilde
+        blas.VMUL(D, lstm_out_data + D, lstm_out_data + D3, lstm_out_data + D);
+
+        // cell_out = a + b
+        blas.VADD(D, lstm_out_data, lstm_out_data + D, cur_cell_out_data);
+
+        // state act: tanh(cell_out) * output_gate
+        act_cell(D, cur_cell_out_data, lstm_out_data);
+        blas.VMUL(D, lstm_out_data, lstm_out_data + D2, cur_hidden_out_data);
+
+        prev_hidden_data = cur_hidden_out_data;
+        prev_cell_data = cur_cell_out_data;
+        cur_cell_out_data = cur_cell_out_data + D;
+        cur_hidden_out_data = cur_hidden_out_data + D;
+      }
+      cur_x_data = cur_x_data + seq_len * M;
+      cur_atten_x_data = cur_atten_x_data + seq_len;
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(attention_lstm, ops::AttentionLSTMOp,
+                  ops::AttentionLSTMOpMaker,
+                  paddle::framework::DefaultGradOpDescMaker<true>);
+
+REGISTER_OP_CPU_KERNEL(attention_lstm, ops::AttentionLSTMKernel<float>,
+                       ops::AttentionLSTMKernel<double>);
diff --git a/paddle/fluid/operators/attention_lstm_op.h b/paddle/fluid/operators/attention_lstm_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..6ede3a7f3c96dd2d13d7c5c19816647e16a3c8d0
--- /dev/null
+++ b/paddle/fluid/operators/attention_lstm_op.h
@@ -0,0 +1,41 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using LoDTensor = framework::LoDTensor;
+using Tensor = framework::Tensor;
+
+class AttentionLSTMOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override;
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override;
+};
+
+class AttentionLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override;
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/fusion_lstm_op.h b/paddle/fluid/operators/fusion_lstm_op.h
index 39dc09b4d116193399d8ac9a51e88dbc3e239918..7f79601602348ac454fc6c0cefcba0643ad8e6e2 100644
--- a/paddle/fluid/operators/fusion_lstm_op.h
+++ b/paddle/fluid/operators/fusion_lstm_op.h
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
-// #include <string>
 #include "paddle/fluid/framework/op_registry.h"
 
 namespace paddle {
diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h
index 8dcf7c99f3860789dee834787eeb8b7ad4cc3530..da185d93c09f9b06bd5968b9c8e93176f9ef014b 100644
--- a/paddle/fluid/operators/math/blas.h
+++ b/paddle/fluid/operators/math/blas.h
@@ -90,6 +90,11 @@ class Blas {
   void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A,
             int lda, const T* B, int ldb, T beta, T* C, int ldc) const;
 
+  template <typename T>
+  void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N,
+            int K, T alpha, const T* A, int lda, const T* B, int ldb, T beta,
+            T* C, int ldc) const;
+
 #ifdef PADDLE_WITH_MKLML
   template <typename T>
   T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N,
@@ -109,6 +114,10 @@ class Blas {
   void GEMM_FREE(T* data) const;
 #endif
 
+  template <typename T>
+  void MatMul(const int M, const int N, const int K, const T* A, const T* B,
+              T* C) const;
+
   template <typename T>
   void MatMul(const framework::Tensor& mat_a, bool trans_a,
               const framework::Tensor& mat_b, bool trans_b, T alpha,
@@ -140,10 +149,19 @@ class Blas {
   template <typename T>
   void VCOPY(int n, const T* x, T* y) const;
 
+  template <typename T>
+  void VEXP(int n, const T* x, T* y) const;
+
   template <typename T>
   void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B,
             T beta, T* C) const;
 
+  template <typename T>
+  T DOT(int n, const T* x, const T* y) const;
+
+  template <typename T>
+  void SCAL(int n, const T a, T* x) const;
+
   template <typename T>
   void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M,
                    int N, int K, T alpha, const T* A, const T* B, T beta, T* C,
@@ -215,11 +233,26 @@ class BlasT : private Blas<DeviceContext> {
     Base()->template VCOPY<T>(args...);
   }
 
+  template <typename... ARGS>
+  void VEXP(ARGS... args) const {
+    Base()->template VEXP<T>(args...);
+  }
+
   template <typename... ARGS>
   void GEMV(ARGS... args) const {
     Base()->template GEMV<T>(args...);
   }
 
+  template <typename... ARGS>
+  T DOT(ARGS... args) const {
+    return Base()->template DOT<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void SCAL(ARGS... args) const {
+    Base()->template SCAL<T>(args...);
+  }
+
   template <typename... ARGS>
   void BatchedGEMM(ARGS... args) const {
     Base()->template BatchedGEMM<T>(args...);
diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h
index dc77b6d793702458a22a2f59b68e9d9f2c23b4ff..e1df78d11e41c5f74e244643f40c6d0581fa6a4a 100644
--- a/paddle/fluid/operators/math/blas_impl.h
+++ b/paddle/fluid/operators/math/blas_impl.h
@@ -73,6 +73,16 @@ struct CBlas<float> {
     platform::dynload::cblas_sgemv(args...);
   }
 
+  template <typename... ARGS>
+  static float DOT(ARGS... args) {
+    return platform::dynload::cblas_sdot(args...);
+  }
+
+  template <typename... ARGS>
+  static void SCAL(ARGS... args) {
+    platform::dynload::cblas_sscal(args...);
+  }
+
   template <typename... ARGS>
   static void GEMM_BATCH(ARGS... args) {
     platform::dynload::cblas_sgemm_batch(args...);
@@ -87,6 +97,11 @@ struct CBlas<float> {
   static void VMUL(ARGS... args) {
     platform::dynload::vsMul(args...);
   }
+
+  template <typename... ARGS>
+  static void VEXP(ARGS... args) {
+    platform::dynload::vsExp(args...);
+  }
 };
 
 template <>
@@ -138,6 +153,16 @@ struct CBlas<double> {
     platform::dynload::cblas_dgemv(args...);
   }
 
+  template <typename... ARGS>
+  static double DOT(ARGS... args) {
+    return platform::dynload::cblas_ddot(args...);
+  }
+
+  template <typename... ARGS>
+  static void SCAL(ARGS... args) {
+    platform::dynload::cblas_dscal(args...);
+  }
+
   template <typename... ARGS>
   static void GEMM_BATCH(ARGS... args) {
     platform::dynload::cblas_dgemm_batch(args...);
@@ -152,6 +177,11 @@ struct CBlas<double> {
   static void VMUL(ARGS... args) {
     platform::dynload::vdMul(args...);
   }
+
+  template <typename... ARGS>
+  static void VEXP(ARGS... args) {
+    platform::dynload::vdExp(args...);
+  }
 };
 
 #else
@@ -210,6 +240,9 @@ struct CBlas<platform::float16> {
     PADDLE_THROW("float16 SMM_GEMM not supported on CPU");
   }
   static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
+  static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); }
+  static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); }
+  static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); }
 #ifdef PADDLE_WITH_MKLML
   static void GEMM_BATCH(...) {
{ PADDLE_THROW("float16 GEMM_BATCH not supported on CPU"); @@ -217,64 +250,6 @@ struct CBlas { #endif }; -template -inline bool UseXSMM(const int &m, const int &n, const int &k, bool transa, - bool transb, const T &alpha, const T &beta) { -#ifdef PADDLE_WITH_LIBXSMM - // Refer to https://github.com/hfp/libxsmm/blob/master/README.md - // But the threshold is custom - constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20; - if (m * n * k > LIBXSMM_THRESHOLD || transa || transb || - std::abs(alpha - static_cast(1) > - std::numeric_limits::epsilon()) || - std::abs(beta) > std::numeric_limits::epsilon()) { - return false; - } else { - return true; - } -#endif - return false; -} - -template <> -inline bool UseXSMM(const int &m, const int &n, const int &k, - bool transa, bool transb, - const platform::float16 &alpha, - const platform::float16 &beta) { - return false; -} - -template -inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha, - const T *A, int lda, const T *B, int ldb, T beta, T *C, - int ldc) { -#ifdef PADDLE_WITH_LIBXSMM - if (UseXSMM(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha, - beta)) { - // Note: SMM use ColMajor - const char transa = 'N'; - const char transb = 'N'; - CBlas::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda, - &beta, C, &ldc); - return; - } -#endif - -#ifdef PADDLE_MKL_SPLIT_GEMM - constexpr int bs = 2; - if (M % bs == 0 && transA == CblasNoTrans && transB == CblasNoTrans) { - for (int off = 0; off < M; off += bs) { - CBlas::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, bs, N, K, alpha, - A + off * lda, lda, B, ldb, beta, C + off * ldb, ldc); - } - return; - } -#endif - CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, - beta, C, ldc); -} - #ifdef PADDLE_WITH_MKLML template <> template @@ -319,8 +294,8 @@ void Blas::GEMM(CBLAS_TRANSPOSE transA, int lda = (transA == CblasNoTrans) ? K : M; int ldb = (transB == CblasNoTrans) ? N : K; int ldc = N; - GEMM_WARP(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, - beta, C, ldc); + CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, + beta, C, ldc); } template <> @@ -329,9 +304,20 @@ void Blas::GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T *A, int lda, const T *B, int ldb, T beta, T *C, int ldc) const { - GEMM_WARP(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, - transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A, - lda, B, ldb, beta, C, ldc); + CBlas::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, + transB == false ? 
+                 transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha,
+                 A, lda, B, ldb, beta, C, ldc);
+}
+
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
+                                            CBLAS_TRANSPOSE transB, int M,
+                                            int N, int K, T alpha, const T *A,
+                                            int lda, const T *B, int ldb,
+                                            T beta, T *C, int ldc) const {
+  CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B,
+                 ldb, beta, C, ldc);
 }
 
 template <typename T>
@@ -399,6 +385,47 @@ void Blas<platform::CPUDeviceContext>::VMUL(int n, const T *x, const T *y,
 #endif
 }
 
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::VEXP(int n, const T *x, T *y) const {
+#ifdef PADDLE_WITH_MKLML
+  CBlas<T>::VEXP(n, x, y);
+#else
+  // TODO(TJ): find out whether OpenBLAS provides a vectorized exp
+  for (int i = 0; i < n; ++i) {
+    y[i] = std::exp(x[i]);
+  }
+#endif
+}
+
+template <>
+template <typename T>
+T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {
+#ifdef PADDLE_WITH_MKLML
+  return CBlas<T>::DOT(n, x, 1, y, 1);
+#else
+  // TODO(TJ): find out whether OpenBLAS provides cblas_?dot
+  T sum = 0;
+  for (int i = 0; i < n; ++i) {
+    sum += x[i] * y[i];
+  }
+  return sum;
+#endif
+}
+
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::SCAL(int n, const T a, T *x) const {
+#ifdef PADDLE_WITH_MKLML
+  CBlas<T>::SCAL(n, a, x, 1);
+#else
+  // TODO(TJ): find out whether OpenBLAS provides cblas_?scal
+  for (int i = 0; i < n; ++i) {
+    x[i] = a * x[i];
+  }
+#endif
+}
+
 template <>
 template <typename T>
 void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N,
@@ -440,6 +467,42 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
 #endif
 }
 
+template <typename DeviceContext>
+template <typename T>
+void Blas<DeviceContext>::MatMul(const int M, const int N, const int K,
+                                 const T *A, const T *B, T *C) const {
+  this->template GEMM<T>(CblasNoTrans, CblasNoTrans, M, N, K,
+                         static_cast<T>(1), A, K, B, N, static_cast<T>(0), C,
+                         N);
+}
+
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::MatMul(const int M, const int N,
+                                              const int K, const T *A,
+                                              const T *B, T *C) const {
+#ifdef PADDLE_WITH_LIBXSMM
+  // Refer to https://github.com/hfp/libxsmm/blob/master/README.md
+  // Since the matrices here are very small, each call is already cheap, and
+  // a threshold check such as if (M * N * K < LIBXSMM_THRESHOLD) would only
+  // add overhead, so call xsmm directly.
+  // Note: SMM uses ColMajor
+  const char transa = 'N';
+  const char transb = 'N';
+  const T alpha = static_cast<T>(1);
+  const T beta = static_cast<T>(0);
+  CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K,
+                     &beta, C, &N);
+  return;
+#endif
+
+  CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
+                 static_cast<T>(1), A, K, B, N, static_cast<T>(0), C, N);
+}
+
 template <typename DeviceContext>
 template <typename T>
 void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h
new file mode 100644
index 0000000000000000000000000000000000000000..48c0da0e368a0fe6efcd758536e5659eeee26f7e
--- /dev/null
+++ b/paddle/fluid/operators/math/cpu_vec.h
@@ -0,0 +1,105 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <string>
+#include "paddle/fluid/platform/cpu_info.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+#define SIGMOID_THRESHOLD_MIN -40.0
+#define SIGMOID_THRESHOLD_MAX 13.0
+#define EXP_MAX_INPUT 40.0
+
+template <typename T>
+inline T sigmoid(T x) {
+  return 1. / (1. + exp(-x));
+}
+
+template <typename T>
+inline T tanh(T x) {
+  return 2. * sigmoid<T>(2. * x) - 1.;
+}
+
+template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+inline void vec_identity(const int n, const T* x, T* y) {
+  // do nothing
+  return;
+}
+
+template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+inline void vec_sigmoid(const int n, const T* x, T* y) {
+  const T min = SIGMOID_THRESHOLD_MIN;
+  const T max = SIGMOID_THRESHOLD_MAX;
+  for (int i = 0; i < n; ++i) {
+    T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
+    y[i] = 1.0 / (1.0 + std::exp(-tmp));
+  }
+}
+
+template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+inline void vec_tanh(const int n, const T* x, T* y) {
+  for (int i = 0; i < n; ++i) {
+    y[i] = tanh<T>(x[i]);
+  }
+}
+
+template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+inline void vec_relu(const int n, const T* x, T* y) {
+  for (int i = 0; i < n; ++i) {
+    y[i] = x[i] > 0 ? x[i] : 0;
+  }
+}
+
+template <>
+inline void vec_relu<float, platform::jit::avx>(const int n, const float* x,
+                                                float* y) {
+  // TODO(TJ): complete me with AVX intrinsics
+  for (int i = 0; i < n; ++i) {
+    y[i] = x[i] > 0 ? x[i] : 0;
+  }
+}
+
+template <>
+inline void vec_relu<float, platform::jit::avx2>(const int n, const float* x,
+                                                 float* y) {
+  // TODO(TJ): complete me with AVX2 intrinsics
+  for (int i = 0; i < n; ++i) {
+    y[i] = x[i] > 0 ? x[i] : 0;
+  }
+}
+
+template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+class VecActivations {
+ public:
+  std::function<void(const int, const T*, T*)> operator()(
+      const std::string& type) {
+    if (type == "sigmoid") {
+      return vec_sigmoid<T, isa>;
+    } else if (type == "relu") {
+      return vec_relu<T, isa>;
+    } else if (type == "tanh") {
+      return vec_tanh<T, isa>;
+    } else if (type == "identity" || type == "") {
+      return vec_identity<T, isa>;
+    }
+    PADDLE_THROW("Not support type %s.", type);
+  }
+};
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/math/fc_compute.h b/paddle/fluid/operators/math/fc_compute.h
index 8600fa9e2c4db9d54cbe0ffb68f82d52c086d4f7..1f5a49c0ab5a10b0d7dc1febd258ce76c467cb1c 100644
--- a/paddle/fluid/operators/math/fc_compute.h
+++ b/paddle/fluid/operators/math/fc_compute.h
@@ -25,17 +25,25 @@ namespace math {
 template <typename DeviceContext, typename T>
 inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
                       const int N, const int K, const T* X, const T* W, T* Y,
-                      const T* B = NULL) {
-  blas.GEMM(CblasNoTrans, CblasNoTrans, M, N, K, static_cast<T>(1), X, W,
-            static_cast<T>(0), Y);
-  if (B) {
+                      const T* B = NULL, bool relu = false) {
+  blas.MatMul(M, N, K, X, W, Y);
+  if (B == NULL) {
+    return;
+  }
+
 #ifdef PADDLE_WITH_MKLML
 #pragma omp parallel for if (FLAGS_paddle_num_threads > 1)
 #endif
-    for (int i = 0; i < M; i++) {
-      blas.AXPY(N, static_cast<T>(1), B, Y + i * N);
-    }
+  for (int i = 0; i < M; i++) {
+    blas.AXPY(N, static_cast<T>(1), B, Y + i * N);
   }
+
+  if (!relu) {
+    return;
+  }
+
+  // TODO(TJ): fuse relu
+  LOG(FATAL) << "Not implemented!";
 }
 
 }  // namespace math
diff --git a/paddle/fluid/platform/cpu_info.cc b/paddle/fluid/platform/cpu_info.cc
index 7d53a684d6068c79659719159696ef5aebfeaa2b..fcd658d67cf4551dbdb9696ef49b5ab3cc58bf95 100644
--- a/paddle/fluid/platform/cpu_info.cc
+++ b/paddle/fluid/platform/cpu_info.cc
@@ -103,15 +103,16 @@ size_t CUDAPinnedMaxChunkSize() {
   return CUDAPinnedMaxAllocSize() / 256;
 }
 
-#ifdef PADDLE_WITH_XBYAK
 namespace jit {
-
+#ifdef PADDLE_WITH_XBYAK
 static Xbyak::util::Cpu cpu;
 bool MayIUse(const cpu_isa_t cpu_isa) {
   using namespace Xbyak::util;  // NOLINT
   switch (cpu_isa) {
     case sse42:
       return cpu.has(Cpu::tSSE42);
+    case avx:
+      return cpu.has(Cpu::tAVX);
     case avx2:
       return cpu.has(Cpu::tAVX2);
     case avx512_common:
@@ -134,8 +135,16 @@ bool MayIUse(const cpu_isa_t cpu_isa) {
   }
   return false;
 }
+#else
+bool MayIUse(const cpu_isa_t cpu_isa) {
+  if (cpu_isa == isa_any) {
+    return true;
+  } else {
+    return false;
+  }
+}
+#endif
 }  // namespace jit
-#endif
 
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/fluid/platform/cpu_info.h b/paddle/fluid/platform/cpu_info.h
index f5f67667594f1ab80058533e4c5d5b04c2592b60..5d17978dd7946596c490dc465dab51e7cf53a044 100644
--- a/paddle/fluid/platform/cpu_info.h
+++ b/paddle/fluid/platform/cpu_info.h
@@ -37,12 +37,11 @@ size_t CUDAPinnedMinChunkSize();
 //! Get the maximum chunk size for buddy allocator.
 size_t CUDAPinnedMaxChunkSize();
 
-#ifdef PADDLE_WITH_XBYAK
 namespace jit {
-
 typedef enum {
   isa_any,
   sse42,
+  avx,
   avx2,
   avx512_common,
   avx512_core,
@@ -55,7 +54,6 @@ typedef enum {
 inline bool MayIUse(const cpu_isa_t cpu_isa);
 
 }  // namespace jit
-#endif
 
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h
index 15ad4a3b40b1ad13a10dd37449c6f6f3e2029df6..aa20553ceffceded09447693c6e92f55fb48702d 100644
--- a/paddle/fluid/platform/dynload/mklml.h
+++ b/paddle/fluid/platform/dynload/mklml.h
@@ -66,10 +66,16 @@ extern void* mklml_dso_handle;
   __macro(cblas_dgemm_free);      \
   __macro(cblas_sgemm_batch);     \
   __macro(cblas_dgemm_batch);     \
+  __macro(cblas_sdot);            \
+  __macro(cblas_ddot);            \
+  __macro(cblas_sscal);           \
+  __macro(cblas_dscal);           \
   __macro(vsAdd);                 \
   __macro(vdAdd);                 \
   __macro(vsMul);                 \
   __macro(vdMul);                 \
+  __macro(vsExp);                 \
+  __macro(vdExp);                 \
   __macro(MKL_Set_Num_Threads)
 
 MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);
diff --git a/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py b/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7382c2244ec3291c4e8f625cc2d15499e0acdac
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py
@@ -0,0 +1,208 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from op_test import OpTest
+from test_fusion_lstm_op import fc, ACTIVATION
+from test_softmax_op import stable_softmax
+
+
+def attention_lstm(
+        x,  # T x M
+        lod,  # 1 x N
+        h0,  # N x D
+        c0,  # N x D
+        fcws,  # (M+D) x 1, 1x1
+        fcbs,  # 1 x 1, 1x1
+        w,  # (M+D) x 4D
+        b,  # 1 x 4D
+        act_gate,
+        act_cell,
+        act_cand):
+
+    T = sum(lod[0])
+    N = len(lod[0])
+    M = x.shape[1]
+    D = b.shape[1] // 4
+    assert T == x.shape[0]
+    assert len(fcws) == len(fcbs)
+    hidden = []
+    cell = []
+
+    start_offset = 0
+    for bid in range(N):
+        seq_len = lod[0][bid]
+        xi = np.copy(x[start_offset:start_offset + seq_len, :]).reshape(
+            seq_len, M)
+        prev_cell = np.copy(c0[bid]).reshape([1, D])
+        prev_hidden = np.copy(h0[bid]).reshape([1, D])
+        for step in range(seq_len):
+            expanded_cell = np.repeat(prev_cell, seq_len, axis=0)
+            tmp = np.concatenate((xi, expanded_cell), axis=1)
+            assert tmp.shape[0] == seq_len
+            assert tmp.shape[1] == M + D
+            for fcid in range(len(fcbs)):
+                tmp = fc(tmp, fcws[fcid], fcbs[fcid])
+                tmp = ACTIVATION['relu'](tmp)
+            tmp = np.reshape(tmp, (1, seq_len))
+            tmp = stable_softmax(tmp).reshape(seq_len, 1)
+            lstmx = xi * tmp  # seq * M
+            lstmx = np.sum(lstmx.reshape(seq_len, M), axis=0).reshape([1, M])
+            lstmin = np.concatenate((prev_hidden, lstmx), axis=1)
+            lstmout = fc(lstmin, w, b).reshape([1, 4 * D])
+
+            g_f, g_i, g_o, cand = np.split(lstmout, 4, axis=1)
+            g_f = act_gate(g_f).reshape([1, D])
+            g_i = act_gate(g_i).reshape([1, D])
+            g_o = act_gate(g_o).reshape([1, D])
+            cand = act_cand(cand).reshape([1, D])
+
+            cell_t = (prev_cell * g_f) + (g_i * cand)
+            hidden_t = g_o * act_cell(cell_t)
+
+            hidden.append(hidden_t.flatten())
+            cell.append(cell_t.flatten())
+
+            prev_cell = cell_t.reshape([1, D])
+            prev_hidden = hidden_t.reshape([1, D])
+
+        start_offset += seq_len
+
+    hidden = np.array(hidden).astype('float32').reshape([T, D])
+    cell = np.array(cell).astype('float32').reshape([T, D])
+    return hidden, cell
+
+
+class TestAttentionLSTMOp(OpTest):
+    def set_conf(self):
+        pass
+
+    def setUp(self):
+        self.op_type = 'attention_lstm'
+        self.lod = [[3]]
+        self.M = 30
+        self.D = 15
+        self.has_initial_hidden = True
+        self.act_gate = 'sigmoid'
+        self.act_cell = 'tanh'
+        self.act_cand = 'tanh'
+        self.set_conf()
+
+        T = sum(self.lod[0])
+        bs = len(self.lod[0])
+
+        x = np.random.normal(size=(T, self.M)).astype('float32')
+        c0 = np.random.normal(size=(bs, self.D)).astype('float32')
+        if self.has_initial_hidden:
+            h0 = np.random.normal(size=(bs, self.D)).astype('float32')
+        else:
+            h0 = np.zeros((bs, self.D)).astype('float32')
+
+        fcw1 = np.random.normal(size=(self.M + self.D, 1)).astype('float32')
+        fcb1 = np.random.normal(size=(1, 1)).astype('float32')
+        fcw2 = np.random.normal(size=(1, 1)).astype('float32')
+        fcb2 = np.random.normal(size=(1, 1)).astype('float32')
+
+        # lstm weight and bias
+        w = np.random.normal(size=(self.M + self.D,
+                                   self.D * 4)).astype('float32')
+        b = np.random.normal(size=(1, self.D * 4)).astype('float32')
+
+        h, c = attention_lstm(x, self.lod, h0, c0, [fcw1, fcw2], [fcb1, fcb2],
+                              w, b, ACTIVATION[self.act_gate],
+                              ACTIVATION[self.act_cell],
+                              ACTIVATION[self.act_cand])
+
+        self.inputs = {
+            'X': (x, self.lod),
+            'C0': c0,
+            'AttentionWeight': fcw1,
+            'AttentionBias': fcb1,
+            'AttentionScalar': fcw2,
+            'AttentionScalarBias': fcb2,
+            'LSTMWeight': w,
+            'LSTMBias': b
+        }
+
+        if self.has_initial_hidden:
+            self.inputs['H0'] = h0
+
+        self.outputs = {
+            'Hidden': (h, self.lod),
+            'Cell': (c, self.lod),
+        }
+        self.attrs = {
+            'gate_activation': self.act_gate,
+            'cell_activation': self.act_cell,
+            'candidate_activation': self.act_cand
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestAttentionOpNonInit(TestAttentionLSTMOp):
+    def set_conf(self):
+        self.has_initial_hidden = False
+
+
+class TestAttentionOpAct(TestAttentionLSTMOp):
+    def set_conf(self):
+        self.M = 3
+        self.D = 2
+        self.act_gate = 'relu'
+        self.act_cell = 'tanh'
+        self.act_cand = 'sigmoid'
+
+
+class TestAttentionOpMD1(TestAttentionLSTMOp):
+    def set_conf(self):
+        self.M = 36
+        self.D = 8
+
+
+class TestAttentionOpMD2(TestAttentionLSTMOp):
+    def set_conf(self):
+        self.M = 8
+        self.D = 8
+
+
+class TestAttentionOpMD3(TestAttentionLSTMOp):
+    def set_conf(self):
+        self.M = 15
+        self.D = 30
+
+
+class TestAttentionOpBS1(TestAttentionLSTMOp):
+    def set_conf(self):
+        self.lod = [[5]]
+        self.M = 16
+        self.D = 32
+
+
+class TestAttentionOpBS2(TestAttentionLSTMOp):
+    def set_conf(self):
+        self.lod = [[3, 6]]
+
+
+class TestAttentionOpBS5(TestAttentionLSTMOp):
+    def set_conf(self):
+        self.lod = [[3, 2, 4, 7, 5]]
+
+
+if __name__ == '__main__':
+    unittest.main()
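
Reviewer note: the attention step that the kernel and the reference test both implement can be
hard to follow from the DOC formulas alone, so here is a minimal, self-contained numpy sketch of
one attention step for a single sequence. The function name and all parameter names below are
illustrative only and are not part of the operator's API; it simply re-derives what the kernel's
steps 1a-1d plus the sum pooling compute.

import numpy as np

def attention_step(x, prev_cell, atten_w, atten_b=None, scalar=None, scalar_b=None):
    """One attention step: x is (seq_len, M), prev_cell is (1, D).

    Mirrors the DOC: concat -> fc((M+D) x 1) with bias, relu -> optional
    scalar with bias, relu -> softmax over time -> dot-mul and sum pool
    to lstm_x_t (1, M).
    """
    seq_len, M = x.shape
    # concat(x, expand(cell_{t-1})) => (seq_len, M + D)
    tmp = np.concatenate([x, np.repeat(prev_cell, seq_len, axis=0)], axis=1)
    # fc result with bias, then relu; tmp.dot(atten_w) covers both the
    # x part (done once per batch in the kernel) and the cell part (DOT)
    fcout = tmp.dot(atten_w)  # (seq_len, 1)
    if atten_b is not None:
        fcout = fcout + atten_b
    fcout = np.maximum(fcout, 0.0)
    # optional scalar (and scalar bias), then relu again
    if scalar is not None:
        fcout = fcout * scalar
        if scalar_b is not None:
            fcout = fcout + scalar_b
        fcout = np.maximum(fcout, 0.0)
    # numerically stable softmax over the time dimension (kernel's 1d)
    e = np.exp(fcout - fcout.max())
    weights = e / e.sum()  # (seq_len, 1)
    # weighted sum pooling over time => lstm_x_t (1, M)
    return (weights * x).sum(axis=0, keepdims=True)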