From b368f13ef4cbbec0699cb9f2c44b10339039146e Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Fri, 9 Feb 2018 20:41:49 +0800 Subject: [PATCH] Fix output dims of sequence expand op --- paddle/operators/sequence_expand_op.cc | 4 +++- .../paddle/v2/fluid/tests/test_sequence_expand.py | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/paddle/operators/sequence_expand_op.cc b/paddle/operators/sequence_expand_op.cc index d34dbd35b6..d2a386ffbe 100644 --- a/paddle/operators/sequence_expand_op.cc +++ b/paddle/operators/sequence_expand_op.cc @@ -29,7 +29,9 @@ class SequenceExpandOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasOutput("Out")); PADDLE_ENFORCE(ctx->HasInput("Y")); framework::DDim out_dim; - out_dim = ctx->GetInputDim("Y"); + auto y_dim = ctx->GetInputDim("Y"); + out_dim = ctx->GetInputDim("X"); + out_dim[0] = y_dim[0]; ctx->ShareLoD("Y", "Out"); ctx->SetOutputDim("Out", out_dim); } diff --git a/python/paddle/v2/fluid/tests/test_sequence_expand.py b/python/paddle/v2/fluid/tests/test_sequence_expand.py index 6fc045125f..0d37751de4 100644 --- a/python/paddle/v2/fluid/tests/test_sequence_expand.py +++ b/python/paddle/v2/fluid/tests/test_sequence_expand.py @@ -73,5 +73,20 @@ class TestSequenceExpandCase3(TestSequenceExpand): self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)} +class TestSequenceExpandCase4(TestSequenceExpand): + def set_data(self): + x_data = np.array( + [0.1, 0.3, 0.2, 0.15, 0.25, 0.2, 0.15, 0.25, 0.1, 0.3]).reshape( + [2, 5]).astype('float32') + x_lod = [[ + 0, + 1, + 2, + ]] + y_data = np.random.uniform(0.1, 1, [2, 1]).astype('float32') + y_lod = [[0, 1, 2], [0, 1, 2]] + self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)} + + if __name__ == '__main__': unittest.main() -- GitLab