提交 733ea0d2 编写于 作者: C chenweihang

adjust infershape details

上级 2969aba1
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/sequence_enumerate_op.h" #include "paddle/fluid/operators/sequence_enumerate_op.h"
#include <vector>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -34,18 +33,12 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel { ...@@ -34,18 +33,12 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
x_dims.size(), 2UL, x_dims.size(), 2UL,
"Input(X) of SequenceEnumerate operator's rank should be 2."); "Input(X) of SequenceEnumerate operator's rank should be 2.");
PADDLE_ENFORCE_EQ(
x_dims[1], 1UL,
"Input(X) of SequenceEnumerate operator's 2nd dimension should be 1.");
const auto win_size = ctx->Attrs().Get<int>("win_size"); const auto win_size = ctx->Attrs().Get<int>("win_size");
// TODO(chenweihang): unittest doesn't has batch size, but test_layers has ctx->SetOutputDim("Out", {x_dims[0], win_size});
auto first_dim = x_dims[0] == -1 ? x_dims[1] : x_dims[0];
PADDLE_ENFORCE(win_size <= first_dim,
"The enumerate window size should be less than or equal to "
"input sequence length.");
std::vector<int64_t> out_shape(x_dims.size() + 1, 0);
for (int i = 0; i < x_dims.size(); ++i) out_shape.emplace_back(x_dims[i]);
out_shape.emplace_back(win_size);
ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
ctx->ShareLoD("X", "Out"); ctx->ShareLoD("X", "Out");
} }
}; };
......
...@@ -5563,7 +5563,7 @@ def sequence_enumerate(input, win_size, pad_value, name=None): ...@@ -5563,7 +5563,7 @@ def sequence_enumerate(input, win_size, pad_value, name=None):
out = fluid.layers.sequence_enumerate(input=x, win_size=3, pad_value=0) out = fluid.layers.sequence_enumerate(input=x, win_size=3, pad_value=0)
""" """
helper = LayerHelper('sequence_enumerate', **locals()) helper = LayerHelper('sequence_enumerate', **locals())
out = helper.create_tmp_variable(helper.input_dtype()) out = helper.create_tmp_variable(helper.input_dtype(), stop_gradient=True)
helper.append_op( helper.append_op(
type='sequence_enumerate', type='sequence_enumerate',
inputs={'X': input}, inputs={'X': input},
......
...@@ -522,10 +522,8 @@ class TestBook(unittest.TestCase): ...@@ -522,10 +522,8 @@ class TestBook(unittest.TestCase):
def test_sequence_enumerate(self): def test_sequence_enumerate(self):
program = Program() program = Program()
with program_guard(program): with program_guard(program):
x = layers.data( x = layers.data(name="input", shape=[1], dtype='int32', lod_level=1)
name="input", shape=[30], dtype='int32', lod_level=1)
out = layers.sequence_enumerate(input=x, win_size=2, pad_value=0) out = layers.sequence_enumerate(input=x, win_size=2, pad_value=0)
self.assertIsNotNone(out)
print(str(program)) print(str(program))
......
...@@ -19,7 +19,7 @@ import numpy as np ...@@ -19,7 +19,7 @@ import numpy as np
from op_test import OpTest from op_test import OpTest
def sequence_enumerate(input_seq, lod0, win_size, pad_value): def sequence_enumerate(input_seq, win_size, pad_value):
out_seq = [] out_seq = []
for idx in range(0, len(input_seq)): for idx in range(0, len(input_seq)):
single_seq = [] single_seq = []
...@@ -48,8 +48,7 @@ class TestSequenceEnumerateOp(OpTest): ...@@ -48,8 +48,7 @@ class TestSequenceEnumerateOp(OpTest):
self.lod = [[9, 4, 11, 6]] self.lod = [[9, 4, 11, 6]]
self.win_size = 2 self.win_size = 2
self.pad_value = 0 self.pad_value = 0
out_seq = sequence_enumerate(self.in_seq, self.lod[0], self.win_size, out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value)
self.pad_value)
self.out_seq = np.array(out_seq).astype("int32") self.out_seq = np.array(out_seq).astype("int32")
...@@ -59,8 +58,7 @@ class TesSequenceEnumerateOpInt64(TestSequenceEnumerateOp): ...@@ -59,8 +58,7 @@ class TesSequenceEnumerateOpInt64(TestSequenceEnumerateOp):
self.lod = [[9, 4, 11, 6]] self.lod = [[9, 4, 11, 6]]
self.win_size = 2 self.win_size = 2
self.pad_value = 0 self.pad_value = 0
out_seq = sequence_enumerate(self.in_seq, self.lod[0], self.win_size, out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value)
self.pad_value)
self.out_seq = np.array(out_seq).astype("int64") self.out_seq = np.array(out_seq).astype("int64")
...@@ -70,8 +68,7 @@ class TestSequenceEnumerateOpMaxWinSize(TestSequenceEnumerateOp): ...@@ -70,8 +68,7 @@ class TestSequenceEnumerateOpMaxWinSize(TestSequenceEnumerateOp):
self.lod = [[9, 4, 11, 6]] self.lod = [[9, 4, 11, 6]]
self.win_size = 30 self.win_size = 30
self.pad_value = 0 self.pad_value = 0
out_seq = sequence_enumerate(self.in_seq, self.lod[0], self.win_size, out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value)
self.pad_value)
self.out_seq = np.array(out_seq).astype("int32") self.out_seq = np.array(out_seq).astype("int32")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册