svd_helper.h 28.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
16

17
#include <Eigen/src/Core/util/Constants.h>
18

19 20 21
#include <Eigen/Dense>
#include <Eigen/SVD>
#include <iostream>
22

23 24
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
25 26 27
#include "paddle/fluid/operators/diag_op.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
28 29
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/for_range.h"
30 31 32
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
33
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
34
#include "paddle/phi/kernels/funcs/math_function.h"
35 36 37 38

namespace paddle {
namespace operators {
namespace math {
39 40 41
using Tensor = phi::DenseTensor;
using InTensors = std::vector<const phi::DenseTensor*>;
using OutTensors = std::vector<phi::DenseTensor*>;
42
using OpName = std::string;
43 44
template <typename T,
          int MajorType = Eigen::RowMajor,
45 46
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
47 48 49

// Element-wise power functor meant to be driven by a ForRange:
// for each visited index idx, output[idx] = pow(input[idx], exp).
template <typename T>
struct PowFunctor {
  PowFunctor(const T* input, T* output, int64_t numel, T exp)
      : input_(input), output_(output), numel_(numel), exp_(exp) {}

  HOSTDEVICE void operator()(int64_t idx) const {
    const T base = input_[idx];
    output_[idx] = pow(base, exp_);
  }

  const T* input_;  // source elements
  T* output_;       // destination elements
  int64_t numel_;   // element count; stored but not read by operator()
  T exp_;           // exponent applied to every element
};

L
Lijunhui 已提交
62 63 64 65 66 67
template <typename T>
struct RealMulComplexFunctor {
  // x: complex number (a+bj)
  // y: complex number (c+0j) pretend to be a real number
  // out: complex number (ac+bcj)
  inline HOSTDEVICE T operator()(T x, T y) {
68
    PADDLE_ENFORCE_LT(
69 70
        y.imag,
        1e-6,
71 72 73
        platform::errors::InvalidArgument("The image part of y must to be 0"
                                          "but got [%d]",
                                          y.imag));
74
    return platform::complex<phi::dtype::Real<T>>(x.real * y.real,
75
                                                  x.imag * y.real);
L
Lijunhui 已提交
76 77 78
  }
};

79
static std::vector<int> GetBroadcastShape(InTensors ins) {
80
  PADDLE_ENFORCE_EQ(
81 82
      ins.size(),
      2,
83 84 85
      platform::errors::InvalidArgument("GetBroadcastShape Receive 2 tensors"
                                        "but got [%d]",
                                        ins.size()));
86 87 88
  auto x_dim = ins[0]->dims();
  auto y_dim = ins[1]->dims();
  std::vector<int> broadcast_shape =
89 90
      (x_dim.size() > y_dim.size() ? phi::vectorize<int>(x_dim)
                                   : phi::vectorize<int>(y_dim));
91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
  int rank_min = std::min(x_dim.size(), y_dim.size());
  int rank_x = x_dim.size();
  int rank_y = y_dim.size();
  int final_rank = broadcast_shape.size();
  for (int i = 1; i <= rank_min; ++i) {
    if (x_dim[rank_x - i] == y_dim[rank_y - i]) {
      broadcast_shape[final_rank - i] = x_dim[rank_x - i];
      continue;
    }
    if (x_dim[rank_x - i] == 1) {
      broadcast_shape[final_rank - i] = y_dim[rank_y - i];
      continue;
    }
    if (y_dim[rank_y - i] == 1) {
      broadcast_shape[final_rank - i] = x_dim[rank_x - i];
      continue;
    }
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Wrong Input Shape in broadcast operator: "
        "Input(X)'s shape must follow the broadcast rule with Input(Y)'s "
        "shape, but received [%s] (X) vs [%s] (Y).",
112 113
        x_dim,
        y_dim));
114 115 116 117
  }
  return broadcast_shape;
}

118
static inline framework::DDim ComputeAndCheckShapeForConcatOp(
119 120
    const bool is_runtime,
    const std::vector<framework::DDim>& inputs_dims,
121 122 123 124 125
    const size_t axis) {
  const size_t n = inputs_dims.size();
  auto out_dims = inputs_dims[0];
  size_t in_zero_dims_size = out_dims.size();
  for (size_t i = 1; i < n; i++) {
126 127 128 129 130 131 132 133 134 135 136
    PADDLE_ENFORCE_EQ(
        inputs_dims[i].size(),
        out_dims.size(),
        platform::errors::InvalidArgument("The shape of input[0] and input[%d] "
                                          "is expected to be equal."
                                          "But received input[0]'s shape = "
                                          "[%s], input[%d]'s shape = [%s].",
                                          i,
                                          inputs_dims[0],
                                          i,
                                          inputs_dims[i]));
137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
    for (size_t j = 0; j < in_zero_dims_size; j++) {
      if (j == axis) {
        if (is_runtime) {
          out_dims[axis] += inputs_dims[i][j];
        } else {
          if (inputs_dims[i][j] == -1 || out_dims[j] == -1) {
            out_dims[axis] = -1;
          } else {
            out_dims[axis] += inputs_dims[i][j];
          }
        }
      } else {
        bool check_shape =
            is_runtime || (inputs_dims[0][j] > 0 && inputs_dims[i][j] > 0);
        if (check_shape) {
          // check all shape in run time
153 154
          PADDLE_ENFORCE_EQ(inputs_dims[0][j],
                            inputs_dims[i][j],
155 156 157 158 159
                            platform::errors::InvalidArgument(
                                "The %d-th dimension of input[0] and input[%d] "
                                "is expected to be equal."
                                "But received input[0]'s shape = "
                                "[%s], input[%d]'s shape = [%s].",
160 161 162 163 164
                                j,
                                i,
                                inputs_dims[0],
                                i,
                                inputs_dims[i]));
165 166 167 168 169 170 171 172 173 174 175 176
        }
        if (!is_runtime && out_dims[j] == -1 && inputs_dims[i][j] > 0) {
          out_dims[j] = inputs_dims[i][j];
        }
      }
    }
  }
  return out_dims;
}

static inline int64_t ComputeAxisForConcatOp(int64_t axis, int64_t rank) {
  PADDLE_ENFORCE_EQ(
177 178
      axis >= -rank && axis < rank,
      true,
179
      platform::errors::InvalidArgument(
180 181 182 183
          "The axis is expected to be in range of [%d, %d), but got %d",
          -rank,
          rank,
          axis));
184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206
  if (axis < 0) {
    axis = axis + rank;
  }
  return axis > 0 ? axis : 0;
}

// Prepared for the broadcast operation
static std::vector<int64_t> get_broadcast_batch_portion(
    std::vector<int64_t> x, std::vector<int64_t> y) {
  size_t size_x = x.size();
  size_t size_y = y.size();
  size_t size = std::max(size_x, size_y);
  std::vector<int64_t> batchPortion(size);

  ptrdiff_t i = (ptrdiff_t)size - 1;
  for (; i >= 0; --i) {
    ptrdiff_t offset = size - i - 1;
    ptrdiff_t dim_x = size_x - offset - 1;
    ptrdiff_t dim_y = size_y - offset - 1;
    int64_t x_size = (dim_x >= 0) ? x[dim_x] : 1;
    int64_t y_size = (dim_y >= 0) ? y[dim_y] : 1;

    PADDLE_ENFORCE_EQ(
207 208
        (x_size == y_size || x_size == 1 || y_size == 1),
        true,
209 210 211
        platform::errors::PreconditionNotMet(
            "The size of tensor x (%d) must match the size of tensor y "
            "(%d) at non-singleton dimension %d.",
212 213 214
            x_size,
            y_size,
            i));
215 216 217 218 219 220

    batchPortion[i] = x_size != 1 ? x_size : y_size;
  }
  return batchPortion;
}

// Switch-case helpers for DeviceIndependenceTensorOperations. Each expands to
// one `case N:` that dispatches to the rank-N template instantiation.
// DITO_TRANSPOSE_RANK_CASE expects DeviceContext, T, dev_ctx, x, ret and axis
// to be in scope at the expansion site; DITO_SLICE_RANK_CASE expects x,
// offset, extends, ret and the member template EigenSliceWrapper.
#define DITO_TRANSPOSE_RANK_CASE(N)                   \
  case N: {                                           \
    phi::funcs::Transpose<DeviceContext, T, N> trans; \
    trans(dev_ctx, x, &ret, axis);                    \
    break;                                            \
  }

#define DITO_SLICE_RANK_CASE(N)                      \
  case N: {                                          \
    EigenSliceWrapper<N>(&x, offset, extends, &ret); \
    break;                                           \
  }

234 235
template <typename T, typename ValueType>
struct DiagAndFillFunctor {
236 237 238 239 240 241 242
  DiagAndFillFunctor(const int m,
                     const int n,
                     const int num_lower_diags,
                     const int num_upper_diags,
                     const ValueType* scale,
                     const T* input,
                     T* output)
243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273
      : m_(m),
        n_(n),
        num_lower_diags_(num_lower_diags),
        num_upper_diags_(num_upper_diags),
        scale_(scale),
        input_(input),
        output_(output) {}

