diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
index b3b4b054490d6df17593156e22e3263c30d1dce1..5b598ab2d788ebf03a86c79b6346d46d0889aa2c 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
@@ -75,6 +75,36 @@ static __global__ void SimpleElemwiseMulGradCUDAKernel(const T* x, const T* y,
   }
 }
 
+template <>
+__global__ void SimpleElemwiseMulGradCUDAKernel<plat::complex64>(
+    const plat::complex64* x, const plat::complex64* y,
+    const plat::complex64* out, const plat::complex64* dout, int64_t size,
+    plat::complex64* dx, plat::complex64* dy) {
+  int col = blockIdx.x * blockDim.x + threadIdx.x;
+
+  while (col < size) {
+    plat::complex64 o = dout[col];
+    dx[col] = plat::complex64(y[col].real, -y[col].imag) * o;
+    dy[col] = plat::complex64(x[col].real, -x[col].imag) * o;
+    col += blockDim.x * gridDim.x;
+  }
+}
+
+template <>
+__global__ void SimpleElemwiseMulGradCUDAKernel<plat::complex128>(
+    const plat::complex128* x, const plat::complex128* y,
+    const plat::complex128* out, const plat::complex128* dout, int64_t size,
+    plat::complex128* dx, plat::complex128* dy) {
+  int col = blockIdx.x * blockDim.x + threadIdx.x;
+
+  while (col < size) {
+    plat::complex128 o = dout[col];
+    dx[col] = plat::complex128(y[col].real, -y[col].imag) * o;
+    dy[col] = plat::complex128(x[col].real, -x[col].imag) * o;
+    col += blockDim.x * gridDim.x;
+  }
+}
+
 template <typename DeviceContext, typename T>
 typename std::enable_if<
     std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h
index a5bd7221c75414ef70a031efeca64ac81468b0f0..66a9e6dd0fcf277904d8805ff66fee8c6971b1d7 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h
@@ -132,11 +132,53 @@ struct MulGradDX {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; }
 };
 
+template <>
+struct MulGradDX<paddle::platform::complex64> {
+  HOSTDEVICE paddle::platform::complex64 operator()(
+      paddle::platform::complex64 x, paddle::platform::complex64 y,
+      paddle::platform::complex64 out, paddle::platform::complex64 dout) const {
+    paddle::platform::complex64 y_conj(y.real, -y.imag);
+    return dout * y_conj;
+  }
+};
+
+template <>
+struct MulGradDX<paddle::platform::complex128> {
+  HOSTDEVICE paddle::platform::complex128 operator()(
+      paddle::platform::complex128 x, paddle::platform::complex128 y,
+      paddle::platform::complex128 out,
+      paddle::platform::complex128 dout) const {
+    paddle::platform::complex128 y_conj(y.real, -y.imag);
+    return dout * y_conj;
+  }
+};
+
 template <typename T>
 struct MulGradDY {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * x; }
 };
 
+template <>
+struct MulGradDY<paddle::platform::complex64> {
+  HOSTDEVICE paddle::platform::complex64 operator()(
+      paddle::platform::complex64 x, paddle::platform::complex64 y,
+      paddle::platform::complex64 out, paddle::platform::complex64 dout) const {
+    paddle::platform::complex64 x_conj(x.real, -x.imag);
+    return dout * x_conj;
+  }
+};
+
+template <>
+struct MulGradDY<paddle::platform::complex128> {
+  HOSTDEVICE paddle::platform::complex128 operator()(
+      paddle::platform::complex128 x, paddle::platform::complex128 y,
+      paddle::platform::complex128 out,
+      paddle::platform::complex128 dout) const {
+    paddle::platform::complex128 x_conj(x.real, -x.imag);
+    return dout * x_conj;
+  }
+};
+
 template <typename DeviceContext, typename T>
 typename std::enable_if<
     std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
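Note: the new kernels and functors implement the conjugate (Wirtinger) gradient convention for complex multiplication: for out = x * y, dx = dout * conj(y) and dy = dout * conj(x), with the conjugation spelled out by hand as complex64/complex128(real, -imag). The sketch below (plain NumPy, not part of the patch; the scalar loss is a hypothetical stand-in for the framework's gradient check) verifies this rule against central finite differences of the real-valued loss L = Re(sum(conj(dout) * (x * y))):

import numpy as np

rng = np.random.default_rng(0)
shape = (2, 3)
x = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
y = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
dout = np.ones(shape) + 1j * np.ones(shape)  # upstream gradient, as in the unit test

# Real-valued scalar loss induced by the upstream gradient.
def loss(x_, y_):
    return np.sum(np.conj(dout) * (x_ * y_)).real

# Analytic gradients, matching MulGradDX / MulGradDY above.
dx = dout * np.conj(y)
dy = dout * np.conj(x)

# Central differences on the real and imaginary parts of x[0, 0]:
# dL/dRe(x) should equal Re(dx), and dL/dIm(x) should equal Im(dx).
eps = 1e-6
d = np.zeros(shape)
d[0, 0] = eps
num_re = (loss(x + d, y) - loss(x - d, y)) / (2 * eps)
num_im = (loss(x + 1j * d, y) - loss(x - 1j * d, y)) / (2 * eps)
assert np.isclose(num_re, dx[0, 0].real)
assert np.isclose(num_im, dx[0, 0].imag)

The same check applied to y confirms dy = dout * conj(x); this is the convention under which gradient descent on the real and imaginary parts minimizes a real loss.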
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
index fd2fe73ad5186612164d636ca818864809557d05..f69fa7084edb193515a626e23b91cda6b1122829 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
@@ -13,13 +13,17 @@
 # limitations under the License.
 
 from __future__ import print_function
+
 import unittest
+
 import numpy as np
-from op_test import OpTest, skip_check_grad_ci
+import paddle
+import paddle.fluid as fluid
 import paddle.fluid.core as core
+from paddle.fluid import Program, compiler, program_guard
 from paddle.fluid.op import Operator
-import paddle.fluid as fluid
-from paddle.fluid import compiler, Program, program_guard
+
+from op_test import OpTest, skip_check_grad_ci
 
 
 class ElementwiseMulOp(OpTest):
@@ -241,5 +245,65 @@ class TestElementwiseMulOpError(unittest.TestCase):
         self.assertRaises(TypeError, fluid.layers.elementwise_mul, x2, y2)
 
 
+class TestComplexElementwiseMulOp(OpTest):
+    def setUp(self):
+        self.op_type = "elementwise_mul"
+        self.init_base_dtype()
+        self.init_input_output()
+        self.init_grad_input_output()
+
+        self.inputs = {
+            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
+            'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
+        }
+        self.attrs = {'axis': -1, 'use_mkldnn': False}
+        self.outputs = {'Out': self.out}
+
+    def init_base_dtype(self):
+        self.dtype = np.float64
+
+    def init_input_output(self):
+        self.x = np.random.random(
+            (2, 3, 4, 5)).astype(self.dtype) + 1J * np.random.random(
+                (2, 3, 4, 5)).astype(self.dtype)
+        self.y = np.random.random(
+            (2, 3, 4, 5)).astype(self.dtype) + 1J * np.random.random(
+                (2, 3, 4, 5)).astype(self.dtype)
+        self.out = self.x * self.y
+
+    def init_grad_input_output(self):
+        self.grad_out = np.ones((2, 3, 4, 5), self.dtype) + 1J * np.ones(
+            (2, 3, 4, 5), self.dtype)
+        self.grad_x = self.grad_out * np.conj(self.y)
+        self.grad_y = self.grad_out * np.conj(self.x)
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            user_defined_grads=[self.grad_x, self.grad_y],
+            user_defined_grad_outputs=[self.grad_out])
+
+    def test_check_grad_ingore_x(self):
+        self.check_grad(
+            ['Y'],
+            'Out',
+            no_grad_set=set("X"),
+            user_defined_grads=[self.grad_y],
+            user_defined_grad_outputs=[self.grad_out])
+
+    def test_check_grad_ingore_y(self):
+        self.check_grad(
+            ['X'],
+            'Out',
+            no_grad_set=set('Y'),
+            user_defined_grads=[self.grad_x],
+            user_defined_grad_outputs=[self.grad_out])
+
+
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
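The OpTest numeric checker cannot finite-difference complex dtypes, which is why the test supplies analytic gradients via user_defined_grads / user_defined_grad_outputs. As an imperative-mode smoke test of the same kernels (a sketch, not part of the patch; it assumes a Paddle build in which paddle.to_tensor and paddle.grad accept complex128 tensors):

import numpy as np
import paddle

paddle.disable_static()

x_np = np.random.rand(2, 3) + 1j * np.random.rand(2, 3)
y_np = np.random.rand(2, 3) + 1j * np.random.rand(2, 3)

x = paddle.to_tensor(x_np, stop_gradient=False)
y = paddle.to_tensor(y_np, stop_gradient=False)
out = paddle.multiply(x, y)

# Same upstream gradient as grad_out in TestComplexElementwiseMulOp.
grad_out = paddle.to_tensor(np.ones((2, 3)) + 1j * np.ones((2, 3)))
dx, dy = paddle.grad([out], [x, y], grad_outputs=[grad_out])

# The backward pass should apply the conjugate rule from the new kernels.
np.testing.assert_allclose(dx.numpy(), grad_out.numpy() * np.conj(y_np))
np.testing.assert_allclose(dy.numpy(), grad_out.numpy() * np.conj(x_np))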