eig_op.h 2.0 KB
Newer Older
L
Lijunhui 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <math.h>
18

L
Lijunhui 已提交
19 20
#include <algorithm>
#include <complex>
21

L
Lijunhui 已提交
22 23
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/for_range.h"
24
#include "paddle/phi/kernels/complex_kernel.h"
25 26 27
#include "paddle/phi/kernels/elementwise_divide_kernel.h"
#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
28
#include "paddle/phi/kernels/funcs/complex_functors.h"
29
#include "paddle/phi/kernels/funcs/diag_functor.h"
30 31
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
32
#include "paddle/phi/kernels/funcs/matrix_solve.h"
33 34 35 36 37
#include "paddle/phi/kernels/funcs/slice.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h"
#include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"

L
Lijunhui 已提交
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
// Preprocessor constant expanding to the double literal 1e-6. Being a macro,
// it is not scoped by the namespaces below and is visible to every file that
// includes this header.
// NOTE(review): presumably a numerical tolerance for the eig kernels; no use
// is visible in this section of the file -- confirm against the callers.
#define EPSILON 1e-6

namespace paddle {
namespace operators {

// Make paddle::framework::Tensor nameable as plain `Tensor` inside
// paddle::operators; the helper functions in this header rely on it.
using paddle::framework::Tensor;

// Returns the product of every dimension of `matrix` except the trailing two,
// i.e. how many 2-D slices the leading (batch) dimensions span. When the rank
// is 2 or less there are no leading dimensions and the result is 1.
inline int BatchCount(const Tensor& matrix) {
  const auto& shape = matrix.dims();
  const int rank = shape.size();
  int batches = 1;
  // Walk only the leading axes; the last two axes are excluded.
  for (int axis = 0; axis + 2 < rank; ++axis) {
    // The running product is kept in an int; the cast spells out the
    // conversion from the dimension type back to int on every step.
    batches = static_cast<int>(batches * shape[axis]);
  }
  return batches;
}

inline int MatrixStride(const Tensor& matrix) {
  framework::DDim dims_list = matrix.dims();
  int num_dims = dims_list.size();
  return dims_list[num_dims - 1] * dims_list[num_dims - 2];
}

}  // namespace operators
}  // namespace paddle