  HOSTDEVICE void operator()(size_t index) const {
    const int col = index % n_;
    const int row = (index / n_) % m_;
    const int band_start = (num_lower_diags_ < 0 ? 0 : row - num_lower_diags_);
    const int band_end =
        (num_upper_diags_ < 0 ? n_ : row + num_upper_diags_ + 1);
    if (col < band_start || col >= band_end) {
      output_[index] = input_[index];
    } else if (col == band_end - 1) {
      output_[index] = static_cast<T>(scale_[index % m_]);
    } else {
      output_[index] = input_[index];
    }
  }

 private:
  const int m_, n_, num_lower_diags_, num_upper_diags_;
  const ValueType* scale_;
  const T* input_;
  T* output_;
};

template <typename DeviceContext, typename T, typename ValueType = T>
274 275 276 277 278 279 280 281
struct DeviceIndependenceTensorOperations {
  // 1. Device indenpendence, for kernel reuse.
  // 2. Input and output is always tensor type.
  // 3. output Tensor is alway allocated
  // 4. Basic Tensor operator is supported
  // 5. The Reused Operator Kernel should only be considered as
  //    a wrap function
  using NameInTensorMap =
282
      std::map<std::string, std::vector<const phi::DenseTensor*>>;
283 284 285 286 287 288
  using NameOutTensor = std::vector<std::string>;

