未验证 提交 76a3678f 编写于 作者: G GaoWei8 提交者: GitHub

fix lod_reset check dtype (#24227)

上级 c36c67fa
...@@ -7660,8 +7660,7 @@ def lod_reset(x, y=None, target_lod=None): ...@@ -7660,8 +7660,7 @@ def lod_reset(x, y=None, target_lod=None):
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
if y is not None: if y is not None:
check_type(y, 'y', (Variable), 'lod_reset') check_type(y, 'y', (Variable), 'lod_reset')
if y.lod_level == 0: #TODO: check y.lod_level = 0 dtype
check_variable_and_dtype(y, 'y', ['int32'], 'lod_reset')
helper.append_op( helper.append_op(
type="lod_reset", inputs={'X': x, type="lod_reset", inputs={'X': x,
'Y': y}, outputs={'Out': out}) 'Y': y}, outputs={'Out': out})
...@@ -7732,8 +7731,7 @@ def lod_append(x, level): ...@@ -7732,8 +7731,7 @@ def lod_append(x, level):
if isinstance(level, Variable): if isinstance(level, Variable):
inputs['Y'] = level inputs['Y'] = level
if level.lod_level == 0: #TODO: check y.lod_level = 0 dtype
check_variable_and_dtype(level, 'level', ['int32'], 'lod_append')
else: else:
attrs['target_lod'] = level attrs['target_lod'] = level
helper.append_op( helper.append_op(
......
...@@ -67,14 +67,6 @@ class TestLodAppendOpError(unittest.TestCase): ...@@ -67,14 +67,6 @@ class TestLodAppendOpError(unittest.TestCase):
name='level3' + dtype, shape=[4], dtype='int32', lod_level=2) name='level3' + dtype, shape=[4], dtype='int32', lod_level=2)
self.assertRaises(TypeError, fluid.layers.lod_append, x3, level3) self.assertRaises(TypeError, fluid.layers.lod_append, x3, level3)
# Input(level) dtype must be int32 when lod_level=0
for dtype in ["bool", "float16", "float32", "float64", "int64"]:
x4 = fluid.layers.data(
name='x4' + dtype, shape=[4], dtype='float32')
level4 = fluid.layers.data(
name='level4_' + dtype, shape=[4], dtype=dtype, lod_level=0)
self.assertRaises(TypeError, fluid.layers.lod_append, x4, level4)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -150,14 +150,6 @@ class TestLodResetOpError(unittest.TestCase): ...@@ -150,14 +150,6 @@ class TestLodResetOpError(unittest.TestCase):
name='y2' + dtype, shape=[4], dtype='int32', lod_level=2) name='y2' + dtype, shape=[4], dtype='int32', lod_level=2)
self.assertRaises(TypeError, fluid.layers.lod_reset, x2, y2) self.assertRaises(TypeError, fluid.layers.lod_reset, x2, y2)
# Input(y) dtype must be int32 when lod_level=0
for dtype in ["bool", "float16", "float32", "float64", "int64"]:
x3 = fluid.layers.data(
name='x3' + dtype, shape=[4], dtype='float32')
y3 = fluid.layers.data(
name='y3' + dtype, shape=[4], dtype=dtype, lod_level=0)
self.assertRaises(TypeError, fluid.layers.lod_reset, x3, y3)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册