From 733ea0d29bac96924e62c5714fa57ce07e2ff220 Mon Sep 17 00:00:00 2001 From: chenweihang Date: Fri, 24 Aug 2018 05:51:56 +0000 Subject: [PATCH] adjust infershape details --- paddle/fluid/operators/sequence_enumerate_op.cc | 15 ++++----------- python/paddle/fluid/layers/nn.py | 4 ++-- .../paddle/fluid/tests/unittests/test_layers.py | 4 +--- .../tests/unittests/test_sequence_enumerate_op.py | 11 ++++------- 4 files changed, 11 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/operators/sequence_enumerate_op.cc b/paddle/fluid/operators/sequence_enumerate_op.cc index cacbb097771..b8c8daf3f39 100644 --- a/paddle/fluid/operators/sequence_enumerate_op.cc +++ b/paddle/fluid/operators/sequence_enumerate_op.cc @@ -13,7 +13,6 @@ // limitations under the License. #include "paddle/fluid/operators/sequence_enumerate_op.h" -#include namespace paddle { namespace operators { @@ -34,18 +33,12 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ( x_dims.size(), 2UL, "Input(X) of SequenceEnumerate operator's rank should be 2."); + PADDLE_ENFORCE_EQ( + x_dims[1], 1UL, + "Input(X) of SequenceEnumerate operator's 2nd dimension should be 1."); const auto win_size = ctx->Attrs().Get("win_size"); - // TODO(chenweihang): unittest doesn't has batch size, but test_layers has - auto first_dim = x_dims[0] == -1 ? x_dims[1] : x_dims[0]; - PADDLE_ENFORCE(win_size <= first_dim, - "The enumerate window size should be less than or equal to " - "input sequence length."); - - std::vector out_shape(x_dims.size() + 1, 0); - for (int i = 0; i < x_dims.size(); ++i) out_shape.emplace_back(x_dims[i]); - out_shape.emplace_back(win_size); - ctx->SetOutputDim("Out", framework::make_ddim(out_shape)); + ctx->SetOutputDim("Out", {x_dims[0], win_size}); ctx->ShareLoD("X", "Out"); } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index c8e4c99f9ea..9411256c74f 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -5563,7 +5563,7 @@ def sequence_enumerate(input, win_size, pad_value, name=None): out = fluid.layers.sequence_enumerate(input=x, win_size=3, pad_value=0) """ helper = LayerHelper('sequence_enumerate', **locals()) - out = helper.create_tmp_variable(helper.input_dtype()) + out = helper.create_tmp_variable(helper.input_dtype(), stop_gradient=True) helper.append_op( type='sequence_enumerate', inputs={'X': input}, @@ -5571,7 +5571,7 @@ def sequence_enumerate(input, win_size, pad_value, name=None): attrs={'win_size': win_size, 'pad_value': pad_value}) - + def stack(x, axis=0): helper = LayerHelper('stack', **locals()) axis = 0 if axis is None else axis diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index c45ccee4bd1..351bcf790bf 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -522,10 +522,8 @@ class TestBook(unittest.TestCase): def test_sequence_enumerate(self): program = Program() with program_guard(program): - x = layers.data( - name="input", shape=[30], dtype='int32', lod_level=1) + x = layers.data(name="input", shape=[1], dtype='int32', lod_level=1) out = layers.sequence_enumerate(input=x, win_size=2, pad_value=0) - self.assertIsNotNone(out) print(str(program)) diff --git a/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py b/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py index 18d91728fb9..41624da5126 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py @@ -19,7 +19,7 @@ import numpy as np from op_test import OpTest -def sequence_enumerate(input_seq, lod0, win_size, pad_value): +def sequence_enumerate(input_seq, win_size, pad_value): out_seq = [] for idx in range(0, len(input_seq)): single_seq = [] @@ -48,8 +48,7 @@ class TestSequenceEnumerateOp(OpTest): self.lod = [[9, 4, 11, 6]] self.win_size = 2 self.pad_value = 0 - out_seq = sequence_enumerate(self.in_seq, self.lod[0], self.win_size, - self.pad_value) + out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value) self.out_seq = np.array(out_seq).astype("int32") @@ -59,8 +58,7 @@ class TesSequenceEnumerateOpInt64(TestSequenceEnumerateOp): self.lod = [[9, 4, 11, 6]] self.win_size = 2 self.pad_value = 0 - out_seq = sequence_enumerate(self.in_seq, self.lod[0], self.win_size, - self.pad_value) + out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value) self.out_seq = np.array(out_seq).astype("int64") @@ -70,8 +68,7 @@ class TestSequenceEnumerateOpMaxWinSize(TestSequenceEnumerateOp): self.lod = [[9, 4, 11, 6]] self.win_size = 30 self.pad_value = 0 - out_seq = sequence_enumerate(self.in_seq, self.lod[0], self.win_size, - self.pad_value) + out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value) self.out_seq = np.array(out_seq).astype("int32") -- GitLab