From d3d69d8cd15b9459e610101bf0b9d9613fb3c464 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Mon, 16 Jan 2023 10:56:42 +0800 Subject: [PATCH] [Eager] polish bmm api (#49823) --- python/paddle/tensor/linalg.py | 41 +++++++++++++++++----------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 4bb2cf4227..a5492d5081 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1557,30 +1557,29 @@ def bmm(x, y, name=None): # [60., 60.]]]) """ - x_shape = x.shape - y_shape = y.shape - if not len(x_shape) == len(y_shape) == 3: - raise ValueError( - "x and y should be 3-dimensional. But received x's dimention: {}, y's dimention: {}".format( - x_shape, y_shape - ) - ) - if x_shape[2] != y_shape[1]: - raise ValueError( - "x's width must be equal with y's height. But received x's shape: {}, y's shape: {}".format( - x_shape, y_shape - ) - ) - if x_shape[0] != y_shape[0]: - raise ValueError( - "x's batch (shape[0]) must be equal with y's batch (shape[0]). But received x's shape: {}, y's shape: {}".format( - x_shape, y_shape - ) - ) - if in_dygraph_mode(): return _C_ops.bmm(x, y) else: + x_shape = x.shape + y_shape = y.shape + if not len(x_shape) == len(y_shape) == 3: + raise ValueError( + "x and y should be 3-dimensional. But received x's dimention: {}, y's dimention: {}".format( + x_shape, y_shape + ) + ) + if x_shape[2] != y_shape[1]: + raise ValueError( + "x's width must be equal with y's height. But received x's shape: {}, y's shape: {}".format( + x_shape, y_shape + ) + ) + if x_shape[0] != y_shape[0]: + raise ValueError( + "x's batch (shape[0]) must be equal with y's batch (shape[0]). But received x's shape: {}, y's shape: {}".format( + x_shape, y_shape + ) + ) helper = LayerHelper('bmm', **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) helper.append_op( -- GitLab