arg_min_max_kernel.cc 7.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/arg_min_max_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {

/// Selects which reduction an ArgMinMaxFunctor specialization performs.
/// Kept as an unscoped enum: callers use the enumerators both qualified
/// (ArgMinMaxType::kArgMin) and unqualified (kArgMin).
enum ArgMinMaxType { kArgMin = 0, kArgMax = 1 };

/// Primary template, intentionally empty. The usable specializations, keyed
/// on the ArgMinMaxType value, are produced by DECLARE_ARG_MIN_MAX_FUNCTOR.
///
/// @tparam Context  context type handed to the functor's operator()
/// @tparam T        element type of the input tensor
/// @tparam Tout     element type the resulting indices are cast to
/// @tparam Rank     rank of the Eigen view of the input
/// @tparam kKind    reduction to perform (kArgMin or kArgMax)
template <typename Context,
          typename T,
          typename Tout,
          int64_t Rank,
          ArgMinMaxType kKind>
struct ArgMinMaxFunctor {};

#define DECLARE_ARG_MIN_MAX_FUNCTOR(eigen_op_type, enum_argminmax_value)  \
  template <typename Context, typename T, typename Tout, int64_t Rank>    \
  struct ArgMinMaxFunctor<Context, T, Tout, Rank, enum_argminmax_value> { \
    void operator()(const Context& dev_ctx,                               \
                    const DenseTensor& in,                                \
                    DenseTensor* out,                                     \
                    phi::DDim x_dims,                                     \
                    int64_t axis,                                         \
                    bool keepdims) {                                      \
      auto in_eigen = EigenTensor<T, Rank>::From(in, x_dims);             \
      if (keepdims) {                                                     \
        auto out_eigen = EigenTensor<Tout, Rank>::From(*out);             \
        out_eigen.device(*(dev_ctx.eigen_device())) =                     \
            in_eigen.eigen_op_type(axis).template cast<Tout>();           \
      } else {                                                            \
        auto out_eigen = EigenTensor<Tout, Rank - 1>::From(*out);         \
        out_eigen.device(*(dev_ctx.eigen_device())) =                     \
            in_eigen.eigen_op_type(axis).template cast<Tout>();           \
      }                                                                   \
    }                                                                     \
  }

DECLARE_ARG_MIN_MAX_FUNCTOR(argmin, ArgMinMaxType::kArgMin);
DECLARE_ARG_MIN_MAX_FUNCTOR(argmax, ArgMinMaxType::kArgMax);

/// Visitor passed to paddle::framework::VisitDataTypeTiny. The visitor calls
/// `apply<Tout>()` with the index element type chosen at run time; `apply`
/// allocates `out` as Tout and runs the rank-specialized ArgMinMaxFunctor for
/// the reduction selected by EnumArgMinMaxValue.
template <typename Context, typename T, ArgMinMaxType EnumArgMinMaxValue>
struct VisitDataArgMinMaxFunctor {
  const Context& dev_ctx;
  const DenseTensor& x;
  int64_t axis;
  bool keepdims;
  bool flatten;
  DenseTensor* out;

  /// Captures the kernel arguments; no work is done until apply().
  explicit VisitDataArgMinMaxFunctor(const Context& dev_ctx,
                                     const DenseTensor& x,
                                     int64_t axis,
                                     bool keepdims,
                                     bool flatten,
                                     DenseTensor* out)
      : dev_ctx(dev_ctx),
        x(x),
        axis(axis),
        keepdims(keepdims),
        flatten(flatten),
        out(out) {}

  /// Writes the arg-min / arg-max indices of `x` into `out` as Tout.
  /// Inputs of rank 1..6 go through ArgMinMaxFunctor; a 0-D input yields
  /// index 0; any higher rank raises InvalidArgument.
  template <typename Tout>
  void apply() const {
    dev_ctx.template Alloc<Tout>(out);

    // A 0-D input holds a single element, so its arg-min/arg-max is index 0.
    // Previously rank 0 reached the `default` branch below, passed the
    // `<= 6` check and returned with `out` allocated but never written.
    if (x.dims().size() == 0) {
      auto out_vec = EigenVector<Tout>::Flatten(*out);
      out_vec.device(*(dev_ctx.eigen_device())) =
          out_vec.constant(static_cast<Tout>(0));
      return;
    }

    // With `flatten`, view the input as 1-D over all of its elements and
    // reduce along axis 0; the reduced dimension is always kept in that case.
    const bool new_keepdims = flatten || keepdims;
    phi::DDim x_dims;
    // int64_t (not int): the functor's axis parameter is int64_t, so there is
    // no reason to narrow it on the way through.
    int64_t new_axis = axis;
    if (flatten) {
      x_dims = phi::make_ddim({x.numel()});
      new_axis = 0;
    } else {
      x_dims = x.dims();
      // Normalize a negative axis so it counts from the last dimension.
      if (axis < 0) new_axis = axis + x_dims.size();
    }

#define CALL_ARG_MINMAX_FUNCTOR(rank)                                         \
  ArgMinMaxFunctor<Context, T, Tout, rank, EnumArgMinMaxValue> functor##rank; \
  functor##rank(dev_ctx, x, out, x_dims, new_axis, new_keepdims)

    switch (x_dims.size()) {
      case 1:
        CALL_ARG_MINMAX_FUNCTOR(1);
        break;
      case 2:
        CALL_ARG_MINMAX_FUNCTOR(2);
        break;
      case 3:
        CALL_ARG_MINMAX_FUNCTOR(3);
        break;
      case 4:
        CALL_ARG_MINMAX_FUNCTOR(4);
        break;
      case 5:
        CALL_ARG_MINMAX_FUNCTOR(5);
        break;
      case 6:
        CALL_ARG_MINMAX_FUNCTOR(6);
        break;
      default:
        // Rank 0 is handled above, so reaching here means rank > 6.
        PADDLE_ENFORCE_LE(
            x_dims.size(),
            6,
            phi::errors::InvalidArgument(
                "%s operator doesn't supports tensors whose ranks are greater "
                "than 6.",
                (EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax")));
        break;
    }
#undef CALL_ARG_MINMAX_FUNCTOR
  }
};

template <typename Context, typename T, ArgMinMaxType EnumArgMinMaxValue>
void ArgMinMaxKernel(const Context& dev_ctx,
                     const DenseTensor& x,
138
                     const Scalar& axis,
139 140 141 142 143 144 145 146 147
                     bool keepdims,
                     bool flatten,
                     int dtype,
                     DenseTensor* out) {
  if (dtype < 0) {
    paddle::framework::VisitDataTypeTiny(
        static_cast<paddle::framework::proto::VarType::Type>(
            paddle::framework::proto::VarType::INT64),
        VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
148
            dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
149 150 151 152 153
    return;
  }
  paddle::framework::VisitDataTypeTiny(
      static_cast<paddle::framework::proto::VarType::Type>(dtype),
      VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
154
          dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
155 156 157 158 159
}

template <typename T, typename Context>
void ArgMinKernel(const Context& dev_ctx,
                  const DenseTensor& x,
160
                  const Scalar& axis,
161 162 163 164 165 166 167 168 169 170 171
                  bool keepdims,
                  bool flatten,
                  int dtype,
                  DenseTensor* out) {
  ArgMinMaxKernel<Context, T, ArgMinMaxType::kArgMin>(
      dev_ctx, x, axis, keepdims, flatten, dtype, out);
}

template <typename T, typename Context>
void ArgMaxKernel(const Context& dev_ctx,
                  const DenseTensor& x,
172
                  const Scalar& axis,
173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203
                  bool keepdims,
                  bool flatten,
                  int dtype,
                  DenseTensor* out) {
  ArgMinMaxKernel<Context, T, ArgMinMaxType::kArgMax>(
      dev_ctx, x, axis, keepdims, flatten, dtype, out);
}

}  // namespace phi

// Register arg_min / arg_max for the CPU backend with the listed element
// types (presumably instantiating T of phi::ArgMinKernel / phi::ArgMaxKernel;
// NOTE(review): confirm against PD_REGISTER_KERNEL's definition). The index
// (output) element type is not fixed here: ArgMinMaxKernel selects it at run
// time from the `dtype` argument.
PD_REGISTER_KERNEL(arg_min,
                   CPU,
                   ALL_LAYOUT,
                   phi::ArgMinKernel,
                   float,
                   double,
                   int32_t,
                   int64_t,
                   int16_t,
                   uint8_t) {}

PD_REGISTER_KERNEL(arg_max,
                   CPU,
                   ALL_LAYOUT,
                   phi::ArgMaxKernel,
                   float,
                   double,
                   int32_t,
                   int64_t,
                   int16_t,
                   uint8_t) {}