  explicit DeviceIndependenceTensorOperations(
      const framework::ExecutionContext& context)
      : context(context) {}

289 290
  phi::DenseTensor Pow(const phi::DenseTensor& x, T exp) {
    phi::DenseTensor out;
291 292
    auto for_range = GetForRange(x.numel());
    int numel = x.numel();
293 294
    PowFunctor<T> functor(
        x.data<T>(), out.mutable_data<T>(x.dims(), x.place()), numel, exp);
295 296 297
    for_range(functor);
    return out;
  }
298 299 300 301 302
  phi::DenseTensor Matmul(const phi::DenseTensor& mat_a,
                          const phi::DenseTensor& mat_b,
                          bool trans_a = false,
                          bool trans_b = false) {
    phi::DenseTensor ret;
303 304
    auto a_dim = mat_a.dims();
    auto b_dim = mat_b.dims();
305
    std::vector<int> x_vec = phi::vectorize<int>(a_dim);
306 307
    x_vec[x_vec.size() - 2] = a_dim[a_dim.size() - (trans_a ? 1 : 2)];
    x_vec[x_vec.size() - 1] = b_dim[b_dim.size() - (trans_b ? 2 : 1)];
308
    ret.Resize(phi::make_ddim(x_vec));
309 310
    ret.mutable_data<T>(context.GetPlace());
    auto blas = GetBlas();
311 312
    auto mat_a_discrib = phi::funcs::CreateMatrixDescriptor(a_dim, 0, trans_a);
    auto mat_b_discrib = phi::funcs::CreateMatrixDescriptor(b_dim, 0, trans_b);
313 314
    blas.MatMul(
        mat_a, mat_a_discrib, mat_b, mat_b_discrib, T(1.0), &ret, T(0.0));
315
    return ret;
316
  }
317

318
  phi::DenseTensor Transpose(const phi::DenseTensor& x) {
319
    // transpose the last two dimision
320
    phi::DenseTensor ret;
321
    auto x_dim = x.dims();
322
    auto x_vec = phi::vectorize<int>(x_dim);
323 324 325 326 327 328 329 330
    int rank = x_vec.size();
    std::swap(x_vec[rank - 1], x_vec[rank - 2]);
    std::vector<int> out_shape = x_vec;
    std::vector<int> axis(rank);
    for (int i = 0; i < rank; ++i) {
      axis[i] = i;
    }
    std::swap(axis[rank - 1], axis[rank - 2]);
331
    auto& dev_ctx = context.template device_context<DeviceContext>();
332
    ret.Resize(phi::make_ddim(x_vec));
333 334 335 336 337 338 339 340 341 342 343 344 345 346
    ret.mutable_data<T>(context.GetPlace());
    switch (rank) {
      DITO_TRANSPOSE_RANK_CASE(2);
      DITO_TRANSPOSE_RANK_CASE(3);
      DITO_TRANSPOSE_RANK_CASE(4);
      DITO_TRANSPOSE_RANK_CASE(5);
      DITO_TRANSPOSE_RANK_CASE(6);
      default: {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Invalid Rank number, "
            "currently only support rank between 2~6"));
      }
    }
    return ret;
347
  }
