Unverified commit fe8d006f, authored by Guo Sheng, committed by GitHub

API/OP(sequence_expand_as) error message enhancement (#23712)

* API/OP(sequence_expand_as) error message enhancement.
test=develop
Co-authored-by: FrostML <380185688@qq.com>
Parent 03ba5b74
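For context, the user-visible effect of the Python-side checks added in this change: passing a non-Variable input or an unsupported dtype now fails fast with a TypeError that names the argument and the op, instead of failing later inside the C++ operator. A minimal sketch mirroring the unit test added below (assumes the Paddle 1.x static-graph API; not part of the commit itself):

    import numpy as np
    import paddle.fluid as fluid

    with fluid.program_guard(fluid.Program(), fluid.Program()):
        x = np.random.random((2, 4)).astype("float32")  # ndarray, not a Variable
        y = fluid.data(name="y", shape=[None, 4], dtype="float32", lod_level=1)
        try:
            fluid.layers.sequence_expand_as(x, y)
        except TypeError as e:
            print(e)  # check_variable_and_dtype rejects x before the op is built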
@@ -27,18 +27,18 @@ class SequenceExpandAsOp : public framework::OperatorWithKernel {
 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of SequenceExpandAsOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Y"),
-                   "Input(Y) of SequenceExpandAsOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of SequenceExpandAsOp should not be null.");
+    OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "SequenceExpandAs");
+    OP_INOUT_CHECK(ctx->HasInputs("Y"), "Input", "Y", "SequenceExpandAs");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SequenceExpandAs");

    auto x_dims = ctx->GetInputDim("X");
    auto out_dims = x_dims;
    PADDLE_ENFORCE_GE(x_dims.size(), 2,
-                      "Dimension number of Input(X) should be at least 2.");
+                      platform::errors::InvalidArgument(
+                          "Dimension number of Input(X) should be at least 2. "
+                          "But received X's dimensions = %d, X's shape = [%s].",
+                          x_dims.size(), x_dims));

    if (ctx->IsRuntime()) {
      framework::Variable* x_var =
@@ -50,11 +50,17 @@ class SequenceExpandAsOp : public framework::OperatorWithKernel {
      auto& y_lod = y_var->Get<LoDTensor>().lod();
      PADDLE_ENFORCE_EQ(y_lod.size(), 1,
-                        "Level number of Input(Y)'s lod should be 1.");
+                        platform::errors::InvalidArgument(
+                            "Level number of Input(Y)'s lod should be 1. But "
+                            "received Y's lod level = %d.",
+                            y_lod.size()));

      PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dim[0]), y_lod[0].size() - 1,
-                        "The first dimension of Input(X) should be equal "
-                        "to the size of Input(Y)'s 0 level lod.");
+                        platform::errors::InvalidArgument(
+                            "The first dimension of Input(X) should be one "
+                            "less than the size of Input(Y)'s 0 level lod. But "
+                            "received X's shape[0] = %d, Y's lod[0].size = %d.",
+                            x_dim[0], y_lod[0].size()));

      int64_t out_first_dim = 0;
      if (y_lod[0].size() <= 1) {
@@ -138,9 +144,9 @@ class SequenceExpandAsOpGrad : public framework::OperatorWithKernel {
 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input(Out@GRAD) should not be null.");
+    OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "SequenceExpandAsGrad");
+    OP_INOUT_CHECK(ctx->HasInputs(framework::GradVarName("Out")), "Input",
+                   "Out@GRAD", "SequenceExpandAsGrad");

    auto x_dims = ctx->GetInputDim("X");
    auto x_grad_name = framework::GradVarName("X");
...
@@ -74,13 +74,25 @@ class SequenceExpandAsKernel : public framework::OpKernel<T> {
    auto *y = context.Input<framework::LoDTensor>("Y");
    auto *out = context.Output<framework::LoDTensor>("Out");

-    PADDLE_ENFORCE_EQ(y->lod().empty(), false,
-                      "Input(Y) Tensor of SequenceExpandAsOp does not contain "
-                      "LoD information.");
+    PADDLE_ENFORCE_EQ(
+        y->lod().empty(), false,
+        platform::errors::InvalidArgument(
+            "Input(Y) of SequenceExpandAsOp has wrong LoD information. "
+            "Expected Y's lod is not empty, but received empty lod."));

    auto &y_lod = y->lod();
-    PADDLE_ENFORCE_EQ(y_lod.size(), 1, "LoD of Y should be 1.");
-    PADDLE_ENFORCE_GT(y_lod[0].size(), 1, ".");
+    PADDLE_ENFORCE_EQ(y_lod.size(), 1,
+                      platform::errors::InvalidArgument(
+                          "Input(Y) of SequenceExpandAsOp has wrong LoD "
+                          "information. Expected Y's lod level = 1, but "
+                          "received lod level = %d.",
+                          y_lod.size()));
+    PADDLE_ENFORCE_GT(y_lod[0].size(), 1,
+                      platform::errors::InvalidArgument(
+                          "Input(Y) of SequenceExpandAsOp has wrong LoD "
+                          "information. Expected the size of Y's lod[0] > 1, "
+                          "but received lod[0].size = %d.",
+                          y_lod[0].size()));

    out->mutable_data<T>(context.GetPlace());
...
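The kernel and InferShape checks above fire at run time, once Y's actual LoD is known. A sketch of the failure path they guard (assumes the Paddle 1.x static-graph API; feeding y as a plain ndarray leaves its LoD empty, so the LoD-level checks above should reject it with the new InvalidArgument wording instead of the old bare messages):

    import numpy as np
    import paddle.fluid as fluid

    x = fluid.data(name="x", shape=[2, 1], dtype="float32")
    y = fluid.data(name="y", shape=[8, 1], dtype="float32", lod_level=1)
    out = fluid.layers.sequence_expand_as(x=x, y=y)

    exe = fluid.Executor(fluid.CPUPlace())
    x_np = np.random.random((2, 1)).astype("float32")
    y_np = np.random.random((8, 1)).astype("float32")
    try:
        # y_np carries no LoD, so one of the runtime LoD checks rejects it.
        exe.run(fluid.default_main_program(),
                feed={"x": x_np, "y": y_np},
                fetch_list=[out])
    except Exception as e:
        print(e)  # "Input(Y) ... has wrong LoD information ..." style message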
@@ -52,7 +52,7 @@ def sequence_conv(input,
                   act=None,
                   name=None):
     """
-    :api_attr: Static Graph
+    :api_attr: Static Graph

     **Notes: The Op only receives LoDTensor as input. If your input is Tensor, please use conv2d Op.(fluid.layers.** :ref:`api_fluid_layers_conv2d` ).
@@ -176,7 +176,7 @@ def sequence_conv(input,
 def sequence_softmax(input, use_cudnn=False, name=None):
     """
-    :api_attr: Static Graph
+    :api_attr: Static Graph

     **Note**:
@@ -260,7 +260,7 @@ def sequence_softmax(input, use_cudnn=False, name=None):
 def sequence_pool(input, pool_type, is_test=False, pad_value=0.0):
     """
-    :api_attr: Static Graph
+    :api_attr: Static Graph

     **Notes: The Op only receives LoDTensor as input. If your input is Tensor, please use pool2d Op.(fluid.layers.** :ref:`api_fluid_layers_pool2d` ).
@@ -374,7 +374,7 @@ def sequence_pool(input, pool_type, is_test=False, pad_value=0.0):
 @templatedoc()
 def sequence_concat(input, name=None):
     """
-    :api_attr: Static Graph
+    :api_attr: Static Graph

     **Notes: The Op only receives LoDTensor as input. If your input is Tensor, please use concat Op.(fluid.layers.** :ref:`api_fluid_layers_concat` ).
@@ -435,7 +435,7 @@ def sequence_concat(input, name=None):
 def sequence_first_step(input):
     """
-    :api_attr: Static Graph
+    :api_attr: Static Graph

     This operator only supports LoDTensor as input. Given the input LoDTensor, it will
     select first time-step feature of each sequence as output.
@@ -489,7 +489,7 @@ def sequence_first_step(input):
 def sequence_last_step(input):
     """
-    :api_attr: Static Graph
+    :api_attr: Static Graph

     This operator only supports LoDTensor as input. Given the input LoDTensor, it will
     select last time-step feature of each sequence as output.
@@ -544,7 +544,7 @@ def sequence_last_step(input):
 def sequence_slice(input, offset, length, name=None):
     """
-    :api_attr: Static Graph
+    :api_attr: Static Graph

     **Sequence Slice Layer**
@@ -632,7 +632,7 @@ def sequence_slice(input, offset, length, name=None):
 def sequence_expand(x, y, ref_level=-1, name=None):
     """
-    :api_attr: Static Graph
+    :api_attr: Static Graph

     Sequence Expand Layer. This layer will expand the input variable ``x`` \
     according to specified level ``ref_level`` lod of ``y``. Please note that \
@@ -768,7 +768,7 @@ def sequence_expand(x, y, ref_level=-1, name=None):
 def sequence_expand_as(x, y, name=None):
     """
-    :api_attr: Static Graph
+    :api_attr: Static Graph

     Sequence Expand As Layer. This OP will expand the input variable ``x`` \
     according to the zeroth level lod of ``y``. Current implementation requires \
@@ -815,7 +815,7 @@ def sequence_expand_as(x, y, name=None):
     Args:
         x (Variable): The input variable which is a Tensor or LoDTensor, with the \
-            dims ``[M, K]``. The data type should be float32, float64, int8, int32 \
+            dims ``[M, K]``. The data type should be float32, float64, int32 \
             or int64.
         y (Variable): The input variable which is a LoDTensor with 1-level lod.
         name (str, optional): For detailed information, please refer \
@@ -872,6 +872,9 @@ def sequence_expand_as(x, y, name=None):
     """
     assert not in_dygraph_mode(), (
         "sequence layer is not supported in dygraph mode yet.")
+    check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'],
+                             'sequence_expand_as')
+    check_type(y, 'y', Variable, 'sequence_expand_as')
     helper = LayerHelper('sequence_expand_as', input=x, **locals())
     dtype = helper.input_dtype()
     tmp = helper.create_variable_for_type_inference(dtype)
@@ -885,7 +888,7 @@ def sequence_expand_as(x, y, name=None):
 def sequence_pad(x, pad_value, maxlen=None, name=None):
     """
-    :api_attr: Static Graph
+    :api_attr: Static Graph

     This layer padding the sequences in a same batch to a common length (according \
     to ``maxlen``). The padding value is defined by ``pad_value``, and will be \
@@ -999,7 +1002,7 @@ def sequence_pad(x, pad_value, maxlen=None, name=None):
 def sequence_unpad(x, length, name=None):
     """
-    :api_attr: Static Graph
+    :api_attr: Static Graph

     **Note**:
@@ -1074,7 +1077,7 @@ def sequence_unpad(x, length, name=None):
 def sequence_reshape(input, new_dim):
     """
-    :api_attr: Static Graph
+    :api_attr: Static Graph

     **Notes: The Op only receives LoDTensor as input. If your input is Tensor, please use reshape Op.(fluid.layers.** :ref:`api_fluid_layers_reshape` ).
@@ -1136,7 +1139,7 @@ def sequence_reshape(input, new_dim):
 def sequence_scatter(input, index, updates, name=None):
     """
-    :api_attr: Static Graph
+    :api_attr: Static Graph

     **Note**:
@@ -1226,7 +1229,7 @@ def sequence_scatter(input, index, updates, name=None):
 def sequence_enumerate(input, win_size, pad_value=0, name=None):
     """
-    :api_attr: Static Graph
+    :api_attr: Static Graph

     Generate a new sequence for the input index sequence with \
     shape ``[d_1, win_size]``, which enumerates all the \
...
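For reference, a correct end-to-end use of the Python API touched above (a sketch assuming the Paddle 1.x static-graph API): x has shape [4, 1] and y carries a 1-level LoD with four sequences of lengths [3, 3, 1, 1], so each row of x is repeated once per time step of its corresponding sequence in y:

    import numpy as np
    import paddle.fluid as fluid

    x = fluid.data(name="x", shape=[4, 1], dtype="float32")
    y = fluid.data(name="y", shape=[8, 1], dtype="float32", lod_level=1)
    out = fluid.layers.sequence_expand_as(x=x, y=y)

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)

    x_np = np.array([[1], [2], [3], [4]]).astype("float32")
    y_np = np.random.random((8, 1)).astype("float32")
    # Four sequences of lengths 3, 3, 1, 1 -- x's first dim must equal their count.
    y_lod = fluid.create_lod_tensor(y_np, [[3, 3, 1, 1]], place)

    result = exe.run(fluid.default_main_program(),
                     feed={"x": x_np, "y": y_lod},
                     fetch_list=[out],
                     return_numpy=False)[0]
    print(np.array(result).ravel())  # [1. 1. 1. 2. 2. 2. 3. 4.]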
@@ -18,7 +18,9 @@ import unittest
 import numpy as np
 import sys
 sys.path.append("../")
+import paddle.fluid as fluid
 from op_test import OpTest
+from paddle.fluid import Program, program_guard

 class TestSequenceExpandAs(OpTest):
@@ -84,5 +86,22 @@ class TestSequenceExpandAsCase3(TestSequenceExpandAs):
         self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}

+class TestSequenceExpandAsOpError(unittest.TestCase):
+    def test_errors(self):
+        with program_guard(Program(), Program()):
+            # the input x must be Variable
+            x1 = np.random.random((2, 4)).astype("float32")
+            self.assertRaises(TypeError, fluid.layers.sequence_expand_as, x1)
+
+            # the dtype of input x must be float32, float64, int32 or int64
+            x2 = fluid.data(name='x2', shape=[None, 4], dtype="bool")
+            self.assertRaises(TypeError, fluid.layers.sequence_expand_as, x2)
+
+            # the input y must be Variable
+            x3 = fluid.data(name='x3', shape=[None, 4], dtype="float32")
+            y = np.random.random((2, 4)).astype("float32")
+            self.assertRaises(TypeError, fluid.layers.sequence_expand_as, x3, y)
+
 if __name__ == '__main__':
     unittest.main()