提交 733ea0d2 编写于 作者: C chenweihang

adjust infershape details

上级 2969aba1
......@@ -13,7 +13,6 @@
// limitations under the License.
#include "paddle/fluid/operators/sequence_enumerate_op.h"
#include <vector>
namespace paddle {
namespace operators {
......@@ -34,18 +33,12 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(
x_dims.size(), 2UL,
"Input(X) of SequenceEnumerate operator's rank should be 2.");
PADDLE_ENFORCE_EQ(
x_dims[1], 1UL,
"Input(X) of SequenceEnumerate operator's 2nd dimension should be 1.");
const auto win_size = ctx->Attrs().Get<int>("win_size");
// TODO(chenweihang): unittest doesn't has batch size, but test_layers has
auto first_dim = x_dims[0] == -1 ? x_dims[1] : x_dims[0];
PADDLE_ENFORCE(win_size <= first_dim,
"The enumerate window size should be less than or equal to "
"input sequence length.");
std::vector<int64_t> out_shape(x_dims.size() + 1, 0);
for (int i = 0; i < x_dims.size(); ++i) out_shape.emplace_back(x_dims[i]);
out_shape.emplace_back(win_size);
ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
ctx->SetOutputDim("Out", {x_dims[0], win_size});
ctx->ShareLoD("X", "Out");
}
};
......
......@@ -5563,7 +5563,7 @@ def sequence_enumerate(input, win_size, pad_value, name=None):
out = fluid.layers.sequence_enumerate(input=x, win_size=3, pad_value=0)
"""
helper = LayerHelper('sequence_enumerate', **locals())
out = helper.create_tmp_variable(helper.input_dtype())
out = helper.create_tmp_variable(helper.input_dtype(), stop_gradient=True)
helper.append_op(
type='sequence_enumerate',
inputs={'X': input},
......
......@@ -522,10 +522,8 @@ class TestBook(unittest.TestCase):
def test_sequence_enumerate(self):
program = Program()
with program_guard(program):
x = layers.data(
name="input", shape=[30], dtype='int32', lod_level=1)
x = layers.data(name="input", shape=[1], dtype='int32', lod_level=1)
out = layers.sequence_enumerate(input=x, win_size=2, pad_value=0)
self.assertIsNotNone(out)
print(str(program))
......
......@@ -19,7 +19,7 @@ import numpy as np
from op_test import OpTest
def sequence_enumerate(input_seq, lod0, win_size, pad_value):
def sequence_enumerate(input_seq, win_size, pad_value):
out_seq = []
for idx in range(0, len(input_seq)):
single_seq = []
......@@ -48,8 +48,7 @@ class TestSequenceEnumerateOp(OpTest):
self.lod = [[9, 4, 11, 6]]
self.win_size = 2
self.pad_value = 0
out_seq = sequence_enumerate(self.in_seq, self.lod[0], self.win_size,
self.pad_value)
out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value)
self.out_seq = np.array(out_seq).astype("int32")
......@@ -59,8 +58,7 @@ class TesSequenceEnumerateOpInt64(TestSequenceEnumerateOp):
self.lod = [[9, 4, 11, 6]]
self.win_size = 2
self.pad_value = 0
out_seq = sequence_enumerate(self.in_seq, self.lod[0], self.win_size,
self.pad_value)
out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value)
self.out_seq = np.array(out_seq).astype("int64")
......@@ -70,8 +68,7 @@ class TestSequenceEnumerateOpMaxWinSize(TestSequenceEnumerateOp):
self.lod = [[9, 4, 11, 6]]
self.win_size = 30
self.pad_value = 0
out_seq = sequence_enumerate(self.in_seq, self.lod[0], self.win_size,
self.pad_value)
out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value)
self.out_seq = np.array(out_seq).astype("int32")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册