未验证 提交 b67dd3e3 编写于 作者: L littletomatodonkey 提交者: GitHub

add addmm dyg mode, test=develop (#24100)

上级 0e6b43a6
......@@ -133,5 +133,19 @@ class TestAddMMOp3(OpTest):
self.check_grad(['Input'], 'Out', no_grad_set=None)
class TestAddMMOp4(unittest.TestCase):
def test_api_with_dygraph(self):
np_input = np.random.random((20, 30)).astype(np.float32)
np_x = np.random.random((20, 6)).astype(np.float32)
np_y = np.random.random((6, 30)).astype(np.float32)
with fluid.dygraph.guard():
input = fluid.dygraph.to_variable(np_input)
x = fluid.dygraph.to_variable(np_x)
y = fluid.dygraph.to_variable(np_y)
out = paddle.tensor.addmm(input, x, y)
assert np.allclose(np_input + np.dot(np_x, np_y), out.numpy())
if __name__ == "__main__":
unittest.main()
......@@ -1001,6 +1001,10 @@ def addmm(input, x, y, alpha=1.0, beta=1.0, name=None):
# [[10.5 10.5]
# [10.5 10.5]]
"""
if in_dygraph_mode():
out = core.ops.addmm(input, x, y, "Alpha", alpha, "Beta", beta)
return out
inputs = {'Input': input, "X": x, "Y": y}
attrs = {'Alpha': alpha, 'Beta': beta}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册