348 349 350 351
  phi::DenseTensor Diag(const phi::DenseTensor& x,
                        int offset = 0,
                        // FIXME  link error
                        int padding_value = 0) {
352 353
    PADDLE_ENFORCE_EQ(padding_value,
                      0,
354 355
                      platform::errors::InvalidArgument(
                          "Current diag only support padding_value = 0"));
356 357
    PADDLE_ENFORCE_EQ(offset,
                      0,
358 359 360 361
                      platform::errors::InvalidArgument(
                          "Current diag only support offset = 0,"
                          "you can use DiagOp instead(not recommend)"));

362
    phi::DenseTensor ret;
363 364 365
    int x_rank = x.dims().size();
    std::vector<int> out_shape;
    if (x_rank == 2) {
366 367 368 369
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Current diag only support vector"
          "-> diagonalized matrix, not support matrix -> vector,"
          " Use DiagOp instead."));
370 371 372 373 374 375 376
    } else if (x_rank == 1) {
      out_shape.push_back(x.dims()[0]);
      out_shape.push_back(x.dims()[0]);
    } else {
      PADDLE_THROW(
          platform::errors::InvalidArgument("Rank must less or equal than 2"));
    }
377 378 379 380 381 382
    ret = Fill({out_shape[0], out_shape[0]}, 0.0);
    T* output = ret.mutable_data<T>(context.GetPlace());
    auto for_range = GetForRange(x.numel());
    for_range(DiagFunctor<T>(x.data<T>(), x.numel(), output));
    return ret;
  }
L
Lijunhui 已提交
383 384

  // batch_diag for CPU only
385
  Tensor BatchDiag(const phi::DenseTensor& x, int batch) {
L
Lijunhui 已提交
386
    Tensor out;
387
    auto* x_data = x.data<phi::dtype::Real<T>>();
L
Lijunhui 已提交
388
    auto numel = x.numel();
389
    auto* out_data = out.mutable_data<phi::dtype::Real<T>>(
390 391
        x.dims(),
        context.GetPlace(),
392
        static_cast<size_t>(numel * sizeof(phi::dtype::Real<T>)));
L
Lijunhui 已提交
393 394 395 396 397 398 399 400

    auto x_dims = x.dims();
    int num_dims = x_dims.size();
    std::vector<int> out_shape;

    for (int i = 0; i < num_dims - 1; ++i) {
      out_shape.push_back(x.dims()[i]);
    }
401
    out.Resize(phi::make_ddim(out_shape));
L
Lijunhui 已提交
402 403 404 405 406 407 408 409 410 411 412 413
    int order = x.dims()[num_dims - 1];
    int stride_out = order * order;
    int stride_in = order + 1;
    for (int i = 0; i < batch; ++i) {
      for (int j = 0; j < order; ++j) {
        out_data[i * order + j] = x_data[stride_out * i + stride_in * j];
      }
    }
    return out;
  }

  // a complex number x times a real number y, which is represented as (a+0j)
414 415
  Tensor RealMulComplex(const phi::DenseTensor& x, const phi::DenseTensor& y) {
    phi::DenseTensor ret;
L
Lijunhui 已提交
416
    std::vector<int> out_shape = GetBroadcastShape({&x, &y});
417
    ret.Resize(phi::make_ddim(out_shape));
L
Lijunhui 已提交
418 419 420 421 422
    ElementwiseComputeEx<RealMulComplexFunctor<T>, DeviceContext, T>(
        context, &x, &y, -1, RealMulComplexFunctor<T>(), &ret);
    return ret;
  }

