未验证 提交 d3d69d8c 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] polish bmm api (#49823)

上级 02384bc6
......@@ -1557,30 +1557,29 @@ def bmm(x, y, name=None):
# [60., 60.]]])
"""
x_shape = x.shape
y_shape = y.shape
if not len(x_shape) == len(y_shape) == 3:
raise ValueError(
"x and y should be 3-dimensional. But received x's dimention: {}, y's dimention: {}".format(
x_shape, y_shape
)
)
if x_shape[2] != y_shape[1]:
raise ValueError(
"x's width must be equal with y's height. But received x's shape: {}, y's shape: {}".format(
x_shape, y_shape
)
)
if x_shape[0] != y_shape[0]:
raise ValueError(
"x's batch (shape[0]) must be equal with y's batch (shape[0]). But received x's shape: {}, y's shape: {}".format(
x_shape, y_shape
)
)
if in_dygraph_mode():
return _C_ops.bmm(x, y)
else:
x_shape = x.shape
y_shape = y.shape
if not len(x_shape) == len(y_shape) == 3:
raise ValueError(
"x and y should be 3-dimensional. But received x's dimention: {}, y's dimention: {}".format(
x_shape, y_shape
)
)
if x_shape[2] != y_shape[1]:
raise ValueError(
"x's width must be equal with y's height. But received x's shape: {}, y's shape: {}".format(
x_shape, y_shape
)
)
if x_shape[0] != y_shape[0]:
raise ValueError(
"x's batch (shape[0]) must be equal with y's batch (shape[0]). But received x's shape: {}, y's shape: {}".format(
x_shape, y_shape
)
)
helper = LayerHelper('bmm', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册