未验证 提交 db5fd2a1 编写于 作者: W will-jl944 提交者: GitHub

multiply supports bool

multiply supports bool  
上级 a2dbb0c2
...@@ -132,6 +132,7 @@ REGISTER_OP_CPU_KERNEL( ...@@ -132,6 +132,7 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, double>, ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int>, ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int64_t>, ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, bool>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>, paddle::platform::complex<float>>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
...@@ -142,6 +143,7 @@ REGISTER_OP_CPU_KERNEL( ...@@ -142,6 +143,7 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, double>, ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int>, ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int64_t>, ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>, paddle::platform::complex<float>>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
...@@ -156,6 +158,8 @@ REGISTER_OP_CPU_KERNEL( ...@@ -156,6 +158,8 @@ REGISTER_OP_CPU_KERNEL(
int>, int>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext, ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>, int64_t>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
bool>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext, ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>, paddle::platform::complex<float>>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext, ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
......
...@@ -121,6 +121,7 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -121,6 +121,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseMulKernel<plat::CUDADeviceContext, double>, ops::ElementwiseMulKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, int>, ops::ElementwiseMulKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, int64_t>, ops::ElementwiseMulKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, bool>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::float16>, ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<float>>, ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<float>>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<double>>); ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<double>>);
...@@ -130,6 +131,7 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -130,6 +131,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, double>, ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int>, ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int64_t>, ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, bool>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::float16>, ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, ops::ElementwiseMulGradKernel<plat::CUDADeviceContext,
plat::complex<float>>, plat::complex<float>>,
...@@ -141,6 +143,7 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -141,6 +143,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, double>, ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int>, ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int64_t>, ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, bool>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, plat::float16>, ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
plat::complex<float>>, plat::complex<float>>,
......
...@@ -70,6 +70,12 @@ class TestMultiplyApi(unittest.TestCase): ...@@ -70,6 +70,12 @@ class TestMultiplyApi(unittest.TestCase):
res = self._run_static_graph_case(x_data, y_data) res = self._run_static_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.multiply(x_data, y_data))) self.assertTrue(np.allclose(res, np.multiply(x_data, y_data)))
# test static computation graph: boolean
x_data = np.random.choice([True, False], size=[200])
y_data = np.random.choice([True, False], size=[200])
res = self._run_static_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.multiply(x_data, y_data)))
# test dynamic computation graph: 1-d array # test dynamic computation graph: 1-d array
x_data = np.random.rand(200) x_data = np.random.rand(200)
y_data = np.random.rand(200) y_data = np.random.rand(200)
...@@ -88,6 +94,12 @@ class TestMultiplyApi(unittest.TestCase): ...@@ -88,6 +94,12 @@ class TestMultiplyApi(unittest.TestCase):
res = self._run_dynamic_graph_case(x_data, y_data) res = self._run_dynamic_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.multiply(x_data, y_data))) self.assertTrue(np.allclose(res, np.multiply(x_data, y_data)))
# test dynamic computation graph: boolean
x_data = np.random.choice([True, False], size=[200])
y_data = np.random.choice([True, False], size=[200])
res = self._run_dynamic_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.multiply(x_data, y_data)))
class TestMultiplyError(unittest.TestCase): class TestMultiplyError(unittest.TestCase):
def test_errors(self): def test_errors(self):
......
...@@ -196,10 +196,10 @@ def _elementwise_op(helper): ...@@ -196,10 +196,10 @@ def _elementwise_op(helper):
assert x is not None, 'x cannot be None in {}'.format(original_op_type) assert x is not None, 'x cannot be None in {}'.format(original_op_type)
assert y is not None, 'y cannot be None in {}'.format(original_op_type) assert y is not None, 'y cannot be None in {}'.format(original_op_type)
check_variable_and_dtype( check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
original_op_type) original_op_type)
check_variable_and_dtype( check_variable_and_dtype(
y, 'y', ['float16', 'float32', 'float64', 'int32', 'int64'], y, 'y', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
original_op_type) original_op_type)
axis = helper.kwargs.get('axis', -1) axis = helper.kwargs.get('axis', -1)
...@@ -473,8 +473,8 @@ def multiply(x, y, name=None): ...@@ -473,8 +473,8 @@ def multiply(x, y, name=None):
``paddle.multiply`` supports broadcasting. If you would like to know more about broadcasting, please refer to :ref:`user_guide_broadcasting` . ``paddle.multiply`` supports broadcasting. If you would like to know more about broadcasting, please refer to :ref:`user_guide_broadcasting` .
Args: Args:
x (Tensor): the input tensor, its data type should be float32, float64, int32, int64. x (Tensor): the input tensor, its data type should be one of float32, float64, int32, int64, bool.
y (Tensor): the input tensor, its data type should be float32, float64, int32, int64. y (Tensor): the input tensor, its data type should be one of float32, float64, int32, int64, bool.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns: Returns:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册