提交 7e2af4c9 编写于 作者: Z zhongpu 提交者: liym27

modify sequence_pool optest apply get_sequence_instance_size_0_input, test=develop (#22214)

上级 fc0b21e1
......@@ -392,6 +392,42 @@ class OpTest(unittest.TestCase):
x = np.random.uniform(0.1, 1, shape).astype('float32')
return (x, lod)
def lod_has_single_zero(self, lod):
for i in range(len(lod) - 2):
if lod[i] != 0 and lod[i + 1] == 0 and lod[i + 2] != 0:
return True
return False
def lod_has_continuous_zero(self, lod):
for i in range(len(lod) - 3):
if lod[i] != 0 and lod[i + 1] == 0 and lod[i + 2] == 0 and lod[
i + 3] != 0:
return True
return False
def get_sequence_instance_size_0_input(self, lod=None, shape=None):
"""Get LoD input data whose instance size is 0.
All sequence related OP unittests should call this function to contain the case of instance size is 0.
Args:
lod (list[list of int], optional): Length-based LoD, lod[0]'s size must at least eight, lod[0] must at least two zeros at the beginning and at least two zeros at the end, the middle position of lod[0] contains a single zero and multiple zero. Default: [[0, 0, 4, 0, 3, 0, 0, 5, 0, 0]].
shape (list, optional): Shape of input, shape[0] should be equals to lod[0][0]. Default: [13, 23].
Returns:
tuple (ndarray, lod): LoD input data whose instance size is 0.
"""
if lod is None:
lod = [[0, 0, 4, 0, 3, 0, 0, 5, 0, 0]]
if shape is None:
shape = [12, 10]
assert len(lod[0]) >= 8
assert lod[0][0] == 0 and lod[0][1] == 0 and lod[0][-1] == 0 and lod[0][
-2] == 0
assert self.lod_has_single_zero(lod[0]) is True
assert self.lod_has_continuous_zero(lod[0]) is True
assert sum(lod[0]) == shape[0]
x = np.random.uniform(0.1, 1, shape).astype('float32')
return (x, lod)
def append_input_output_for_dygraph(self, op_proto, np_list, is_input,
if_return_inputs_grad_dict, block):
def create_var(np_value, name, is_input, if_return_inputs_grad_dict):
......
......@@ -105,6 +105,17 @@ class TestSeqAvgPoolBatch1(TestSeqAvgPool):
return x
class TestSeqAvgPoolInstance0(TestSeqAvgPool):
def set_lod(self):
return [[0, 0, 4, 0, 3, 0, 0, 5, 0, 0]]
def set_lod_data(self):
lod = self.set_lod()
x, _ = self.get_sequence_instance_size_0_input(
lod=lod, shape=[sum(lod[0]), 10])
return x
class TestSeqAvgPoolLen0(TestSeqAvgPool):
def set_lod(self):
return [[0, 4, 0, 7, 0]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册