Unverified commit 20659fc9, authored by tensor-tang, committed by GitHub

Merge pull request #13107 from tensor-tang/optimize/op/fusion_gru

Optimize fusion gru
@@ -13,16 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/fusion_gru_op.h"
+#include <cstring>  // for memcpy
 #include <string>
-#include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/operators/math/blas.h"
-#include "paddle/fluid/operators/math/detail/activation_functions.h"
-#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
-#include "paddle/fluid/operators/math/detail/gru_kernel.h"
+#include "paddle/fluid/operators/math/cpu_vec.h"
 #include "paddle/fluid/operators/math/fc_compute.h"
-#include "paddle/fluid/operators/math/gru_compute.h"
-#include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/sequence2batch.h"
+#include "paddle/fluid/platform/cpu_info.h"

 namespace paddle {
 namespace operators {
@@ -35,12 +32,12 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
                  "Input(WeightH) of GRU should not be null.");
   PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("BatchedGate"),
-                 "Output(BatchedGate) of GRU should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("BatchResetHiddenPrev"),
-                 "Output(BatchResetHiddenPrev) of GRU should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
-                 "Output(BatchedHidden) of GRU should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
+                 "Output(ReorderedH0) of GRU should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
+                 "Output(BatchedInput) of GRU should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"),
+                 "Output(BatchedOut) of GRU should not be null.");
   PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
                  "Output(Hidden) of GRU should not be null.");
@@ -83,12 +80,16 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
   }
   framework::DDim out_dims({x_dims[0], frame_size});
   ctx->SetOutputDim("Hidden", out_dims);
-  ctx->SetOutputDim("BatchedGate", {x_dims[0], wx_dims[1]});
-  ctx->SetOutputDim("BatchedHidden", out_dims);
-  ctx->SetOutputDim("BatchResetHiddenPrev", out_dims);
+  ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
+  ctx->SetOutputDim("BatchedOut", out_dims);
   ctx->ShareLoD("X", "Hidden");
-  int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
+  int xx_width;
+  if (ctx->Attrs().Get<bool>("use_seq")) {
+    xx_width = wx_dims[1];
+  } else {
+    xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
+  }
   ctx->SetOutputDim("XX", {x_dims[0], xx_width});
   ctx->ShareLoD("X", "XX");
 }
@@ -115,22 +116,29 @@ void FusionGRUOpMaker::Make() {
            "(Tensor) The FC weight with shape (M x 3D),"
            "where M is the dim size of x, D is the hidden size. ");
   AddInput("WeightH",
-           "(Tensor) (D x 3D) Same as GRUOp, where D is the hidden size. ");
+           "(Tensor) (D x 3D) Same as GRUOp, where D is the hidden size. "
+           "This weight is not exactly D x 3D as {W_update, W_reset, W_state}; "
+           "actually it is two parts of D x 2D and D x D: "
+           "{W_update, W_reset; W_state}"
+           "{D x (D + D); D x D}");
   AddInput("Bias",
            "(Tensor, optional) (1 x 3D)."
            "Almost same as GRUOp."
            "Note: if there is an FC bias, it should be added to this bias.")
       .AsDispensable();
+  AddOutput("ReorderedH0", "(Tensor) (N x D), where N is the mini-batch size.")
+      .AsIntermediate();
   AddOutput("XX",
-            "(LoDTensor) the result after X * WeightX (size is T x 4D)"
+            "(LoDTensor) the result after X * WeightX (size is T x 3D)"
             " or batched_X (size is T x M), this will be automatically chosen,"
             " where T is the total time steps in this mini-batch,"
             " D is the hidden size, M is the dim size of x input.")
       .AsIntermediate();
-  AddOutput("BatchedGate", "(LoDTensor) Same as GRUOp").AsIntermediate();
-  AddOutput("BatchResetHiddenPrev", "(LoDTensor) (T x 3D) Same as GRUOp.")
-      .AsIntermediate();
-  AddOutput("BatchedHidden", "(LoDTensor) (T x D) Same as GRUOp.")
+  AddOutput("BatchedInput",
+            "(LoDTensor) This is the batched result of input X"
+            " or the batched result after fc, shape (T x 3D)")
+      .AsIntermediate();
+  AddOutput("BatchedOut", "(LoDTensor) (T x D) save batched hidden.")
       .AsIntermediate();
   AddOutput("Hidden", "(LoDTensor) (T x D) Same as GRUOp");
   AddAttr<std::string>("activation",
@@ -146,6 +154,10 @@ void FusionGRUOpMaker::Make() {
            "(bool, default: False) "
            "whether to compute reversed GRU.")
       .SetDefault(false);
+  AddAttr<bool>("use_seq",
+                "(bool, default: True) "
+                "whether to use seq mode to compute GRU.")
+      .SetDefault(true);
   AddComment(R"DOC(
 The Fusion complete GRU Operator.
 This operator fuses the fully-connected operator into GRU,
@@ -153,172 +165,261 @@ more details can refer to GRU op.
 )DOC");
 }

-template <typename DeviceContext, typename T>
-inline void ReorderInitState(const DeviceContext& ctx,
-                             const framework::Tensor& src,
-                             framework::Vector<size_t> index_lod,
-                             framework::Tensor* dst, bool indexed_src) {
-  math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
-  dst->mutable_data<T>(src.dims(), ctx.GetPlace());
-  row_shuffle(ctx, src, index_lod, dst, indexed_src);
-}
-
-template <typename DeviceContext, typename T>
+template <typename T>
 class FusionGRUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* x = ctx.Input<LoDTensor>("X");
-    auto* wx = ctx.Input<Tensor>("WeightX");
-    auto* wh = ctx.Input<Tensor>("WeightH");
-    auto* bias = ctx.Input<Tensor>("Bias");
-    auto* h0 = ctx.Input<Tensor>("H0");
-
-    auto* xx = ctx.Output<LoDTensor>("XX");
-    auto* batched_gate = ctx.Output<LoDTensor>("BatchedGate");
-    auto* batch_reset_hidden_prev =
-        ctx.Output<LoDTensor>("BatchResetHiddenPrev");
-    auto* batch_hidden = ctx.Output<LoDTensor>("BatchedHidden");
-    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
-    bool is_reverse = ctx.Attr<bool>("is_reverse");
-
-    T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
-    T* batched_gate_data = batched_gate->mutable_data<T>(ctx.GetPlace());
-    batch_reset_hidden_prev->mutable_data<T>(ctx.GetPlace());
-    batch_hidden->mutable_data<T>(ctx.GetPlace());
-    hidden_out->mutable_data<T>(ctx.GetPlace());
-
-    const T* x_data = x->data<T>();
-    const T* wx_data = wx->data<T>();
-    const T* wh_data = wh->data<T>();
-    auto x_dims = x->dims();
-    auto wx_dims = wx->dims();
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
-    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
-    if (x_dims[1] > wx_dims[1]) {
-      math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
-                                        x_data, wx_data, xx_data,
-                                        bias ? bias->data<T>() : NULL);
-      to_batch(dev_ctx, *xx, batched_gate, true, is_reverse);
-    } else {
-      to_batch(dev_ctx, *x, xx, true, is_reverse);
-      batched_gate->set_lod(xx->lod());
-      math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
-                                        xx_data, wx_data, batched_gate_data,
-                                        bias ? bias->data<T>() : NULL);
-    }
-
-    int frame_size = static_cast<int>(wx_dims[1] / 3);
-    math::GRUMetaValue<T> gru_value;
-    gru_value.gate_weight = const_cast<T*>(wh_data);
-    gru_value.state_weight =
-        const_cast<T*>(wh_data + 2 * frame_size * frame_size);
-    Tensor ordered_h0;
-
-    framework::Vector<size_t> order(batched_gate->lod()[2]);
-
-    if (h0) {
-      ReorderInitState<DeviceContext, T>(
-          ctx.template device_context<DeviceContext>(), *h0, order, &ordered_h0,
-          true);
-      gru_value.prev_out_value = ordered_h0.data<T>();
-    } else {
-      gru_value.prev_out_value = nullptr;
-    }
-    auto batch_starts = batched_gate->lod()[0];
-    size_t seq_len = batch_starts.size() - 1;
-    auto active_node =
-        math::detail::GetActivationType(ctx.Attr<std::string>("activation"));
-    auto active_gate = math::detail::GetActivationType(
-        ctx.Attr<std::string>("gate_activation"));
-
-#ifdef PADDLE_WITH_MKLML
-    // use MKL packed to speedup GEMM
-    if (FLAGS_paddle_num_threads >= 4) {
-      auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
-      T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
-                                       frame_size * 2 /*width of weight*/,
-                                       frame_size /*height of height*/);
-      PADDLE_ENFORCE(packed_gate);
-      blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2,
-                     frame_size, T(1.0), gru_value.gate_weight, frame_size * 2,
-                     packed_gate);
-      T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
-                                        frame_size /*width of weight*/,
-                                        frame_size /*height of height*/);
-      PADDLE_ENFORCE(packed_state);
-      blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size,
-                     frame_size, T(1.0), gru_value.state_weight, frame_size,
-                     packed_state);
-      for (size_t n = 0; n < seq_len; n++) {
-        int bstart = static_cast<int>(batch_starts[n]);
-        int bend = static_cast<int>(batch_starts[n + 1]);
-        int cur_batch_size = bend - bstart;
-
-        Tensor gate_t = batched_gate->Slice(bstart, bend);
-        Tensor reset_hidden_prev_t =
-            batch_reset_hidden_prev->Slice(bstart, bend);
-        Tensor hidden_t = batch_hidden->Slice(bstart, bend);
-        gru_value.output_value = hidden_t.data<T>();
-        gru_value.gate_value = gate_t.data<T>();
-        gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
-
-        if (gru_value.prev_out_value) {
-          blas.GEMM_COMPUTE(
-              CblasNoTrans, CblasPacked, cur_batch_size, frame_size * 2,
-              frame_size, gru_value.prev_out_value, frame_size, packed_gate,
-              frame_size * 2, T(1), gru_value.gate_value, frame_size * 3);
-        }
-
-        math::detail::forward_reset_output(
-            math::detail::forward::gru_resetOutput<T>(), gru_value, frame_size,
-            cur_batch_size, active_gate);
-
-        if (gru_value.prev_out_value) {
-          blas.GEMM_COMPUTE(
-              CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size,
-              gru_value.reset_output_value, frame_size, packed_state,
-              frame_size, T(1), gru_value.gate_value + frame_size * 2,
-              frame_size * 3);
-        }
-
-        math::detail::forward_final_output(
-            math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
-            cur_batch_size, active_node);
-
-        gru_value.prev_out_value = gru_value.output_value;
-      }
-
-      blas.GEMM_FREE(packed_gate);
-      blas.GEMM_FREE(packed_state);
-    } else {
-#endif
-      for (size_t n = 0; n < seq_len; n++) {
-        int bstart = static_cast<int>(batch_starts[n]);
-        int bend = static_cast<int>(batch_starts[n + 1]);
-        int cur_batch_size = bend - bstart;
-
-        Tensor gate_t = batched_gate->Slice(bstart, bend);
-        Tensor reset_hidden_prev_t =
-            batch_reset_hidden_prev->Slice(bstart, bend);
-        Tensor hidden_t = batch_hidden->Slice(bstart, bend);
-        gru_value.output_value = hidden_t.data<T>();
-        gru_value.gate_value = gate_t.data<T>();
-        gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
-
-        math::GRUUnitFunctor<DeviceContext, T>::compute(
-            dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
-            active_gate);
-
-        gru_value.prev_out_value = gru_value.output_value;
-      }
-#ifdef PADDLE_WITH_MKLML
-    }
-#endif
-    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
-    batch_hidden->set_lod(batched_gate->lod());
-    to_seq(dev_ctx, *batch_hidden, hidden_out);
-  }
+    if (ctx.Attr<bool>("use_seq")) {
+      SeqCompute(ctx);
+    } else {
+      BatchCompute(ctx);
+    }
+  }
+
+#define INIT_VEC_FUNC                                                      \
+  std::function<void(const int, const T *, T *)> act_gate, act_state;      \
+  std::function<void(const int, const T*, const T*, const T*, T*)> cross;  \
+  auto& act_gate_str = ctx.Attr<std::string>("gate_activation");           \
+  auto& act_state_str = ctx.Attr<std::string>("activation");               \
+  if (platform::jit::MayIUse(platform::jit::avx)) {                        \
+    math::VecActivations<T, platform::jit::avx> act_functor;               \
+    act_gate = act_functor(act_gate_str);                                  \
+    act_state = act_functor(act_state_str);                                \
+    cross = math::vec_cross<T, platform::jit::avx>;                        \
+  } else {                                                                 \
+    math::VecActivations<T, platform::jit::isa_any> act_functor;           \
+    act_gate = act_functor(act_gate_str);                                  \
+    act_state = act_functor(act_state_str);                                \
+    cross = math::vec_cross<T, platform::jit::isa_any>;                    \
+  }
+
+#define INIT_BASE_INPUT_OUTPUT                        \
+  auto* h0 = ctx.Input<Tensor>("H0");                 \
+  auto* wx = ctx.Input<Tensor>("WeightX");            \
+  auto* wh = ctx.Input<Tensor>("WeightH");            \
+  auto* bias = ctx.Input<Tensor>("Bias");             \
+  auto* xx = ctx.Output<LoDTensor>("XX");             \
+  auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
+  bool is_reverse = ctx.Attr<bool>("is_reverse");
+
+#define INIT_BASE_SIZES                  \
+  auto x_dims = x->dims();   /* T x M*/  \
+  auto wh_dims = wh->dims(); /* D x 3D*/ \
+  const int total_T = x_dims[0];         \
+  const int M = x_dims[1];               \
+  const int D = wh_dims[0];              \
+  const int D3 = wh_dims[1];             \
+  const int D2 = D * 2;
+
+  void SeqCompute(const framework::ExecutionContext& ctx) const {
+    using DeviceContext = paddle::platform::CPUDeviceContext;
+    auto* x = ctx.Input<LoDTensor>("X");
+    INIT_BASE_INPUT_OUTPUT
+    INIT_BASE_SIZES
+    INIT_VEC_FUNC
+
+    auto x_lod = x->lod();
+    const int N = x_lod[0].size() - 1;
+    const T* x_data = x->data<T>();
+    const T* h0_data = h0 ? h0->data<T>() : nullptr;
+    const T* wx_data = wx->data<T>();
+    const T* wh_data = wh->data<T>();
+    const T* wh_state_data = wh_data + D * D2;
+    T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
+    T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
+
+    auto blas = math::GetBlas<DeviceContext, T>(ctx);
+    math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data,
+                                      xx_data,
+                                      bias ? bias->data<T>() : nullptr);
+
+    int xx_offset = D3;
+    int gate_offset = D;
+    if (is_reverse) {
+      const int offset = (total_T - 1) * D;
+      xx_data = xx_data + offset * 3;
+      hidden_out_data = hidden_out_data + offset;
+      xx_offset = -D3;
+      gate_offset = -D;
+    }
+    auto move_step = [&]() {
+      xx_data = xx_data + xx_offset;
+      hidden_out_data = hidden_out_data + gate_offset;
+    };
+    for (int i = 0; i < N; ++i) {
+      int bid = is_reverse ? N - 1 - i : i;
+      int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
+      const T* prev_hidden_data = nullptr;
+      int tstart = 0;
+      if (h0_data) {
+        prev_hidden_data = h0_data + bid * D;
+      } else {
+        // W: {W_update, W_reset; W_state}
+        // update gate
+        act_gate(D, xx_data, xx_data);
+        // state gate
+        act_state(D, xx_data + D2, xx_data + D2);
+        // out = a*b
+        blas.VMUL(D, xx_data, xx_data + D2, hidden_out_data);
+        // save prev
+        prev_hidden_data = hidden_out_data;
+        tstart = 1;
+        move_step();
+      }
+      for (int step = tstart; step < seq_len; ++step) {
+        // gemm prev * (Wu + Wr)
+        blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1),
+                  prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data,
+                  D3);
+        act_gate(D2, xx_data, xx_data);
+        // rt = rt*ht_1 inplace result
+        blas.VMUL(D, prev_hidden_data, xx_data + D, hidden_out_data);
+
+        // gemm rt * Ws
+        blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
+                  hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
+                  xx_data + D2, D3);
+        act_state(D, xx_data + D2, xx_data + D2);
+        // out = zt*ht~ + (1-zt)*ht_1
+        cross(D, xx_data, xx_data + D2, prev_hidden_data, hidden_out_data);
+        // save prev
+        prev_hidden_data = hidden_out_data;
+        move_step();
+      }
+    }
+  }
+
+  void BatchCompute(const framework::ExecutionContext& ctx) const {
+    using DeviceContext = paddle::platform::CPUDeviceContext;
+    auto* x = ctx.Input<LoDTensor>("X");
+    if (x->lod()[0].size() == 2) {
+      SeqCompute(ctx);
+      return;
+    }
+    INIT_BASE_INPUT_OUTPUT
+    INIT_BASE_SIZES
+    INIT_VEC_FUNC
+
+    auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
+    auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
+    auto* batched_out = ctx.Output<LoDTensor>("BatchedOut");
+
+    const T* x_data = x->data<T>();
+    const T* wx_data = wx->data<T>();
+    const T* wh_data = wh->data<T>();
+    T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
+    T* batched_input_data = batched_input->mutable_data<T>(ctx.GetPlace());
+    T* batched_out_data = batched_out->mutable_data<T>(ctx.GetPlace());
+    hidden_out->mutable_data<T>(ctx.GetPlace());
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
+    if (M > D3) {
+      math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data,
+                                        xx_data,
+                                        bias ? bias->data<T>() : nullptr);
+      to_batch(dev_ctx, *xx, batched_input, true, is_reverse);
+    } else {
+      to_batch(dev_ctx, *x, xx, true, is_reverse);
+      batched_input->set_lod(xx->lod());
+      math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, xx_data, wx_data,
+                                        batched_input_data,
+                                        bias ? bias->data<T>() : nullptr);
+    }
+
+    auto batched_lod = batched_input->lod();
+    const auto& seq_order = batched_lod[2];
+    const int max_bs = seq_order.size();
+    reordered_h0->Resize({max_bs, D});
+
+    int tstart = 0;
+    T* prev_hidden_data = nullptr;
+    if (h0) {
+      // reorder h0
+      T* reordered_h0_data = reordered_h0->mutable_data<T>(ctx.GetPlace());
+      const T* h0_data = h0->data<T>();
+      prev_hidden_data = reordered_h0_data;
+      size_t sz = sizeof(T) * D;
+      for (int i = 0; i < max_bs; ++i) {
+        std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz);
+        reordered_h0_data += D;
+      }
+    } else {
+      // compute without h0
+      T* cur_in_data = batched_input_data;
+      T* cur_out_data = batched_out_data;
+      // W: {W_update, W_reset; W_state}
+      for (int i = 0; i < max_bs; ++i) {
+        // update gate
+        act_gate(D, cur_in_data, cur_in_data);
+        // state gate
+        act_state(D, cur_in_data + D2, cur_in_data + D2);
+        // out = a*b
+        blas.VMUL(D, cur_in_data, cur_in_data + D2, cur_out_data);
+        // add offset
+        cur_in_data += D3;
+        cur_out_data += D;
+      }
+      tstart = 1;
+      prev_hidden_data = batched_out_data;
+    }
+    // Then start from next
+    const T* wh_state_data = wh_data + D * D2;
+    const auto& batch_starts = batched_lod[0];
+    const int max_seq_len = batch_starts.size() - 1;
+    batched_input_data = batched_input_data + tstart * max_bs * D3;
+    batched_out_data = batched_out_data + tstart * max_bs * D;
+    for (int step = tstart; step < max_seq_len; ++step) {
+      const int cur_bs = batch_starts[step + 1] - batch_starts[step];
+      // gemm prev * (Wu + Wr)
+      blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D2, D, static_cast<T>(1),
+                prev_hidden_data, D, wh_data, D2, static_cast<T>(1),
+                batched_input_data, D3);
+
+      T* cur_batched_data = batched_input_data;
+      T* cur_out_data = batched_out_data;
+      T* cur_prev_hidden_data = prev_hidden_data;
+      for (int i = 0; i < cur_bs; ++i) {
+        act_gate(D2, cur_batched_data, cur_batched_data);
+        // rt = rt*ht_1 inplace result
+        blas.VMUL(D, cur_prev_hidden_data, cur_batched_data + D, cur_out_data);
+
+        cur_batched_data += D3;
+        cur_prev_hidden_data += D;
+        cur_out_data += D;
+      }
+
+      cur_batched_data = batched_input_data;
+      cur_out_data = batched_out_data;
+      blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D, D, static_cast<T>(1),
+                cur_out_data, D, wh_state_data, D, static_cast<T>(1),
+                cur_batched_data + D2, D3);
+
+      cur_prev_hidden_data = prev_hidden_data;
+      for (int i = 0; i < cur_bs; ++i) {
+        // ht~ = act_state(...)
+        act_state(D, cur_batched_data + D2, cur_batched_data + D2);
+        // out = zt*ht~ + (1-zt)*ht_1
+        cross(D, cur_batched_data, cur_batched_data + D2, cur_prev_hidden_data,
+              cur_out_data);
+
+        cur_batched_data += D3;
+        cur_prev_hidden_data += D;
+        cur_out_data += D;
+      }
+      prev_hidden_data = batched_out_data;
+      batched_out_data = cur_out_data;
+      batched_input_data = cur_batched_data;
+    }
+
+    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
+    batched_out->set_lod(batched_lod);
+    to_seq(dev_ctx, *batched_out, hidden_out);
+  }
+#undef INIT_VEC_FUNC
+#undef INIT_BASE_SIZES
+#undef INIT_BASE_INPUT_OUTPUT
 };

 }  // namespace operators
@@ -327,6 +428,5 @@ class FusionGRUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
-REGISTER_OP_CPU_KERNEL(
-    fusion_gru, ops::FusionGRUKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::FusionGRUKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_CPU_KERNEL(fusion_gru, ops::FusionGRUKernel<float>,
+                       ops::FusionGRUKernel<double>);
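For reference, here is a minimal plain-C++ sketch (not part of this commit) of the math that one GRU time step performs in SeqCompute above, matching the kernel comments ("out = zt*ht~ + (1-zt)*ht_1") and the {W_update, W_reset; W_state} layout of WeightH. It assumes the default sigmoid gate activation and tanh state activation; the helper name GruStepReference is illustrative only.

// Illustrative sketch, not part of the commit: one GRU step for one sequence,
// written with plain loops instead of the BLAS/AVX kernels used by the op.
// gates holds [u | r | c] of width 3D, already containing x_t * WeightX + bias.
// wh stores {W_update, W_reset} as a D x 2D block followed by W_state (D x D).
#include <cmath>
#include <vector>

static float Sigmoid(float v) { return 1.f / (1.f + std::exp(-v)); }

void GruStepReference(std::vector<float>* gates, const std::vector<float>& wh,
                      const std::vector<float>& h_prev, int D,
                      std::vector<float>* h_out) {
  const int D2 = 2 * D;
  // gates[0:2D] += h_prev * {W_update, W_reset}   (leading dim of this block is 2D)
  for (int j = 0; j < D2; ++j) {
    float acc = 0.f;
    for (int k = 0; k < D; ++k) acc += h_prev[k] * wh[k * D2 + j];
    (*gates)[j] += acc;
  }
  for (int j = 0; j < D2; ++j) (*gates)[j] = Sigmoid((*gates)[j]);  // u_t, r_t
  // r_t (elementwise*) h_prev feeds W_state, as in "rt = rt*ht_1" above
  std::vector<float> rh(D);
  for (int j = 0; j < D; ++j) rh[j] = (*gates)[D + j] * h_prev[j];
  const float* w_state = wh.data() + D * D2;  // same offset as wh_state_data
  // candidate h~ = tanh(gates[2D:3D] + (r_t*h_prev) * W_state)
  for (int j = 0; j < D; ++j) {
    float acc = 0.f;
    for (int k = 0; k < D; ++k) acc += rh[k] * w_state[k * D + j];
    (*gates)[D2 + j] = std::tanh((*gates)[D2 + j] + acc);
  }
  // h_t = u_t*h~ + (1 - u_t)*h_prev  (same formula as vec_cross below)
  for (int j = 0; j < D; ++j) {
    (*h_out)[j] =
        (*gates)[j] * (*gates)[D2 + j] + (1.f - (*gates)[j]) * h_prev[j];
  }
}

The fused kernel performs exactly these steps, but batches the two matrix products through BLAS GEMM and replaces the elementwise loops with the vectorized act_gate/act_state/vec_cross routines selected by INIT_VEC_FUNC.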
@@ -132,6 +132,121 @@ inline void vec_scal<float, platform::jit::avx512_common>(const int n,
   vec_scal<float, platform::jit::avx2>(n, a, x, y);
 }
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
inline void vec_bias_sub(const int n, const T a, const T* x, T* y) {
for (int i = 0; i < n; ++i) {
y[i] = a - x[i];
}
}
template <>
inline void vec_bias_sub<float, platform::jit::avx>(const int n, const float a,
const float* x, float* y) {
#ifdef __AVX__
constexpr int block = AVX_FLOAT_BLOCK;
if (n < block) {
vec_bias_sub<float, platform::jit::isa_any>(n, a, x, y);
return;
}
const int rest = n % block;
const int end = n - rest;
int i = 0;
__m256 bias = _mm256_set1_ps(a);
__m256 tmp;
#define MOVE_ONE_STEP \
tmp = _mm256_loadu_ps(x + i); \
tmp = _mm256_sub_ps(bias, tmp); \
_mm256_storeu_ps(y + i, tmp)
for (i = 0; i < end; i += block) {
MOVE_ONE_STEP;
}
#undef MOVE_ONE_STEP
if (rest == 0) {
return;
}
// can not continue move step if src and dst are inplace
for (i = n - rest; i < n; ++i) {
y[i] = a - x[i];
}
#else
vec_bias_sub<float, platform::jit::isa_any>(n, a, x, y);
#endif
}
template <>
inline void vec_bias_sub<float, platform::jit::avx2>(const int n, const float a,
const float* x, float* y) {
vec_bias_sub<float, platform::jit::avx>(n, a, x, y);
}
template <>
inline void vec_bias_sub<float, platform::jit::avx512_common>(const int n,
const float a,
const float* x,
float* y) {
// TODO(TJ): enable me
vec_bias_sub<float, platform::jit::avx2>(n, a, x, y);
}
// out = x*y + (1-x)*z
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
inline void vec_cross(const int n, const T* x, const T* y, const T* z, T* out) {
for (int i = 0; i < n; ++i) {
out[i] = x[i] * y[i] + (static_cast<T>(1) - x[i]) * z[i];
}
}
template <>
inline void vec_cross<float, platform::jit::avx>(const int n, const float* x,
const float* y, const float* z,
float* out) {
#ifdef __AVX__
constexpr int block = AVX_FLOAT_BLOCK;
if (n < block) {
vec_cross<float, platform::jit::isa_any>(n, x, y, z, out);
return;
}
const int rest = n % block;
const int end = n - rest;
int i = 0;
__m256 bias = _mm256_set1_ps(1.f);
__m256 tmpx, tmpy, tmpz;
for (i = 0; i < end; i += block) {
tmpx = _mm256_loadu_ps(x + i);
tmpy = _mm256_loadu_ps(y + i);
tmpz = _mm256_loadu_ps(z + i);
tmpy = _mm256_mul_ps(tmpx, tmpy);
tmpx = _mm256_sub_ps(bias, tmpx);
tmpz = _mm256_mul_ps(tmpx, tmpz);
tmpz = _mm256_add_ps(tmpy, tmpz);
_mm256_storeu_ps(out + i, tmpz);
}
if (rest == 0) {
return;
}
// can not continue move step if src and dst are inplace
for (i = n - rest; i < n; ++i) {
out[i] = x[i] * y[i] + (1.f - x[i]) * z[i];
}
#else
vec_cross<float, platform::jit::isa_any>(n, x, y, z, out);
#endif
}
template <>
inline void vec_cross<float, platform::jit::avx2>(const int n, const float* x,
const float* y,
const float* z, float* out) {
vec_cross<float, platform::jit::avx>(n, x, y, z, out);
}
template <>
inline void vec_cross<float, platform::jit::avx512_common>(
const int n, const float* x, const float* y, const float* z, float* out) {
// TODO(TJ): enable me
vec_cross<float, platform::jit::avx>(n, x, y, z, out);
}
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any> template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
inline void vec_add_bias(const int n, const T a, const T* x, T* y) { inline void vec_add_bias(const int n, const T a, const T* x, T* y) {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
......
@@ -92,7 +92,7 @@ class LoDTensor2BatchFunctor {
     // Calculate the start position of each batch.
     // example:  sequences = {s0, s1, s2}
     //           s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
-    //           num_batch = 5,
+    //           max_seqlen = 5,
     //           batchIndex = {b0, b1, b2, b3, b4}
     //           b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1
     //           batch_start_positions[6] = {0, 3, 6, 9, 11, 12}
@@ -109,7 +109,7 @@ class LoDTensor2BatchFunctor {
     //    where 1 is the second sequence,
     //          0 is the first sequence,
     //          2 is the third sequence.
-    // The num_batch represents batch size after rearranging the
+    // The max_seqlen represents batch size after rearranging the
     // input LodTensor. It is also the maximum length of input sequence.

     paddle::framework::LoD batch_lods;
@@ -118,8 +118,8 @@ class LoDTensor2BatchFunctor {
     batch_lods.emplace_back(std::vector<size_t>{0});

     // batch_lods[0] is the start positions for batch LoDTensor
-    int num_batch = seq_info[0].length;
-    batch_lods[0].resize(static_cast<size_t>(num_batch + 1));
+    int max_seqlen = seq_info[0].length;
+    batch_lods[0].resize(static_cast<size_t>(max_seqlen + 1));
     // batch_lods[1] is the raw index in the input LoDTensor
     batch_lods[1].resize(static_cast<size_t>(lod_tensor.dims()[0]));
     // batch_lods[2] is the sort order for the input LoDTensor.
@@ -128,7 +128,7 @@ class LoDTensor2BatchFunctor {
     size_t* batch_starts = batch_lods[0].data();
     size_t* seq2batch_idx = batch_lods[1].data();
     batch_starts[0] = 0;
-    for (int n = 0; n < num_batch; n++) {
+    for (int n = 0; n < max_seqlen; n++) {
       auto batch_id = static_cast<int>(batch_starts[n]);
       for (size_t i = 0; i < seq_info.size(); ++i) {
         int seq_len = seq_info[i].length;
......
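To make the batching example in the comments of LoDTensor2BatchFunctor above concrete, here is a self-contained sketch (not part of the commit) that derives batch_start_positions = {0, 3, 6, 9, 11, 12} from the sequence lengths {4, 5, 3}: sequences are sorted by length ({s1:5, s0:4, s2:3}), max_seqlen = 5, and time step t takes one element from every sequence still longer than t.

// Illustrative only: recompute the batch offsets from the example above.
#include <algorithm>
#include <cstdio>
#include <functional>
#include <vector>

int main() {
  std::vector<int> lengths = {4, 5, 3};  // s0, s1, s2
  std::vector<int> sorted = lengths;
  std::sort(sorted.begin(), sorted.end(), std::greater<int>());
  const int max_seqlen = sorted.front();
  std::vector<int> batch_starts = {0};
  for (int t = 0; t < max_seqlen; ++t) {
    // width of batch t = number of sequences whose length exceeds t
    int width = 0;
    for (int len : sorted) width += (len > t) ? 1 : 0;
    batch_starts.push_back(batch_starts.back() + width);
  }
  for (int v : batch_starts) std::printf("%d ", v);  // prints: 0 3 6 9 11 12
  std::printf("\n");
  return 0;
}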