未验证 提交 eaad3e4c 编写于 作者: Y Yiqun Liu 提交者: GitHub

Add check of input in sequence_expand op. (#15466)

* Add check of input in sequence_expand op.
test=develop

* Correct the unittest of sequence_expand op.
test=develop
上级 f4dec5cd
......@@ -68,6 +68,11 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
"Level number of Input(X)'s lod could be 0. Otherwise "
"size of Input(X)'s first level lod should be equal to "
"size of Input(Y)'s referred level lod.");
} else {
PADDLE_ENFORCE_EQ(x_dims[0], y_lod[ref_level].size() - 1,
"When Input(X)'s lod is null, the dims[0] of "
"Input(X) should match the "
"size of Input(Y)'s referred level lod.");
}
int64_t out_first_dim = 0;
......
......@@ -81,11 +81,10 @@ class TestSequenceExpand(OpTest):
class TestSequenceExpandCase1(TestSequenceExpand):
def set_data(self):
x_data = np.random.uniform(0.1, 1, [5, 1]).astype('float32')
x_lod = [[2, 3]]
y_data = np.random.uniform(0.1, 1, [13, 1]).astype('float32')
y_lod = [[2, 3], [2, 2, 3, 3, 3]]
self.inputs = {'X': x_data, 'Y': (y_data, y_lod)}
self.attrs = {'ref_level': 0}
self.attrs = {'ref_level': 1}
class TestSequenceExpandCase2(TestSequenceExpand):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册