Unverified commit 2a260d9b authored by chentianyu03, committed by GitHub

change the grad of div for complex types (#29804)

* change the grad of div for complex types (gradient formulas sketched below)

* fix a bug where input gradients were returned in an order that did not match the inputs under check (see the OpTest sketch below)
Parent e219b8cc
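For reference, here is a minimal NumPy sketch (illustrative only, not part of the patch) of the complex-division gradient formulas that the kernels and functors below implement: for out = x / y, dx = dout / conj(y) and dy = -dout * conj(out / y), which equals -dout * conj(x / y²).

```python
import numpy as np

# Illustrative shapes and values; the real kernels work element-wise on tensors.
x = np.random.rand(4) + 1j * np.random.rand(4)
y = np.random.rand(4) + 1j * np.random.rand(4)
dout = np.ones(4) + 1j * np.ones(4)

out = x / y
dx = dout / np.conj(y)         # gradient w.r.t. x
dy = -dout * np.conj(out / y)  # gradient w.r.t. y

# conj(out / y) == conj(x / y**2), the expression used in the unit test below.
assert np.allclose(dy, -dout * np.conj(x / y / y))
```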
...
@@ -75,6 +75,45 @@ static __global__ void SimpleElemwiseDivGradCUDAKernel(const T* x, const T* y,
  }
}
template <>
__global__ void SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex64>(
    const paddle::platform::complex64* x, const paddle::platform::complex64* y,
    const paddle::platform::complex64* out,
    const paddle::platform::complex64* dout, int64_t size,
    paddle::platform::complex64* dx, paddle::platform::complex64* dy) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  while (col < size) {
    paddle::platform::complex64 o = dout[col];
    paddle::platform::complex64 y_conj(y[col].real, -y[col].imag);
    paddle::platform::complex64 out_div_y_conj((out[col] / y[col]).real,
                                               -(out[col] / y[col]).imag);
    dx[col] = o / y_conj;
    dy[col] = -o * out_div_y_conj;
    col += blockDim.x * gridDim.x;
  }
}
template <>
__global__ void SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex128>(
    const paddle::platform::complex128* x,
    const paddle::platform::complex128* y,
    const paddle::platform::complex128* out,
    const paddle::platform::complex128* dout, int64_t size,
    paddle::platform::complex128* dx, paddle::platform::complex128* dy) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  while (col < size) {
    paddle::platform::complex128 o = dout[col];
    paddle::platform::complex128 y_conj(y[col].real, -y[col].imag);
    paddle::platform::complex128 out_div_y_conj((out[col] / y[col]).real,
                                                -(out[col] / y[col]).imag);
    dx[col] = o / y_conj;
    dy[col] = -o * out_div_y_conj;
    col += blockDim.x * gridDim.x;
  }
}
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
...
...
@@ -73,6 +73,27 @@ struct DivGradDX {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout / y; }
};
template <>
struct DivGradDX<paddle::platform::complex64> {
  HOSTDEVICE paddle::platform::complex64 operator()(
      paddle::platform::complex64 x, paddle::platform::complex64 y,
      paddle::platform::complex64 out, paddle::platform::complex64 dout) const {
    paddle::platform::complex64 y_conj(y.real, -y.imag);
    return dout / y_conj;
  }
};
template <>
struct DivGradDX<paddle::platform::complex128> {
  HOSTDEVICE paddle::platform::complex128 operator()(
      paddle::platform::complex128 x, paddle::platform::complex128 y,
      paddle::platform::complex128 out,
      paddle::platform::complex128 dout) const {
    paddle::platform::complex128 y_conj(y.real, -y.imag);
    return dout / y_conj;
  }
};
template <typename T>
struct DivGradDY {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
...
@@ -80,6 +101,28 @@ struct DivGradDY {
  }
};
template <>
struct DivGradDY<paddle::platform::complex64> {
  HOSTDEVICE paddle::platform::complex64 operator()(
      paddle::platform::complex64 x, paddle::platform::complex64 y,
      paddle::platform::complex64 out, paddle::platform::complex64 dout) const {
    paddle::platform::complex64 out_div_y_conj((out / y).real, -(out / y).imag);
    return -dout * out_div_y_conj;
  }
};
template <>
struct DivGradDY<paddle::platform::complex128> {
  HOSTDEVICE paddle::platform::complex128 operator()(
      paddle::platform::complex128 x, paddle::platform::complex128 y,
      paddle::platform::complex128 out,
      paddle::platform::complex128 dout) const {
    paddle::platform::complex128 out_div_y_conj((out / y).real,
                                                -(out / y).imag);
    return -dout * out_div_y_conj;
  }
};
template <typename T>
struct DivDoubleDY {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
...
...
@@ -1616,7 +1616,7 @@ class OpTest(unittest.TestCase):
        targets = [
            outputs[name] for name in outputs if name in output_names
        ]
-        inputs = [inputs[name] for name in inputs if name in input_to_check]
+        inputs = [inputs[name] for name in input_to_check if name in inputs]
        grad_inputs = paddle.static.gradients(targets, inputs, grad_outputs,
                                              no_grad_set)
        fetch_list = grad_inputs
...
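A small, self-contained sketch (with made-up variable values) of the ordering issue the OpTest change above fixes: iterating over the `inputs` dict yields gradients in the dict's insertion order, while iterating over `input_to_check` yields them in the order the caller asked for.

```python
# Hypothetical values, purely for illustration.
inputs = {'X': 'x_var', 'Y': 'y_var'}  # dict insertion order: X, Y
input_to_check = ['Y', 'X']            # caller expects grads in this order

# Old comprehension: order follows the dict -> ['x_var', 'y_var']
old = [inputs[name] for name in inputs if name in input_to_check]

# New comprehension: order follows input_to_check -> ['y_var', 'x_var']
new = [inputs[name] for name in input_to_check if name in inputs]

assert old == ['x_var', 'y_var']
assert new == ['y_var', 'x_var']
```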
...
@@ -261,5 +261,65 @@ class TestDivideOp(unittest.TestCase):
        self.assertEqual((np_z == z_expected).all(), True)

class TestComplexElementwiseDivOp(OpTest):
    def setUp(self):
        self.op_type = "elementwise_div"
        self.init_base_dtype()
        self.init_input_output()
        self.init_grad_input_output()

        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
            'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
        }
        self.attrs = {'axis': -1, 'use_mkldnn': False}
        self.outputs = {'Out': self.out}

    def init_base_dtype(self):
        self.dtype = np.float64

    def init_input_output(self):
        self.x = np.random.random(
            (2, 3, 4, 5)).astype(self.dtype) + 1J * np.random.random(
                (2, 3, 4, 5)).astype(self.dtype)
        self.y = np.random.random(
            (2, 3, 4, 5)).astype(self.dtype) + 1J * np.random.random(
                (2, 3, 4, 5)).astype(self.dtype)
        self.out = self.x / self.y

    def init_grad_input_output(self):
        self.grad_out = np.ones((2, 3, 4, 5), self.dtype) + 1J * np.ones(
            (2, 3, 4, 5), self.dtype)
        self.grad_x = self.grad_out / np.conj(self.y)
        self.grad_y = -self.grad_out * np.conj(self.x / self.y / self.y)

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(
            ['X', 'Y'],
            'Out',
            user_defined_grads=[self.grad_x, self.grad_y],
            user_defined_grad_outputs=[self.grad_out])

    def test_check_grad_ingore_x(self):
        self.check_grad(
            ['Y'],
            'Out',
            no_grad_set=set("X"),
            user_defined_grads=[self.grad_y],
            user_defined_grad_outputs=[self.grad_out])

    def test_check_grad_ingore_y(self):
        self.check_grad(
            ['X'],
            'Out',
            no_grad_set=set('Y'),
            user_defined_grads=[self.grad_x],
            user_defined_grad_outputs=[self.grad_out])
if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()