Commit 1409586e authored by Chen Weihang, committed by Tao Luo

Add LoD empty check for all related sequence ops (#19980)

* add lod check for sequence op, test=develop

* delete unnecessary check in expand op, test=develop
Parent 88af4ab6
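Every hunk below follows the same pattern: before a kernel touches `lod()[0]`, it asserts that the input LoDTensor actually carries LoD information, so a missing LoD fails fast with a clear message instead of an out-of-range access. The following is a minimal sketch of that pattern, not code from this commit; the kernel and op names (`SomeSequenceKernel`, `SomeSequenceOp`) are placeholders, and the include path assumes the usual fluid operator header.

```cpp
#include "paddle/fluid/framework/op_registry.h"

// Illustrative kernel showing the guard this commit adds to each sequence op;
// SomeSequenceKernel/SomeSequenceOp are placeholders, not real Paddle ops.
template <typename DeviceContext, typename T>
class SomeSequenceKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<paddle::framework::LoDTensor>("X");
    // The new check: fail early if Input(X) carries no LoD at all.
    PADDLE_ENFORCE_EQ(in->lod().empty(), false,
                      "Input(X) Tensor of SomeSequenceOp does not contain "
                      "LoD information.");
    // Only after the check is it safe to index the level-0 LoD.
    auto& lod0 = in->lod()[0];
    size_t seq_num = lod0.size() - 1;  // number of sequences in the batch
    (void)seq_num;
  }
};
```

The real hunks differ only in which input is guarded (X, Y, Ids, ROW, COLUMN) and in the op name quoted in the error message.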
@@ -61,6 +61,9 @@ class SeqConcatKernel : public framework::OpKernel<T> {
size_t lod_size = 0;
for (auto &x : xs) {
if (lod_size == 0) {
PADDLE_ENFORCE_EQ(x.get().lod().empty(), false,
"Input(X) Tensor of SequenceConcatOp does not "
"contain LoD information.");
lod_size = x.get().lod()[0].size();
} else {
PADDLE_ENFORCE_EQ(
......
@@ -39,6 +39,9 @@ class SequenceConvKernel : public framework::OpKernel<T> {
int context_stride = context.Attr<int>("contextStride");
bool padding_trainable = context.Attr<bool>("paddingTrainable");
PADDLE_ENFORCE_EQ(
in->lod().empty(), false,
"Input(X) Tensor of SequenceConvOp does not contain LoD information.");
PADDLE_ENFORCE_EQ(in->lod().size(), 1UL,
"Only support one level sequence now.");
......
@@ -29,6 +29,10 @@ class SequenceEnumerateKernel : public framework::OpKernel<T> {
int win_size = context.Attr<int>("win_size");
auto pad_value = static_cast<T>(context.Attr<int>("pad_value"));
PADDLE_ENFORCE_EQ(in->lod().empty(), false,
"Input(X) Tensor of SequenceEnumerateOp does not contain "
"LoD information.");
auto in_dims = in->dims();
auto lod0 = in->lod()[0];
PADDLE_ENFORCE_EQ(
......
@@ -28,6 +28,9 @@ class SequenceEraseKernel : public framework::OpKernel<T> {
auto* out = ctx.Output<framework::LoDTensor>("Out");
auto lod = in->lod();
PADDLE_ENFORCE_EQ(
lod.empty(), false,
"Input(X) Tensor of SequenceEraseOp does not contain LoD information.");
PADDLE_ENFORCE_EQ(lod[lod.size() - 1].back(), (size_t)in->numel(),
"The actual size mismatches with the LoD information.");
auto tokens = ctx.Attr<std::vector<int>>("tokens");
......
@@ -75,6 +75,10 @@ class SequenceExpandAsKernel : public framework::OpKernel<T> {
auto *y = context.Input<framework::LoDTensor>("Y");
auto *out = context.Output<framework::LoDTensor>("Out");
PADDLE_ENFORCE_EQ(y->lod().empty(), false,
"Input(Y) Tensor of SequenceExpandAsOp does not contain "
"LoD information.");
auto &y_lod = y->lod();
PADDLE_ENFORCE_EQ(y_lod.size(), 1, "LoD of Y should be 1.");
PADDLE_ENFORCE_GT(y_lod[0].size(), 1, ".");
......
@@ -92,6 +92,10 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
auto& x_lod = x->lod();
auto& y_lod = y->lod();
PADDLE_ENFORCE_EQ(y_lod.empty(), false,
"Input(Y) Tensor of SequenceExpandOp does not contain "
"LoD information.");
if (ref_level == -1) ref_level = y_lod.size() - 1;
out->mutable_data<T>(context.GetPlace());
......
@@ -35,6 +35,10 @@ class SequencePadOpKernel : public framework::OpKernel<T> {
auto* len_t = ctx.Output<LoDTensor>("Length");
out->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE_EQ(
x->lod().empty(), false,
"Input(X) Tensor of SequencePadOp does not contain LoD information.");
const auto* pad_value = ctx.Input<LoDTensor>("PadValue");
int padded_length = ctx.Attr<int>("padded_length");
......
@@ -38,8 +38,9 @@ class SequencePoolKernel : public framework::OpKernel<T> {
auto lod = in->lod();
auto lod_level = lod.size();
// InferShape by lod
- PADDLE_ENFORCE_GE(lod_level, 1UL,
- "The lod level of input shall be 1 at least.");
+ PADDLE_ENFORCE_GT(
+ lod_level, 0,
+ "Input(X) Tensor of SequencePoolOp does not contain LoD information.");
PADDLE_ENFORCE_LE(lod_level, 2UL,
"The lod level of input shall be no more than 2.");
PADDLE_ENFORCE_GE(
......
@@ -32,6 +32,9 @@ class SequenceReshapeKernel : public framework::OpKernel<T> {
int64_t in_width = in_dims[1];
auto& in_lod = in->lod();
PADDLE_ENFORCE_EQ(in_lod.empty(), false,
"Input(X) Tensor of SequenceReshapeOp does not contain "
"LoD information.");
PADDLE_ENFORCE_EQ(in_lod.size(), 1UL,
"Only support one level sequence now.");
PADDLE_ENFORCE_EQ(
......
@@ -107,6 +107,9 @@ class SequenceReverseOpKernel : public framework::OpKernel<T> {
auto &x = *ctx.Input<LoDTensor>("X");
auto *y = ctx.Output<LoDTensor>("Y");
PADDLE_ENFORCE_EQ(x.lod().empty(), false,
"Input(X) Tensor of SequenceReverseOp does not contain "
"LoD information.");
PADDLE_ENFORCE_EQ(x.lod().size(), 1,
"SequenceReverse Op only support one level lod.");
......
@@ -34,6 +34,9 @@ class SequenceScatterOpKernel : public framework::OpKernel<T> {
auto* out = ctx.Output<Tensor>("Out");
auto& ids_lod = ids->lod();
PADDLE_ENFORCE_EQ(ids_lod.empty(), false,
"Input(Ids) Tensor of SequenceScatterOp does not contain "
"LoD information.");
// Initialize out as same as x
out->mutable_data<T>(ctx.GetPlace());
......
@@ -49,8 +49,11 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
auto* out = ctx.Output<LoDTensor>("Out");
auto lod = in->lod();
- auto n = lod[0].size() - 1;
+ PADDLE_ENFORCE_EQ(
+ lod.empty(), false,
+ "Input(X) Tensor of SequenceSliceOp does not contain LoD information.");
+ auto n = lod[0].size() - 1;
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
PADDLE_ENFORCE_EQ(
n, static_cast<size_t>(length->dims()[0]),
......
@@ -95,6 +95,9 @@ class SequenceSoftmaxKernel : public framework::OpKernel<T> {
auto lod = x->lod();
auto dims = x->dims();
PADDLE_ENFORCE_EQ(lod.empty(), false,
"Input(X) Tensor of SequenceSoftmaxOp does not contain "
"LoD information.");
const size_t level = lod.size() - 1;
PADDLE_ENFORCE_GT(
......
@@ -70,6 +70,16 @@ class SequenceTopkAvgPoolingKernel : public framework::OpKernel<T> {
auto* out = context.Output<LoDTensor>("Out");
auto* pos = context.Output<Tensor>("pos");
PADDLE_ENFORCE_EQ(in->lod().empty(), false,
"Input(X) Tensor of SequenceTopkAvgPoolingOp does not "
"contain LoD information.");
PADDLE_ENFORCE_EQ(row->lod().empty(), false,
"Input(ROW) Tensor of SequenceTopkAvgPoolingOp does not "
"contain LoD information.");
PADDLE_ENFORCE_EQ(col->lod().empty(), false,
"Input(COLUMN) Tensor of SequenceTopkAvgPoolingOp does "
"not contain LoD information.");
auto channel_num = context.Attr<int>("channel_num");
auto topks = context.Attr<std::vector<int>>("topks");
auto k_num = topks.size();
......