未验证 提交 f13dcfb1 编写于 作者: J Jack Zhou 提交者: GitHub

Add AsExtra for transpose, lstm, gru (#35317)

* Add AsExtra for transpose

* add AsExtra for lstm op

* add AsExtra for gru
上级 b333dac0
...@@ -33,13 +33,15 @@ class GRUOp : public framework::OperatorWithKernel { ...@@ -33,13 +33,15 @@ class GRUOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRU"); OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRU");
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRU"); OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRU");
OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "GRU");
OP_INOUT_CHECK(ctx->HasOutput("BatchResetHiddenPrev"), "Output",
"BatchResetHiddenPrev", "GRU");
OP_INOUT_CHECK(ctx->HasOutput("BatchHidden"), "Output", "BatchHidden",
"GRU");
OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "GRU"); OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "GRU");
bool is_test = ctx->Attrs().Get<bool>("is_test");
if (!is_test) {
OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "GRU");
OP_INOUT_CHECK(ctx->HasOutput("BatchResetHiddenPrev"), "Output",
"BatchResetHiddenPrev", "GRU");
OP_INOUT_CHECK(ctx->HasOutput("BatchHidden"), "Output", "BatchHidden",
"GRU");
}
auto input_dims = ctx->GetInputDim("Input"); auto input_dims = ctx->GetInputDim("Input");
auto weight_dims = ctx->GetInputDim("Weight"); auto weight_dims = ctx->GetInputDim("Weight");
int input_size = input_dims[1]; int input_size = input_dims[1];
...@@ -84,9 +86,11 @@ class GRUOp : public framework::OperatorWithKernel { ...@@ -84,9 +86,11 @@ class GRUOp : public framework::OperatorWithKernel {
"[%d, %d] (Bias) vs [1, %d] (frame_size * 3).", "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
bias_height, bias_width, frame_size * 3)); bias_height, bias_width, frame_size * 3));
} }
ctx->SetOutputDim("BatchGate", input_dims); if (!is_test) {
ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size}); ctx->SetOutputDim("BatchGate", input_dims);
ctx->SetOutputDim("BatchHidden", {input_dims[0], frame_size}); ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size});
ctx->SetOutputDim("BatchHidden", {input_dims[0], frame_size});
}
ctx->SetOutputDim("Hidden", {input_dims[0], frame_size}); ctx->SetOutputDim("Hidden", {input_dims[0], frame_size});
ctx->ShareLoD("Input", "Hidden"); ctx->ShareLoD("Input", "Hidden");
} }
...@@ -124,19 +128,22 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -124,19 +128,22 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
"organized in batches. The LoD size is 2. The first LoD contains " "organized in batches. The LoD size is 2. The first LoD contains "
"the batch offsets and the second LoD contains the indexes in " "the batch offsets and the second LoD contains the indexes in "
"the raw sequence data.") "the raw sequence data.")
.AsIntermediate(); .AsIntermediate()
.AsExtra();
AddOutput( AddOutput(
"BatchResetHiddenPrev", "BatchResetHiddenPrev",
"(LoDTensor) The reset hidden state LoDTensor organized in batches. " "(LoDTensor) The reset hidden state LoDTensor organized in batches. "
"This LoDTensor is a matrix with shape (T X D) and has the same LoD " "This LoDTensor is a matrix with shape (T X D) and has the same LoD "
"with `BatchGate`.") "with `BatchGate`.")
.AsIntermediate(); .AsIntermediate()
.AsExtra();
AddOutput( AddOutput(
"BatchHidden", "BatchHidden",
"(LoDTensor) The hidden state LoDTensor organized in batches. " "(LoDTensor) The hidden state LoDTensor organized in batches. "
"This LoDTensor is a matrix with shape (T X D) and has the same LoD " "This LoDTensor is a matrix with shape (T X D) and has the same LoD "
"with `BatchGate`.") "with `BatchGate`.")
.AsIntermediate(); .AsIntermediate()
.AsExtra();
AddOutput( AddOutput(
"Hidden", "Hidden",
"(LoDTensor) the hidden state LoDTensor organized in sequences. " "(LoDTensor) the hidden state LoDTensor organized in sequences. "
...@@ -155,6 +162,9 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -155,6 +162,9 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default: False) " "(bool, default: False) "
"whether to compute reversed GRU.") "whether to compute reversed GRU.")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("is_test", "True if in test phase.")
.SetDefault(false)
.AsExtra();
AddAttr<bool>("origin_mode", AddAttr<bool>("origin_mode",
"bool" "bool"
"use origin mode in article https://arxiv.org/abs/1412.3555") "use origin mode in article https://arxiv.org/abs/1412.3555")
...@@ -269,24 +279,42 @@ class GRUCPUKernel : public framework::OpKernel<T> { ...@@ -269,24 +279,42 @@ class GRUCPUKernel : public framework::OpKernel<T> {
public: public:
void BatchCompute(const framework::ExecutionContext& context) const { void BatchCompute(const framework::ExecutionContext& context) const {
using DeviceContext = paddle::platform::CPUDeviceContext; using DeviceContext = paddle::platform::CPUDeviceContext;
using LodTensorPtr = LoDTensor*;
bool is_test = context.Attr<bool>("is_test");
bool origin_mode = context.Attr<bool>("origin_mode"); bool origin_mode = context.Attr<bool>("origin_mode");
auto* input = context.Input<LoDTensor>("Input"); auto* input = context.Input<LoDTensor>("Input");
auto* h0 = context.Input<Tensor>("H0"); auto* h0 = context.Input<Tensor>("H0");
auto* weight = context.Input<Tensor>("Weight"); auto* weight = context.Input<Tensor>("Weight");
const T* weight_data = weight->data<T>(); const T* weight_data = weight->data<T>();
auto* bias = context.Input<Tensor>("Bias"); auto* bias = context.Input<Tensor>("Bias");
auto* batch_gate = context.Output<LoDTensor>("BatchGate");
batch_gate->mutable_data<T>(context.GetPlace());
auto* batch_reset_hidden_prev =
context.Output<LoDTensor>("BatchResetHiddenPrev");
batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
batch_hidden->mutable_data<T>(context.GetPlace());
auto* hidden = context.Output<LoDTensor>("Hidden"); auto* hidden = context.Output<LoDTensor>("Hidden");
hidden->mutable_data<T>(context.GetPlace()); hidden->mutable_data<T>(context.GetPlace());
auto input_dims = input->dims();
auto hidden_dims = hidden->dims(); auto hidden_dims = hidden->dims();
LodTensorPtr batch_gate, batch_reset_hidden_prev, batch_hidden;
LoDTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, batch_hidden_tmp;
if (is_test) {
batch_gate = &batch_gate_tmp;
batch_gate->Resize(input_dims);
batch_reset_hidden_prev = &batch_reset_hidden_prev_tmp;
batch_reset_hidden_prev->Resize(hidden_dims);
batch_hidden = &batch_hidden_tmp;
batch_hidden->Resize(hidden_dims);
} else {
batch_gate = context.Output<LoDTensor>("BatchGate");
batch_hidden = context.Output<LoDTensor>("BatchHidden");
batch_reset_hidden_prev =
context.Output<LoDTensor>("BatchResetHiddenPrev");
}
batch_gate->mutable_data<T>(context.GetPlace());
batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
batch_hidden->mutable_data<T>(context.GetPlace());
bool is_reverse = context.Attr<bool>("is_reverse"); bool is_reverse = context.Attr<bool>("is_reverse");
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch; math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
auto& dev_ctx = context.template device_context<DeviceContext>(); auto& dev_ctx = context.template device_context<DeviceContext>();
......
...@@ -28,24 +28,42 @@ template <typename DeviceContext, typename T> ...@@ -28,24 +28,42 @@ template <typename DeviceContext, typename T>
class GRUKernel : public framework::OpKernel<T> { class GRUKernel : public framework::OpKernel<T> {
public: public:
void BatchCompute(const framework::ExecutionContext& context) const { void BatchCompute(const framework::ExecutionContext& context) const {
using LodTensorPtr = LoDTensor*;
bool is_test = context.Attr<bool>("is_test");
bool origin_mode = context.Attr<bool>("origin_mode"); bool origin_mode = context.Attr<bool>("origin_mode");
auto* input = context.Input<LoDTensor>("Input"); auto* input = context.Input<LoDTensor>("Input");
auto* h0 = context.Input<Tensor>("H0"); auto* h0 = context.Input<Tensor>("H0");
auto* weight = context.Input<Tensor>("Weight"); auto* weight = context.Input<Tensor>("Weight");
const T* weight_data = weight->data<T>(); const T* weight_data = weight->data<T>();
auto* bias = context.Input<Tensor>("Bias"); auto* bias = context.Input<Tensor>("Bias");
auto* batch_gate = context.Output<LoDTensor>("BatchGate");
batch_gate->mutable_data<T>(context.GetPlace());
auto* batch_reset_hidden_prev =
context.Output<LoDTensor>("BatchResetHiddenPrev");
batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
batch_hidden->mutable_data<T>(context.GetPlace());
auto* hidden = context.Output<LoDTensor>("Hidden"); auto* hidden = context.Output<LoDTensor>("Hidden");
hidden->mutable_data<T>(context.GetPlace()); hidden->mutable_data<T>(context.GetPlace());
auto input_dims = input->dims();
auto hidden_dims = hidden->dims(); auto hidden_dims = hidden->dims();
LodTensorPtr batch_gate, batch_reset_hidden_prev, batch_hidden;
LoDTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, batch_hidden_tmp;
if (is_test) {
batch_gate = &batch_gate_tmp;
batch_gate->Resize(input_dims);
batch_reset_hidden_prev = &batch_reset_hidden_prev_tmp;
batch_reset_hidden_prev->Resize(hidden_dims);
batch_hidden = &batch_hidden_tmp;
batch_hidden->Resize(hidden_dims);
} else {
batch_gate = context.Output<LoDTensor>("BatchGate");
batch_hidden = context.Output<LoDTensor>("BatchHidden");
batch_reset_hidden_prev =
context.Output<LoDTensor>("BatchResetHiddenPrev");
}
batch_gate->mutable_data<T>(context.GetPlace());
batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
batch_hidden->mutable_data<T>(context.GetPlace());
bool is_reverse = context.Attr<bool>("is_reverse"); bool is_reverse = context.Attr<bool>("is_reverse");
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch; math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
auto& dev_ctx = context.template device_context<DeviceContext>(); auto& dev_ctx = context.template device_context<DeviceContext>();
......
...@@ -30,10 +30,15 @@ class LSTMOp : public framework::OperatorWithKernel { ...@@ -30,10 +30,15 @@ class LSTMOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "LSTM"); OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "LSTM");
OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "LSTM"); OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "LSTM");
OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "LSTM");
OP_INOUT_CHECK(ctx->HasOutput("BatchCellPreAct"), "Output",
"BatchCellPreAct", "LSTM");
bool is_test = ctx->Attrs().Get<bool>("is_test");
if (!is_test) {
OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate",
"LSTM");
OP_INOUT_CHECK(ctx->HasOutput("BatchCellPreAct"), "Output",
"BatchCellPreAct", "LSTM");
}
auto in_dims = ctx->GetInputDim("Input"); auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
in_dims.size(), 2, in_dims.size(), 2,
...@@ -103,8 +108,10 @@ class LSTMOp : public framework::OperatorWithKernel { ...@@ -103,8 +108,10 @@ class LSTMOp : public framework::OperatorWithKernel {
framework::DDim out_dims({in_dims[0], frame_size}); framework::DDim out_dims({in_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims); ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("Cell", out_dims); ctx->SetOutputDim("Cell", out_dims);
ctx->SetOutputDim("BatchGate", in_dims); if (!is_test) {
ctx->SetOutputDim("BatchCellPreAct", out_dims); ctx->SetOutputDim("BatchGate", in_dims);
ctx->SetOutputDim("BatchCellPreAct", out_dims);
}
ctx->ShareLoD("Input", "Hidden"); ctx->ShareLoD("Input", "Hidden");
ctx->ShareLoD("Input", "Cell"); ctx->ShareLoD("Input", "Cell");
} }
...@@ -164,11 +171,13 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -164,11 +171,13 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
"LoD is the batch offsets and the second LoD contains the " "LoD is the batch offsets and the second LoD contains the "
"indexes, which denote the position of reorganized sequence " "indexes, which denote the position of reorganized sequence "
"in the raw input.") "in the raw input.")
.AsIntermediate(); .AsIntermediate()
.AsExtra();
AddOutput("BatchCellPreAct", AddOutput("BatchCellPreAct",
"(LoDTensor) This LoDTensor is obtained in the forward and used " "(LoDTensor) This LoDTensor is obtained in the forward and used "
"in the backward.") "in the backward.")
.AsIntermediate(); .AsIntermediate()
.AsExtra();
AddAttr<bool>("use_peepholes", AddAttr<bool>("use_peepholes",
"(bool, default: True) " "(bool, default: True) "
"whether to enable diagonal/peephole connections.") "whether to enable diagonal/peephole connections.")
...@@ -177,6 +186,9 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -177,6 +186,9 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default: False) " "(bool, default: False) "
"whether to compute reversed LSTM.") "whether to compute reversed LSTM.")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("is_test", "True if in test phase.")
.SetDefault(false)
.AsExtra();
AddAttr<std::string>( AddAttr<std::string>(
"gate_activation", "gate_activation",
"(string, default: sigmoid)" "(string, default: sigmoid)"
......
...@@ -40,6 +40,8 @@ template <typename DeviceContext, typename T> ...@@ -40,6 +40,8 @@ template <typename DeviceContext, typename T>
class LSTMKernel : public framework::OpKernel<T> { class LSTMKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
bool is_test = ctx.Attr<bool>("is_test");
auto* input = ctx.Input<LoDTensor>("Input"); auto* input = ctx.Input<LoDTensor>("Input");
auto* weight = ctx.Input<Tensor>("Weight"); auto* weight = ctx.Input<Tensor>("Weight");
auto* bias = ctx.Input<Tensor>("Bias"); auto* bias = ctx.Input<Tensor>("Bias");
...@@ -47,7 +49,14 @@ class LSTMKernel : public framework::OpKernel<T> { ...@@ -47,7 +49,14 @@ class LSTMKernel : public framework::OpKernel<T> {
auto* hidden_t0 = ctx.Input<Tensor>("H0"); auto* hidden_t0 = ctx.Input<Tensor>("H0");
auto* cell_t0 = ctx.Input<Tensor>("C0"); auto* cell_t0 = ctx.Input<Tensor>("C0");
auto* batch_gate = ctx.Output<LoDTensor>("BatchGate"); LoDTensor* batch_gate = nullptr;
LoDTensor batch_gate_temp;
if (is_test) {
batch_gate = &batch_gate_temp;
batch_gate->Resize(input->dims());
} else {
batch_gate = ctx.Output<LoDTensor>("BatchGate");
}
batch_gate->mutable_data<T>(ctx.GetPlace()); batch_gate->mutable_data<T>(ctx.GetPlace());
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
hidden_out->mutable_data<T>(ctx.GetPlace()); hidden_out->mutable_data<T>(ctx.GetPlace());
...@@ -99,8 +108,13 @@ class LSTMKernel : public framework::OpKernel<T> { ...@@ -99,8 +108,13 @@ class LSTMKernel : public framework::OpKernel<T> {
} }
// Use the local variable as here. // Use the local variable as here.
LoDTensor batch_hidden, batch_cell; LoDTensor batch_hidden, batch_cell, batch_cell_pre_act_temp;
auto* batch_cell_pre_act = ctx.Output<LoDTensor>("BatchCellPreAct"); LoDTensor* batch_cell_pre_act;
if (is_test) {
batch_cell_pre_act = &batch_cell_pre_act_temp;
} else {
batch_cell_pre_act = ctx.Output<LoDTensor>("BatchCellPreAct");
}
batch_hidden.mutable_data<T>(dims, ctx.GetPlace()); batch_hidden.mutable_data<T>(dims, ctx.GetPlace());
batch_cell.mutable_data<T>(dims, ctx.GetPlace()); batch_cell.mutable_data<T>(dims, ctx.GetPlace());
batch_cell_pre_act->mutable_data<T>(dims, ctx.GetPlace()); batch_cell_pre_act->mutable_data<T>(dims, ctx.GetPlace());
......
...@@ -119,14 +119,16 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -119,14 +119,16 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
"tensor's axes according to the values given."); "tensor's axes according to the values given.");
AddAttr<bool>("use_mkldnn", AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel") "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false); .SetDefault(false)
.AsExtra();
AddAttr<std::string>( AddAttr<std::string>(
"data_format", "data_format",
"(string, default NCHW) Only used in " "(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". " "An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, " "Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ") "the input will be transformed automatically. ")
.SetDefault("AnyLayout"); .SetDefault("AnyLayout")
.AsExtra();
AddAttr<bool>( AddAttr<bool>(
"use_quantizer", "use_quantizer",
"(bool, default false) " "(bool, default false) "
...@@ -262,7 +264,9 @@ class Transpose2OpMaker : public TransposeOpMaker { ...@@ -262,7 +264,9 @@ class Transpose2OpMaker : public TransposeOpMaker {
public: public:
void Make() override { void Make() override {
TransposeOpMaker::Make(); TransposeOpMaker::Make();
AddOutput("XShape", "(Tensor)The output tensor.").AsIntermediate(); AddOutput("XShape", "(Tensor)The output tensor.")
.AsIntermediate()
.AsExtra();
} }
}; };
......
...@@ -18,7 +18,7 @@ import unittest ...@@ -18,7 +18,7 @@ import unittest
import numpy as np import numpy as np
import math import math
import functools import functools
from op_test import OpTest from op_test import OpTest, skip_check_grad_ci
from paddle.fluid.tests.unittests.test_lstm_op import ACTIVATION from paddle.fluid.tests.unittests.test_lstm_op import ACTIVATION
from paddle import fluid from paddle import fluid
from paddle.fluid import Program, program_guard from paddle.fluid import Program, program_guard
...@@ -106,6 +106,9 @@ class TestGRUOp(OpTest): ...@@ -106,6 +106,9 @@ class TestGRUOp(OpTest):
def set_confs(self): def set_confs(self):
pass pass
def set_is_test(self):
self.is_test = False
def setUp(self): def setUp(self):
self.op_type = "gru" self.op_type = "gru"
self.lod = [[2, 4, 3]] self.lod = [[2, 4, 3]]
...@@ -118,6 +121,7 @@ class TestGRUOp(OpTest): ...@@ -118,6 +121,7 @@ class TestGRUOp(OpTest):
self.dtype = 'float64' self.dtype = 'float64'
self.origin_mode = False self.origin_mode = False
self.set_confs() self.set_confs()
self.set_is_test()
T = sum(self.lod[0]) T = sum(self.lod[0])
N = len(self.lod[0]) N = len(self.lod[0])
...@@ -153,7 +157,8 @@ class TestGRUOp(OpTest): ...@@ -153,7 +157,8 @@ class TestGRUOp(OpTest):
'activation': self.act_state, 'activation': self.act_state,
'gate_activation': self.act_gate, 'gate_activation': self.act_gate,
'is_reverse': self.is_reverse, 'is_reverse': self.is_reverse,
'origin_mode': self.origin_mode 'origin_mode': self.origin_mode,
'is_test': self.is_test
} }
def test_check_output(self): def test_check_output(self):
...@@ -229,6 +234,21 @@ class TestGRUOpReverseOriginMode(TestGRUOp): ...@@ -229,6 +234,21 @@ class TestGRUOpReverseOriginMode(TestGRUOp):
self.origin_mode = True self.origin_mode = True
class TestGRUOpInference(TestGRUOp):
def set_is_test(self):
self.is_test = True
def test_check_output(self):
new_outputs = {}
new_outputs['Hidden'] = self.outputs['Hidden']
self.outputs = new_outputs
super(TestGRUOpInference, self).test_check_output()
# avoid checking gradient
def test_check_grad(self):
pass
class TestGruOpError(unittest.TestCase): class TestGruOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
......
...@@ -16,7 +16,7 @@ from __future__ import print_function ...@@ -16,7 +16,7 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest, skip_check_grad_ci
from paddle import fluid from paddle import fluid
from paddle.fluid.layers import lstm as LSTM from paddle.fluid.layers import lstm as LSTM
from paddle.fluid.layers import fill_constant from paddle.fluid.layers import fill_constant
...@@ -212,10 +212,14 @@ class LstmUnitTestError(unittest.TestCase): ...@@ -212,10 +212,14 @@ class LstmUnitTestError(unittest.TestCase):
class TestLstmOp(OpTest): class TestLstmOp(OpTest):
def set_is_test(self):
self.is_test = False
def set_lod(self): def set_lod(self):
self.lod = [[2, 3, 2]] self.lod = [[2, 3, 2]]
def set_argument(self): def set_argument(self):
self.set_is_test()
self.set_lod() self.set_lod()
self.D = 16 self.D = 16
...@@ -269,7 +273,8 @@ class TestLstmOp(OpTest): ...@@ -269,7 +273,8 @@ class TestLstmOp(OpTest):
'is_reverse': self.is_reverse, 'is_reverse': self.is_reverse,
'gate_activation': self.act_gate, 'gate_activation': self.act_gate,
'cell_activation': self.act_cell, 'cell_activation': self.act_cell,
'candidate_activation': self.act_cand 'candidate_activation': self.act_cand,
'is_test': self.is_test
} }
def test_check_output(self): def test_check_output(self):
...@@ -302,6 +307,15 @@ class TestLstmOpCase3(TestLstmOp): ...@@ -302,6 +307,15 @@ class TestLstmOpCase3(TestLstmOp):
self.lod = [[2, 0, 4]] self.lod = [[2, 0, 4]]
class TestLstmOpInference(TestLstmOp):
def set_is_test(self):
self.is_test = True
# avoid checking gradient
def test_check_grad(self):
pass
class TestLstmOpError(unittest.TestCase): class TestLstmOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册