From b67dd3e37afbfda47de8f5a679a2d5c643891137 Mon Sep 17 00:00:00 2001 From: littletomatodonkey <2120160898@bit.edu.cn> Date: Fri, 24 Apr 2020 20:18:09 +0800 Subject: [PATCH] add addmm dyg mode, test=develop (#24100) --- .../paddle/fluid/tests/unittests/test_addmm_op.py | 14 ++++++++++++++ python/paddle/tensor/math.py | 4 ++++ 2 files changed, 18 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_addmm_op.py b/python/paddle/fluid/tests/unittests/test_addmm_op.py index cb44f37225e..8c0b599a379 100644 --- a/python/paddle/fluid/tests/unittests/test_addmm_op.py +++ b/python/paddle/fluid/tests/unittests/test_addmm_op.py @@ -133,5 +133,19 @@ class TestAddMMOp3(OpTest): self.check_grad(['Input'], 'Out', no_grad_set=None) +class TestAddMMOp4(unittest.TestCase): + def test_api_with_dygraph(self): + np_input = np.random.random((20, 30)).astype(np.float32) + np_x = np.random.random((20, 6)).astype(np.float32) + np_y = np.random.random((6, 30)).astype(np.float32) + + with fluid.dygraph.guard(): + input = fluid.dygraph.to_variable(np_input) + x = fluid.dygraph.to_variable(np_x) + y = fluid.dygraph.to_variable(np_y) + out = paddle.tensor.addmm(input, x, y) + assert np.allclose(np_input + np.dot(np_x, np_y), out.numpy()) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 20be0e0c919..29dbb74482c 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -1001,6 +1001,10 @@ def addmm(input, x, y, alpha=1.0, beta=1.0, name=None): # [[10.5 10.5] # [10.5 10.5]] """ + if in_dygraph_mode(): + out = core.ops.addmm(input, x, y, "Alpha", alpha, "Beta", beta) + return out + inputs = {'Input': input, "X": x, "Y": y} attrs = {'Alpha': alpha, 'Beta': beta} -- GitLab