From a2eebf2f01891f19aca7fd1e5ec414efab80cb33 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Sun, 30 Sep 2018 18:03:32 +0800 Subject: [PATCH] bugfix: fusion lstm and gru batch,seq mode switch test=develop --- paddle/fluid/operators/fusion_gru_op.cc | 5 +++-- paddle/fluid/operators/fusion_lstm_op.cc | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/fusion_gru_op.cc b/paddle/fluid/operators/fusion_gru_op.cc index 31e87d911..a04c1c126 100644 --- a/paddle/fluid/operators/fusion_gru_op.cc +++ b/paddle/fluid/operators/fusion_gru_op.cc @@ -290,12 +290,13 @@ class FusionGRUKernel : public framework::OpKernel { void BatchCompute(const framework::ExecutionContext& ctx) const { using DeviceContext = paddle::platform::CPUDeviceContext; auto* x = ctx.Input("X"); + INIT_BASE_INPUT_OUTPUT + INIT_BASE_SIZES if (x->lod()[0].size() == 2) { + xx->Resize({total_T, D3}); SeqCompute(ctx); return; } - INIT_BASE_INPUT_OUTPUT - INIT_BASE_SIZES INIT_VEC_FUNC auto* reordered_h0 = ctx.Output("ReorderedH0"); diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index 8ca79d20e..a20a5880d 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -424,11 +424,12 @@ class FuisonLSTMKernel : public framework::OpKernel { void BatchCompute(const framework::ExecutionContext& ctx) const { using DeviceContext = platform::CPUDeviceContext; INIT_BASE_INPUT_OUTPUT + INIT_BASE_SIZES if (x->lod()[0].size() == 2) { + xx->Resize({x_dims[0], D4}); SeqCompute(ctx); return; } - INIT_BASE_SIZES INIT_VEC_FUNC INIT_BASE_INPUT_DATAS -- GitLab