Unverified commit 66cae915 — authored by GaoWei8, committed via GitHub

Op (lod_reset) error message enhancement (#23499)

Parent commit: 63bfe0b9
......@@ -24,16 +24,17 @@ class LoDResetOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of LoDResetOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of LoDResetOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "LoDReset");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "LoDReset");
if (!ctx->HasInput("Y")) {
auto level0 = ctx->Attrs().Get<std::vector<int>>("target_lod");
PADDLE_ENFORCE_GT(level0.size(), 0,
"If Input(Y) not provided, the target lod should be "
"specified by attribute `target_lod`.");
PADDLE_ENFORCE_GT(
static_cast<int64_t>(level0.size()), 0,
platform::errors::InvalidArgument(
"If Input(Y) not provided, the target lod should be "
"specified by attribute `target_lod`. But the size of "
"`target_lod` is 0."));
} else if (ctx->IsRuntime()) {
ctx->ShareLoD("Y", "Out");
}
......@@ -181,10 +182,9 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of LoDResetGradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@Grad) of LoDResetGradOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "LoDResetGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Output",
framework::GradVarName("Out"), "LoDResetGrad");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
......
......@@ -38,9 +38,14 @@ class LoDResetKernel : public framework::OpKernel<T> {
if (lod_t->lod().size() > 0) {
auto y_lod = lod_t->lod();
auto last_level = y_lod[y_lod.size() - 1];
PADDLE_ENFORCE_EQ((int64_t)(last_level.back()), in->dims()[0],
"Last value of `Y`'s last level LoD should be equal "
"to the first dimension of `X`");
PADDLE_ENFORCE_EQ(
static_cast<int64_t>(last_level.back()), in->dims()[0],
platform::errors::InvalidArgument(
"The last value of `Y`'s last level LoD should be equal "
"to the first dimension of `X`. But received the last value of "
"`Y`'s last level LoD is %d, the first dimension of `X` is "
"%d. ",
static_cast<int64_t>(last_level.back()), in->dims()[0]));
out->set_lod(y_lod);
return; // early return, since lod already set
} else {
......@@ -56,16 +61,33 @@ class LoDResetKernel : public framework::OpKernel<T> {
level0 = ctx.Attr<std::vector<int>>("target_lod");
}
PADDLE_ENFORCE_GT(level0.size(), 1UL,
"Size of target LoD should be greater than 1.");
PADDLE_ENFORCE_EQ(level0[0], 0,
"Target LoD should be a vector starting from 0.");
PADDLE_ENFORCE_EQ(level0.back(), in->dims()[0],
"Target LoD should be a vector end with the "
"first dimension of Input(X).");
PADDLE_ENFORCE_GT(
level0.size(), 1UL,
platform::errors::InvalidArgument(
"The size of target LoD should be greater than 1. But received the "
"size of target LoD is %d.",
level0.size()));
PADDLE_ENFORCE_EQ(static_cast<int64_t>(level0[0]), 0,
platform::errors::InvalidArgument(
"Target LoD should be a vector starting from 0. But "
"target LoD starts from %d.",
static_cast<int64_t>(level0[0])));
PADDLE_ENFORCE_EQ(
static_cast<int64_t>(level0.back()), in->dims()[0],
platform::errors::InvalidArgument(
"The last value of `Target LoD`'s last level LoD should be equal "
"to the first dimension of `X`. But received the last value of "
"`Target LoD`'s last level LoD is %d, the first dimension of `X` "
"is "
"%d. ",
static_cast<int64_t>(level0.back()), in->dims()[0]));
for (size_t i = 0; i < level0.size() - 1; ++i) {
PADDLE_ENFORCE(level0[i + 1] >= level0[i],
"Target LoD should be an ascending vector.");
PADDLE_ENFORCE_GE(
level0[i + 1], level0[i],
platform::errors::InvalidArgument(
"Target LoD should be an ascending vector. But the %s element is "
"%s and the %s element of Target LoD is %s.",
i + 1, level0[i + 1], i, level0[i]));
}
// cast level0 to size_t
......
......@@ -6189,9 +6189,16 @@ def lod_reset(x, y=None, target_lod=None):
y = fluid.layers.data(name='y', shape=[10, 20], lod_level=2)
out = fluid.layers.lod_reset(x=x, y=y)
"""
check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'],
'lod_reset')
helper = LayerHelper("lod_reset", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
if y is not None:
if y.lod_level > 0:
check_variable_and_dtype(
y, 'y', ['float32', 'float64', 'int32', 'int64'], 'lod_reset')
else:
check_variable_and_dtype(y, 'y', ['int32', 'int64'], 'lod_reset')
helper.append_op(
type="lod_reset", inputs={'X': x,
'Y': y}, outputs={'Out': out})
......
......@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
from paddle.fluid import Program, program_guard
class TestLodResetOpByAttr(OpTest):
......@@ -132,5 +133,32 @@ class TestLodAppendOpByAttr(OpTest):
self.check_grad(["X"], "Out", check_dygraph=False)
class TestLodResetOpError(unittest.TestCase):
    """Negative tests for fluid.layers.lod_reset input validation."""

    def test_errors(self):
        with program_guard(Program(), Program()):

            def test_Variable():
                # The input must be a Variable: passing a raw LoDTensor
                # (wrapped in a list) must raise TypeError.
                x1 = fluid.create_lod_tensor(
                    np.ones([6]), [3, 3], fluid.CPUPlace())
                y1 = fluid.create_lod_tensor(
                    np.ones([6]), [2, 2, 2], fluid.CPUPlace())
                self.assertRaises(TypeError, fluid.layers.lod_reset, [x1, y1])

            def test_type():
                # x dtype must be float32, float64, int32 or int64;
                # uint8 must raise TypeError.
                x2 = fluid.layers.data(shape=[4], dtype='uint8', name='x2')
                # NOTE: y gets its own name ('y2'); reusing 'x2' here was a typo.
                y2 = fluid.layers.data(
                    shape=[4], dtype='uint8', name='y2', lod_level=2)
                # Pass x and y as separate arguments so the dtype check on x
                # (not a "x is a list" check) is what raises.
                self.assertRaises(TypeError, fluid.layers.lod_reset, x2, y2)

            def test_type2():
                # When y has lod_level == 0 its dtype must be int32 or int64;
                # float32 must raise TypeError.
                x3 = fluid.layers.data(shape=[4], dtype='float32', name='x3')
                y3 = fluid.layers.data(
                    shape=[4], dtype='float32', name='y3', lod_level=0)
                # x3 is a valid float32 Variable, so the TypeError must come
                # from the y dtype check — pass x and y separately.
                self.assertRaises(TypeError, fluid.layers.lod_reset, x3, y3)

            test_Variable()
            test_type()
            test_type2()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Please finish editing this message first!
To comment, please register.