423 424
  phi::DenseTensor Div(const phi::DenseTensor& x, const phi::DenseTensor& y) {
    phi::DenseTensor ret;
425 426 427 428 429 430 431 432 433 434
    if (x.type() != y.type()) {
      ret.mutable_data<T>(x.dims(), context.GetPlace());
      auto x_vector = EigenVector<T>::Flatten(x);
      auto y_vector = EigenVector<ValueType>::Flatten(y);
      auto out_vector = EigenVector<T>::Flatten(ret);
      auto& place =
          *context.template device_context<DeviceContext>().eigen_device();
      out_vector.device(place) = x_vector / y_vector;
    } else {
      std::vector<int> out_shape = GetBroadcastShape({&x, &y});
435
      ret.Resize(phi::make_ddim(out_shape));
436 437 438
      ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(
          context, &x, &y, -1, DivFunctor<T>(), &ret);
    }
439
    return ret;
440
  }
441
  phi::DenseTensor Add(const phi::DenseTensor& x, const phi::DenseTensor& y) {
442
    // element wise add, support numpy broadcast.
443
    phi::DenseTensor ret;
444
    std::vector<int> out_shape = GetBroadcastShape({&x, &y});
445
    ret.Resize(phi::make_ddim(out_shape));
446 447 448
    ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
        context, &x, &y, -1, AddFunctor<T>(), &ret);
    return ret;
449
  }
450 451
  phi::DenseTensor Mul(const phi::DenseTensor& x, const phi::DenseTensor& y) {
    phi::DenseTensor ret;
452
    std::vector<int> out_shape = GetBroadcastShape({&x, &y});
453
    ret.Resize(phi::make_ddim(out_shape));
454 455 456 457 458
    ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
        context, &x, &y, -1, MulFunctor<T>(), &ret);
    return ret;
  }

459 460
  phi::DenseTensor ReduceSum(const phi::DenseTensor& x,
                             std::vector<int> out_dim) {
461 462 463 464 465 466
    framework::AttributeMap attrs;
    attrs["dim"] = std::vector<int>{-1};
    NameInTensorMap inputs({{"X", {&x}}});
    return CreateOpRunAndReturnTensor("reduce_sum", inputs, attrs, out_dim);
  }

467 468
  phi::DenseTensor ReduceMax(const phi::DenseTensor& x,
                             std::vector<int> out_dim) {
469 470 471 472
    framework::AttributeMap attrs;
    attrs["dim"] = std::vector<int>{-1};
    NameInTensorMap inputs({{"X", {&x}}});
    return CreateOpRunAndReturnTensor("reduce_max", inputs, attrs, out_dim);
473
  }
474 475
  // Support float and complex type subtraction,the default is T type
  template <typename InT = T>
476 477
  phi::DenseTensor Sub(const phi::DenseTensor& x, const phi::DenseTensor& y) {
    phi::DenseTensor ret;
478
    std::vector<int> out_shape = GetBroadcastShape({&x, &y});
479
    ret.Resize(phi::make_ddim(out_shape));
480 481 482 483
    if (platform::is_gpu_place(context.GetPlace())) {
#if defined(__NVCC__) || defined(__HIPCC__)
      // For GPU, there is no need to define XxxInverseFunctor and call
      // ElementwiseComputeEx in two branches.
484 485
      ElementwiseComputeEx<SubFunctor<InT>, DeviceContext, InT>(
          context, &x, &y, -1, SubFunctor<InT>(), &ret);
486
#endif
487
    } else {
488
      if (x.dims().size() >= y.dims().size()) {
489 490
        ElementwiseComputeEx<SubFunctor<InT>, DeviceContext, InT>(
            context, &x, &y, -1, SubFunctor<InT>(), &ret);
491
      } else {
492 493 494 495
        // This is copyed from elementwise_sub, which means we
        // need reverse will xrank < yrank
        ElementwiseComputeEx<InverseSubFunctor<InT>, DeviceContext, InT>(
            context, &x, &y, -1, InverseSubFunctor<InT>(), &ret);
496
      }
497 498
    }
    return ret;
499
  }
500
  const phi::DenseTensor Unsqueeze(const phi::DenseTensor& x, int axis = 0) {
501
    // don't copy data, only change the dims
502
    phi::DenseTensor out;
503
    out.ShareDataWith(x);
504
    std::vector<int> out_shape = phi::vectorize<int>(x.dims());
505 506 507 508 509 510 511
    if (axis >= 0) {
      auto index = (out_shape.begin() + axis);
      out_shape.insert(index, 1);
    } else if (axis < 0) {
      auto index = (out_shape.end() + axis + 1);
      out_shape.insert(index, 1);
    }
512
    out.Resize(phi::make_ddim(out_shape));
513 514
    return out;
  }
