未验证 提交 d32beea2 编写于 作者: J joejiong 提交者: GitHub

Add Checking Type for "multiply" operation (#26508)

Co-authored-by: Nwujionghao <wujionghao@email.com>
As the title
上级 2f5bdd8d
......@@ -44,8 +44,8 @@ class TestMultiplyAPI(unittest.TestCase):
def __run_dynamic_graph_case(self, x_data, y_data, axis=-1):
paddle.disable_static()
x = paddle.to_variable(x_data)
y = paddle.to_variable(y_data)
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
res = paddle.multiply(x, y, axis=axis)
return res.numpy()
......@@ -126,17 +126,31 @@ class TestMultiplyError(unittest.TestCase):
paddle.disable_static()
x_data = np.random.randn(200).astype(np.int8)
y_data = np.random.randn(200).astype(np.int8)
x = paddle.to_variable(x_data)
y = paddle.to_variable(y_data)
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
self.assertRaises(fluid.core.EnforceNotMet, paddle.multiply, x, y)
# test dynamic computation graph: inputs must be broadcastable
x_data = np.random.rand(200, 5)
y_data = np.random.rand(200)
x = paddle.to_variable(x_data)
y = paddle.to_variable(y_data)
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
self.assertRaises(fluid.core.EnforceNotMet, paddle.multiply, x, y)
# test dynamic computation graph: inputs must be broadcastable(python)
x_data = np.random.rand(200, 5)
y_data = np.random.rand(200)
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
self.assertRaises(fluid.core.EnforceNotMet, paddle.multiply, x, y)
# test dynamic computation graph: dtype must be same
x_data = np.random.randn(200).astype(np.int64)
y_data = np.random.randn(200).astype(np.float64)
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
self.assertRaises(TypeError, paddle.multiply, x, y)
if __name__ == '__main__':
unittest.main()
......@@ -562,10 +562,23 @@ floor_mod = remainder #DEFINE_ALIAS
def multiply(x, y, axis=-1, name=None):
"""
:alias_main: paddle.multiply
:alias: paddle.multiply,paddle.tensor.multiply,paddle.tensor.math.multiply
multiply two tensors element-wise. The equation is:
Examples:
.. math::
out = x * y
**Note**:
``paddle.multiply`` supports broadcasting. If you would like to know more about broadcasting, please refer to :ref:`user_guide_broadcasting` .
Args:
x (Tensor): the input tensor, its data type should be float32, float64, int32, int64.
y (Tensor): the input tensor, its data type should be float32, float64, int32, int64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
N-D Tensor. A location into which the result is stored. Its dimension equals with $x$.
Examples:
.. code-block:: python
......@@ -575,21 +588,26 @@ Examples:
paddle.disable_static()
x_data = np.array([[1, 2], [3, 4]], dtype=np.float32)
y_data = np.array([[5, 6], [7, 8]], dtype=np.float32)
x = paddle.to_variable(x_data)
y = paddle.to_variable(y_data)
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
res = paddle.multiply(x, y)
print(res.numpy()) # [[5, 12], [21, 32]]
x_data = np.array([[[1, 2, 3], [1, 2, 3]]], dtype=np.float32)
y_data = np.array([1, 2], dtype=np.float32)
x = paddle.to_variable(x_data)
y = paddle.to_variable(y_data)
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
res = paddle.multiply(x, y, axis=1)
print(res.numpy()) # [[[1, 2, 3], [2, 4, 6]]]
"""
op_type = 'elementwise_mul'
act = None
if x.dtype != y.dtype:
raise TypeError(
'Input tensors must be same type, but received type of x: %s, type of y: %s '
% (x.dtype, y.dtype))
if in_dygraph_mode():
return _elementwise_op_in_dygraph(
x, y, axis=axis, act=act, op_name=op_type)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册