未验证 提交 d3d69d8c 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] polish bmm api (#49823)

上级 02384bc6
......@@ -1557,6 +1557,9 @@ def bmm(x, y, name=None):
# [60., 60.]]])
"""
if in_dygraph_mode():
return _C_ops.bmm(x, y)
else:
x_shape = x.shape
y_shape = y.shape
if not len(x_shape) == len(y_shape) == 3:
......@@ -1577,10 +1580,6 @@ def bmm(x, y, name=None):
x_shape, y_shape
)
)
if in_dygraph_mode():
return _C_ops.bmm(x, y)
else:
helper = LayerHelper('bmm', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册