未验证 提交 f0f06992 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #12878 from tensor-tang/feature/op/attention_lstm

Add attention lstm cpu forward
......@@ -138,12 +138,6 @@ else()
set(THIRD_PARTY_BUILD_TYPE Release)
endif()
if(WITH_MKL)
option(MKL_SPLIT_GEMM "PaddlePaddle MKL gemm would split to small ones" OFF)
if (MKL_SPLIT_GEMM)
add_definitions(-DPADDLE_MKL_SPLIT_GEMM)
endif()
endif()
set(WITH_MKLML ${WITH_MKL})
if (NOT DEFINED WITH_MKLDNN)
if (WITH_MKL AND AVX2_FOUND)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/attention_lstm_op.h"
#include <sys/time.h>
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("C0"),
"Input(C0) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"),
"Input(LSTMWeight) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LSTMBias"),
"Input(LSTMBias) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"),
"Input(AttentionWeight) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Output(Cell) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"),
"Output(AttentionedX) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"),
"Output(AttentionFCOut) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("LSTMX"),
"Output(LSTMX) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"),
"Output(LSTMOUT) of AttentionLSTM should not be null.");
auto x_dims = ctx->GetInputDim("X");
const int M = x_dims[1];
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
auto w_dims = ctx->GetInputDim("LSTMWeight");
const int D = w_dims[1] / 4;
PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2.");
PADDLE_ENFORCE_EQ(w_dims[0], D + M,
"LSTMWeight dims should be (%d + %d) * %d.", D + M, 4 * D);
auto b_dims = ctx->GetInputDim("LSTMBias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D);
PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.", 4 * D);
auto c_dims = ctx->GetInputDim("C0");
PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
if (ctx->HasInput("H0")) {
auto h_dims = ctx->GetInputDim("H0");
PADDLE_ENFORCE(h_dims == c_dims,
"The dimension of Input(H0) and Input(C0) "
"should be the same.");
}
auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2,
"Input(AttentionWeight)'s rank must be 2.");
PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D,
"AttentionWeight shapes must be (%d + %d) * 1.", M, D);
PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
"AttentionWeight shapes must be (%d + %d) * 1.", M, D);
if (ctx->HasInput("AttentionBias")) {
auto atten_b_dims = ctx->GetInputDim("AttentionBias");
PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
"Input(AttentionBias)'s rank must be 2.");
PADDLE_ENFORCE_EQ(atten_b_dims[0], 1,
"AttentionBias shapes must be 1 * 1.");
PADDLE_ENFORCE_EQ(atten_b_dims[1], 1,
"AttentionBias shapes must be 1 * 1.");
}
if (ctx->HasInput("AttentionScalar")) {
auto dims = ctx->GetInputDim("AttentionScalar");
PADDLE_ENFORCE_EQ(dims.size(), 2,
"Input(AttentionScalar)'s rank must be 2.");
PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1.");
PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
}
if (ctx->HasInput("AttentionScalarBias")) {
auto dims = ctx->GetInputDim("AttentionScalarBias");
PADDLE_ENFORCE(
ctx->HasInput("AttentionScalar"),
"AttentionScalar should not be null when have AttentionScalarBias.");
PADDLE_ENFORCE_EQ(dims.size(), 2,
"Input(AttentionScalarBias)'s rank must be 2.");
PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalarBias shapes must be 1 * 1.");
PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalarBias shapes must be 1 * 1.");
}
framework::DDim out_dims({x_dims[0], D});
ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("Cell", out_dims);
ctx->SetOutputDim("AttentionedX", {x_dims[0], 1});
ctx->SetOutputDim("LSTMX", {1, M});
ctx->SetOutputDim("LSTMOUT", {1, 4 * D});
// AttentionFCOut should be reshape as (maxseqlen,1) in runtime
ctx->ShareLoD("X", "Hidden");
ctx->ShareLoD("X", "Cell");
}
framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
void AttentionLSTMOpMaker::Make() {
AddInput("X",
"(LoDTensor) the input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T X M), where T is the "
"total time steps in this mini-batch, M is the dim size of x.");
AddInput("C0",
"(Tensor) LSTM C0"
"This is a tensor with shape (N x D), where N is the batch size, D "
"is the gate size."
"C0 is necessary because of attention.");
AddInput("H0",
"(Tensor, optional) LSTM H0"
"This is a tensor with shape (N x D), where N is the "
"batch size and D is the gate size.")
.AsDispensable();
AddInput("AttentionWeight",
"(Tensor) the weights of attention fc. Always relu the fc result."
"The shape is ((M+D) x 1), where M is the dim size of x, D is the "
"gate size of LSTM.");
AddInput("AttentionBias",
"(Tensor, optional) the bias of attention fc."
"The shape is (1 x 1)")
.AsDispensable();
AddInput("AttentionScalar",
"(Tensor, optional) the scalar on the result of attentioned fc. "
"Always relu the Scalar."
"The shape is (1 x 1)")
.AsDispensable();
AddInput("AttentionScalarBias",
"(Tensor, optional) the scalar bias of attention fc."
"The shape is (1 x 1)")
.AsDispensable();
AddInput("LSTMWeight",
"(Tensor) the combined weight of LSTM"
" - The shape is ((D+M) x 4D), where D is the hidden gate size, M "
"is the dim size of x"
" - Weight = {W_forget, W_input, W_output, W_cell}");
AddInput("LSTMBias",
"(Tensor) the combined bias of LSTM, shape (1x4D)."
"Note: we should add the bias of hidden and context accorindg to "
"the same gate: "
"{B_forget, B_input, B_output, B_cell}");
AddOutput("Hidden",
"(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`.");
AddOutput("Cell",
"(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`.");
AddOutput("AttentionedX",
"(Tensor) shape is (T x 1), the result after X * AttentionWeight,"
" where T is the total time steps in this mini-batch,"
" D is the hidden size.")
.AsIntermediate();
AddOutput("AttentionFCOut",
"(Tensor) (max_seq_len, 1), compute at each step.")
.AsIntermediate();
AddOutput("LSTMX",
"(Tensor) the input X of LSTM for each step."
"Shape is (1 x M), where M is the x frame size")
.AsIntermediate();
AddOutput(
"LSTMOUT",
"(Tensor) the output of LSTM X(1*(D+M))* weight((D+M)*4D) for each step."
"Shape is (1 x 4D), where M is the x frame size")
.AsIntermediate();
AddAttr<std::string>("gate_activation",
"(string, default: sigmoid)"
"The activation for input gate, forget gate and output "
"gate, `sigmoid` by default.")
.SetDefault("sigmoid")
.InEnum({"sigmoid", "tanh", "relu", "identity"});
AddAttr<std::string>("cell_activation",
"(string, default: tanh)"
"The activation for cell output, `tanh` by defalut.")
.SetDefault("tanh")
.InEnum({"sigmoid", "tanh", "relu", "identity"});
AddAttr<std::string>("candidate_activation",
"(string, default: tanh)"
"The activation for candidate hidden state, "
"`tanh` by default.")
.SetDefault("tanh")
.InEnum({"sigmoid", "tanh", "relu", "identity"});
AddComment(R"DOC(
Attention Long-Short Term Memory (LSTM) Operator.
Attention part:
concat( x(seqlen * M), expand( cell_t-1(1,D) ) ) => tmp(seqlen*(M+D))
tmp(seqlen*(M+D)) * fc((M+D)*1) => fcout(seqlen*1) with bias, relu
fcout(seqlen*1) * scalar => fcout(seqlen*1) with bias, relu
dotmul and sum pool ( fcout(seqlen*1), x(seqlen * M) ) => lstm_x_t(1, M)
LSTM part:
use lstm_x_t as input and compute as standard LSTM.
)DOC");
}
// y[i] = (x[i] + bias[0]) > 0 ? (x[i] + bias[0]) : 0;
template <typename T>
inline void bias_relu(const int n, const T* x, const T* bias, T* y) {
if (bias) {
for (int i = 0; i < n; ++i) {
y[i] = x[i] + bias[0];
}
math::vec_relu<T>(n, y, y);
} else {
math::vec_relu<T>(n, x, y);
}
}
template <typename DeviceContext, typename T>
inline void vec_softmax(const math::BlasT<DeviceContext, T>& blas, const int n,
const T* x, T* y) {
T scalar = x[0];
// max
for (int i = 1; i < n; ++i) {
scalar = scalar < x[i] ? x[i] : scalar;
}
// sub
for (int i = 0; i < n; ++i) {
y[i] = x[i] - scalar;
}
// exp
blas.VEXP(n, y, y);
// sum
scalar = T(0);
for (int i = 0; i < n; ++i) {
scalar += y[i];
}
// scale
blas.SCAL(n, static_cast<T>(1) / scalar, y);
}
template <typename T>
class AttentionLSTMKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using DeviceContext = paddle::platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X");
auto* h0 = ctx.Input<Tensor>("H0");
auto* c0 = ctx.Input<Tensor>("C0");
auto* atten_w = ctx.Input<Tensor>("AttentionWeight");
auto* atten_b = ctx.Input<Tensor>("AttentionBias");
auto* atten_scalar = ctx.Input<Tensor>("AttentionScalar");
auto* atten_scalar_bias = ctx.Input<Tensor>("AttentionScalarBias");
auto* lstm_w = ctx.Input<Tensor>("LSTMWeight");
auto* lstm_b = ctx.Input<Tensor>("LSTMBias");
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
auto* cell_out = ctx.Output<LoDTensor>("Cell");
auto* atted_x = ctx.Output<Tensor>("AttentionedX");
auto* fc_out = ctx.Output<Tensor>("AttentionFCOut");
auto* lstm_x = ctx.Output<Tensor>("LSTMX");
auto* lstm_out = ctx.Output<Tensor>("LSTMOUT");
// some shape should be reshape here since infershape can not get lod info
auto x_lod = x->lod();
const int N = x_lod[0].size() - 1; // batch size
auto x_dims = x->dims(); // T x M
auto w_dims = lstm_w->dims(); // (D+M) x 4D
const int total_T = x_dims[0];
const int M = x_dims[1]; // x frame size
const int D = w_dims[1] / 4; // gate frame size
const int D2 = D * 2;
const int D3 = D * 3;
const int D4 = w_dims[1];
int max_seq_len = x_lod[0][1];
for (int i = 1; i < N; ++i) {
int len = x_lod[0][i + 1] - x_lod[0][i];
max_seq_len = max_seq_len < len ? len : max_seq_len;
}
PADDLE_ENFORCE_EQ(x_lod.size(), 1, "Input(X)'s lod size must be 1.");
PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
fc_out->Resize({max_seq_len, 1});
math::VecActivations<T> act_functor;
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
act_gate = act_functor(ctx.Attr<std::string>("gate_activation"));
act_cell = act_functor(ctx.Attr<std::string>("cell_activation"));
act_cand = act_functor(ctx.Attr<std::string>("candidate_activation"));
const T* x_data = x->data<T>();
const T* h0_data = h0 ? h0->data<T>() : NULL;
const T* c0_data = c0->data<T>();
const T* lstm_w_data = lstm_w->data<T>();
const T* lstm_b_data = lstm_b->data<T>();
const T* atten_w_data = atten_w->data<T>();
const T* atten_b_data = atten_b ? atten_b->data<T>() : NULL;
const T* atten_scalar_data = atten_scalar ? atten_scalar->data<T>() : NULL;
const T* atten_scalar_bias_data =
atten_scalar_bias ? atten_scalar_bias->data<T>() : NULL;
T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());
T* atted_x_data = atted_x->mutable_data<T>(ctx.GetPlace());
T* fc_out_data = fc_out->mutable_data<T>(ctx.GetPlace());
T* lstm_x_data = lstm_x->mutable_data<T>(ctx.GetPlace());
T* lstm_out_data = lstm_out->mutable_data<T>(ctx.GetPlace());
// x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1
auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, total_T, 1, M, x_data, atten_w_data,
atted_x_data, atten_b_data);
const T* cur_atten_x_data = atted_x_data;
const T* cur_x_data = x_data;
const T* prev_cell_data = NULL;
const T* prev_hidden_data = NULL;
T* cur_cell_out_data = cell_out_data;
T* cur_hidden_out_data = hidden_out_data;
for (int i = 0; i < N; ++i) {
int seq_len = x_lod[0][i + 1] - x_lod[0][i];
prev_cell_data = c0_data + i * D;
prev_hidden_data = h0_data ? h0_data + i * D : NULL;
for (int step = 0; step < seq_len; ++step) {
/// 1. compute attention vector
// 1a. prev_cell(1xD) * fc(D) rest part of atten_wgt
T prev_cell_bias = blas.DOT(D, prev_cell_data, atten_w_data + M);
// 1b. add cell bias and relu
bias_relu<T>(seq_len, cur_atten_x_data, &prev_cell_bias, fc_out_data);
// 1c. fc scalar
if (atten_scalar_data) {
blas.SCAL(seq_len, *atten_scalar_data, fc_out_data);
bias_relu<T>(seq_len, fc_out_data, atten_scalar_bias_data,
fc_out_data);
}
// 1d. softmax
vec_softmax<DeviceContext, T>(blas, seq_len, fc_out_data, fc_out_data);
// mul x(seq_len*M) and sum pool
math::FCCompute<DeviceContext, T>(blas, 1, M, seq_len, fc_out_data,
cur_x_data, lstm_x_data);
/// 2. compute LSTM step
// lstm weight : concat[forget , input , output , tilde]
// shape : (D + M) x (4 * D)
// fc inputX(1xM) * weightX(M*(4D)) => 1 x 4D
blas.MatMul(1, D4, M, lstm_x_data, lstm_w_data + D * D4, lstm_out_data);
if (prev_hidden_data) {
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1),
prev_hidden_data, D, lstm_w_data, D4, static_cast<T>(1),
lstm_out_data, D4);
}
// since input is 1xM, so can use add bias
blas.VADD(D4, lstm_b_data, lstm_out_data, lstm_out_data);
// gate act: sigmoid
act_gate(D3, lstm_out_data, lstm_out_data);
// candicate act: tanh
act_cand(D, lstm_out_data + D3, lstm_out_data + D3);
// a = forget * prev_cell
blas.VMUL(D, lstm_out_data, prev_cell_data, lstm_out_data);
// b = input * tilde
blas.VMUL(D, lstm_out_data + D, lstm_out_data + D3, lstm_out_data + D);
// cell_out = a + b
blas.VADD(D, lstm_out_data, lstm_out_data + D, cur_cell_out_data);
// state act tanh(cell_out) * output_gate
act_cell(D, cur_cell_out_data, lstm_out_data);
blas.VMUL(D, lstm_out_data, lstm_out_data + D2, cur_hidden_out_data);
prev_hidden_data = cur_hidden_out_data;
prev_cell_data = cur_cell_out_data;
cur_cell_out_data = cur_cell_out_data + D;
cur_hidden_out_data = cur_hidden_out_data + D;
}
cur_x_data = cur_x_data + seq_len * M;
cur_atten_x_data = cur_atten_x_data + seq_len;
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(attention_lstm, ops::AttentionLSTMOp,
ops::AttentionLSTMOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(attention_lstm, ops::AttentionLSTMKernel<float>,
ops::AttentionLSTMKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
class AttentionLSTMOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class AttentionLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
} // namespace operators
} // namespace paddle
......@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
// #include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
......
......@@ -90,6 +90,11 @@ class Blas {
void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A,
int lda, const T* B, int ldb, T beta, T* C, int ldc) const;
template <typename T>
void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C,
int ldc) const;
#ifdef PADDLE_WITH_MKLML
template <typename T>
T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N,
......@@ -109,6 +114,10 @@ class Blas {
void GEMM_FREE(T* data) const;
#endif
template <typename T>
void MatMul(const int M, const int N, const int K, const T* A, const T* B,
T* C) const;
template <typename T>
void MatMul(const framework::Tensor& mat_a, bool trans_a,
const framework::Tensor& mat_b, bool trans_b, T alpha,
......@@ -140,10 +149,19 @@ class Blas {
template <typename T>
void VCOPY(int n, const T* x, T* y) const;
template <typename T>
void VEXP(int n, const T* x, T* y) const;
template <typename T>
void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta,
T* C) const;
template <typename T>
T DOT(int n, const T* x, const T* y) const;
template <typename T>
void SCAL(int n, const T a, T* x) const;
template <typename T>
void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N,
int K, T alpha, const T* A, const T* B, T beta, T* C,
......@@ -215,11 +233,26 @@ class BlasT : private Blas<DeviceContext> {
Base()->template VCOPY<T>(args...);
}
template <typename... ARGS>
void VEXP(ARGS... args) const {
Base()->template VEXP<T>(args...);
}
template <typename... ARGS>
void GEMV(ARGS... args) const {
Base()->template GEMV<T>(args...);
}
template <typename... ARGS>
T DOT(ARGS... args) const {
return Base()->template DOT<T>(args...);
}
template <typename... ARGS>
void SCAL(ARGS... args) const {
Base()->template SCAL<T>(args...);
}
template <typename... ARGS>
void BatchedGEMM(ARGS... args) const {
Base()->template BatchedGEMM<T>(args...);
......
......@@ -73,6 +73,16 @@ struct CBlas<float> {
platform::dynload::cblas_sgemv(args...);
}
template <typename... ARGS>
static float DOT(ARGS... args) {
return platform::dynload::cblas_sdot(args...);
}
template <typename... ARGS>
static void SCAL(ARGS... args) {
platform::dynload::cblas_sscal(args...);
}
template <typename... ARGS>
static void GEMM_BATCH(ARGS... args) {
platform::dynload::cblas_sgemm_batch(args...);
......@@ -87,6 +97,11 @@ struct CBlas<float> {
static void VMUL(ARGS... args) {
platform::dynload::vsMul(args...);
}
template <typename... ARGS>
static void VEXP(ARGS... args) {
platform::dynload::vsExp(args...);
}
};
template <>
......@@ -138,6 +153,16 @@ struct CBlas<double> {
platform::dynload::cblas_dgemv(args...);
}
template <typename... ARGS>
static double DOT(ARGS... args) {
return platform::dynload::cblas_ddot(args...);
}
template <typename... ARGS>
static void SCAL(ARGS... args) {
platform::dynload::cblas_dscal(args...);
}
template <typename... ARGS>
static void GEMM_BATCH(ARGS... args) {
platform::dynload::cblas_dgemm_batch(args...);
......@@ -152,6 +177,11 @@ struct CBlas<double> {
static void VMUL(ARGS... args) {
platform::dynload::vdMul(args...);
}
template <typename... ARGS>
static void VEXP(ARGS... args) {
platform::dynload::vdExp(args...);
}
};
#else
......@@ -210,6 +240,9 @@ struct CBlas<platform::float16> {
PADDLE_THROW("float16 SMM_GEMM not supported on CPU");
}
static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); }
static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
#ifdef PADDLE_WITH_MKLML
static void GEMM_BATCH(...) {
PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
......@@ -217,64 +250,6 @@ struct CBlas<platform::float16> {
#endif
};
template <typename T>
inline bool UseXSMM(const int &m, const int &n, const int &k, bool transa,
bool transb, const T &alpha, const T &beta) {
#ifdef PADDLE_WITH_LIBXSMM
// Refer to https://github.com/hfp/libxsmm/blob/master/README.md
// But the threshold is custom
constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20;
if (m * n * k > LIBXSMM_THRESHOLD || transa || transb ||
std::abs<T>(alpha - static_cast<T>(1) >
std::numeric_limits<T>::epsilon()) ||
std::abs<T>(beta) > std::numeric_limits<T>::epsilon()) {
return false;
} else {
return true;
}
#endif
return false;
}
template <>
inline bool UseXSMM<platform::float16>(const int &m, const int &n, const int &k,
bool transa, bool transb,
const platform::float16 &alpha,
const platform::float16 &beta) {
return false;
}
template <typename T>
inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha,
const T *A, int lda, const T *B, int ldb, T beta, T *C,
int ldc) {
#ifdef PADDLE_WITH_LIBXSMM
if (UseXSMM<T>(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha,
beta)) {
// Note: SMM use ColMajor
const char transa = 'N';
const char transb = 'N';
CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda,
&beta, C, &ldc);
return;
}
#endif
#ifdef PADDLE_MKL_SPLIT_GEMM
constexpr int bs = 2;
if (M % bs == 0 && transA == CblasNoTrans && transB == CblasNoTrans) {
for (int off = 0; off < M; off += bs) {
CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, bs, N, K, alpha,
A + off * lda, lda, B, ldb, beta, C + off * ldb, ldc);
}
return;
}
#endif
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
#ifdef PADDLE_WITH_MKLML
template <>
template <typename T>
......@@ -319,8 +294,8 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
GEMM_WARP<T>(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <>
......@@ -329,9 +304,20 @@ void Blas<platform::CPUDeviceContext>::GEMM(bool transA, bool transB, int M,
int N, int K, T alpha, const T *A,
int lda, const T *B, int ldb,
T beta, T *C, int ldc) const {
GEMM_WARP<T>(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
lda, B, ldb, beta, C, ldc);
CBlas<T>::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
lda, B, ldb, beta, C, ldc);
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB, int M,
int N, int K, T alpha, const T *A,
int lda, const T *B, int ldb,
T beta, T *C, int ldc) const {
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <typename DeviceContext>
......@@ -399,6 +385,47 @@ void Blas<platform::CPUDeviceContext>::VMUL(int n, const T *x, const T *y,
#endif
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::VEXP(int n, const T *x, T *y) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::VEXP(n, x, y);
#else
// try to find if openblas support vexp
for (int i = 0; i < n; ++i) {
y[i] = std::exp(x[i]);
}
#endif
}
template <>
template <typename T>
T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {
#ifdef PADDLE_WITH_MKLML
return CBlas<T>::DOT(n, x, 1, y, 1);
#else
// try to find if openblas support cblas_dot
T sum = 0;
for (int i = 0; i < n; ++i) {
sum += x[i] * y[i];
}
return sum;
#endif
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::SCAL(int n, const T a, T *x) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::SCAL(n, a, x, 1);
#else
// try to find if openblas support cblas_scal
for (int i = 0; i < n; ++i) {
x[i] = a * x[i];
}
#endif
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
......@@ -440,6 +467,42 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
#endif
}
template <typename DeviceContext>
template <typename T>
void Blas<DeviceContext>::MatMul(const int M, const int N, const int K,
const T *A, const T *B, T *C) const {
this->template GEMM<T>(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
static_cast<T>(1), A, K, B, N, static_cast<T>(0), C,
N);
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::MatMul(const int M, const int N,
const int K, const T *A,
const T *B, T *C) const {
#ifdef PADDLE_WITH_LIBXSMM
// Refer to https://github.com/hfp/libxsmm/blob/master/README.md
// But the threshold is custom constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20;
// Since the matrix is very small,
// so the unit of calculation is already very fast,
// and the if( M*N*K < LIBXSMM_THRESHOLD) would be overhead,
// use xsmm directly.
// Note: SMM use ColMajor
const char transa = 'N';
const char transb = 'N';
const T alpha = static_cast<T>(1);
const T beta = static_cast<T>(0);
CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta,
C, &N);
return;
#endif
CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
static_cast<T>(1), A, K, B, N, static_cast<T>(0), C, N);
}
template <typename DeviceContext>
template <typename T>
void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace math {
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
template <typename T>
inline T sigmoid(T x) {
return 1. / (1. + exp(-x));
}
template <typename T>
inline T tanh(T x) {
return 2. * sigmoid(2. * x) - 1.;
}
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
inline void vec_identity(const int n, const T* x, T* y) {
// do nothing
return;
}
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
inline void vec_sigmoid(const int n, const T* x, T* y) {
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < n; ++i) {
T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = 1.0 / (1.0 + std::exp(-tmp));
}
}
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
inline void vec_tanh(const int n, const T* x, T* y) {
for (int i = 0; i < n; ++i) {
y[i] = tanh<T>(x[i]);
}
}
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
inline void vec_relu(const int n, const T* x, T* y) {
for (int i = 0; i < n; ++i) {
y[i] = x[i] > 0 ? x[i] : 0;
}
}
template <>
inline void vec_relu<float, platform::jit::avx2>(const int n, const float* x,
float* y) {
// TODO(TJ): complete me
for (int i = 0; i < n; ++i) {
y[i] = x[i] > 0 ? x[i] : 0;
}
}
template <>
inline void vec_relu<float, platform::jit::avx>(const int n, const float* x,
float* y) {
// TODO(TJ): complete me
for (int i = 0; i < n; ++i) {
y[i] = x[i] > 0 ? x[i] : 0;
}
}
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
class VecActivations {
public:
std::function<void(const int, const T*, T*)> operator()(
const std::string& type) {
if (type == "sigmoid") {
return vec_sigmoid<T, isa>;
} else if (type == "relu") {
return vec_relu<T, isa>;
} else if (type == "tanh") {
return vec_tanh<T, isa>;
} else if (type == "identity" || type == "") {
return vec_identity<T, isa>;
}
PADDLE_THROW("Not support type %s.", type);
}
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -25,17 +25,25 @@ namespace math {
template <typename DeviceContext, typename T>
inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
const int N, const int K, const T* X, const T* W, T* Y,
const T* B = NULL) {
blas.GEMM(CblasNoTrans, CblasNoTrans, M, N, K, static_cast<T>(1), X, W,
static_cast<T>(0), Y);
if (B) {
const T* B = NULL, bool relu = false) {
blas.MatMul(M, N, K, X, W, Y);
if (B == NULL) {
return;
}
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for if (FLAGS_paddle_num_threads > 1)
#endif
for (int i = 0; i < M; i++) {
blas.AXPY(N, static_cast<T>(1), B, Y + i * N);
}
for (int i = 0; i < M; i++) {
blas.AXPY(N, static_cast<T>(1), B, Y + i * N);
}
if (!relu) {
return;
}
// TODO(TJ): fuse relu
LOG(FATAL) << "Not implemented!";
}
} // namespace math
......
......@@ -103,15 +103,16 @@ size_t CUDAPinnedMaxChunkSize() {
return CUDAPinnedMaxAllocSize() / 256;
}
#ifdef PADDLE_WITH_XBYAK
namespace jit {
#ifdef PADDLE_WITH_XBYAK
static Xbyak::util::Cpu cpu;
bool MayIUse(const cpu_isa_t cpu_isa) {
using namespace Xbyak::util; // NOLINT
switch (cpu_isa) {
case sse42:
return cpu.has(Cpu::tSSE42);
case avx:
return cpu.has(Cpu::tAVX);
case avx2:
return cpu.has(Cpu::tAVX2);
case avx512_common:
......@@ -134,8 +135,16 @@ bool MayIUse(const cpu_isa_t cpu_isa) {
}
return false;
}
#else
bool MayIUse(const cpu_isa_t cpu_isa) {
if (cpu_isa == isa_any) {
return true;
} else {
return false;
}
}
#endif
} // namespace jit
#endif
} // namespace platform
} // namespace paddle
......@@ -37,12 +37,11 @@ size_t CUDAPinnedMinChunkSize();
//! Get the maximum chunk size for buddy allocator.
size_t CUDAPinnedMaxChunkSize();
#ifdef PADDLE_WITH_XBYAK
namespace jit {
typedef enum {
isa_any,
sse42,
avx,
avx2,
avx512_common,
avx512_core,
......@@ -55,7 +54,6 @@ typedef enum {
inline bool MayIUse(const cpu_isa_t cpu_isa);
} // namespace jit
#endif
} // namespace platform
} // namespace paddle
......@@ -66,10 +66,16 @@ extern void* mklml_dso_handle;
__macro(cblas_dgemm_free); \
__macro(cblas_sgemm_batch); \
__macro(cblas_dgemm_batch); \
__macro(cblas_sdot); \
__macro(cblas_ddot); \
__macro(cblas_sscal); \
__macro(cblas_dscal); \
__macro(vsAdd); \
__macro(vdAdd); \
__macro(vsMul); \
__macro(vdMul); \
__macro(vsExp); \
__macro(vdExp); \
__macro(MKL_Set_Num_Threads)
MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
from test_fusion_lstm_op import fc, ACTIVATION
from test_softmax_op import stable_softmax
def attention_lstm(
x, # T x M
lod, # 1 x N
h0, # N x D
c0, # N x D
fcws, # (M+D) x 1, 1x1
fcbs, # 1 x 1, 1x1
w, # (M+D) x 4D
b, # 1 x 4D
act_gate,
act_cell,
act_cand):
T = sum(lod[0])
N = len(lod[0])
M = x.shape[1]
D = b.shape[1] / 4
assert T == x.shape[0]
assert len(fcws) == len(fcbs)
hidden = []
cell = []
start_offset = 0
for bid in range(N):
seq_len = lod[0][bid]
xi = np.copy(x[start_offset:start_offset + seq_len, :]).reshape(seq_len,
M)
prev_cell = np.copy(c0[bid]).reshape([1, D])
prev_hidden = np.copy(h0[bid]).reshape([1, D])
for step in range(seq_len):
expanded_cell = np.repeat(prev_cell, seq_len, axis=0)
tmp = np.concatenate((xi, expanded_cell), axis=1)
assert tmp.shape[0] == seq_len
assert tmp.shape[1] == M + D
for fcid in range(len(fcbs)):
tmp = fc(tmp, fcws[fcid], fcbs[fcid])
tmp = ACTIVATION['relu'](tmp)
tmp = np.reshape(tmp, (1, seq_len))
tmp = stable_softmax(tmp).reshape(seq_len, 1)
lstmx = xi * tmp # seq * M
lstmx = np.sum(lstmx.reshape(seq_len, M), axis=0).reshape([1, M])
lstmin = np.concatenate((prev_hidden, lstmx), axis=1)
lstmout = fc(lstmin, w, b).reshape([1, 4 * D])
g_f, g_i, g_o, cand = np.split(lstmout, 4, axis=1)
g_f = act_gate(g_f).reshape([1, D])
g_i = act_gate(g_i).reshape([1, D])
g_o = act_gate(g_o).reshape([1, D])
cand = act_cand(cand).reshape([1, D])
cell_t = (prev_cell * g_f) + (g_i * cand)
hidden_t = g_o * act_cell(cell_t)
hidden.append(hidden_t.flatten())
cell.append(cell_t.flatten())
prev_cell = cell_t.reshape([1, D])
prev_hidden = hidden_t.reshape([1, D])
start_offset += seq_len
hidden = np.array(hidden).astype('float32').reshape([T, D])
cell = np.array(cell).astype('float32').reshape([T, D])
return hidden, cell
class TestAttentionLSTMOp(OpTest):
def set_conf(self):
pass
def setUp(self):
self.op_type = 'attention_lstm'
self.lod = [[3]]
self.M = 30
self.D = 15
self.has_initial_hidden = True
self.act_gate = 'sigmoid'
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.set_conf()
T = sum(self.lod[0])
bs = len(self.lod[0])
x = np.random.normal(size=(T, self.M)).astype('float32')
c0 = np.random.normal(size=(bs, self.D)).astype('float32')
if self.has_initial_hidden:
h0 = np.random.normal(size=(bs, self.D)).astype('float32')
else:
h0 = np.zeros((bs, self.D)).astype('float32')
fcw1 = np.random.normal(size=(self.M + self.D, 1)).astype('float32')
fcb1 = np.random.normal(size=(1, 1)).astype('float32')
fcw2 = np.random.normal(size=(1, 1)).astype('float32')
fcb2 = np.random.normal(size=(1, 1)).astype('float32')
# lstm weight and bias
w = np.random.normal(size=(self.M + self.D,
self.D * 4)).astype('float32')
b = np.random.normal(size=(1, self.D * 4)).astype('float32')
h, c = attention_lstm(x, self.lod, h0, c0, [fcw1, fcw2], [fcb1, fcb2],
w, b, ACTIVATION[self.act_gate],
ACTIVATION[self.act_cell],
ACTIVATION[self.act_cand])
self.inputs = {
'X': (x, self.lod),
'C0': c0,
'AttentionWeight': fcw1,
'AttentionBias': fcb1,
'AttentionScalar': fcw2,
'AttentionScalarBias': fcb2,
'LSTMWeight': w,
'LSTMBias': b
}
if self.has_initial_hidden:
self.inputs['H0'] = h0
self.outputs = {
'Hidden': (h, self.lod),
'Cell': (c, self.lod),
}
self.attrs = {
'gate_activation': self.act_gate,
'cell_activation': self.act_cell,
'candidate_activation': self.act_cand
}
def test_check_output(self):
self.check_output()
class TestAttentionOpNonInit(TestAttentionLSTMOp):
def set_conf(self):
self.has_initial_hidden = False
class TestAttentionOpAct(TestAttentionLSTMOp):
def set_conf(self):
self.M = 3
self.D = 2
self.act_gate = 'relu'
self.act_cell = 'tanh'
self.act_cand = 'sigmoid'
class TestAttentionOpMD1(TestAttentionLSTMOp):
def set_conf(self):
self.M = 36
self.D = 8
class TestAttentionOpMD2(TestAttentionLSTMOp):
def set_conf(self):
self.M = 8
self.D = 8
class TestAttentionOpMD3(TestAttentionLSTMOp):
def set_conf(self):
self.M = 15
self.D = 30
class TestAttentionOpBS1(TestAttentionLSTMOp):
def set_conf(self):
self.lod = [[5]]
self.M = 16
self.D = 32
class TestAttentionOpBS2(TestAttentionLSTMOp):
def set_conf(self):
self.lod = [[3, 6]]
class TestAttentionOpBS5(TestAttentionLSTMOp):
def set_conf(self):
self.lod = [[3, 2, 4, 7, 5]]
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册