未验证 提交 1cc35f36 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #13118 from tensor-tang/optimize/op/fusion_lstm

Optimize fusion lstm batch mode
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
namespace paddle {
......@@ -94,11 +95,31 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
op_desc.SetOutput("Hidden", {hidden_n->Name()});
op_desc.SetOutput("Cell", {cell_n->Name()});
op_desc.SetOutput("XX", {xx_n->Name()});
op_desc.SetOutput("BatchedGate", {"blstm_0.tmp_2"});
op_desc.SetOutput("BatchCellPreAct", {"blstm_1.tmp_2"});
op_desc.SetOutput("BatchedInput", {"blstm_0.tmp_2"});
op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse"));
op_desc.SetAttr("use_peepholes", lstm_n->Op()->GetAttr("use_peepholes"));
// TODO(TJ): get from attr
op_desc.SetAttr("use_seq", true);
#define TMP_NAME(x) "at.new.tmp." #x
#define OP_SET_OUT(x) op_desc.SetOutput(#x, {TMP_NAME(x)})
OP_SET_OUT(BatchedCell);
OP_SET_OUT(BatchedHidden);
OP_SET_OUT(ReorderedH0);
OP_SET_OUT(ReorderedC0);
#undef OP_SET_OUT
auto* op = graph->CreateOpNode(&op_desc);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto* scope = graph->Get<Scope*>(kParamScopeAttr);
#define TMP_NEW(x) scope->Var(TMP_NAME(x))->GetMutable<LoDTensor>()
TMP_NEW(BatchedCell);
TMP_NEW(BatchedHidden);
TMP_NEW(ReorderedH0);
TMP_NEW(ReorderedC0);
#undef TMP_NEW
#undef TMP_NAME
#define LINK_TO(a, b) \
a->outputs.push_back(b); \
......@@ -116,7 +137,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
auto fc_no_bias_handler = [&](
const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
#define GET_NODE(name__) \
std::string name__##key = name_scope + "/" + #name__; \
auto* name__##n = pattern->RetrieveNode(name__##key); \
......
......@@ -16,14 +16,10 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/cpu_info.h"
DEFINE_bool(seq_mode, true, "Use sequence mode");
namespace paddle {
namespace operators {
......@@ -42,10 +38,16 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
"Output(Hidden) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Output(Cell) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedGate"),
"Output(BatchedGate) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
"Output(BatchedGate) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
"Output(BatchedInput) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
"Output(BatchedHidden) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"),
"Output(BatchedCell) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
"Output(ReorderedH0) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"),
"Output(ReorderedC0) of LSTM should not be null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
......@@ -97,13 +99,14 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework::DDim out_dims({x_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("Cell", out_dims);
ctx->SetOutputDim("BatchedGate", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchCellPreAct", out_dims);
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedHidden", out_dims);
ctx->SetOutputDim("BatchedCell", out_dims);
ctx->ShareLoD("X", "Hidden");
ctx->ShareLoD("X", "Cell");
int xx_width;
if (FLAGS_seq_mode) {
if (ctx->Attrs().Get<bool>("use_seq")) {
xx_width = wx_dims[1];
} else {
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
......@@ -169,9 +172,11 @@ void FusionLSTMOpMaker::Make() {
" where T is the total time steps in this mini-batch,"
" D is the hidden size, M is the dim size of x input.")
.AsIntermediate();
AddOutput("BatchedGate", "(LoDTensor) (same as LSTMOp).").AsIntermediate();
AddOutput("BatchCellPreAct", "(LoDTensor) (same as LSTMOp).")
.AsIntermediate();
AddOutput("BatchedInput", "(LoDTensor) (T x 4D).").AsIntermediate();
AddOutput("BatchedHidden", "(LoDTensor) (T x D).").AsIntermediate();
AddOutput("BatchedCell", "(LoDTensor) (T x D).").AsIntermediate();
AddOutput("ReorderedH0", "(LoDTensor) (N x D).").AsIntermediate();
AddOutput("ReorderedC0", "(LoDTensor) (N x D).").AsIntermediate();
AddAttr<bool>("use_peepholes",
"(bool, defalut: True) "
"whether to enable diagonal/peephole connections.")
......@@ -180,6 +185,10 @@ void FusionLSTMOpMaker::Make() {
"(bool, defalut: False) "
"whether to compute reversed LSTM.")
.SetDefault(false);
AddAttr<bool>("use_seq",
"(bool, defalut: True) "
"whether to use seq mode to compute.")
.SetDefault(true);
AddAttr<std::string>("gate_activation",
"(string, default: sigmoid)"
"The activation for input gate, forget gate and output "
......@@ -203,64 +212,60 @@ This operator fuse the X into LSTM, more details can refer to LSTM op.
)DOC");
}
template <typename DeviceContext, typename T>
inline void ReorderInitState(const DeviceContext& ctx,
const framework::Tensor& src,
framework::Vector<size_t> index_lod,
framework::Tensor* dst, bool indexed_src) {
math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
dst->mutable_data<T>(src.dims(), ctx.GetPlace());
// TODO(TJ): check mem copy perf
row_shuffle(ctx, src, index_lod, dst, indexed_src);
}
template <typename T>
class FuisonLSTMKernel : public framework::OpKernel<T> {
public:
#define INIT_VEC_FUNC \
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand; \
auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); \
auto& act_cell_str = ctx.Attr<std::string>("cell_activation"); \
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation"); \
if (platform::jit::MayIUse(platform::jit::avx)) { \
math::VecActivations<T, platform::jit::avx> act_functor; \
act_gate = act_functor(act_gate_str); \
act_cell = act_functor(act_cell_str); \
act_cand = act_functor(act_cand_str); \
} else { \
math::VecActivations<T, platform::jit::isa_any> act_functor; \
act_gate = act_functor(act_gate_str); \
act_cell = act_functor(act_cell_str); \
act_cand = act_functor(act_cand_str); \
}
#define INIT_BASE_INPUT_OUTPUT \
auto* x = ctx.Input<LoDTensor>("X"); \
auto* h0 = ctx.Input<Tensor>("H0"); \
auto* c0 = ctx.Input<Tensor>("C0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* wh = ctx.Input<Tensor>("WeightH"); \
auto* bias = ctx.Input<Tensor>("Bias"); \
auto* xx = ctx.Output<LoDTensor>("XX"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
auto* cell_out = ctx.Output<LoDTensor>("Cell"); \
bool is_reverse = ctx.Attr<bool>("is_reverse");
#define INIT_BASE_SIZES \
auto x_dims = x->dims(); /* T x M*/ \
auto wh_dims = wh->dims(); /* D x 4D*/ \
const int M = x_dims[1]; \
const int D = wh_dims[0]; \
const int D2 = D * 2; \
const int D3 = D * 3; \
const int D4 = wh_dims[1];
void SeqCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = paddle::platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X");
auto* h0 = ctx.Input<Tensor>("H0");
auto* c0 = ctx.Input<Tensor>("C0");
auto* wx = ctx.Input<Tensor>("WeightX");
auto* wh = ctx.Input<Tensor>("WeightH");
auto* bias = ctx.Input<Tensor>("Bias");
auto* xx = ctx.Output<LoDTensor>("XX");
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
auto* cell_out = ctx.Output<LoDTensor>("Cell");
bool is_reverse = ctx.Attr<bool>("is_reverse");
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
if (platform::jit::MayIUse(platform::jit::avx)) {
math::VecActivations<T, platform::jit::avx> act_functor;
act_gate = act_functor(act_gate_str);
act_cell = act_functor(act_cell_str);
act_cand = act_functor(act_cand_str);
} else {
math::VecActivations<T, platform::jit::isa_any> act_functor;
act_gate = act_functor(act_gate_str);
act_cell = act_functor(act_cell_str);
act_cand = act_functor(act_cand_str);
}
INIT_BASE_INPUT_OUTPUT
INIT_BASE_SIZES
INIT_VEC_FUNC
auto x_lod = x->lod();
auto x_dims = x->dims(); // T x M
auto wh_dims = wh->dims(); // D x 4D
const int total_T = x_dims[0];
const int N = x_lod[0].size() - 1; // batch size
const int M = x_dims[1]; // x frame size
const int D = wh_dims[0];
const int D2 = D * 2;
const int D3 = D * 3;
const int D4 = wh_dims[1];
const T* x_data = x->data<T>();
const T* h0_data = h0 ? h0->data<T>() : NULL;
const T* c0_data = c0 ? c0->data<T>() : NULL;
const T* h0_data = h0 ? h0->data<T>() : nullptr;
const T* c0_data = c0 ? c0->data<T>() : nullptr;
const T* wx_data = wx->data<T>();
const T* wh_data = wh->data<T>();
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
......@@ -290,12 +295,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
for (int i = 0; i < N; ++i) {
int bid = is_reverse ? N - 1 - i : i;
int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
const T* prev_cell_data = NULL;
const T* prev_hidden_data = NULL;
const T* prev_c_data = nullptr;
const T* prev_h_data = nullptr;
int tstart = 0;
if (h0_data) {
prev_hidden_data = h0_data + bid * D;
prev_cell_data = c0_data + bid * D;
prev_h_data = h0_data + bid * D;
prev_c_data = c0_data + bid * D;
} else {
// W_ch, W_ih, W_fh, W_oh
act_gate(D3, xx_data + D, xx_data + D);
......@@ -307,23 +312,22 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
// prev
prev_hidden_data = hidden_out_data;
prev_cell_data = cell_out_data;
prev_h_data = hidden_out_data;
prev_c_data = cell_out_data;
tstart = 1;
move_step();
}
for (int step = tstart; step < seq_len; ++step) {
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1),
prev_hidden_data, D, wh_data, D4, static_cast<T>(1), xx_data,
D4);
prev_h_data, D, wh_data, D4, static_cast<T>(1), xx_data, D4);
// W_ch, W_ih, W_fh, W_oh
act_gate(D3, xx_data + D, xx_data + D);
act_cand(D, xx_data, xx_data);
// a = forget * prev_cell
blas.VMUL(D, xx_data + D2, prev_cell_data, xx_data + D2);
blas.VMUL(D, xx_data + D2, prev_c_data, xx_data + D2);
// b = input * tilde
blas.VMUL(D, xx_data, xx_data + D, xx_data + D);
......@@ -336,8 +340,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
// prev
prev_hidden_data = hidden_out_data;
prev_cell_data = cell_out_data;
prev_h_data = hidden_out_data;
prev_c_data = cell_out_data;
move_step();
}
......@@ -346,143 +350,155 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
void BatchCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X");
auto* wx = ctx.Input<Tensor>("WeightX");
auto* wh = ctx.Input<Tensor>("WeightH");
auto* bias = ctx.Input<Tensor>("Bias");
auto* hidden_t0 = ctx.Input<Tensor>("H0");
auto* cell_t0 = ctx.Input<Tensor>("C0");
auto* xx = ctx.Output<LoDTensor>("XX");
auto* batched_gate = ctx.Output<LoDTensor>("BatchedGate");
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
auto* cell_out = ctx.Output<LoDTensor>("Cell");
bool is_reverse = ctx.Attr<bool>("is_reverse");
INIT_BASE_INPUT_OUTPUT
if (x->lod()[0].size() == 2) {
SeqCompute(ctx);
return;
}
INIT_BASE_SIZES
INIT_VEC_FUNC
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
T* batched_gate_data = batched_gate->mutable_data<T>(ctx.GetPlace());
hidden_out->mutable_data<T>(ctx.GetPlace());
cell_out->mutable_data<T>(ctx.GetPlace());
auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
auto* reordered_c0 = ctx.Output<Tensor>("ReorderedC0");
auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
auto* batched_c_out = ctx.Output<LoDTensor>("BatchedCell");
auto* batched_h_out = ctx.Output<LoDTensor>("BatchedHidden");
const T* x_data = x->data<T>();
const T* wx_data = wx->data<T>();
auto x_dims = x->dims();
auto wx_dims = wx->dims();
const T* wh_data = wh->data<T>();
auto place = ctx.GetPlace();
T* xx_data = xx->mutable_data<T>(place);
T* batched_input_data = batched_input->mutable_data<T>(place);
T* batched_c_out_data = batched_c_out->mutable_data<T>(place);
T* batched_h_out_data = batched_h_out->mutable_data<T>(place);
hidden_out->mutable_data<T>(place);
cell_out->mutable_data<T>(place);
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
if (x_dims[1] > wx_dims[1]) {
math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
x_data, wx_data, xx_data,
bias->data<T>());
to_batch(dev_ctx, *xx, batched_gate, true, is_reverse);
if (M > D4) {
math::FCCompute<DeviceContext, T>(blas, x_dims[0], D4, M, x_data, wx_data,
xx_data, bias->data<T>());
to_batch(dev_ctx, *xx, batched_input, true, is_reverse);
} else {
to_batch(dev_ctx, *x, xx, true, is_reverse);
batched_gate->set_lod(xx->lod());
math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
xx_data, wx_data, batched_gate_data,
batched_input->set_lod(xx->lod());
math::FCCompute<DeviceContext, T>(blas, x_dims[0], D4, M, xx_data,
wx_data, batched_input_data,
bias->data<T>());
}
int frame_size = static_cast<int>(wx_dims[1] / 4);
framework::DDim out_dims({x_dims[0], frame_size});
math::LstmMetaValue<T> lstm_value;
// no peephole
lstm_value.check_ig = nullptr;
lstm_value.check_fg = nullptr;
lstm_value.check_og = nullptr;
lstm_value.prev_state_value = nullptr;
Tensor ordered_c0;
framework::Vector<size_t> order(batched_gate->lod()[2]);
if (cell_t0) {
// Since the batch computing for LSTM reorders the input sequence
// according to their length. The initialized cell state also needs
// to reorder.
ReorderInitState<DeviceContext, T>(dev_ctx, *cell_t0, order, &ordered_c0,
true);
lstm_value.prev_state_value = ordered_c0.data<T>();
auto batched_lod = batched_input->lod();
const auto& seq_order = batched_lod[2];
const int max_bs = seq_order.size();
reordered_h0->Resize({max_bs, D});
reordered_c0->Resize({max_bs, D});
int tstart = 0;
T* prev_h_data = nullptr;
T* prev_c_data = nullptr;
if (h0) {
// reorder h0, c0
T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
T* reordered_c0_data = reordered_c0->mutable_data<T>(place);
const T* h0_data = h0->data<T>();
const T* c0_data = c0->data<T>();
prev_h_data = reordered_h0_data;
prev_c_data = reordered_c0_data;
size_t sz = sizeof(T) * D;
for (int i = 0; i < max_bs; ++i) {
std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz);
std::memcpy(reordered_c0_data, c0_data + seq_order[i] * D, sz);
reordered_h0_data += D;
reordered_c0_data += D;
}
} else {
// compute without h0, c0
T* cur_in_data = batched_input_data;
T* cur_h_out_data = batched_h_out_data;
T* cur_c_out_data = batched_c_out_data;
// W_ch, W_ih, W_fh, W_oh
for (int i = 0; i < max_bs; ++i) {
act_gate(D3, cur_in_data + D, cur_in_data + D);
act_cand(D, cur_in_data, cur_in_data);
// cell out= input*tilde
blas.VMUL(D, cur_in_data, cur_in_data + D, cur_c_out_data);
// hidden out= act_state(cellout) * outgate
act_cell(D, cur_c_out_data, cur_in_data + D2);
blas.VMUL(D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data);
// add offset
cur_in_data += D4;
cur_c_out_data += D;
cur_h_out_data += D;
}
tstart = 1;
prev_h_data = batched_h_out_data;
prev_c_data = batched_c_out_data;
}
// Then start from next
const auto& batch_starts = batched_lod[0];
const int max_seq_len = batch_starts.size() - 1;
const int offset = tstart * max_bs * D;
batched_input_data = batched_input_data + offset * 4;
batched_h_out_data = batched_h_out_data + offset;
batched_c_out_data = batched_c_out_data + offset;
for (int step = tstart; step < max_seq_len; ++step) {
const int cur_bs = batch_starts[step + 1] - batch_starts[step];
blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D4, D, static_cast<T>(1),
prev_h_data, D, wh_data, D4, static_cast<T>(1),
batched_input_data, D4);
T* cur_in_data = batched_input_data;
T* cur_prev_c_data = prev_c_data;
T* cur_c_out_data = batched_c_out_data;
T* cur_h_out_data = batched_h_out_data;
for (int i = 0; i < cur_bs; ++i) {
// W_ch, W_ih, W_fh, W_oh
act_gate(D3, cur_in_data + D, cur_in_data + D);
act_cand(D, cur_in_data, cur_in_data);
// a = forget * prev_cell
blas.VMUL(D, cur_in_data + D2, cur_prev_c_data, cur_in_data + D2);
// b = input * tilde
blas.VMUL(D, cur_in_data, cur_in_data + D, cur_in_data + D);
// cell out= a+b
blas.VADD(D, cur_in_data + D, cur_in_data + D2, cur_c_out_data);
// hidden out= act_state(cellout) * outgate
act_cell(D, cur_c_out_data, cur_in_data + D2);
blas.VMUL(D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data);
// Use the local variable as here.
LoDTensor batch_hidden, batch_cell;
auto* batch_cell_pre_act = ctx.Output<LoDTensor>("BatchCellPreAct");
batch_hidden.mutable_data<T>(out_dims, ctx.GetPlace());
batch_cell.mutable_data<T>(out_dims, ctx.GetPlace());
batch_cell_pre_act->mutable_data<T>(out_dims, ctx.GetPlace());
auto batch_starts = batched_gate->lod()[0];
size_t max_seq_len = batch_starts.size() - 1;
auto gate_act = math::detail::GetActivationType(
ctx.Attr<std::string>("gate_activation"));
auto cell_act = math::detail::GetActivationType(
ctx.Attr<std::string>("cell_activation"));
auto cand_act = math::detail::GetActivationType(
ctx.Attr<std::string>("candidate_activation"));
for (size_t n = 0; n < max_seq_len; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
Tensor gate_t = batched_gate->Slice(bstart, bend);
Tensor out_t = batch_hidden.Slice(bstart, bend);
Tensor cell_t = batch_cell.Slice(bstart, bend);
Tensor cell_pre_act_t = batch_cell_pre_act->Slice(bstart, bend);
int cur_batch_size = bend - bstart;
if (n > 0) {
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size;
auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end);
// TODO(TJ): use gemm directly
blas.MatMul(pre_hidden_t, false, *wh, false, static_cast<T>(1.0),
&gate_t, static_cast<T>(1.0));
} else if (hidden_t0) {
// TODO(TJ): move h0 outside for
// If n == 0 and there is no initialized hidden state, that is to say
// the H0 is zeros, the calculation W_h * H0 will be skiped.
// If n == 0 and there is initialized hidden state, calculate W_h * H0.
// Since the batch computing for LSTM reorders the input sequence
// according to their length. The initialized hidden state also needs
// to reorder.
Tensor ordered_h0;
ReorderInitState<DeviceContext, T>(dev_ctx, *hidden_t0, order,
&ordered_h0, true);
// TODO(TJ): use gemm directly
blas.MatMul(ordered_h0, false, *wh, false, static_cast<T>(1.0), &gate_t,
static_cast<T>(1.0));
cur_in_data += D4;
cur_prev_c_data += D;
cur_c_out_data += D;
cur_h_out_data += D;
}
lstm_value.gate_value = gate_t.data<T>();
lstm_value.output_value = out_t.data<T>();
lstm_value.state_value = cell_t.data<T>();
lstm_value.state_active_value = cell_pre_act_t.data<T>();
math::LstmUnitFunctor<DeviceContext, T>::compute(
dev_ctx, lstm_value, frame_size, cur_batch_size, gate_act, cell_act,
cand_act);
lstm_value.prev_state_value = lstm_value.state_value;
prev_c_data = batched_c_out_data;
prev_h_data = batched_h_out_data;
batched_c_out_data = cur_c_out_data;
batched_h_out_data = cur_h_out_data;
batched_input_data = cur_in_data;
}
math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
batch_hidden.set_lod(batched_gate->lod());
// restore the output hidden in LoDTensor from the batch hidden
to_seq(dev_ctx, batch_hidden, hidden_out);
batch_cell.set_lod(batched_gate->lod());
// restore the output cell state in LoDTensor from the batch cell
to_seq(dev_ctx, batch_cell, cell_out);
batched_h_out->set_lod(batched_lod);
to_seq(dev_ctx, *batched_h_out, hidden_out);
batched_c_out->set_lod(batched_lod);
to_seq(dev_ctx, *batched_c_out, cell_out);
}
void Compute(const framework::ExecutionContext& ctx) const override {
if (FLAGS_seq_mode) {
if (ctx.Attr<bool>("use_seq")) {
SeqCompute(ctx);
} else {
BatchCompute(ctx);
}
}
#undef INIT_BASE_SIZES
#undef INIT_BASE_INPUT_OUTPUT
#undef INIT_VEC_FUNC
};
} // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册