// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <math.h>
#include <algorithm>
#include <complex>
#include "paddle/fluid/operators/math/matrix_solve.h"
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/pten/kernels/funcs/complex_functors.h"
#include "paddle/pten/kernels/funcs/lapack/lapack_function.h"
#include "paddle/pten/kernels/funcs/math_function.h"
#define EPSILON 1e-6
namespace paddle {
namespace operators {
using paddle::framework::Tensor;
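// Number of matrices in a batched tensor, i.e. the product of all
// dimensions except the trailing two matrix dimensions.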
inline int BatchCount(const Tensor& matrix) {
int count = 1;
int num_dims = matrix.dims().size();
for (int i = 0; i < num_dims - 2; ++i) {
count *= matrix.dims()[i];
}
return count;
}
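// Stride between consecutive matrices in a batched tensor: rows * cols.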
inline int MatrixStride(const Tensor& matrix) {
framework::DDim dims_list = matrix.dims();
int num_dims = dims_list.size();
return dims_list[num_dims - 1] * dims_list[num_dims - 2];
}
// Transpose two axis of a Tensor
template <typename DeviceContext, typename T>
void TransposeTwoAxis(const Tensor& input, Tensor* transposed_input,
const int axis1, const int axis2,
const framework::ExecutionContext& context) {
std::vector<int> permute(input.dims().size());
std::iota(permute.begin(), permute.end(), 0);
permute[axis1] = axis2;
permute[axis2] = axis1;
transposed_input->mutable_data<T>(input.dims(), context.GetPlace());
auto& dev_ctx = context.template device_context<platform::CPUDeviceContext>();
TransCompute<DeviceContext, T>(input.dims().size(), dev_ctx, input,
transposed_input, permute);
}
// Applies Eig to a batch of matrices; values and vectors are overwritten.
// LAPACK writes its exit status into info, which is checked after each call.
template <typename T>
void LapackEig(Tensor* input, Tensor* values, Tensor* vectors, int info,
const framework::ExecutionContext& context) {
  char jobvl = 'N';  // left eigenvectors are not computed
  char jobvr = 'V';  // only right eigenvectors are computed
int num_dims = input->dims().size();
int order = input->dims()[num_dims - 1];
T* input_data = input->data<T>();
int lda = std::max<int>(1, order);
T* values_data = values->mutable_data<T>(context.GetPlace());
T* lvector_data = nullptr;
int ldvl = 1;
T* rvector_data = vectors->mutable_data<T>(context.GetPlace());
int ldvr = lda;
int lwork = -1;
int batch_count = BatchCount(*input);
int matrix_stride = MatrixStride(*input);
int values_stride = values->dims()[values->dims().size() - 1];
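  // rwork is only consumed by the complex LAPACK drivers (cgeev/zgeev);
  // the real drivers (sgeev/dgeev) ignore it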
Tensor rwork;
pten::funcs::Real<T>* rwork_data = nullptr;
rwork.Resize(framework::make_ddim({lda * 2}));
rwork_data = rwork.mutable_data<pten::funcs::Real<T>>(context.GetPlace());
  // call lapackEig once with lwork = -1 to query the optimal size of the
  // work array
T computed_work_size;
pten::funcs::lapackEig<T, pten::funcs::Real<T>>(
jobvl, jobvr, order, input_data, lda, values_data, lvector_data, ldvl,
rvector_data, ldvr, &computed_work_size, lwork, rwork_data, &info);
lwork = std::max<int>(
1, static_cast<int>(pten::funcs::Real<T>(computed_work_size)));
Tensor work;
work.Resize(framework::make_ddim({lwork}));
T* work_data = work.mutable_data<T>(context.GetPlace());
for (auto i = 0; i < batch_count; ++i) {
T* current_matrix = &input_data[i * matrix_stride];
T* current_values = &values_data[i * values_stride];
T* current_rvectors = &rvector_data[i * matrix_stride];
pten::funcs::lapackEig<T, pten::funcs::Real<T>>(
jobvl, jobvr, order, current_matrix, lda, current_values, lvector_data,
ldvl, current_rvectors, ldvr, work_data, lwork, rwork_data, &info);
    PADDLE_ENFORCE_EQ(
        info, 0,
        platform::errors::PreconditionNotMet(
            "current info is not 0, computation failed. "
            "= 0: successful exit. "
            "< 0: if INFO = -i, the i-th argument had an illegal value. "
            "> 0: if INFO = i, the QR algorithm failed to compute all the "
            "eigenvalues, and no eigenvectors have been computed; "
            "elements i+1:N of WR and WI contain eigenvalues which "
            "have converged."));
}
}
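// Computes eigenvalues and right eigenvectors for a batch of matrices.
// LAPACK expects column-major storage, so the input is transposed before
// the call and the resulting eigenvectors are transposed back afterwards.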
template <typename DeviceContext, typename T>
void ApplyEigKernel(const Tensor& input, Tensor* values, Tensor* vectors,
const framework::ExecutionContext& context) {
Tensor input_column_major;
Tensor vectors_row_major;
int num_dims = input.dims().size();
  // transpose to column-major memory layout, i.e. the dims go from
  // [batch, row, col] to [batch, col, row]
TransposeTwoAxis<DeviceContext, T>(input, &input_column_major, num_dims - 1,
num_dims - 2, context);
  // give 'vectors_row_major' its shape before it is passed to LapackEig(),
  // which allocates the memory via mutable_data()
vectors_row_major.Resize(input.dims());
int info = 0;
LapackEig<T>(&input_column_major, values, &vectors_row_major, info, context);
  // transpose the eigenvectors back from the column-major LAPACK layout to
  // the original row-major layout
TransposeTwoAxis<DeviceContext, T>(vectors_row_major, vectors, num_dims - 1,
num_dims - 2, context);
}
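// Unpacks LAPACK's packed real-eigenvector format into complex eigenvectors.
// For real input, *geev stores a real eigenvalue's eigenvector as one real
// column, while a conjugate pair (l, conj(l)) occupies two consecutive
// columns j and j+1 encoding v = col_j + i * col_{j+1} and
// conj(v) = col_j - i * col_{j+1}. r_vectors is expected in transposed
// layout, so vecs[j * order + i] addresses element i of eigenvector j.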
template <typename T, typename Tout>
void ConstructComplexVectors(Tensor* c_vectors, const Tensor& c_values,
const Tensor& r_vectors,
const framework::ExecutionContext& ctx,
int batch_count, int order) {
int matrix_stride = MatrixStride(r_vectors);
auto* c_vectors_data = c_vectors->mutable_data<Tout>(ctx.GetPlace());
auto* c_values_data = c_values.data<Tout>();
auto* r_v_data = r_vectors.data<T>();
for (int b = 0; b < batch_count; b++) {
auto* vecs = &r_v_data[b * matrix_stride];
auto* res = &c_vectors_data[b * matrix_stride];
auto* vals = &c_values_data[b * order];
    for (int j = 0; j < order; j++) {
      // LAPACK stores a conjugate pair consecutively with the positive
      // imaginary part first, so a near-zero imag marks a real eigenvalue
      if (vals[j].imag < EPSILON) {
for (int i = 0; i < order; i++) {
res[j * order + i] = platform::complex<T>(vecs[j * order + i], 0);
}
} else {
for (int i = 0; i < order; i++) {
res[j * order + i] = platform::complex<T>(vecs[j * order + i],
vecs[(j + 1) * order + i]);
res[(j + 1) * order + i] = platform::complex<T>(
vecs[j * order + i], -vecs[(j + 1) * order + i]);
}
        j++;  // skip the second column of the pair; it was filled above
}
}
}
}
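// For real input the kernel runs the real LAPACK driver and assembles the
// complex outputs from its packed results; for complex input it calls the
// complex driver directly.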
template <typename DeviceContext, typename T, typename Tout>
class EigKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* out_values = context.Output<Tensor>("Eigenvalues");
auto* out_vectors = context.Output<Tensor>("Eigenvectors");
if (!framework::IsComplexType(framework::TransToProtoVarType(x->dtype()))) {
out_values->mutable_data<Tout>(context.GetPlace());
out_vectors->mutable_data<Tout>(context.GetPlace());
int batch_count = BatchCount(*x);
int order = x->dims()[x->dims().size() - 1];
Tensor real_values;
Tensor real_vectors;
      // double the last dim of real_values: the first half stores the real
      // parts (WR) and the second half stores the imaginary parts (WI)
std::vector<int> origin_dim =
framework::vectorize<int>(out_values->dims());
int last_item = origin_dim.back();
origin_dim.pop_back();
origin_dim.push_back(last_item * 2);
framework::DDim big_dim = framework::make_ddim(origin_dim);
real_values.mutable_data<pten::funcs::Real<T>>(big_dim,
context.GetPlace());
real_vectors.mutable_data<pten::funcs::Real<T>>(x->dims(),
context.GetPlace());
ApplyEigKernel<DeviceContext, pten::funcs::Real<T>>(
*x, &real_values, &real_vectors, context);
auto dito = math::DeviceIndependenceTensorOperations<
DeviceContext, pten::funcs::Real<T>, Tout>(context);
// 1. extract real part & imag part from real_values
Tensor real_part = dito.Slice(real_values, {-1}, {0}, {order});
Tensor imag_part = dito.Slice(real_values, {-1}, {order}, {order * 2});
// 2. construct complex values
auto* real_part_data = real_part.data<pten::funcs::Real<T>>();
auto* imag_part_data = imag_part.data<pten::funcs::Real<T>>();
int out_values_numel = out_values->numel();
platform::ForRange<DeviceContext> for_range(
context.template device_context<DeviceContext>(), out_values_numel);
pten::funcs::RealImagToComplexFunctor<Tout> functor(
real_part_data, imag_part_data,
out_values->mutable_data<Tout>(context.GetPlace()), out_values_numel);
for_range(functor);
// 3. construct complex vectors
Tensor real_vector_trans = dito.Transpose(real_vectors);
Tensor out_vectors_trans;
out_vectors_trans.mutable_data<Tout>(x->dims(), context.GetPlace());
ConstructComplexVectors<pten::funcs::Real<T>, Tout>(
&out_vectors_trans, *out_values, real_vector_trans, context,
batch_count, order);
TransposeTwoAxis<DeviceContext, Tout>(out_vectors_trans, out_vectors,
x->dims().size() - 1,
x->dims().size() - 2, context);
} else {
out_values->mutable_data<T>(context.GetPlace());
out_vectors->mutable_data<T>(context.GetPlace());
ApplyEigKernel<DeviceContext, T>(*x, out_values, out_vectors, context);
}
}
};
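// Backward of Eig for complex-valued V (eigenvectors) and L (eigenvalues).
// Writing A = V diag(L) V^{-1}, the code below follows the standard
// eigendecomposition gradient (see e.g. Giles (2008), "An extended
// collection of matrix derivative results"):
//   x_grad = V^{-H} (diag(gL) + F * (V^H gV - V^H V Re(diag(V^H gV)))) V^H
// where F_ij = 1 / (conj(L_j) - conj(L_i)) for i != j, '*' is elementwise,
// and the leading V^{-H} factor is applied by solving the linear system
// V^H x_grad = rhs at the end.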
template <typename DeviceContext, typename Tout>
void ComputeBackwardForComplexInput(
const Tensor& V, const Tensor& L, const Tensor& gL, const Tensor& gV,
Tout* x_grad_data, int batch_count, int order,
const framework::ExecutionContext& context) {
auto dito =
math::DeviceIndependenceTensorOperations<DeviceContext, Tout, Tout>(
context);
Tensor trans_v = dito.Transpose(V);
Tensor Vh = dito.Conj(trans_v);
Tensor Lconj = dito.Conj(L);
Tensor Econj = dito.Sub(dito.Unsqueeze(Lconj, -2), dito.Unsqueeze(Lconj, -1));
Tensor VhgV = dito.Matmul(Vh, gV);
Tensor diag_real = dito.Real(VhgV);
Tensor diag_res = dito.BatchDiag(diag_real, batch_count);
Tensor diag_unsqueezed = dito.Unsqueeze(diag_res, -2);
// turn diag_unsqueezed into complex
auto numel = diag_unsqueezed.numel();
Tensor diag_unsqueezed_complex;
auto* data_diag_un = diag_unsqueezed.data<pten::funcs::Real<Tout>>();
auto* data_diag_un_com = diag_unsqueezed_complex.mutable_data<Tout>(
diag_unsqueezed.dims(), context.GetPlace(),
static_cast<size_t>(numel * sizeof(Tout)));
auto& dev_ctx = context.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
pten::funcs::RealToComplexFunctor<Tout> functor(data_diag_un,
data_diag_un_com, numel);
for_range(functor);
  // broadcast-multiply: scale each column of V by the real-valued
  // diagonal entries
Tensor res1 = dito.RealMulComplex(V, diag_unsqueezed_complex);
Tensor res2 = dito.Matmul(Vh, res1);
Tensor result = dito.Sub(VhgV, res2);
result.mutable_data<Tout>(V.dims(), context.GetPlace());
result = dito.Div(result, Econj);
result = dito.DiagFill(order, order, order, 0, gL, result);
Tensor rhs = dito.Matmul(result, Vh);
  // solve the linear system V^H * x_grad = rhs
  // Vh: coefficient matrix with shape [m, m]
  // rhs: right-hand side with shape [m, k]
  // x_grad_data: receives the solution
int m = Vh.dims()[Vh.dims().size() - 1];
int k = rhs.dims()[rhs.dims().size() - 1];
auto* matrix_data = Vh.data<Tout>();
auto* rhs_data = rhs.data<Tout>();
math::SolveLinearSystem<Tout>(matrix_data, rhs_data, x_grad_data, m, k,
batch_count);
}
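// Grad kernel: fetches the forward outputs and their gradients and
// delegates the actual computation to ComputeBackwardForComplexInput.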
template <typename DeviceContext, typename T, typename Tout>
class EigGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto& L = *context.Input<Tensor>("Eigenvalues");
auto& V = *context.Input<Tensor>("Eigenvectors");
auto& gL = *context.Input<Tensor>(framework::GradVarName("Eigenvalues"));
auto& gV = *context.Input<Tensor>(framework::GradVarName("Eigenvectors"));
auto& x_grad = *context.Output<Tensor>(framework::GradVarName("X"));
auto* x_grad_data = x_grad.mutable_data<Tout>(context.GetPlace());
    const framework::DDim& dim_origin = V.dims();
    int num_dims = dim_origin.size();
int batch_count = BatchCount(V);
const int order = dim_origin[num_dims - 1];
ComputeBackwardForComplexInput<DeviceContext, Tout>(
V, L, gL, gV, x_grad_data, batch_count, order, context);
}
};
} // namespace operators
} // namespace paddle