From 2a260d9b0e6fdda3ef3720be55a3cc8002b31fe5 Mon Sep 17 00:00:00 2001
From: chentianyu03
Date: Tue, 22 Dec 2020 17:58:44 +0800
Subject: [PATCH] change the grad of div when complex types (#29804)

* change the grad of div when complex types

* fix the grads of inputs args order not match bug
---
 .../elementwise/elementwise_div_op.cu         | 39 ++++++++++++
 .../elementwise/elementwise_div_op.h          | 43 +++++++++++++
 .../paddle/fluid/tests/unittests/op_test.py   |  2 +-
 .../unittests/test_elementwise_div_op.py      | 60 +++++++++++++++++++
 4 files changed, 143 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cu b/paddle/fluid/operators/elementwise/elementwise_div_op.cu
index df5a2115c3b..96583d06571 100644
--- a/paddle/fluid/operators/elementwise/elementwise_div_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cu
@@ -75,6 +75,45 @@ static __global__ void SimpleElemwiseDivGradCUDAKernel(const T* x, const T* y,
   }
 }
 
+template <>
+__global__ void SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex64>(
+    const paddle::platform::complex64* x, const paddle::platform::complex64* y,
+    const paddle::platform::complex64* out,
+    const paddle::platform::complex64* dout, int64_t size,
+    paddle::platform::complex64* dx, paddle::platform::complex64* dy) {
+  int col = blockIdx.x * blockDim.x + threadIdx.x;
+
+  while (col < size) {
+    paddle::platform::complex64 o = dout[col];
+    paddle::platform::complex64 y_conj(y[col].real, -y[col].imag);
+    paddle::platform::complex64 out_div_y_conj((out[col] / y[col]).real,
+                                               -(out[col] / y[col]).imag);
+    dx[col] = o / y_conj;
+    dy[col] = -o * out_div_y_conj;
+    col += blockDim.x * gridDim.x;
+  }
+}
+
+template <>
+__global__ void SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex128>(
+    const paddle::platform::complex128* x,
+    const paddle::platform::complex128* y,
+    const paddle::platform::complex128* out,
+    const paddle::platform::complex128* dout, int64_t size,
+    paddle::platform::complex128* dx, paddle::platform::complex128* dy) {
+  int col = blockIdx.x * blockDim.x + threadIdx.x;
+
+  while (col < size) {
+    paddle::platform::complex128 o = dout[col];
+    paddle::platform::complex128 y_conj(y[col].real, -y[col].imag);
+    paddle::platform::complex128 out_div_y_conj((out[col] / y[col]).real,
+                                                -(out[col] / y[col]).imag);
+    dx[col] = o / y_conj;
+    dy[col] = -o * out_div_y_conj;
+    col += blockDim.x * gridDim.x;
+  }
+}
+
 template <typename DeviceContext, typename T>
 typename std::enable_if<
     std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
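Note on the math: for out = x / y, the two kernels above implement the
conjugate-gradient identities dx = dout / conj(y) and
dy = -dout * conj(out / y), where conj(out / y) equals conj(x) / conj(y)^2.
A minimal NumPy sketch (illustration only, not Paddle code) checking that the
two algebraic forms of dy agree:

    import numpy as np

    rng = np.random.default_rng(0)
    shape = (4,)
    x = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
    y = rng.uniform(1, 2, shape) + 1j * rng.uniform(1, 2, shape)  # kept away from 0
    dout = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

    out = x / y
    dx = dout / np.conj(y)         # mirrors dx[col] = o / y_conj
    dy = -dout * np.conj(out / y)  # mirrors dy[col] = -o * out_div_y_conj

    # out / y == x / y**2, so dy above equals -dout * conj(x) / conj(y)**2
    assert np.allclose(dy, -dout * np.conj(x) / np.conj(y) ** 2)

The kernels build the conjugates by hand (y_conj, out_div_y_conj), presumably
because paddle::platform::complex64/complex128 expose real and imag members
directly; the host-side functor specializations in elementwise_div_op.h below
implement the same formulas.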
diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h
index 1d016fba34b..d824014713d 100644
--- a/paddle/fluid/operators/elementwise/elementwise_div_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h
@@ -73,6 +73,27 @@ struct DivGradDX {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout / y; }
 };
 
+template <>
+struct DivGradDX<paddle::platform::complex64> {
+  HOSTDEVICE paddle::platform::complex64 operator()(
+      paddle::platform::complex64 x, paddle::platform::complex64 y,
+      paddle::platform::complex64 out, paddle::platform::complex64 dout) const {
+    paddle::platform::complex64 y_conj(y.real, -y.imag);
+    return dout / y_conj;
+  }
+};
+
+template <>
+struct DivGradDX<paddle::platform::complex128> {
+  HOSTDEVICE paddle::platform::complex128 operator()(
+      paddle::platform::complex128 x, paddle::platform::complex128 y,
+      paddle::platform::complex128 out,
+      paddle::platform::complex128 dout) const {
+    paddle::platform::complex128 y_conj(y.real, -y.imag);
+    return dout / y_conj;
+  }
+};
+
 template <typename T>
 struct DivGradDY {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
@@ -80,6 +101,28 @@ struct DivGradDY {
   }
 };
 
+template <>
+struct DivGradDY<paddle::platform::complex64> {
+  HOSTDEVICE paddle::platform::complex64 operator()(
+      paddle::platform::complex64 x, paddle::platform::complex64 y,
+      paddle::platform::complex64 out, paddle::platform::complex64 dout) const {
+    paddle::platform::complex64 out_div_y_conj((out / y).real, -(out / y).imag);
+    return -dout * out_div_y_conj;
+  }
+};
+
+template <>
+struct DivGradDY<paddle::platform::complex128> {
+  HOSTDEVICE paddle::platform::complex128 operator()(
+      paddle::platform::complex128 x, paddle::platform::complex128 y,
+      paddle::platform::complex128 out,
+      paddle::platform::complex128 dout) const {
+    paddle::platform::complex128 out_div_y_conj((out / y).real,
+                                                -(out / y).imag);
+    return -dout * out_div_y_conj;
+  }
+};
+
 template <typename T>
 struct DivDoubleDY {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index f077a0286d3..25c0e3bced9 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -1616,7 +1616,7 @@ class OpTest(unittest.TestCase):
         targets = [
             outputs[name] for name in outputs if name in output_names
         ]
-        inputs = [inputs[name] for name in inputs if name in input_to_check]
+        inputs = [inputs[name] for name in input_to_check if name in inputs]
         grad_inputs = paddle.static.gradients(targets, inputs, grad_outputs,
                                               no_grad_set)
         fetch_list = grad_inputs
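The op_test.py change is the second bullet of the commit message. The old
comprehension iterated the inputs dict, so the fetched gradients came back in
dict order, while user_defined_grads is ordered by input_to_check; whenever
the two orders differed, each gradient was compared against the wrong
reference. Iterating input_to_check preserves the caller's order. A
hypothetical two-input repro (variable values made up for illustration):

    inputs = {'X': 'x_var', 'Y': 'y_var'}  # op's input map, in dict order
    input_to_check = ['Y', 'X']            # the order the caller expects

    old = [inputs[name] for name in inputs if name in input_to_check]
    new = [inputs[name] for name in input_to_check if name in inputs]

    print(old)  # ['x_var', 'y_var'] -- dict order, mismatches user grads
    print(new)  # ['y_var', 'x_var'] -- follows input_to_check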
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py
index 3cfbac8b613..f93802c47c9 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py
@@ -261,5 +261,65 @@ class TestDivideOp(unittest.TestCase):
             self.assertEqual((np_z == z_expected).all(), True)
 
 
+class TestComplexElementwiseDivOp(OpTest):
+    def setUp(self):
+        self.op_type = "elementwise_div"
+        self.init_base_dtype()
+        self.init_input_output()
+        self.init_grad_input_output()
+
+        self.inputs = {
+            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
+            'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
+        }
+        self.attrs = {'axis': -1, 'use_mkldnn': False}
+        self.outputs = {'Out': self.out}
+
+    def init_base_dtype(self):
+        self.dtype = np.float64
+
+    def init_input_output(self):
+        self.x = np.random.random(
+            (2, 3, 4, 5)).astype(self.dtype) + 1J * np.random.random(
+                (2, 3, 4, 5)).astype(self.dtype)
+        self.y = np.random.random(
+            (2, 3, 4, 5)).astype(self.dtype) + 1J * np.random.random(
+                (2, 3, 4, 5)).astype(self.dtype)
+        self.out = self.x / self.y
+
+    def init_grad_input_output(self):
+        self.grad_out = np.ones((2, 3, 4, 5), self.dtype) + 1J * np.ones(
+            (2, 3, 4, 5), self.dtype)
+        self.grad_x = self.grad_out / np.conj(self.y)
+        self.grad_y = -self.grad_out * np.conj(self.x / self.y / self.y)
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            user_defined_grads=[self.grad_x, self.grad_y],
+            user_defined_grad_outputs=[self.grad_out])
+
+    def test_check_grad_ingore_x(self):
+        self.check_grad(
+            ['Y'],
+            'Out',
+            no_grad_set=set("X"),
+            user_defined_grads=[self.grad_y],
+            user_defined_grad_outputs=[self.grad_out])
+
+    def test_check_grad_ingore_y(self):
+        self.check_grad(
+            ['X'],
+            'Out',
+            no_grad_set=set('Y'),
+            user_defined_grads=[self.grad_x],
+            user_defined_grad_outputs=[self.grad_out])
+
+
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
-- 
GitLab
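The expected gradients in init_grad_input_output follow the same convention as
the kernels: grad_in = grad_out * conj(d out / d in). Assuming that convention
corresponds to the real-valued pairing L = Re(sum(conj(grad_out) * out)) (an
assumption for illustration; the patch itself does not state it), the test's
grad_x = grad_out / conj(y) can be confirmed by finite differences:

    import numpy as np

    rng = np.random.default_rng(1)
    n = 8
    x = rng.standard_normal(n) + 1j * rng.standard_normal(n)
    y = rng.uniform(1, 2, n) + 1j * rng.uniform(1, 2, n)  # away from zero
    g = rng.standard_normal(n) + 1j * rng.standard_normal(n)  # plays grad_out

    def loss(xv):
        # Real-valued pairing of grad_out with out = xv / y (assumed convention).
        return np.real(np.sum(np.conj(g) * (xv / y)))

    eps = 1e-6
    num = np.zeros(n, dtype=complex)
    for k in range(n):
        e = np.zeros(n)
        e[k] = eps
        # Pack d/d(Re x_k) + 1j * d/d(Im x_k), each by central differences.
        num[k] = ((loss(x + e) - loss(x - e)) +
                  1j * (loss(x + 1j * e) - loss(x - 1j * e))) / (2 * eps)

    assert np.allclose(num, g / np.conj(y), atol=1e-5)  # grad_x = grad_out / conj(y)

The same setup reproduces grad_y as -g * np.conj(x / y / y), matching both the
test's expectation and the kernels' -dout * conj(out / y).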