Unverified commit db5fd2a1, authored by will-jl944, committed by GitHub

multiply supports bool

Parent a2dbb0c2
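A quick, hedged illustration of what this commit enables (the tensors and shapes below are illustrative, not taken from the patch): with the `bool` kernels registered, `paddle.multiply` accepts boolean tensors and matches `np.multiply`, i.e. an element-wise logical AND.

```python
import numpy as np
import paddle

# Illustrative inputs; the patch's own tests draw them with np.random.choice([True, False]).
x = paddle.to_tensor([True, True, False, False])
y = paddle.to_tensor([True, False, True, False])

out = paddle.multiply(x, y)  # element-wise product of two bool tensors
print(out.numpy())           # [ True False False False]
assert np.array_equal(out.numpy(), np.multiply(x.numpy(), y.numpy()))
```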
@@ -132,6 +132,7 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, double>,
     ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int64_t>,
+    ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, bool>,
     ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
                               paddle::platform::complex<float>>,
     ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
@@ -142,6 +143,7 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
+    ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, bool>,
     ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
                                   paddle::platform::complex<float>>,
     ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
@@ -156,6 +158,8 @@ REGISTER_OP_CPU_KERNEL(
                                         int>,
     ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                         int64_t>,
+    ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
+                                        bool>,
     ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                         paddle::platform::complex<float>>,
     ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
......
@@ -121,6 +121,7 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseMulKernel<plat::CUDADeviceContext, double>,
     ops::ElementwiseMulKernel<plat::CUDADeviceContext, int>,
     ops::ElementwiseMulKernel<plat::CUDADeviceContext, int64_t>,
+    ops::ElementwiseMulKernel<plat::CUDADeviceContext, bool>,
     ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::float16>,
     ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<float>>,
     ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<double>>);
@@ -130,6 +131,7 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, double>,
     ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int>,
     ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int64_t>,
+    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, bool>,
     ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::float16>,
     ops::ElementwiseMulGradKernel<plat::CUDADeviceContext,
                                   plat::complex<float>>,
@@ -141,6 +143,7 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, double>,
     ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int>,
     ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
+    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, bool>,
     ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
     ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
                                         plat::complex<float>>,
......
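The REGISTER_OP_CUDA_KERNEL entries above mirror the CPU ones, so the bool multiply path is also available on GPU. A minimal sketch of exercising it, assuming a CUDA build (the device guard below is illustrative, not part of the patch):

```python
import paddle

# Use the GPU kernel when the build has CUDA, otherwise fall back to the CPU kernel.
paddle.set_device('gpu' if paddle.is_compiled_with_cuda() else 'cpu')

x = paddle.to_tensor([True, False, True])
y = paddle.to_tensor([True, True, False])
print(paddle.multiply(x, y))  # dispatched to the newly registered bool kernel
```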
@@ -70,6 +70,12 @@ class TestMultiplyApi(unittest.TestCase):
         res = self._run_static_graph_case(x_data, y_data)
         self.assertTrue(np.allclose(res, np.multiply(x_data, y_data)))
 
+        # test static computation graph: boolean
+        x_data = np.random.choice([True, False], size=[200])
+        y_data = np.random.choice([True, False], size=[200])
+        res = self._run_static_graph_case(x_data, y_data)
+        self.assertTrue(np.allclose(res, np.multiply(x_data, y_data)))
+
         # test dynamic computation graph: 1-d array
         x_data = np.random.rand(200)
         y_data = np.random.rand(200)
@@ -88,6 +94,12 @@ class TestMultiplyApi(unittest.TestCase):
         res = self._run_dynamic_graph_case(x_data, y_data)
         self.assertTrue(np.allclose(res, np.multiply(x_data, y_data)))
 
+        # test dynamic computation graph: boolean
+        x_data = np.random.choice([True, False], size=[200])
+        y_data = np.random.choice([True, False], size=[200])
+        res = self._run_dynamic_graph_case(x_data, y_data)
+        self.assertTrue(np.allclose(res, np.multiply(x_data, y_data)))
+
 
 class TestMultiplyError(unittest.TestCase):
     def test_errors(self):
......
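The `_run_static_graph_case` / `_run_dynamic_graph_case` helpers live elsewhere in this test file and are not shown in the diff. The sketch below is only an assumption about what the dynamic-graph helper roughly does, written so the new boolean case can be read in isolation; the helper name here is a stand-in, not the actual implementation.

```python
import numpy as np
import paddle

def run_dynamic_graph_case(x_data, y_data):
    # Assumed behaviour of the test helper: eager-mode multiply on tensors
    # built from the numpy inputs, returning a numpy array for comparison.
    paddle.disable_static()
    x = paddle.to_tensor(x_data)
    y = paddle.to_tensor(y_data)
    return paddle.multiply(x, y).numpy()

# Mirrors the boolean case added above.
x_data = np.random.choice([True, False], size=[200])
y_data = np.random.choice([True, False], size=[200])
assert np.allclose(run_dynamic_graph_case(x_data, y_data),
                   np.multiply(x_data, y_data))
```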
@@ -196,10 +196,10 @@ def _elementwise_op(helper):
     assert x is not None, 'x cannot be None in {}'.format(original_op_type)
     assert y is not None, 'y cannot be None in {}'.format(original_op_type)
     check_variable_and_dtype(
-        x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'],
+        x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
         original_op_type)
     check_variable_and_dtype(
-        y, 'y', ['float16', 'float32', 'float64', 'int32', 'int64'],
+        y, 'y', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
         original_op_type)
 
     axis = helper.kwargs.get('axis', -1)
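With `'bool'` added to the whitelist passed to `check_variable_and_dtype`, boolean inputs now pass the static-graph type check while unlisted dtypes are still rejected. A hedged sketch of both paths (assuming `int8` remains outside the whitelist and Paddle 2.x static-graph APIs):

```python
import paddle

paddle.enable_static()

# bool now passes the dtype whitelist in _elementwise_op ...
x = paddle.static.data(name='x', shape=[3], dtype='bool')
y = paddle.static.data(name='y', shape=[3], dtype='bool')
out = paddle.multiply(x, y)

# ... while a dtype that is still outside the whitelist is rejected.
a = paddle.static.data(name='a', shape=[3], dtype='int8')
b = paddle.static.data(name='b', shape=[3], dtype='int8')
try:
    paddle.multiply(a, b)
except TypeError as err:
    print('rejected:', err)
```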
@@ -473,8 +473,8 @@ def multiply(x, y, name=None):
     ``paddle.multiply`` supports broadcasting. If you would like to know more about broadcasting, please refer to :ref:`user_guide_broadcasting` .
 
     Args:
-        x (Tensor): the input tensor, its data type should be float32, float64, int32, int64.
-        y (Tensor): the input tensor, its data type should be float32, float64, int32, int64.
+        x (Tensor): the input tensor, its data type should be one of float32, float64, int32, int64, bool.
+        y (Tensor): the input tensor, its data type should be one of float32, float64, int32, int64, bool.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
......
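Matching the updated docstring, a short hedged usage example that combines broadcasting with the newly documented bool dtype (shapes chosen only for illustration):

```python
import paddle

# A [2, 3] bool tensor broadcast-multiplied by a [3] bool tensor.
x = paddle.to_tensor([[True, False, True],
                      [False, True, True]])
y = paddle.to_tensor([True, True, False])

out = paddle.multiply(x, y)  # shape [2, 3], dtype bool
print(out)
```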