515 516
  phi::DenseTensor Fill(std::vector<int> shape, float fill_value) {
    phi::DenseTensor ret;
517
    ret.Resize(phi::make_ddim(shape));
518 519
    ret.mutable_data<T>(context.GetPlace());
    auto& dev_ctx = context.template device_context<DeviceContext>();
520
    phi::funcs::SetConstant<DeviceContext, T>()(dev_ctx, &ret, T(fill_value));
521
    return ret;
522
  }
523
  phi::DenseTensor Infinits(std::vector<int> shape) {
524 525
    auto value = static_cast<T>(std::numeric_limits<double>::infinity());
    return Fill(shape, value);
526
  }
527
  phi::DenseTensor Eye(int n) {
528
    auto output = Fill({n}, 1);
529 530 531
    auto ret = Diag(output);
    return ret;
  }
532 533 534 535 536
  phi::DenseTensor Slice(const phi::DenseTensor& x,
                         std::vector<int> axes,
                         std::vector<int> starts,
                         std::vector<int> ends) {
    phi::DenseTensor ret;
537
    std::vector<int> new_axes = axes;
538
    std::vector<int> out_shape = phi::vectorize<int>(x.dims());
539
    size_t rank = out_shape.size();
540
    PADDLE_ENFORCE_EQ(
541 542
        axes.size(),
        starts.size(),
543 544
        platform::errors::InvalidArgument("Slice Operator Argument Invalided"));
    PADDLE_ENFORCE_EQ(
545 546
        ends.size(),
        starts.size(),
547 548 549 550 551 552 553
        platform::errors::InvalidArgument("Slice Operator Argument Invalided"));
    for (unsigned int i = 0; i < axes.size(); ++i) {
      int axis = axes[i];
      if (axis < 0) axis = rank + axis;
      new_axes[i] = axis;  // change negative to positive
      int st = starts[i];
      int ed = ends[i];
554 555
      PADDLE_ENFORCE_GT(ed,
                        st,
556 557 558 559
                        platform::errors::InvalidArgument(
                            "C++ Slice Operation Not Support End < Start"));
      out_shape[axis] = ed - st;
    }
560 561 562 563 564 565 566 567 568
    std::vector<int> offset(rank), extends(rank);
    for (size_t i = 0; i < rank; ++i) {
      offset[i] = 0;
      extends[i] = x.dims()[i];
    }
    for (size_t i = 0; i < new_axes.size(); ++i) {
      offset[new_axes[i]] = starts[i];
      extends[new_axes[i]] = ends[i] - starts[i];
    }
569
    ret.Resize(phi::make_ddim(out_shape));
570 571 572 573 574 575 576 577 578 579 580 581 582 583 584
    ret.mutable_data<T>(context.GetPlace());
    switch (rank) {
      DITO_SLICE_RANK_CASE(1);
      DITO_SLICE_RANK_CASE(2);
      DITO_SLICE_RANK_CASE(3);
      DITO_SLICE_RANK_CASE(4);
      DITO_SLICE_RANK_CASE(5);
      DITO_SLICE_RANK_CASE(6);
      default: {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Invalid Rank number, "
            "currently only support rank between 2~6"));
      }
    }
    return ret;
585 586
  }

587 588 589
  phi::DenseTensor TrilTriu(const phi::DenseTensor& x,
                            int diagonal,
                            bool lower) {
590 591 592 593 594
    framework::AttributeMap attrs;
    attrs["diagonal"] = diagonal;
    attrs["lower"] = lower;
    NameInTensorMap inputs({{"X", {&x}}});
    int x_rank = x.dims().size();
595
    PADDLE_ENFORCE_GE(
596 597
        x_rank,
        2,
598
        platform::errors::InvalidArgument("Rank must be at least 2."));
599
    std::vector<int> out_shape = phi::vectorize<int>(x.dims());
600 601 602
    return CreateOpRunAndReturnTensor("tril_triu", inputs, attrs, out_shape);
  }

