未验证 提交 5ef14dd3 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #13715 from tensor-tang/fix/op

bugfix fusion lstm and gru batch,seq  mode switch
...@@ -290,12 +290,13 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -290,12 +290,13 @@ class FusionGRUKernel : public framework::OpKernel<T> {
void BatchCompute(const framework::ExecutionContext& ctx) const { void BatchCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = paddle::platform::CPUDeviceContext; using DeviceContext = paddle::platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X"); auto* x = ctx.Input<LoDTensor>("X");
INIT_BASE_INPUT_OUTPUT
INIT_BASE_SIZES
if (x->lod()[0].size() == 2) { if (x->lod()[0].size() == 2) {
xx->Resize({total_T, D3});
SeqCompute(ctx); SeqCompute(ctx);
return; return;
} }
INIT_BASE_INPUT_OUTPUT
INIT_BASE_SIZES
INIT_VEC_FUNC INIT_VEC_FUNC
auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0"); auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
......
...@@ -432,11 +432,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -432,11 +432,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
void BatchCompute(const framework::ExecutionContext& ctx) const { void BatchCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = platform::CPUDeviceContext; using DeviceContext = platform::CPUDeviceContext;
INIT_BASE_INPUT_OUTPUT INIT_BASE_INPUT_OUTPUT
INIT_BASE_SIZES
if (x->lod()[0].size() == 2) { if (x->lod()[0].size() == 2) {
xx->Resize({x_dims[0], D4});
SeqCompute(ctx); SeqCompute(ctx);
return; return;
} }
INIT_BASE_SIZES
INIT_VEC_FUNC INIT_VEC_FUNC
INIT_BASE_INPUT_DATAS INIT_BASE_INPUT_DATAS
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册