未验证 提交 30e4cacd 编写于 作者: W wawltor 提交者: GitHub

Add the support dygraph attribute of op mm, support the out attribute

chery-pick from the pr#23978
上级 56fd2e47
......@@ -37,7 +37,8 @@ std::map<std::string, std::set<std::string>> op_passing_out_map = {
{"momentum", {"ParamOut", "VelocityOut"}},
{"batch_norm", {"MeanOut", "VarianceOut"}},
{"accuracy", {"Correct", "Total"}},
{"fill_constant", {"Out"}}};
{"fill_constant", {"Out"}},
{"matmul", {"Out"}}};
// clang-format off
const char* OUT_INITIALIZER_TEMPLATE =
......
......@@ -936,7 +936,8 @@ class Linear(layers.Layer):
def forward(self, input):
if in_dygraph_mode():
pre_bias = core.ops.matmul(input, self.weight, 'transpose_X', False,
pre_bias = _varbase_creator(dtype=input.dtype)
core.ops.matmul(input, self.weight, pre_bias, 'transpose_X', False,
'transpose_Y', False, "alpha", 1)
pre_act = dygraph_utils._append_bias_in_dygraph(
pre_bias, self.bias, axis=len(input.shape) - 1)
......
......@@ -280,6 +280,42 @@ class API_TestMm(unittest.TestCase):
"two value is\
{}\n{}, check diff!".format(np_res, expected_result))
def test_dygraph_with_out(self):
device = fluid.CPUPlace()
with fluid.dygraph.guard(device):
input_array1 = np.random.rand(3, 4).astype("float64")
input_array2 = np.random.rand(4, 3).astype("float64")
out_array = np.random.rand(3, 3).astype("float64")
data1 = fluid.dygraph.to_variable(input_array1)
data2 = fluid.dygraph.to_variable(input_array2)
paddle_out_holder = fluid.dygraph.to_variable(out_array)
out = paddle.mm(data1, data2, out=paddle_out_holder)
self.assertTrue(np.allclose(paddle_out_holder.numpy(), out.numpy()))
def test_dygraph_without_out(self):
device = fluid.CPUPlace()
with fluid.dygraph.guard(device):
input_array1 = np.random.rand(3, 4).astype("float64")
input_array2 = np.random.rand(4, 3).astype("float64")
data1 = fluid.dygraph.to_variable(input_array1)
data2 = fluid.dygraph.to_variable(input_array2)
out = paddle.mm(data1, data2)
expected_result = np.matmul(input_array1, input_array2)
self.assertTrue(np.allclose(expected_result, out.numpy()))
class Test_API_Matmul(unittest.TestCase):
def test_dygraph_without_out(self):
device = fluid.CPUPlace()
with fluid.dygraph.guard(device):
input_array1 = np.random.rand(3, 4).astype("float64")
input_array2 = np.random.rand(4, 3).astype("float64")
data1 = fluid.dygraph.to_variable(input_array1)
data2 = fluid.dygraph.to_variable(input_array2)
out = paddle.matmul(data1, data2)
expected_result = np.matmul(input_array1, input_array2)
self.assertTrue(np.allclose(expected_result, out.numpy()))
class API_TestMmError(unittest.TestCase):
def test_errors(self):
......
......@@ -14,7 +14,7 @@
from paddle.common_ops_import import *
from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import check_variable_and_dtype, check_type
from ..fluid.framework import in_dygraph_mode
from ..fluid.framework import in_dygraph_mode, _varbase_creator
__all__ = [
'matmul',
......@@ -109,8 +109,10 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None):
}
if in_dygraph_mode():
return core.ops.matmul(x, y, 'transpose_X', transpose_x, 'transpose_Y',
out = _varbase_creator(dtype=x.dtype)
core.ops.matmul(x, y, out, 'transpose_X', transpose_x, 'transpose_Y',
transpose_y, 'alpha', float(alpha))
return out
def __check_input(x, y):
var_names = {'x': x, 'y': y}
......
......@@ -19,7 +19,7 @@ from __future__ import print_function
from paddle.common_ops_import import *
from ..fluid import layers
from ..fluid.framework import core
from ..fluid.framework import core, _varbase_creator
from ..fluid.layers.layer_function_generator import _generate_doc_string_
# TODO: define math functions
......@@ -902,7 +902,10 @@ def mm(input, mat2, out=None, name=None):
out = paddle.mm(x, mat2) # out shape is [2, 2]
"""
if in_dygraph_mode():
return core.ops.matmul(input, mat2)
if out is None:
out = _varbase_creator(dtype=input.dtype)
core.ops.matmul(input, mat2, out)
return out
def __check_input(x, y):
var_names = {'x': x, 'y': y}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册