Commit a2eebf2f authored by: T tensor-tang

bugfix: fusion LSTM and GRU batch/seq mode switch

test=develop
Parent acde6e3c
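In both fused kernels, BatchCompute falls back to SeqCompute when the input LoD describes a single sequence (x->lod()[0].size() == 2). Before this fix, the INIT_BASE_* macros ran only after that early return and the intermediate tensor xx was never resized for the sequence path. The hunks below hoist INIT_BASE_INPUT_OUTPUT / INIT_BASE_SIZES above the check and resize xx ({total_T, D3} for GRU, {x_dims[0], D4} for LSTM) before dispatching into SeqCompute. The following is a minimal standalone sketch of the corrected dispatch order; the names Ctx, Buffer, SeqPath, BatchPath, and lod_level0_size are hypothetical stand-ins, not the real Paddle types or API.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Stand-in for the intermediate tensor "xx" shared by both compute paths.
struct Buffer {
  std::vector<int64_t> dims;
  void Resize(std::vector<int64_t> d) { dims = std::move(d); }
};

// Hypothetical slimmed-down context: lod_level0_size mirrors x->lod()[0].size(),
// total_T and D3 mirror values that INIT_BASE_SIZES would normally derive.
struct Ctx {
  size_t lod_level0_size = 2;
  int64_t total_T = 0;
  int64_t D3 = 0;
  Buffer xx;
};

void SeqPath(Ctx& /*ctx*/) { /* sequence-mode kernel (stub) */ }
void BatchPath(Ctx& /*ctx*/) { /* batch-mode kernel (stub) */ }

void BatchCompute(Ctx& ctx) {
  // Corrected order: sizes are known *before* the mode switch, so the
  // fallback into the sequence path sees a correctly shaped xx.
  if (ctx.lod_level0_size == 2) {          // a single sequence: seq mode suffices
    ctx.xx.Resize({ctx.total_T, ctx.D3});  // prepare xx for the sequence kernel
    SeqPath(ctx);
    return;
  }
  BatchPath(ctx);                          // otherwise run the batched kernel
}

int main() {
  Ctx ctx;
  ctx.total_T = 8;
  ctx.D3 = 3 * 16;    // illustrative: D3 suggests 3 gate projections for GRU
  BatchCompute(ctx);  // takes the SeqPath branch and resizes xx to {8, 48}
  return 0;
}
```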
@@ -290,12 +290,13 @@ class FusionGRUKernel : public framework::OpKernel<T> {
   void BatchCompute(const framework::ExecutionContext& ctx) const {
     using DeviceContext = paddle::platform::CPUDeviceContext;
     auto* x = ctx.Input<LoDTensor>("X");
+    INIT_BASE_INPUT_OUTPUT
+    INIT_BASE_SIZES
     if (x->lod()[0].size() == 2) {
+      xx->Resize({total_T, D3});
       SeqCompute(ctx);
       return;
     }
-    INIT_BASE_INPUT_OUTPUT
-    INIT_BASE_SIZES
     INIT_VEC_FUNC
     auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
......
@@ -424,11 +424,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
   void BatchCompute(const framework::ExecutionContext& ctx) const {
     using DeviceContext = platform::CPUDeviceContext;
     INIT_BASE_INPUT_OUTPUT
+    INIT_BASE_SIZES
     if (x->lod()[0].size() == 2) {
+      xx->Resize({x_dims[0], D4});
       SeqCompute(ctx);
       return;
     }
-    INIT_BASE_SIZES
     INIT_VEC_FUNC
     INIT_BASE_INPUT_DATAS
......