未验证 提交 2841b004 编写于 作者: H Hui Zhang 提交者: GitHub

tirl bool for jit (#46512)

上级 6b11b693
...@@ -925,7 +925,8 @@ def _tril_triu_op(helper): ...@@ -925,7 +925,8 @@ def _tril_triu_op(helper):
assert x is not None, 'x cannot be None in {}'.format(op_type) assert x is not None, 'x cannot be None in {}'.format(op_type)
check_variable_and_dtype( check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], op_type) x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
op_type)
if len(x.shape) < 2: if len(x.shape) < 2:
raise ValueError("x shape in {} must be at least 2-D".format(op_type)) raise ValueError("x shape in {} must be at least 2-D".format(op_type))
diagonal = helper.kwargs.get('diagonal', 0) diagonal = helper.kwargs.get('diagonal', 0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册