未验证 提交 274f4e94 编写于 作者: W whs 提交者: GitHub

Merge pull request #8334 from wanghaoshuang/fix_seq_expand

Fix output dims of sequence expand op
...@@ -29,7 +29,9 @@ class SequenceExpandOp : public framework::OperatorWithKernel { ...@@ -29,7 +29,9 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Out")); PADDLE_ENFORCE(ctx->HasOutput("Out"));
PADDLE_ENFORCE(ctx->HasInput("Y")); PADDLE_ENFORCE(ctx->HasInput("Y"));
framework::DDim out_dim; framework::DDim out_dim;
out_dim = ctx->GetInputDim("Y"); auto y_dim = ctx->GetInputDim("Y");
out_dim = ctx->GetInputDim("X");
out_dim[0] = y_dim[0];
ctx->ShareLoD("Y", "Out"); ctx->ShareLoD("Y", "Out");
ctx->SetOutputDim("Out", out_dim); ctx->SetOutputDim("Out", out_dim);
} }
......
...@@ -73,5 +73,20 @@ class TestSequenceExpandCase3(TestSequenceExpand): ...@@ -73,5 +73,20 @@ class TestSequenceExpandCase3(TestSequenceExpand):
self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)} self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
class TestSequenceExpandCase4(TestSequenceExpand):
def set_data(self):
x_data = np.array(
[0.1, 0.3, 0.2, 0.15, 0.25, 0.2, 0.15, 0.25, 0.1, 0.3]).reshape(
[2, 5]).astype('float32')
x_lod = [[
0,
1,
2,
]]
y_data = np.random.uniform(0.1, 1, [2, 1]).astype('float32')
y_lod = [[0, 1, 2], [0, 1, 2]]
self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册