未验证 提交 76a3678f 编写于 作者: G GaoWei8 提交者: GitHub

fix lod_reset check dtype (#24227)

上级 c36c67fa
......@@ -7660,8 +7660,7 @@ def lod_reset(x, y=None, target_lod=None):
out = helper.create_variable_for_type_inference(dtype=x.dtype)
if y is not None:
check_type(y, 'y', (Variable), 'lod_reset')
if y.lod_level == 0:
check_variable_and_dtype(y, 'y', ['int32'], 'lod_reset')
#TODO: check y.lod_level = 0 dtype
helper.append_op(
type="lod_reset", inputs={'X': x,
'Y': y}, outputs={'Out': out})
......@@ -7732,8 +7731,7 @@ def lod_append(x, level):
if isinstance(level, Variable):
inputs['Y'] = level
if level.lod_level == 0:
check_variable_and_dtype(level, 'level', ['int32'], 'lod_append')
#TODO: check y.lod_level = 0 dtype
else:
attrs['target_lod'] = level
helper.append_op(
......
......@@ -67,14 +67,6 @@ class TestLodAppendOpError(unittest.TestCase):
name='level3' + dtype, shape=[4], dtype='int32', lod_level=2)
self.assertRaises(TypeError, fluid.layers.lod_append, x3, level3)
# Input(level) dtype must be int32 when lod_level=0
for dtype in ["bool", "float16", "float32", "float64", "int64"]:
x4 = fluid.layers.data(
name='x4' + dtype, shape=[4], dtype='float32')
level4 = fluid.layers.data(
name='level4_' + dtype, shape=[4], dtype=dtype, lod_level=0)
self.assertRaises(TypeError, fluid.layers.lod_append, x4, level4)
if __name__ == "__main__":
unittest.main()
......@@ -150,14 +150,6 @@ class TestLodResetOpError(unittest.TestCase):
name='y2' + dtype, shape=[4], dtype='int32', lod_level=2)
self.assertRaises(TypeError, fluid.layers.lod_reset, x2, y2)
# Input(y) dtype must be int32 when lod_level=0
for dtype in ["bool", "float16", "float32", "float64", "int64"]:
x3 = fluid.layers.data(
name='x3' + dtype, shape=[4], dtype='float32')
y3 = fluid.layers.data(
name='y3' + dtype, shape=[4], dtype=dtype, lod_level=0)
self.assertRaises(TypeError, fluid.layers.lod_reset, x3, y3)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册