Unverified · commit 666e6651 authored by chentianyu03, committed by GitHub

change the kron gradient for complex types (#29995)

Parent commit: a5e422c8
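Background for this change: with complex inputs, the backward pass of kron must multiply the incoming gradient by the complex conjugate of the other operand — the usual convention for gradients of a real-valued loss with respect to complex tensors — which is what the two functor specializations and the new test below implement. As a minimal, self-contained NumPy sketch (not part of the patch), the conjugate rule for the gradient w.r.t. the first operand can be checked against central finite differences:

import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(2, 3)) + 1j * rng.normal(size=(2, 3))
B = rng.normal(size=(4, 5)) + 1j * rng.normal(size=(4, 5))
G = rng.normal(size=(8, 15)) + 1j * rng.normal(size=(8, 15))  # upstream gradient

def loss(A):
    # real-valued scalar loss: Re<G, kron(A, B)>
    return np.real(np.sum(np.conj(G) * np.kron(A, B)))

# the conjugate rule the patch implements:
# grad_A[p, q] = sum_{i, j} G[p*4 + i, q*5 + j] * conj(B[i, j])
grad = np.einsum('piqj,ij->pq', G.reshape(2, 4, 3, 5), np.conj(B))

# check against central finite differences on the real and imaginary parts
eps = 1e-6
num = np.zeros_like(A)
for p in range(2):
    for q in range(3):
        d = np.zeros_like(A)
        d[p, q] = eps
        d_re = (loss(A + d) - loss(A - d)) / (2 * eps)
        d_im = (loss(A + 1j * d) - loss(A - 1j * d)) / (2 * eps)
        num[p, q] = d_re + 1j * d_im

np.testing.assert_allclose(grad, num, rtol=1e-5)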
@@ -26,6 +26,9 @@ limitations under the License. */
namespace paddle {
namespace operators {
using complex64 = paddle::platform::complex64;
using complex128 = paddle::platform::complex128;

// Process an element in the output, used with a parallel-for
template <typename T>
struct KronElemFunctor {
@@ -172,6 +175,128 @@ struct KronGradElemFunctor {
const int ndims_;
};
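
// For complex types the gradient must use the complex conjugate of the other
// operand: dL/dA accumulates dout * conj(B) and dL/dB accumulates
// dout * conj(A), hence the (real, -imag) constructions below.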
template <>
struct KronGradElemFunctor<complex64> {
KronGradElemFunctor(const complex64* dout, const complex64* A,
const complex64* B, complex64* dout_a, complex64* dout_b,
const int64_t* stride_dout, const int64_t* stride_a,
const int64_t* stride_b, const int64_t* shape_b,
const int64_t numel_a, const int64_t numel_b,
const int ndims)
: dout_(dout),
A_(A),
B_(B),
dout_a_(dout_a),
dout_b_(dout_b),
stride_dout_(stride_dout),
stride_a_(stride_a),
stride_b_(stride_b),
shape_b_(shape_b),
numel_a_(numel_a),
numel_b_(numel_b),
        ndims_(ndims) {}

  HOSTDEVICE void operator()(int64_t idx) {
int64_t index = idx;
int64_t index_a = 0;
int64_t index_b = 0;
for (int i = 0; i < ndims_; i++) {
auto pos_i = index / stride_dout_[i];
index = index % stride_dout_[i];
auto pos_ai = pos_i / shape_b_[i];
auto pos_bi = pos_i % shape_b_[i];
index_a += stride_a_[i] * pos_ai;
index_b += stride_b_[i] * pos_bi;
}
if (dout_a_) {
size_t index_out_a = index_a * numel_b_ + index_b;
dout_a_[index_out_a] =
dout_[idx] * complex64(B_[index_b].real, -B_[index_b].imag);
}
if (dout_b_) {
size_t index_out_b = index_b * numel_a_ + index_a;
dout_b_[index_out_b] =
dout_[idx] * complex64(A_[index_a].real, -A_[index_a].imag);
}
  }

 private:
const complex64* dout_;
const complex64* A_;
const complex64* B_;
complex64* dout_a_;
complex64* dout_b_;
const int64_t* stride_dout_;
const int64_t* stride_a_;
const int64_t* stride_b_;
const int64_t* shape_b_;
const int64_t numel_a_;
const int64_t numel_b_;
const int ndims_;
};

template <>
struct KronGradElemFunctor<complex128> {
KronGradElemFunctor(const complex128* dout, const complex128* A,
const complex128* B, complex128* dout_a,
complex128* dout_b, const int64_t* stride_dout,
const int64_t* stride_a, const int64_t* stride_b,
const int64_t* shape_b, const int64_t numel_a,
const int64_t numel_b, const int ndims)
: dout_(dout),
A_(A),
B_(B),
dout_a_(dout_a),
dout_b_(dout_b),
stride_dout_(stride_dout),
stride_a_(stride_a),
stride_b_(stride_b),
shape_b_(shape_b),
numel_a_(numel_a),
numel_b_(numel_b),
        ndims_(ndims) {}

  HOSTDEVICE void operator()(int64_t idx) {
int64_t index = idx;
int64_t index_a = 0;
int64_t index_b = 0;
for (int i = 0; i < ndims_; i++) {
auto pos_i = index / stride_dout_[i];
index = index % stride_dout_[i];
auto pos_ai = pos_i / shape_b_[i];
auto pos_bi = pos_i % shape_b_[i];
index_a += stride_a_[i] * pos_ai;
index_b += stride_b_[i] * pos_bi;
}
if (dout_a_) {
size_t index_out_a = index_a * numel_b_ + index_b;
dout_a_[index_out_a] =
dout_[idx] * complex128(B_[index_b].real, -B_[index_b].imag);
}
if (dout_b_) {
size_t index_out_b = index_b * numel_a_ + index_a;
dout_b_[index_out_b] =
dout_[idx] * complex128(A_[index_a].real, -A_[index_a].imag);
}
  }

 private:
const complex128* dout_;
const complex128* A_;
const complex128* B_;
complex128* dout_a_;
complex128* dout_b_;
const int64_t* stride_dout_;
const int64_t* stride_a_;
const int64_t* stride_b_;
const int64_t* shape_b_;
const int64_t numel_a_;
const int64_t numel_b_;
const int ndims_;
};

template <typename T>
struct IdentityFunctor {
HOSTDEVICE explicit inline IdentityFunctor() {}
......
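A note on the index arithmetic shared by both specializations: operator() decomposes a flat index into dout into per-dimension coordinates using the output strides, then splits each coordinate pos_i into A's coordinate (pos_i / shape_b[i]) and B's coordinate (pos_i % shape_b[i]). The sketch below (plain Python, assuming row-major element strides, which is what the kernel derives from the tensor dims) mirrors that loop and verifies it against np.kron:

import numpy as np

def elem_strides(shape):
    # row-major strides, measured in elements
    s = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        s[i] = s[i + 1] * shape[i + 1]
    return s

def kron_grad_indices(idx, stride_dout, stride_a, stride_b, shape_b):
    # mirrors KronGradElemFunctor::operator()
    index, index_a, index_b = idx, 0, 0
    for i in range(len(shape_b)):
        pos_i, index = divmod(index, stride_dout[i])
        pos_ai, pos_bi = divmod(pos_i, shape_b[i])
        index_a += stride_a[i] * pos_ai
        index_b += stride_b[i] * pos_bi
    return index_a, index_b

A = np.arange(6).reshape(2, 3)
B = np.arange(1, 21).reshape(4, 5)
out = np.kron(A, B)
sd, sa, sb = (elem_strides(t.shape) for t in (out, A, B))
for idx in range(out.size):
    ia, ib = kron_grad_indices(idx, sd, sa, sb, B.shape)
    assert out.flat[idx] == A.flat[ia] * B.flat[ib]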
@@ -102,5 +102,90 @@ class TestKronLayer(unittest.TestCase):
np.testing.assert_allclose(c, np.kron(a, b))


class TestComplexKronOp(OpTest):
def setUp(self):
self.op_type = "kron"
self.x_shape = np.array([10, 10])
self.y_shape = np.array([3, 35])
self.out_shape = self.x_shape * self.y_shape
self.init_base_dtype()
self.init_input_output()
self.init_grad_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
self.attrs = {'axis': -1, 'use_mkldnn': False}
        self.outputs = {'Out': self.out}

    def init_base_dtype(self):
        self.dtype = np.float64

    def init_input_output(self):
self.x = np.random.random(self.x_shape).astype(
self.dtype) + 1J * np.random.random(self.x_shape).astype(self.dtype)
self.y = np.random.random(self.y_shape).astype(
self.dtype) + 1J * np.random.random(self.y_shape).astype(self.dtype)
        self.out = np.kron(self.x, self.y)

    def init_grad_input_output(self):
self.grad_out = np.ones(self.out_shape, self.dtype) + 1J * np.ones(
self.out_shape, self.dtype)
self.grad_x = self.get_grad_x_by_numpy()
        self.grad_y = self.get_grad_y_by_numpy()

    def get_grad_x_by_numpy(self):
        grad_x = np.zeros(self.x_shape, np.complex128)
for x_i in range(self.x_shape[0]):
for x_j in range(self.x_shape[1]):
for i in range(self.y_shape[0]):
for j in range(self.y_shape[1]):
idx_i = x_i * self.y_shape[0] + i
idx_j = x_j * self.y_shape[1] + j
grad_x[x_i][x_j] += self.grad_out[idx_i][
idx_j] * np.conj(self.y[i][j])
        return grad_x

    def get_grad_y_by_numpy(self):
        grad_y = np.zeros(self.y_shape, np.complex128)
for y_i in range(self.y_shape[0]):
for y_j in range(self.y_shape[1]):
for x_i in range(self.x_shape[0]):
for x_j in range(self.x_shape[1]):
idx_i = x_i * self.y_shape[0] + y_i
idx_j = x_j * self.y_shape[1] + y_j
grad_y[y_i][y_j] += self.grad_out[idx_i][
idx_j] * np.conj(self.x[x_i][x_j])
return grad_y
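
    # Equivalent vectorized forms of the two loop-based references above
    # (illustrative helpers, not part of the original patch): view grad_out
    # as blocks of shape (x0, y0, x1, y1) and contract with the conjugate
    # of the other operand.
    def get_grad_x_vectorized(self):
        g = self.grad_out.reshape(self.x_shape[0], self.y_shape[0],
                                  self.x_shape[1], self.y_shape[1])
        return np.einsum('aibj,ij->ab', g, np.conj(self.y))

    def get_grad_y_vectorized(self):
        g = self.grad_out.reshape(self.x_shape[0], self.y_shape[0],
                                  self.x_shape[1], self.y_shape[1])
        return np.einsum('aibj,ab->ij', g, np.conj(self.x))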

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
self.check_grad(
['X', 'Y'],
'Out',
user_defined_grads=[self.grad_x, self.grad_y],
            user_defined_grad_outputs=[self.grad_out])

    def test_check_grad_ignore_x(self):
self.check_grad(
['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[self.grad_y],
            user_defined_grad_outputs=[self.grad_out])

    def test_check_grad_ignore_y(self):
self.check_grad(
['X'],
'Out',
no_grad_set=set('Y'),
user_defined_grads=[self.grad_x],
            user_defined_grad_outputs=[self.grad_out])


if __name__ == '__main__':
paddle.enable_static()
unittest.main()