diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 81c2a4548d8812edd1b3ff4458f923a8a49d1157..600f93683eff357dcc869f6a271d53d60a061d62 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -326,6 +326,16 @@ kernel : func : bitwise_xor +# bmm +- api : bmm + args : (Tensor x, Tensor y) + output : Tensor + infer_meta : + func : BmmInferMeta + kernel : + func : bmm + backward : bmm_grad + # box_coder - api : box_coder args : (Tensor prior_box, Tensor prior_box_var, Tensor target_box, str code_type, bool box_normalized, int axis, float[] variance) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 47183fed746d040981f6af69bf8e961428199a51..310cf7c151ff22e7babcfa9f7f4cc993866d5cb5 100644 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -260,6 +260,15 @@ kernel : func : bilinear_tensor_product_grad +- backward_api : bmm_grad + forward : bmm (Tensor x, Tensor y) -> Tensor(out) + args : (Tensor x, Tensor y, Tensor out_grad) + output : Tensor(x_grad), Tensor(y_grad) + infer_meta : + func : BmmGradInferMeta + kernel : + func : bmm_grad + - backward_api : brelu_grad forward : brelu (Tensor x, float t_min, float t_max) -> Tensor(out) args : (Tensor x, Tensor out_grad, float t_min, float t_max) diff --git a/python/paddle/fluid/tests/unittests/test_bmm_op.py b/python/paddle/fluid/tests/unittests/test_bmm_op.py index b9a5853c492f595e6de58a280e5fa86811544fa8..5e5c41ae882798b0180d6a2922ea80ffe9264ce8 100644 --- a/python/paddle/fluid/tests/unittests/test_bmm_op.py +++ b/python/paddle/fluid/tests/unittests/test_bmm_op.py @@ -27,6 +27,7 @@ class TestBmmOp(OpTest): def setUp(self): self.op_type = "bmm" + self.python_api = paddle.tensor.bmm X = np.random.random((10, 3, 4)).astype("float64") Y = np.random.random((10, 4, 5)).astype("float64") self.inputs = {'X': X, 'Y': Y} @@ -34,10 +35,10 @@ class TestBmmOp(OpTest): self.outputs = {'Out': Out} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_checkout_grad(self): - self.check_grad(['X', 'Y'], 'Out') + self.check_grad(['X', 'Y'], 'Out', check_eager=True) class API_TestBmm(unittest.TestCase): diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index a77d6b5a2ad92a36286a690543baaddd5397ddde..d1468765b5907b100f8e058c50d3dc42057da97f 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1521,6 +1521,9 @@ def bmm(x, y, name=None): "x's batch (shape[0]) must be equal with y's batch (shape[0]). But received x's shape: {}, y's shape: {}" .format(x_shape, y_shape)) + if in_dygraph_mode(): + return _C_ops.final_state_bmm(x, y) + if paddle.in_dynamic_mode(): return _C_ops.bmm(x, y)