// expand_kernel_impl.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <vector>

20 21
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
22 23
// Highest tensor rank handled by this header. ExpandKernel rejects an input
// rank or a target-shape length above this value and dispatches to
// Expand<Context, T, Rank> for Rank in 1..MAX_RANK_SUPPORTED, so the switch
// cases there have to be kept in sync with this number.
#define MAX_RANK_SUPPORTED 6

24
namespace phi {
25 26 27 28 29 30 31 32 33
// Short alias for DenseTensor. The functions visible in this header spell
// DenseTensor out directly and do not reference the alias.
using Tensor = DenseTensor;

template <typename Context, typename T, int Rank>
void Expand(const Context& ctx,
            const DenseTensor& x,
            const ScalarArray& shape,
            DenseTensor* out) {
  auto in_dims = x.dims();
  auto expand_shape = shape.GetData();
34
  auto vec_in_dims = phi::vectorize<int>(in_dims);
35 36 37 38 39 40 41
  auto diff = expand_shape.size() - vec_in_dims.size();
  vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
  std::vector<int> repeat_times(vec_in_dims.size());
  for (size_t i = 0; i < vec_in_dims.size(); ++i) {
    PADDLE_ENFORCE_NE(
        expand_shape[i],
        0,
42
        phi::errors::InvalidArgument("The expanded size cannot be zero."));
43 44 45 46
    if (i < diff) {
      PADDLE_ENFORCE_GT(
          expand_shape[i],
          0,
47
          phi::errors::InvalidArgument(
48 49 50 51 52 53 54 55 56
              "The expanded size (%d) for non-existing dimensions must be "
              "positive for expand_v2 op.",
              expand_shape[i]));
      repeat_times[i] = expand_shape[i];
    } else if (expand_shape[i] > 0) {
      if (vec_in_dims[i] != 1) {
        PADDLE_ENFORCE_EQ(
            vec_in_dims[i],
            expand_shape[i],
57
            phi::errors::InvalidArgument(
58 59 60 61 62 63 64 65 66 67 68 69
                "The value (%d) of the non-singleton dimension does not match"
                " the corresponding value (%d) in shape for expand_v2 op.",
                vec_in_dims[i],
                expand_shape[i]));
        repeat_times[i] = 1;
      } else {
        repeat_times[i] = expand_shape[i];
      }
    } else {
      PADDLE_ENFORCE_EQ(
          expand_shape[i],
          -1,
70
          phi::errors::InvalidArgument(
71 72 73 74 75 76 77 78 79 80 81 82
              "When the value in shape is negative for expand_v2 op, "
              "only -1 is supported, but the value received is %d.",
              expand_shape[i]));
      repeat_times[i] = 1;
    }
  }

  Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
  for (size_t i = 0; i < repeat_times.size(); ++i) {
    bcast_dims[i] = repeat_times[i];
  }

83
  DDim new_in_dims = phi::make_ddim(vec_in_dims);
84
  DDim out_dims(new_in_dims);
85 86 87 88 89 90 91 92 93 94 95 96 97 98
  for (size_t i = 0; i < repeat_times.size(); ++i) {
    out_dims[i] *= repeat_times[i];
  }

  out->Resize(out_dims);
  auto x0 = EigenTensor<T, Rank>::From(x, new_in_dims);
  ctx.template Alloc<T>(out);
  out->data<T>();

  auto y = EigenTensor<T, Rank>::From(*out, out_dims);
  auto& place = *ctx.eigen_device();
  // use 32-bit index to speed up
  bool use_32bit_index = y.size() < Eigen::NumTraits<int>::highest();
  if (use_32bit_index) {
99
    phi::funcs::EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(
100 101
        place, To32BitIndex(y), To32BitIndex(x0), bcast_dims);
  } else {
102
    phi::funcs::EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(
103 104 105 106 107 108 109 110 111 112 113 114 115
        place, y, x0, bcast_dims);
  }
}

template <typename T, typename Context>
void ExpandKernel(const Context& ctx,
                  const DenseTensor& x,
                  const ScalarArray& shape,
                  DenseTensor* out) {
  auto rank = x.dims().size();
  PADDLE_ENFORCE_GE(
      rank,
      1,
116
      phi::errors::InvalidArgument(
117 118 119 120 121 122
          "The rank of the input 'X' for expand_v2 op must be positive, "
          "but the value received is %d.",
          rank));
  PADDLE_ENFORCE_LE(
      rank,
      MAX_RANK_SUPPORTED,
123
      phi::errors::InvalidArgument(
124 125 126 127 128 129 130 131 132
          "The rank of the input 'X' for expand_v2 op must be less than "
          "or equal to %d, but the value received is %d.",
          MAX_RANK_SUPPORTED,
          rank));
  auto expand_shape = shape.GetData();
  auto shape_size = expand_shape.size();
  PADDLE_ENFORCE_GE(
      shape_size,
      rank,
133
      phi::errors::InvalidArgument(
134 135 136 137 138 139 140
          "The number (%d) of elements of 'shape' for expand_v2 op must be "
          "greater than or equal to the rank (%d) of the input 'X'.",
          shape_size,
          rank));
  PADDLE_ENFORCE_LE(
      shape_size,
      MAX_RANK_SUPPORTED,
141
      phi::errors::InvalidArgument(
142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168
          "The number (%d) of elements of 'shape' for expand_v2 op must be "
          "less than or equal to %d.",
          shape_size,
          MAX_RANK_SUPPORTED));
  rank = std::max(rank, static_cast<int>(shape_size));
  switch (rank) {
    case 1:
      Expand<Context, T, 1>(ctx, x, shape, out);
      break;
    case 2:
      Expand<Context, T, 2>(ctx, x, shape, out);
      break;
    case 3:
      Expand<Context, T, 3>(ctx, x, shape, out);
      break;
    case 4:
      Expand<Context, T, 4>(ctx, x, shape, out);
      break;
    case 5:
      Expand<Context, T, 5>(ctx, x, shape, out);
      break;
    case 6:
      Expand<Context, T, 6>(ctx, x, shape, out);
      break;
  }
}

169
}  // namespace phi