603 604 605 606 607
  phi::DenseTensor TriangularSolve(const phi::DenseTensor& x,
                                   const phi::DenseTensor& y,
                                   bool upper,
                                   bool transpose,
                                   bool unitriangular) {
608 609 610 611 612 613 614 615
    framework::AttributeMap attrs;
    attrs["upper"] = upper;
    attrs["transpose"] = transpose;
    attrs["unitriangular"] = unitriangular;
    NameInTensorMap inputs({{"X", {&x}}, {"Y", {&y}}});
    auto x_dims = x.dims();
    auto y_dims = y.dims();
    auto y_dims_n = y_dims.size();
616 617
    std::vector<int64_t> x_dims_vec = phi::vectorize<int64_t>(x_dims);
    std::vector<int64_t> y_dims_vec = phi::vectorize<int64_t>(y_dims);
618 619 620 621 622 623 624
    std::vector<int64_t> x_dims_vec_cut(x_dims_vec.begin(),
                                        x_dims_vec.end() - 2);
    std::vector<int64_t> y_dims_vec_cut(y_dims_vec.begin(),
                                        y_dims_vec.end() - 2);
    std::vector<int64_t> expand_batch_portion =
        get_broadcast_batch_portion(x_dims_vec_cut, y_dims_vec_cut);
    std::vector<int64_t> y_broadcast_dims({expand_batch_portion});
625 626 627
    y_broadcast_dims.insert(
        y_broadcast_dims.end(),
        {y_dims_vec[y_dims_n - 2], y_dims_vec[y_dims_n - 1]});
628 629
    std::vector<int> out_shape(y_broadcast_dims.begin(),
                               y_broadcast_dims.end());
630 631
    return CreateOpRunAndReturnTensor(
        "triangular_solve", inputs, attrs, out_shape);
632 633
  }

634 635 636
  phi::DenseTensor ConcatTwoTensors(const phi::DenseTensor& x,
                                    const phi::DenseTensor& y,
                                    int axis) {
637 638 639 640 641 642 643 644 645 646 647 648
    framework::AttributeMap attrs;
    attrs["axis"] = axis;
    std::vector<framework::DDim> inputs_dims({x.dims(), y.dims()});
    NameInTensorMap inputs({{"X", {&x, &y}}});
    size_t axis_ =
        ComputeAxisForConcatOp(static_cast<int64_t>(axis),
                               static_cast<int64_t>(inputs_dims[0].size()));
    framework::DDim out_dims =
        ComputeAndCheckShapeForConcatOp(true, inputs_dims, axis_);
    if (out_dims[axis_] < 0) {
      out_dims[axis_] = -1;
    }
649
    std::vector<int> out_shape = phi::vectorize<int>(out_dims);
650 651 652
    return CreateOpRunAndReturnTensor("concat", inputs, attrs, out_shape);
  }

653
  Tensor Conj(const phi::DenseTensor& x) {
654 655 656 657
    Tensor out;
    auto* out_data = out.mutable_data<T>(x.dims(), context.GetPlace());
    auto* x_data = x.data<T>();
    auto for_range = GetForRange(x.numel());
658
    phi::funcs::ConjFunctor<T> functor(x_data, x.numel(), out_data);
659 660 661 662
    for_range(functor);
    return out;
  }

663
  Tensor Real(const phi::DenseTensor& x) {
L
Lijunhui 已提交
664 665
    Tensor out;
    auto numel = x.numel();
666
    auto* out_data = out.mutable_data<phi::dtype::Real<T>>(
667 668
        x.dims(),
        context.GetPlace(),
669
        static_cast<size_t>(numel * sizeof(phi::dtype::Real<T>)));
L
Lijunhui 已提交
670 671
    auto* x_data = x.data<T>();
    auto for_range = GetForRange(numel);
672
    phi::funcs::RealFunctor<T> functor(x_data, out_data, numel);
L
Lijunhui 已提交
673 674 675 676
    for_range(functor);
    return out;
  }

677 678 679 680
  Tensor DiagFill(const int m,
                  const int n,
                  const int num_lower_diags,
                  const int num_upper_diags,
681 682
                  const phi::DenseTensor& scale,
                  const phi::DenseTensor& input) {
683 684 685 686
    Tensor out;
    auto& dev_ctx = context.template device_context<DeviceContext>();
    platform::ForRange<DeviceContext> for_range(dev_ctx, input.numel());
    DiagAndFillFunctor<T, ValueType> diag_and_copy_functor(
687 688 689 690 691 692 693
        m,
        n,
        num_lower_diags,
        num_upper_diags,
        scale.data<ValueType>(),
        input.data<T>(),
        out.mutable_data<T>(input.dims(), input.place()));
694 695 696 697
    for_range(diag_and_copy_functor);
    return out;
  }

698 699
 private:
  const framework::ExecutionContext& context;
