Unverified commit ddfc3d2c, authored by chentianyu03, committed by GitHub

change grad elementwise_mul for complex types (#29757)

* add conj op for complex types

* add conj for complex types

* add more test case

* add conj_op test

* modify conj api and impl

* add complex type for fill_constant_op xpu

* add setConstant for complex type

* remove complex conj test file

* use user-defined grads for test_conj_op

* add test case for static mode of conj api

* modify conj doc

* change input args name to x

* remove useless codes

* conj supports real types

* add conj test case for real number

* delete inputs that do not need to be calculated in dygraph op_test

* delete inputs that do not need to be calculated in dygraph op_test

* modify grad of mul for complex types

* fix bug where the order of the input args' gradients did not match the inputs
Parent 2a260d9b
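At the heart of this change is the gradient rule for complex element-wise multiplication: for out = x * y, the upstream gradient is propagated as dx = dout * conj(y) and dy = dout * conj(x). The CUDA kernels and the MulGradDX / MulGradDY functors in the diff below apply exactly this rule per element, and the new Python test supplies the same expressions as user-defined gradients. A minimal NumPy sketch of the rule (the shapes and the finite-difference check are illustrative, not part of the patch):

```python
import numpy as np

# Illustrative shapes; the new test below uses (2, 3, 4, 5) tensors.
x = np.random.random((2, 3)) + 1j * np.random.random((2, 3))
y = np.random.random((2, 3)) + 1j * np.random.random((2, 3))
out = x * y

# Upstream gradient (an all-ones complex tensor, as in the new test case).
dout = np.ones_like(out) + 1j * np.ones_like(out)

# Conjugate gradient rule for element-wise complex multiplication.
dx = dout * np.conj(y)
dy = dout * np.conj(x)

# Finite-difference sanity check, using the convention that the gradient of a
# real-valued loss w.r.t. z = a + ib is dL/da + 1j * dL/db.
def loss(x_val):
    return np.real(np.sum(np.conj(dout) * (x_val * y)))

eps = 1e-6
e = np.zeros_like(x)
e[0, 0] = 1.0  # perturb a single element
numeric = ((loss(x + eps * e) - loss(x - eps * e)) +
           1j * (loss(x + 1j * eps * e) - loss(x - 1j * eps * e))) / (2 * eps)
assert np.allclose(numeric, dx[0, 0], atol=1e-5)
```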
@@ -75,6 +75,36 @@ static __global__ void SimpleElemwiseMulGradCUDAKernel(const T* x, const T* y,
  }
}

template <>
__global__ void SimpleElemwiseMulGradCUDAKernel<plat::complex64>(
    const plat::complex64* x, const plat::complex64* y,
    const plat::complex64* out, const plat::complex64* dout, int64_t size,
    plat::complex64* dx, plat::complex64* dy) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  // Grid-stride loop; for complex inputs the gradients use the conjugate rule:
  // dx = dout * conj(y), dy = dout * conj(x).
  while (col < size) {
    plat::complex64 o = dout[col];
    dx[col] = plat::complex64(y[col].real, -y[col].imag) * o;
    dy[col] = plat::complex64(x[col].real, -x[col].imag) * o;
    col += blockDim.x * gridDim.x;
  }
}

template <>
__global__ void SimpleElemwiseMulGradCUDAKernel<plat::complex128>(
    const plat::complex128* x, const plat::complex128* y,
    const plat::complex128* out, const plat::complex128* dout, int64_t size,
    plat::complex128* dx, plat::complex128* dy) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  // Same conjugate rule as above, for complex128.
  while (col < size) {
    plat::complex128 o = dout[col];
    dx[col] = plat::complex128(y[col].real, -y[col].imag) * o;
    dy[col] = plat::complex128(x[col].real, -x[col].imag) * o;
    col += blockDim.x * gridDim.x;
  }
}

template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
......
@@ -132,11 +132,53 @@ struct MulGradDX {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; }
};

// Complex specializations: dX = dOut * conj(Y).
template <>
struct MulGradDX<paddle::platform::complex64> {
  HOSTDEVICE paddle::platform::complex64 operator()(
      paddle::platform::complex64 x, paddle::platform::complex64 y,
      paddle::platform::complex64 out, paddle::platform::complex64 dout) const {
    paddle::platform::complex64 y_conj(y.real, -y.imag);
    return dout * y_conj;
  }
};

template <>
struct MulGradDX<paddle::platform::complex128> {
  HOSTDEVICE paddle::platform::complex128 operator()(
      paddle::platform::complex128 x, paddle::platform::complex128 y,
      paddle::platform::complex128 out,
      paddle::platform::complex128 dout) const {
    paddle::platform::complex128 y_conj(y.real, -y.imag);
    return dout * y_conj;
  }
};

template <typename T>
struct MulGradDY {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * x; }
};

// Complex specializations: dY = dOut * conj(X).
template <>
struct MulGradDY<paddle::platform::complex64> {
  HOSTDEVICE paddle::platform::complex64 operator()(
      paddle::platform::complex64 x, paddle::platform::complex64 y,
      paddle::platform::complex64 out, paddle::platform::complex64 dout) const {
    paddle::platform::complex64 x_conj(x.real, -x.imag);
    return dout * x_conj;
  }
};

template <>
struct MulGradDY<paddle::platform::complex128> {
  HOSTDEVICE paddle::platform::complex128 operator()(
      paddle::platform::complex128 x, paddle::platform::complex128 y,
      paddle::platform::complex128 out,
      paddle::platform::complex128 dout) const {
    paddle::platform::complex128 x_conj(x.real, -x.imag);
    return dout * x_conj;
  }
};

template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
......
@@ -13,13 +13,17 @@
# limitations under the License.
from __future__ import print_function

import unittest

import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, compiler, program_guard
from paddle.fluid.op import Operator

from op_test import OpTest, skip_check_grad_ci
class ElementwiseMulOp(OpTest):
@@ -241,5 +245,65 @@ class TestElementwiseMulOpError(unittest.TestCase):
        self.assertRaises(TypeError, fluid.layers.elementwise_mul, x2, y2)


class TestComplexElementwiseMulOp(OpTest):
    def setUp(self):
        self.op_type = "elementwise_mul"
        self.init_base_dtype()
        self.init_input_output()
        self.init_grad_input_output()

        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
            'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
        }
        self.attrs = {'axis': -1, 'use_mkldnn': False}
        self.outputs = {'Out': self.out}

    def init_base_dtype(self):
        self.dtype = np.float64

    def init_input_output(self):
        self.x = np.random.random(
            (2, 3, 4, 5)).astype(self.dtype) + 1J * np.random.random(
                (2, 3, 4, 5)).astype(self.dtype)
        self.y = np.random.random(
            (2, 3, 4, 5)).astype(self.dtype) + 1J * np.random.random(
                (2, 3, 4, 5)).astype(self.dtype)
        self.out = self.x * self.y

    def init_grad_input_output(self):
        # Analytic gradients follow the conjugate rule:
        # dX = dOut * conj(Y), dY = dOut * conj(X).
        self.grad_out = np.ones((2, 3, 4, 5), self.dtype) + 1J * np.ones(
            (2, 3, 4, 5), self.dtype)
        self.grad_x = self.grad_out * np.conj(self.y)
        self.grad_y = self.grad_out * np.conj(self.x)

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(
            ['X', 'Y'],
            'Out',
            user_defined_grads=[self.grad_x, self.grad_y],
            user_defined_grad_outputs=[self.grad_out])

    def test_check_grad_ingore_x(self):
        self.check_grad(
            ['Y'],
            'Out',
            no_grad_set=set("X"),
            user_defined_grads=[self.grad_y],
            user_defined_grad_outputs=[self.grad_out])

    def test_check_grad_ingore_y(self):
        self.check_grad(
            ['X'],
            'Out',
            no_grad_set=set('Y'),
            user_defined_grads=[self.grad_x],
            user_defined_grad_outputs=[self.grad_out])


if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()
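For context, here is a hypothetical dynamic-graph usage of the patched gradient. This sketch is not part of the diff; it assumes a Paddle build that includes this change, plus complex-tensor support in paddle.to_tensor and paddle.grad:

```python
import numpy as np
import paddle

paddle.disable_static()  # dygraph mode

x_np = np.random.random((2, 3)) + 1j * np.random.random((2, 3))
y_np = np.random.random((2, 3)) + 1j * np.random.random((2, 3))

x = paddle.to_tensor(x_np, stop_gradient=False)
y = paddle.to_tensor(y_np, stop_gradient=False)
out = paddle.multiply(x, y)  # dispatches to elementwise_mul

dout_np = np.ones((2, 3)) + 1j * np.ones((2, 3))
dout = paddle.to_tensor(dout_np)
dx, dy = paddle.grad(outputs=[out], inputs=[x, y], grad_outputs=[dout])

# With this patch, the returned gradients follow the conjugate rule.
assert np.allclose(dx.numpy(), dout_np * np.conj(y_np))
assert np.allclose(dy.numpy(), dout_np * np.conj(x_np))
```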