diff --git a/paddle/fluid/operators/gru_op.cc b/paddle/fluid/operators/gru_op.cc
index dc82e4fa754ebf586aa37f9bb547c1f5b3416fb8..a2d61695649dcc6825dbcda9258b03983ae435af 100644
--- a/paddle/fluid/operators/gru_op.cc
+++ b/paddle/fluid/operators/gru_op.cc
@@ -33,13 +33,15 @@ class GRUOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRU");
     OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRU");
-    OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "GRU");
-    OP_INOUT_CHECK(ctx->HasOutput("BatchResetHiddenPrev"), "Output",
-                   "BatchResetHiddenPrev", "GRU");
-    OP_INOUT_CHECK(ctx->HasOutput("BatchHidden"), "Output", "BatchHidden",
-                   "GRU");
     OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "GRU");
-
+    bool is_test = ctx->Attrs().Get<bool>("is_test");
+    if (!is_test) {
+      OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "GRU");
+      OP_INOUT_CHECK(ctx->HasOutput("BatchResetHiddenPrev"), "Output",
+                     "BatchResetHiddenPrev", "GRU");
+      OP_INOUT_CHECK(ctx->HasOutput("BatchHidden"), "Output", "BatchHidden",
+                     "GRU");
+    }
     auto input_dims = ctx->GetInputDim("Input");
     auto weight_dims = ctx->GetInputDim("Weight");
     int input_size = input_dims[1];
@@ -84,9 +86,11 @@ class GRUOp : public framework::OperatorWithKernel {
                             "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
                             bias_height, bias_width, frame_size * 3));
     }
-    ctx->SetOutputDim("BatchGate", input_dims);
-    ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size});
-    ctx->SetOutputDim("BatchHidden", {input_dims[0], frame_size});
+    if (!is_test) {
+      ctx->SetOutputDim("BatchGate", input_dims);
+      ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size});
+      ctx->SetOutputDim("BatchHidden", {input_dims[0], frame_size});
+    }
     ctx->SetOutputDim("Hidden", {input_dims[0], frame_size});
     ctx->ShareLoD("Input", "Hidden");
   }
@@ -124,19 +128,22 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
               "organized in batches. The LoD size is 2. The first LoD contains "
               "the batch offsets and the second LoD contains the indexes in "
               "the raw sequence data.")
-        .AsIntermediate();
+        .AsIntermediate()
+        .AsExtra();
     AddOutput(
         "BatchResetHiddenPrev",
         "(LoDTensor) The reset hidden state LoDTensor organized in batches. "
         "This LoDTensor is a matrix with shape (T X D) and has the same LoD "
         "with `BatchGate`.")
-        .AsIntermediate();
+        .AsIntermediate()
+        .AsExtra();
     AddOutput(
         "BatchHidden",
         "(LoDTensor) The hidden state LoDTensor organized in batches. "
         "This LoDTensor is a matrix with shape (T X D) and has the same LoD "
         "with `BatchGate`.")
-        .AsIntermediate();
+        .AsIntermediate()
+        .AsExtra();
     AddOutput(
         "Hidden",
         "(LoDTensor) the hidden state LoDTensor organized in sequences. "
@@ -155,6 +162,9 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
                   "(bool, default: False) "
                   "whether to compute reversed GRU.")
         .SetDefault(false);
+    AddAttr<bool>("is_test", "True if in test phase.")
+        .SetDefault(false)
+        .AsExtra();
     AddAttr<bool>("origin_mode",
                   "bool"
                   "use origin mode in article https://arxiv.org/abs/1412.3555")
@@ -269,24 +279,42 @@ class GRUCPUKernel : public framework::OpKernel<T> {
  public:
   void BatchCompute(const framework::ExecutionContext& context) const {
     using DeviceContext = paddle::platform::CPUDeviceContext;
+    using LodTensorPtr = LoDTensor*;
+    bool is_test = context.Attr<bool>("is_test");
+
     bool origin_mode = context.Attr<bool>("origin_mode");
     auto* input = context.Input<LoDTensor>("Input");
     auto* h0 = context.Input<Tensor>("H0");
     auto* weight = context.Input<Tensor>("Weight");
     const T* weight_data = weight->data<T>();
     auto* bias = context.Input<Tensor>("Bias");
-    auto* batch_gate = context.Output<LoDTensor>("BatchGate");
-    batch_gate->mutable_data<T>(context.GetPlace());
-    auto* batch_reset_hidden_prev =
-        context.Output<LoDTensor>("BatchResetHiddenPrev");
-    batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
-    auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
-    batch_hidden->mutable_data<T>(context.GetPlace());
     auto* hidden = context.Output<LoDTensor>("Hidden");
     hidden->mutable_data<T>(context.GetPlace());
+
+    auto input_dims = input->dims();
     auto hidden_dims = hidden->dims();
+
+    LodTensorPtr batch_gate, batch_reset_hidden_prev, batch_hidden;
+    LoDTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, batch_hidden_tmp;
+    if (is_test) {
+      batch_gate = &batch_gate_tmp;
+      batch_gate->Resize(input_dims);
+
+      batch_reset_hidden_prev = &batch_reset_hidden_prev_tmp;
+      batch_reset_hidden_prev->Resize(hidden_dims);
+
+      batch_hidden = &batch_hidden_tmp;
+      batch_hidden->Resize(hidden_dims);
+    } else {
+      batch_gate = context.Output<LoDTensor>("BatchGate");
+      batch_hidden = context.Output<LoDTensor>("BatchHidden");
+      batch_reset_hidden_prev =
+          context.Output<LoDTensor>("BatchResetHiddenPrev");
+    }
+    batch_gate->mutable_data<T>(context.GetPlace());
+    batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
+    batch_hidden->mutable_data<T>(context.GetPlace());
+
     bool is_reverse = context.Attr<bool>("is_reverse");
     math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
     auto& dev_ctx = context.template device_context<DeviceContext>();
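Note: with the InferShape change above, a gru op whose is_test attribute is true no longer has to declare the BatchGate/BatchResetHiddenPrev/BatchHidden workspace outputs. On the Python side this attribute is normally flipped by Program.clone(for_test=True); a minimal sketch of that flow (layer sizes and variable names are illustrative, not taken from this patch):

```python
import paddle.fluid as fluid

main, startup = fluid.Program(), fluid.Program()
with fluid.program_guard(main, startup):
    # dynamic_gru expects its input width to be 3 * hidden size (30 -> 10).
    x = fluid.data(name="x", shape=[None, 30], dtype="float32", lod_level=1)
    out = fluid.layers.dynamic_gru(input=x, size=10)

# clone(for_test=True) sets is_test=True on every op that declares the
# attribute, so the cloned gru op can skip its Batch* workspace outputs.
infer = main.clone(for_test=True)
```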
" @@ -155,6 +162,9 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, default: False) " "whether to compute reversed GRU.") .SetDefault(false); + AddAttr("is_test", "True if in test phase.") + .SetDefault(false) + .AsExtra(); AddAttr("origin_mode", "bool" "use origin mode in article https://arxiv.org/abs/1412.3555") @@ -269,24 +279,42 @@ class GRUCPUKernel : public framework::OpKernel { public: void BatchCompute(const framework::ExecutionContext& context) const { using DeviceContext = paddle::platform::CPUDeviceContext; + using LodTensorPtr = LoDTensor*; + bool is_test = context.Attr("is_test"); + bool origin_mode = context.Attr("origin_mode"); auto* input = context.Input("Input"); auto* h0 = context.Input("H0"); auto* weight = context.Input("Weight"); const T* weight_data = weight->data(); auto* bias = context.Input("Bias"); - auto* batch_gate = context.Output("BatchGate"); - batch_gate->mutable_data(context.GetPlace()); - auto* batch_reset_hidden_prev = - context.Output("BatchResetHiddenPrev"); - batch_reset_hidden_prev->mutable_data(context.GetPlace()); - auto* batch_hidden = context.Output("BatchHidden"); - batch_hidden->mutable_data(context.GetPlace()); auto* hidden = context.Output("Hidden"); hidden->mutable_data(context.GetPlace()); + auto input_dims = input->dims(); auto hidden_dims = hidden->dims(); + LodTensorPtr batch_gate, batch_reset_hidden_prev, batch_hidden; + LoDTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, batch_hidden_tmp; + if (is_test) { + batch_gate = &batch_gate_tmp; + batch_gate->Resize(input_dims); + + batch_reset_hidden_prev = &batch_reset_hidden_prev_tmp; + batch_reset_hidden_prev->Resize(hidden_dims); + + batch_hidden = &batch_hidden_tmp; + batch_hidden->Resize(hidden_dims); + } else { + batch_gate = context.Output("BatchGate"); + batch_hidden = context.Output("BatchHidden"); + batch_reset_hidden_prev = + context.Output("BatchResetHiddenPrev"); + } + batch_gate->mutable_data(context.GetPlace()); + batch_reset_hidden_prev->mutable_data(context.GetPlace()); + batch_hidden->mutable_data(context.GetPlace()); + bool is_reverse = context.Attr("is_reverse"); math::LoDTensor2BatchFunctor to_batch; auto& dev_ctx = context.template device_context(); diff --git a/paddle/fluid/operators/gru_op.cu.cc b/paddle/fluid/operators/gru_op.cu.cc index bdc5debaea790c740f2e133b66e2cfb9e334dc3e..edd7f8a7cf5539c50a8173a5d85bd6de0468d2ab 100644 --- a/paddle/fluid/operators/gru_op.cu.cc +++ b/paddle/fluid/operators/gru_op.cu.cc @@ -28,24 +28,42 @@ template class GRUKernel : public framework::OpKernel { public: void BatchCompute(const framework::ExecutionContext& context) const { + using LodTensorPtr = LoDTensor*; + + bool is_test = context.Attr("is_test"); bool origin_mode = context.Attr("origin_mode"); auto* input = context.Input("Input"); auto* h0 = context.Input("H0"); auto* weight = context.Input("Weight"); const T* weight_data = weight->data(); auto* bias = context.Input("Bias"); - auto* batch_gate = context.Output("BatchGate"); - batch_gate->mutable_data(context.GetPlace()); - auto* batch_reset_hidden_prev = - context.Output("BatchResetHiddenPrev"); - batch_reset_hidden_prev->mutable_data(context.GetPlace()); - auto* batch_hidden = context.Output("BatchHidden"); - batch_hidden->mutable_data(context.GetPlace()); auto* hidden = context.Output("Hidden"); hidden->mutable_data(context.GetPlace()); + auto input_dims = input->dims(); auto hidden_dims = hidden->dims(); + LodTensorPtr batch_gate, batch_reset_hidden_prev, batch_hidden; + LoDTensor batch_gate_tmp, 
diff --git a/paddle/fluid/operators/lstm_op.cc b/paddle/fluid/operators/lstm_op.cc
index 2c9669cbd65499b73f6bb03a8b962279a311709e..0405578f5dc1edbebf410b34199fdf106a397a9b 100644
--- a/paddle/fluid/operators/lstm_op.cc
+++ b/paddle/fluid/operators/lstm_op.cc
@@ -30,10 +30,15 @@ class LSTMOp : public framework::OperatorWithKernel {
     OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "LSTM");
     OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "LSTM");
-    OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "LSTM");
-    OP_INOUT_CHECK(ctx->HasOutput("BatchCellPreAct"), "Output",
-                   "BatchCellPreAct", "LSTM");
+    bool is_test = ctx->Attrs().Get<bool>("is_test");
+
+    if (!is_test) {
+      OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate",
+                     "LSTM");
+      OP_INOUT_CHECK(ctx->HasOutput("BatchCellPreAct"), "Output",
+                     "BatchCellPreAct", "LSTM");
+    }
     auto in_dims = ctx->GetInputDim("Input");
     PADDLE_ENFORCE_EQ(
         in_dims.size(), 2,
@@ -103,8 +108,10 @@ class LSTMOp : public framework::OperatorWithKernel {
     framework::DDim out_dims({in_dims[0], frame_size});
     ctx->SetOutputDim("Hidden", out_dims);
     ctx->SetOutputDim("Cell", out_dims);
-    ctx->SetOutputDim("BatchGate", in_dims);
-    ctx->SetOutputDim("BatchCellPreAct", out_dims);
+    if (!is_test) {
+      ctx->SetOutputDim("BatchGate", in_dims);
+      ctx->SetOutputDim("BatchCellPreAct", out_dims);
+    }
     ctx->ShareLoD("Input", "Hidden");
     ctx->ShareLoD("Input", "Cell");
   }
@@ -164,11 +171,13 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
               "LoD is the batch offsets and the second LoD contains the "
               "indexes, which denote the position of reorganized sequence "
               "in the raw input.")
-        .AsIntermediate();
+        .AsIntermediate()
+        .AsExtra();
     AddOutput("BatchCellPreAct",
               "(LoDTensor) This LoDTensor is obtained in the forward and used "
               "in the backward.")
-        .AsIntermediate();
+        .AsIntermediate()
+        .AsExtra();
     AddAttr<bool>("use_peepholes",
                   "(bool, default: True) "
                   "whether to enable diagonal/peephole connections.")
@@ -177,6 +186,9 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
                   "(bool, default: False) "
                   "whether to compute reversed LSTM.")
         .SetDefault(false);
+    AddAttr<bool>("is_test", "True if in test phase.")
+        .SetDefault(false)
+        .AsExtra();
     AddAttr<std::string>(
         "gate_activation",
         "(string, default: sigmoid)"
diff --git a/paddle/fluid/operators/lstm_op.h b/paddle/fluid/operators/lstm_op.h
index a4434283abb6f4c20bda9198a5557f4abcc3f470..c6f43b949a73696e5f19c753c061f0a1e1553dcf 100644
--- a/paddle/fluid/operators/lstm_op.h
+++ b/paddle/fluid/operators/lstm_op.h
@@ -40,6 +40,8 @@ template <typename DeviceContext, typename T>
 class LSTMKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    bool is_test = ctx.Attr<bool>("is_test");
+
     auto* input = ctx.Input<LoDTensor>("Input");
     auto* weight = ctx.Input<Tensor>("Weight");
     auto* bias = ctx.Input<Tensor>("Bias");
@@ -47,7 +49,14 @@ class LSTMKernel : public framework::OpKernel<T> {
     auto* hidden_t0 = ctx.Input<Tensor>("H0");
     auto* cell_t0 = ctx.Input<Tensor>("C0");
 
-    auto* batch_gate = ctx.Output<LoDTensor>("BatchGate");
+    LoDTensor* batch_gate = nullptr;
+    LoDTensor batch_gate_temp;
+    if (is_test) {
+      batch_gate = &batch_gate_temp;
+      batch_gate->Resize(input->dims());
+    } else {
+      batch_gate = ctx.Output<LoDTensor>("BatchGate");
+    }
     batch_gate->mutable_data<T>(ctx.GetPlace());
     auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
     hidden_out->mutable_data<T>(ctx.GetPlace());
@@ -99,8 +108,13 @@ class LSTMKernel : public framework::OpKernel<T> {
     }
 
     // Use the local variable as here.
-    LoDTensor batch_hidden, batch_cell;
-    auto* batch_cell_pre_act = ctx.Output<LoDTensor>("BatchCellPreAct");
+    LoDTensor batch_hidden, batch_cell, batch_cell_pre_act_temp;
+    LoDTensor* batch_cell_pre_act;
+    if (is_test) {
+      batch_cell_pre_act = &batch_cell_pre_act_temp;
+    } else {
+      batch_cell_pre_act = ctx.Output<LoDTensor>("BatchCellPreAct");
+    }
     batch_hidden.mutable_data<T>(dims, ctx.GetPlace());
     batch_cell.mutable_data<T>(dims, ctx.GetPlace());
     batch_cell_pre_act->mutable_data<T>(dims, ctx.GetPlace());
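The LSTM changes mirror the GRU ones: BatchGate and BatchCellPreAct become optional in test mode and the kernel falls back to local temporaries. A quick way to confirm the new attribute actually lands on the op after cloning (a sketch; it assumes clone(for_test=True) flips is_test on ops that declare the attribute, as in the GRU sketch above):

```python
import paddle.fluid as fluid

main, startup = fluid.Program(), fluid.Program()
with fluid.program_guard(main, startup):
    # dynamic_lstm expects its input width to equal size (= 4 * hidden size).
    x = fluid.data(name="x", shape=[None, 64], dtype="float32", lod_level=1)
    hidden, cell = fluid.layers.dynamic_lstm(input=x, size=64)

infer = main.clone(for_test=True)
for op in infer.global_block().ops:
    if op.type == "lstm":
        assert op.attr("is_test")  # flipped by the clone
```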
ctx.Input("Input"); auto* weight = ctx.Input("Weight"); auto* bias = ctx.Input("Bias"); @@ -47,7 +49,14 @@ class LSTMKernel : public framework::OpKernel { auto* hidden_t0 = ctx.Input("H0"); auto* cell_t0 = ctx.Input("C0"); - auto* batch_gate = ctx.Output("BatchGate"); + LoDTensor* batch_gate = nullptr; + LoDTensor batch_gate_temp; + if (is_test) { + batch_gate = &batch_gate_temp; + batch_gate->Resize(input->dims()); + } else { + batch_gate = ctx.Output("BatchGate"); + } batch_gate->mutable_data(ctx.GetPlace()); auto* hidden_out = ctx.Output("Hidden"); hidden_out->mutable_data(ctx.GetPlace()); @@ -99,8 +108,13 @@ class LSTMKernel : public framework::OpKernel { } // Use the local variable as here. - LoDTensor batch_hidden, batch_cell; - auto* batch_cell_pre_act = ctx.Output("BatchCellPreAct"); + LoDTensor batch_hidden, batch_cell, batch_cell_pre_act_temp; + LoDTensor* batch_cell_pre_act; + if (is_test) { + batch_cell_pre_act = &batch_cell_pre_act_temp; + } else { + batch_cell_pre_act = ctx.Output("BatchCellPreAct"); + } batch_hidden.mutable_data(dims, ctx.GetPlace()); batch_cell.mutable_data(dims, ctx.GetPlace()); batch_cell_pre_act->mutable_data(dims, ctx.GetPlace()); diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index da7f824b3a6f12db6adcf90fd448fc18aa3030af..18ee5d71541e06dcedf5cc0e24a4f324e9bd539f 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -119,14 +119,16 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker { "tensor's axes according to the values given."); AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") - .SetDefault(false); + .SetDefault(false) + .AsExtra(); AddAttr( "data_format", "(string, default NCHW) Only used in " "An optional string from: \"NHWC\", \"NCHW\". " "Defaults to \"NHWC\". Specify the data format of the output data, " "the input will be transformed automatically. 
") - .SetDefault("AnyLayout"); + .SetDefault("AnyLayout") + .AsExtra(); AddAttr( "use_quantizer", "(bool, default false) " @@ -262,7 +264,9 @@ class Transpose2OpMaker : public TransposeOpMaker { public: void Make() override { TransposeOpMaker::Make(); - AddOutput("XShape", "(Tensor)The output tensor.").AsIntermediate(); + AddOutput("XShape", "(Tensor)The output tensor.") + .AsIntermediate() + .AsExtra(); } }; diff --git a/python/paddle/fluid/tests/unittests/test_gru_op.py b/python/paddle/fluid/tests/unittests/test_gru_op.py index 3ec943ef2e04a28324a034676394e3fb02caceba..7740cc0b03b494e09a2c6bd4189de2880bdef0a2 100644 --- a/python/paddle/fluid/tests/unittests/test_gru_op.py +++ b/python/paddle/fluid/tests/unittests/test_gru_op.py @@ -18,7 +18,7 @@ import unittest import numpy as np import math import functools -from op_test import OpTest +from op_test import OpTest, skip_check_grad_ci from paddle.fluid.tests.unittests.test_lstm_op import ACTIVATION from paddle import fluid from paddle.fluid import Program, program_guard @@ -106,6 +106,9 @@ class TestGRUOp(OpTest): def set_confs(self): pass + def set_is_test(self): + self.is_test = False + def setUp(self): self.op_type = "gru" self.lod = [[2, 4, 3]] @@ -118,6 +121,7 @@ class TestGRUOp(OpTest): self.dtype = 'float64' self.origin_mode = False self.set_confs() + self.set_is_test() T = sum(self.lod[0]) N = len(self.lod[0]) @@ -153,7 +157,8 @@ class TestGRUOp(OpTest): 'activation': self.act_state, 'gate_activation': self.act_gate, 'is_reverse': self.is_reverse, - 'origin_mode': self.origin_mode + 'origin_mode': self.origin_mode, + 'is_test': self.is_test } def test_check_output(self): @@ -229,6 +234,21 @@ class TestGRUOpReverseOriginMode(TestGRUOp): self.origin_mode = True +class TestGRUOpInference(TestGRUOp): + def set_is_test(self): + self.is_test = True + + def test_check_output(self): + new_outputs = {} + new_outputs['Hidden'] = self.outputs['Hidden'] + self.outputs = new_outputs + super(TestGRUOpInference, self).test_check_output() + + # avoid checking gradient + def test_check_grad(self): + pass + + class TestGruOpError(unittest.TestCase): def test_errors(self): with program_guard(Program(), Program()): diff --git a/python/paddle/fluid/tests/unittests/test_lstm_op.py b/python/paddle/fluid/tests/unittests/test_lstm_op.py index 185255439cc264358ba3d562dfd9f136f870779a..fff5fef29221e633def4c9258069189835c79e52 100644 --- a/python/paddle/fluid/tests/unittests/test_lstm_op.py +++ b/python/paddle/fluid/tests/unittests/test_lstm_op.py @@ -16,7 +16,7 @@ from __future__ import print_function import unittest import numpy as np -from op_test import OpTest +from op_test import OpTest, skip_check_grad_ci from paddle import fluid from paddle.fluid.layers import lstm as LSTM from paddle.fluid.layers import fill_constant @@ -212,10 +212,14 @@ class LstmUnitTestError(unittest.TestCase): class TestLstmOp(OpTest): + def set_is_test(self): + self.is_test = False + def set_lod(self): self.lod = [[2, 3, 2]] def set_argument(self): + self.set_is_test() self.set_lod() self.D = 16 @@ -269,7 +273,8 @@ class TestLstmOp(OpTest): 'is_reverse': self.is_reverse, 'gate_activation': self.act_gate, 'cell_activation': self.act_cell, - 'candidate_activation': self.act_cand + 'candidate_activation': self.act_cand, + 'is_test': self.is_test } def test_check_output(self): @@ -302,6 +307,15 @@ class TestLstmOpCase3(TestLstmOp): self.lod = [[2, 0, 4]] +class TestLstmOpInference(TestLstmOp): + def set_is_test(self): + self.is_test = True + + # avoid checking gradient + def 
diff --git a/python/paddle/fluid/tests/unittests/test_lstm_op.py b/python/paddle/fluid/tests/unittests/test_lstm_op.py
index 185255439cc264358ba3d562dfd9f136f870779a..fff5fef29221e633def4c9258069189835c79e52 100644
--- a/python/paddle/fluid/tests/unittests/test_lstm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_lstm_op.py
@@ -16,7 +16,7 @@ from __future__ import print_function
 
 import unittest
 import numpy as np
-from op_test import OpTest
+from op_test import OpTest, skip_check_grad_ci
 from paddle import fluid
 from paddle.fluid.layers import lstm as LSTM
 from paddle.fluid.layers import fill_constant
@@ -212,10 +212,14 @@ class LstmUnitTestError(unittest.TestCase):
 
 
 class TestLstmOp(OpTest):
+    def set_is_test(self):
+        self.is_test = False
+
     def set_lod(self):
         self.lod = [[2, 3, 2]]
 
     def set_argument(self):
+        self.set_is_test()
         self.set_lod()
         self.D = 16
@@ -269,7 +273,8 @@ class TestLstmOp(OpTest):
             'is_reverse': self.is_reverse,
             'gate_activation': self.act_gate,
             'cell_activation': self.act_cell,
-            'candidate_activation': self.act_cand
+            'candidate_activation': self.act_cand,
+            'is_test': self.is_test
         }
 
     def test_check_output(self):
@@ -302,6 +307,15 @@ class TestLstmOpCase3(TestLstmOp):
         self.lod = [[2, 0, 4]]
 
 
+class TestLstmOpInference(TestLstmOp):
+    def set_is_test(self):
+        self.is_test = True
+
+    # avoid checking gradient
+    def test_check_grad(self):
+        pass
+
+
 class TestLstmOpError(unittest.TestCase):
     def test_errors(self):
         with program_guard(Program(), Program()):
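To run just the two new inference cases locally, something along these lines should work (module names match the files touched above; run from the unittests directory so op_test is importable):

```python
import unittest

from test_gru_op import TestGRUOpInference
from test_lstm_op import TestLstmOpInference

loader = unittest.TestLoader()
suite = unittest.TestSuite()
suite.addTests(loader.loadTestsFromTestCase(TestGRUOpInference))
suite.addTests(loader.loadTestsFromTestCase(TestLstmOpInference))
unittest.TextTestRunner(verbosity=2).run(suite)
```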