700 701
  phi::funcs::BlasT<DeviceContext, T> GetBlas() {
    return phi::funcs::GetBlas<DeviceContext, T>(context);
702 703 704 705 706
  }
  platform::ForRange<DeviceContext> GetForRange(int numel) {
    auto& dev_ctx = context.template device_context<DeviceContext>();
    return platform::ForRange<DeviceContext>(dev_ctx, numel);
  }
707
  template <size_t D>
708
  void EigenSliceWrapper(const phi::DenseTensor* in,
709
                         const std::vector<int>& start,
710
                         const std::vector<int>& end,
711
                         phi::DenseTensor* out) {
712 713
    // Slice by call Eigen Tensor Function `.slice()`
    size_t rank = in->dims().size();
714 715
    PADDLE_ENFORCE_EQ(start.size(),
                      rank,
716 717 718
                      platform::errors::InvalidArgument(
                          "EigenSliceWrapper function start "
                          "argument must have the same length as input rank."));
719 720
    PADDLE_ENFORCE_EQ(end.size(),
                      rank,
721 722 723 724 725 726 727 728 729 730 731 732 733 734
                      platform::errors::InvalidArgument(
                          "EigenSliceWrapper function end "
                          "argument must have the same length as input rank."));
    auto eigen_place_ptr =
        context.template device_context<DeviceContext>().eigen_device();
    auto eigen_place = *eigen_place_ptr;
    auto out_t = framework::EigenTensor<T, D>::From(*out, out->dims());
    auto in_t = framework::EigenTensor<T, D>::From(*in, in->dims());
    Eigen::DSizes<int, D> offsets_32bit, extents_32bit;
    for (size_t i = 0; i < D; i++) {
      offsets_32bit[i] = start[i];
      extents_32bit[i] = end[i];
    }
    EigenSlice<std::decay_t<decltype(eigen_place)>, T, D>::Eval(
735 736 737 738 739
        eigen_place,
        framework::To32BitIndex(out_t),
        framework::To32BitIndex(in_t),
        offsets_32bit,
        extents_32bit);
740
  }
741
  phi::DenseTensor CreateOpRunAndReturnTensor(
742 743 744 745
      const std::string& type,
      const NameInTensorMap& inputs,
      const framework::AttributeMap& attrs,
      std::vector<int> out_shape,
746
      NameOutTensor out_str = {"Out"}) {
747
    // varialble set dims must be phi::DenseTensor / SelectedRowTensor
748 749 750
    framework::Scope& local_scope = context.scope().NewScope();
    framework::VariableNameMap op_outputs;
    for (auto out_name : out_str) {
751
      local_scope.Var("tmp_" + out_name)->GetMutable<phi::DenseTensor>();
752 753 754 755
      op_outputs[out_name].emplace_back("tmp_" + out_name);
    }
    auto out_var = local_scope.Var("tmp_Out");  // return the Out
    // create Out Tensor and allocat memory
756
    out_var->GetMutable<phi::DenseTensor>()->mutable_data<T>(
757 758
        phi::make_ddim(out_shape), context.GetPlace());
    // phi::make_ddim(out_shape)
759 760 761 762 763 764 765 766 767
    framework::VariableNameMap op_inputs;
    int counter = 0;
    for (auto item : inputs) {
      auto& tensors = item.second;
      std::vector<std::string> name_vector;
      for (auto each_tensor : tensors) {
        // create score variable and reset the tensor.
        std::string _name = "tmp" + std::to_string(counter++);
        auto in_var = local_scope.Var(_name);  // create
768
        phi::DenseTensor tmp_tns;
769
        tmp_tns.ShareDataWith(*each_tensor);  // tensor -> lodtensor
770
        (*in_var->GetMutable<phi::DenseTensor>()) =
771 772 773 774 775
            tmp_tns;  // initialize and set value
        name_vector.emplace_back(_name);
      }
      op_inputs[item.first] = name_vector;
    }
776

777 778 779
    auto op =
        framework::OpRegistry::CreateOp(type, op_inputs, op_outputs, attrs);
    op->Run(local_scope, context.GetPlace());
780
    phi::DenseTensor out;
781
    out.ShareDataWith(*(out_var->GetMutable<phi::DenseTensor>()));
782
    out.Resize(phi::make_ddim(out_shape));
783 784 785 786 787 788 789
    context.scope().DeleteScope(&local_scope);
    return out;
  }
};
}  // namespace math
}  // namespace operators
}  // namespace paddle