From ea0b98e00740d09697d703c8bd135bfcac6266f9 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Sun, 30 Sep 2018 18:03:32 +0800 Subject: [PATCH] bugfix: fusion lstm and gru batch,seq mode switch test=develop --- paddle/fluid/operators/fusion_gru_op.cc | 5 +++-- paddle/fluid/operators/fusion_lstm_op.cc | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/fusion_gru_op.cc b/paddle/fluid/operators/fusion_gru_op.cc index 31e87d9113..a04c1c1263 100644 --- a/paddle/fluid/operators/fusion_gru_op.cc +++ b/paddle/fluid/operators/fusion_gru_op.cc @@ -290,12 +290,13 @@ class FusionGRUKernel : public framework::OpKernel { void BatchCompute(const framework::ExecutionContext& ctx) const { using DeviceContext = paddle::platform::CPUDeviceContext; auto* x = ctx.Input("X"); + INIT_BASE_INPUT_OUTPUT + INIT_BASE_SIZES if (x->lod()[0].size() == 2) { + xx->Resize({total_T, D3}); SeqCompute(ctx); return; } - INIT_BASE_INPUT_OUTPUT - INIT_BASE_SIZES INIT_VEC_FUNC auto* reordered_h0 = ctx.Output("ReorderedH0"); diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index 23e8edd18d..ae1f6d8e48 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -432,11 +432,12 @@ class FuisonLSTMKernel : public framework::OpKernel { void BatchCompute(const framework::ExecutionContext& ctx) const { using DeviceContext = platform::CPUDeviceContext; INIT_BASE_INPUT_OUTPUT + INIT_BASE_SIZES if (x->lod()[0].size() == 2) { + xx->Resize({x_dims[0], D4}); SeqCompute(ctx); return; } - INIT_BASE_SIZES INIT_VEC_FUNC INIT_BASE_INPUT_DATAS -- GitLab