Unverified commit f13dcfb1, authored by Jack Zhou, committed by GitHub

Add AsExtra for transpose, lstm, gru (#35317)

* Add AsExtra for transpose

* add AsExtra for lstm op

* add AsExtra for gru
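
Context for the hunks below (each shows only a fragment of its file): in Paddle's fluid framework an OpMaker declares an op's inputs, outputs, and attributes. `.AsIntermediate()` hides an output from user-facing APIs, and `.AsExtra()`, applied throughout this commit, marks an output or attribute as outside the op's core definition. The new `is_test` attribute lets the GRU/LSTM kernels skip training-only workspace outputs at inference time. A minimal sketch of the combined pattern, assuming Paddle's fluid headers; `ExampleOpMaker` is illustrative, not part of the commit:

// A hedged sketch of the pattern this commit applies, not code from the
// commit itself: "ExampleOpMaker" is hypothetical; the real makers are
// GRUOpMaker, LSTMOpMaker, and TransposeOpMaker in the diff below.
#include "paddle/fluid/framework/op_proto_maker.h"

namespace paddle {
namespace operators {

class ExampleOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Input", "(LoDTensor) the input sequence.");
    AddOutput("Hidden", "(LoDTensor) the essential output.");
    // A training-only workspace: AsIntermediate() hides it from
    // user-facing APIs, AsExtra() marks it as non-essential.
    AddOutput("BatchGate", "(LoDTensor) training-time workspace.")
        .AsIntermediate()
        .AsExtra();
    // Extra attribute that lets kernels skip the workspace outputs.
    AddAttr<bool>("is_test", "True if in test phase.")
        .SetDefault(false)
        .AsExtra();
  }
};

}  // namespace operators
}  // namespace paddle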
Parent b333dac0
@@ -33,13 +33,15 @@ class GRUOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRU");
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRU");
OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "GRU");
OP_INOUT_CHECK(ctx->HasOutput("BatchResetHiddenPrev"), "Output",
"BatchResetHiddenPrev", "GRU");
OP_INOUT_CHECK(ctx->HasOutput("BatchHidden"), "Output", "BatchHidden",
"GRU");
OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "GRU");
bool is_test = ctx->Attrs().Get<bool>("is_test");
if (!is_test) {
OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "GRU");
OP_INOUT_CHECK(ctx->HasOutput("BatchResetHiddenPrev"), "Output",
"BatchResetHiddenPrev", "GRU");
OP_INOUT_CHECK(ctx->HasOutput("BatchHidden"), "Output", "BatchHidden",
"GRU");
}
auto input_dims = ctx->GetInputDim("Input");
auto weight_dims = ctx->GetInputDim("Weight");
int input_size = input_dims[1];
@@ -84,9 +86,11 @@ class GRUOp : public framework::OperatorWithKernel {
"[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
bias_height, bias_width, frame_size * 3));
}
ctx->SetOutputDim("BatchGate", input_dims);
ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size});
ctx->SetOutputDim("BatchHidden", {input_dims[0], frame_size});
if (!is_test) {
ctx->SetOutputDim("BatchGate", input_dims);
ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size});
ctx->SetOutputDim("BatchHidden", {input_dims[0], frame_size});
}
ctx->SetOutputDim("Hidden", {input_dims[0], frame_size});
ctx->ShareLoD("Input", "Hidden");
}
@@ -124,19 +128,22 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
"organized in batches. The LoD size is 2. The first LoD contains "
"the batch offsets and the second LoD contains the indexes in "
"the raw sequence data.")
.AsIntermediate();
.AsIntermediate()
.AsExtra();
AddOutput(
"BatchResetHiddenPrev",
"(LoDTensor) The reset hidden state LoDTensor organized in batches. "
"This LoDTensor is a matrix with shape (T X D) and has the same LoD "
"with `BatchGate`.")
.AsIntermediate();
.AsIntermediate()
.AsExtra();
AddOutput(
"BatchHidden",
"(LoDTensor) The hidden state LoDTensor organized in batches. "
"This LoDTensor is a matrix with shape (T X D) and has the same LoD "
"with `BatchGate`.")
.AsIntermediate();
.AsIntermediate()
.AsExtra();
AddOutput(
"Hidden",
"(LoDTensor) the hidden state LoDTensor organized in sequences. "
@@ -155,6 +162,9 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default: False) "
"whether to compute reversed GRU.")
.SetDefault(false);
AddAttr<bool>("is_test", "True if in test phase.")
.SetDefault(false)
.AsExtra();
AddAttr<bool>("origin_mode",
"bool"
"use origin mode in article https://arxiv.org/abs/1412.3555")
@@ -269,24 +279,42 @@ class GRUCPUKernel : public framework::OpKernel<T> {
public:
void BatchCompute(const framework::ExecutionContext& context) const {
using DeviceContext = paddle::platform::CPUDeviceContext;
using LodTensorPtr = LoDTensor*;
bool is_test = context.Attr<bool>("is_test");
bool origin_mode = context.Attr<bool>("origin_mode");
auto* input = context.Input<LoDTensor>("Input");
auto* h0 = context.Input<Tensor>("H0");
auto* weight = context.Input<Tensor>("Weight");
const T* weight_data = weight->data<T>();
auto* bias = context.Input<Tensor>("Bias");
auto* batch_gate = context.Output<LoDTensor>("BatchGate");
batch_gate->mutable_data<T>(context.GetPlace());
auto* batch_reset_hidden_prev =
context.Output<LoDTensor>("BatchResetHiddenPrev");
batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
batch_hidden->mutable_data<T>(context.GetPlace());
auto* hidden = context.Output<LoDTensor>("Hidden");
hidden->mutable_data<T>(context.GetPlace());
auto input_dims = input->dims();
auto hidden_dims = hidden->dims();
LodTensorPtr batch_gate, batch_reset_hidden_prev, batch_hidden;
LoDTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, batch_hidden_tmp;
if (is_test) {
batch_gate = &batch_gate_tmp;
batch_gate->Resize(input_dims);
batch_reset_hidden_prev = &batch_reset_hidden_prev_tmp;
batch_reset_hidden_prev->Resize(hidden_dims);
batch_hidden = &batch_hidden_tmp;
batch_hidden->Resize(hidden_dims);
} else {
batch_gate = context.Output<LoDTensor>("BatchGate");
batch_hidden = context.Output<LoDTensor>("BatchHidden");
batch_reset_hidden_prev =
context.Output<LoDTensor>("BatchResetHiddenPrev");
}
batch_gate->mutable_data<T>(context.GetPlace());
batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
batch_hidden->mutable_data<T>(context.GetPlace());
bool is_reverse = context.Attr<bool>("is_reverse");
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
auto& dev_ctx = context.template device_context<DeviceContext>();
...
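
The kernel-side change is the same in the CPU kernel above and the templated device kernel below: when `is_test` is true, the batch workspaces become function-local tensors instead of graph outputs, so an inference graph need not carry BatchGate, BatchResetHiddenPrev, or BatchHidden. A condensed sketch of that branch, assuming the usual fluid kernel includes; `PrepareBatchGate` is an illustrative helper, not code from the commit:

// Condensed sketch of the is_test branch shared by the kernels in this
// commit; "PrepareBatchGate" is a hypothetical name for illustration.
void PrepareBatchGate(const framework::ExecutionContext& context,
                      const framework::DDim& input_dims) {
  bool is_test = context.Attr<bool>("is_test");
  framework::LoDTensor batch_gate_tmp;  // scratch; dies with this call
  framework::LoDTensor* batch_gate = nullptr;
  if (is_test) {
    // Inference: the graph has no BatchGate output, use local scratch.
    batch_gate = &batch_gate_tmp;
    batch_gate->Resize(input_dims);
  } else {
    // Training: backward needs BatchGate, so keep it as a graph output.
    batch_gate = context.Output<framework::LoDTensor>("BatchGate");
  }
  batch_gate->mutable_data<float>(context.GetPlace());  // allocate either way
}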
@@ -28,24 +28,42 @@ template <typename DeviceContext, typename T>
class GRUKernel : public framework::OpKernel<T> {
public:
void BatchCompute(const framework::ExecutionContext& context) const {
using LodTensorPtr = LoDTensor*;
bool is_test = context.Attr<bool>("is_test");
bool origin_mode = context.Attr<bool>("origin_mode");
auto* input = context.Input<LoDTensor>("Input");
auto* h0 = context.Input<Tensor>("H0");
auto* weight = context.Input<Tensor>("Weight");
const T* weight_data = weight->data<T>();
auto* bias = context.Input<Tensor>("Bias");
auto* batch_gate = context.Output<LoDTensor>("BatchGate");
batch_gate->mutable_data<T>(context.GetPlace());
auto* batch_reset_hidden_prev =
context.Output<LoDTensor>("BatchResetHiddenPrev");
batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
batch_hidden->mutable_data<T>(context.GetPlace());
auto* hidden = context.Output<LoDTensor>("Hidden");
hidden->mutable_data<T>(context.GetPlace());
auto input_dims = input->dims();
auto hidden_dims = hidden->dims();
LodTensorPtr batch_gate, batch_reset_hidden_prev, batch_hidden;
LoDTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, batch_hidden_tmp;
if (is_test) {
batch_gate = &batch_gate_tmp;
batch_gate->Resize(input_dims);
batch_reset_hidden_prev = &batch_reset_hidden_prev_tmp;
batch_reset_hidden_prev->Resize(hidden_dims);
batch_hidden = &batch_hidden_tmp;
batch_hidden->Resize(hidden_dims);
} else {
batch_gate = context.Output<LoDTensor>("BatchGate");
batch_hidden = context.Output<LoDTensor>("BatchHidden");
batch_reset_hidden_prev =
context.Output<LoDTensor>("BatchResetHiddenPrev");
}
batch_gate->mutable_data<T>(context.GetPlace());
batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
batch_hidden->mutable_data<T>(context.GetPlace());
bool is_reverse = context.Attr<bool>("is_reverse");
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
auto& dev_ctx = context.template device_context<DeviceContext>();
...
@@ -30,10 +30,15 @@ class LSTMOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "LSTM");
OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "LSTM");
OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "LSTM");
OP_INOUT_CHECK(ctx->HasOutput("BatchCellPreAct"), "Output",
"BatchCellPreAct", "LSTM");
bool is_test = ctx->Attrs().Get<bool>("is_test");
if (!is_test) {
OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate",
"LSTM");
OP_INOUT_CHECK(ctx->HasOutput("BatchCellPreAct"), "Output",
"BatchCellPreAct", "LSTM");
}
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ(
in_dims.size(), 2,
@@ -103,8 +108,10 @@ class LSTMOp : public framework::OperatorWithKernel {
framework::DDim out_dims({in_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("Cell", out_dims);
ctx->SetOutputDim("BatchGate", in_dims);
ctx->SetOutputDim("BatchCellPreAct", out_dims);
if (!is_test) {
ctx->SetOutputDim("BatchGate", in_dims);
ctx->SetOutputDim("BatchCellPreAct", out_dims);
}
ctx->ShareLoD("Input", "Hidden");
ctx->ShareLoD("Input", "Cell");
}
@@ -164,11 +171,13 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
"LoD is the batch offsets and the second LoD contains the "
"indexes, which denote the position of reorganized sequence "
"in the raw input.")
.AsIntermediate();
.AsIntermediate()
.AsExtra();
AddOutput("BatchCellPreAct",
"(LoDTensor) This LoDTensor is obtained in the forward and used "
"in the backward.")
.AsIntermediate();
.AsIntermediate()
.AsExtra();
AddAttr<bool>("use_peepholes",
"(bool, default: True) "
"whether to enable diagonal/peephole connections.")
@@ -177,6 +186,9 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default: False) "
"whether to compute reversed LSTM.")
.SetDefault(false);
AddAttr<bool>("is_test", "True if in test phase.")
.SetDefault(false)
.AsExtra();
AddAttr<std::string>(
"gate_activation",
"(string, default: sigmoid)"
...
@@ -40,6 +40,8 @@ template <typename DeviceContext, typename T>
class LSTMKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
bool is_test = ctx.Attr<bool>("is_test");
auto* input = ctx.Input<LoDTensor>("Input");
auto* weight = ctx.Input<Tensor>("Weight");
auto* bias = ctx.Input<Tensor>("Bias");
@@ -47,7 +49,14 @@ class LSTMKernel : public framework::OpKernel<T> {
auto* hidden_t0 = ctx.Input<Tensor>("H0");
auto* cell_t0 = ctx.Input<Tensor>("C0");
auto* batch_gate = ctx.Output<LoDTensor>("BatchGate");
LoDTensor* batch_gate = nullptr;
LoDTensor batch_gate_temp;
if (is_test) {
batch_gate = &batch_gate_temp;
batch_gate->Resize(input->dims());
} else {
batch_gate = ctx.Output<LoDTensor>("BatchGate");
}
batch_gate->mutable_data<T>(ctx.GetPlace());
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
hidden_out->mutable_data<T>(ctx.GetPlace());
@@ -99,8 +108,13 @@ class LSTMKernel : public framework::OpKernel<T> {
}
// Use the local variable as here.
LoDTensor batch_hidden, batch_cell;
auto* batch_cell_pre_act = ctx.Output<LoDTensor>("BatchCellPreAct");
LoDTensor batch_hidden, batch_cell, batch_cell_pre_act_temp;
LoDTensor* batch_cell_pre_act;
if (is_test) {
batch_cell_pre_act = &batch_cell_pre_act_temp;
} else {
batch_cell_pre_act = ctx.Output<LoDTensor>("BatchCellPreAct");
}
batch_hidden.mutable_data<T>(dims, ctx.GetPlace());
batch_cell.mutable_data<T>(dims, ctx.GetPlace());
batch_cell_pre_act->mutable_data<T>(dims, ctx.GetPlace());
...
@@ -119,14 +119,16 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
"tensor's axes according to the values given.");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
.SetDefault(false)
.AsExtra();
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
.SetDefault("AnyLayout")
.AsExtra();
AddAttr<bool>(
"use_quantizer",
"(bool, default false) "
@@ -262,7 +264,9 @@ class Transpose2OpMaker : public TransposeOpMaker {
public:
void Make() override {
TransposeOpMaker::Make();
AddOutput("XShape", "(Tensor)The output tensor.").AsIntermediate();
AddOutput("XShape", "(Tensor)The output tensor.")
.AsIntermediate()
.AsExtra();
}
};
...
@@ -18,7 +18,7 @@ import unittest
import numpy as np
import math
import functools
from op_test import OpTest
from op_test import OpTest, skip_check_grad_ci
from paddle.fluid.tests.unittests.test_lstm_op import ACTIVATION
from paddle import fluid
from paddle.fluid import Program, program_guard
@@ -106,6 +106,9 @@ class TestGRUOp(OpTest):
def set_confs(self):
pass
def set_is_test(self):
self.is_test = False
def setUp(self):
self.op_type = "gru"
self.lod = [[2, 4, 3]]
@@ -118,6 +121,7 @@
self.dtype = 'float64'
self.origin_mode = False
self.set_confs()
self.set_is_test()
T = sum(self.lod[0])
N = len(self.lod[0])
@@ -153,7 +157,8 @@
'activation': self.act_state,
'gate_activation': self.act_gate,
'is_reverse': self.is_reverse,
'origin_mode': self.origin_mode
'origin_mode': self.origin_mode,
'is_test': self.is_test
}
def test_check_output(self):
@@ -229,6 +234,21 @@ class TestGRUOpReverseOriginMode(TestGRUOp):
self.origin_mode = True
class TestGRUOpInference(TestGRUOp):
def set_is_test(self):
self.is_test = True
def test_check_output(self):
new_outputs = {}
new_outputs['Hidden'] = self.outputs['Hidden']
self.outputs = new_outputs
super(TestGRUOpInference, self).test_check_output()
# avoid checking gradient
def test_check_grad(self):
pass
class TestGruOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
...
@@ -16,7 +16,7 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
from op_test import OpTest, skip_check_grad_ci
from paddle import fluid
from paddle.fluid.layers import lstm as LSTM
from paddle.fluid.layers import fill_constant
@@ -212,10 +212,14 @@ class LstmUnitTestError(unittest.TestCase):
class TestLstmOp(OpTest):
def set_is_test(self):
self.is_test = False
def set_lod(self):
self.lod = [[2, 3, 2]]
def set_argument(self):
self.set_is_test()
self.set_lod()
self.D = 16
@@ -269,7 +273,8 @@
'is_reverse': self.is_reverse,
'gate_activation': self.act_gate,
'cell_activation': self.act_cell,
'candidate_activation': self.act_cand
'candidate_activation': self.act_cand,
'is_test': self.is_test
}
def test_check_output(self):
@@ -302,6 +307,15 @@ class TestLstmOpCase3(TestLstmOp):
self.lod = [[2, 0, 4]]
class TestLstmOpInference(TestLstmOp):
def set_is_test(self):
self.is_test = True
# avoid checking gradient
def test_check_grad(self):
pass
class TestLstmOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
...