diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 4bb2cf4227deed1b9e8cd9f1720ac2ec62331c78..a5492d508103736ef44c8bb4586254ca2910838f 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1557,30 +1557,29 @@ def bmm(x, y, name=None): # [60., 60.]]]) """ - x_shape = x.shape - y_shape = y.shape - if not len(x_shape) == len(y_shape) == 3: - raise ValueError( - "x and y should be 3-dimensional. But received x's dimention: {}, y's dimention: {}".format( - x_shape, y_shape - ) - ) - if x_shape[2] != y_shape[1]: - raise ValueError( - "x's width must be equal with y's height. But received x's shape: {}, y's shape: {}".format( - x_shape, y_shape - ) - ) - if x_shape[0] != y_shape[0]: - raise ValueError( - "x's batch (shape[0]) must be equal with y's batch (shape[0]). But received x's shape: {}, y's shape: {}".format( - x_shape, y_shape - ) - ) - if in_dygraph_mode(): return _C_ops.bmm(x, y) else: + x_shape = x.shape + y_shape = y.shape + if not len(x_shape) == len(y_shape) == 3: + raise ValueError( + "x and y should be 3-dimensional. But received x's dimention: {}, y's dimention: {}".format( + x_shape, y_shape + ) + ) + if x_shape[2] != y_shape[1]: + raise ValueError( + "x's width must be equal with y's height. But received x's shape: {}, y's shape: {}".format( + x_shape, y_shape + ) + ) + if x_shape[0] != y_shape[0]: + raise ValueError( + "x's batch (shape[0]) must be equal with y's batch (shape[0]). But received x's shape: {}, y's shape: {}".format( + x_shape, y_shape + ) + ) helper = LayerHelper('bmm', **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) helper.append_op(