// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/impl/compare_kernel_impl.h"

#ifdef PADDLE_WITH_XPU_KP
#include "paddle/phi/backends/xpu/xpu_context.h"
#else
#include <thrust/device_ptr.h>  // thrust::device_ptr (used by CompareAllKernelImpl)
#include <thrust/fill.h>

#include <vector>

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/compare_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/gpu/reduce.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h"
#endif
namespace phi {

template <typename T>
struct BitwiseAdd {
  // Identity element of the reduction: true (cast to T), since `x & true == x`.
  inline T initial() { return static_cast<T>(true); }

  // NOTE: despite the struct's name, this computes bitwise AND (`a & b`),
  // not `a + b`. It is used by CompareAllKernelImpl below to fold
  // element-wise comparison flags: the reduced value is true only if
  // every element compared true.
  __host__ __device__ __forceinline__ T operator()(const T& a,
                                                   const T& b) const {
    return a & b;
  }
};

// Broadcast-compares `x` and `y` with Functor and writes the result to
// `out`. When `out` shares its allocation with `x`, the comparison runs
// in-place with the input element type T; otherwise `out` is allocated
// as a bool tensor and the kernel emits bool.
// NOTE(review): `InverseFunctor` is unused in this overload; presumably
// required to match the interface declared in compare_kernel_impl.h —
// confirm against that header.
template <typename T,
          typename Context,
          typename Functor,
          typename InverseFunctor>
inline void CompareKernelImpl(const Context& ctx,
                              const DenseTensor& x,
                              const DenseTensor& y,
                              int axis,
                              DenseTensor* out) {
  const bool is_inplace = out->IsSharedWith(x);
  if (!is_inplace) {
    // Fresh output: comparison results are stored as bool.
    ctx.template Alloc<bool>(out);
  }

  std::vector<const DenseTensor*> ins{&x, &y};
  std::vector<DenseTensor*> outs{out};
  if (is_inplace) {
    // Reuse x's buffer as-is, so keep its element type T.
    funcs::BroadcastKernel<T>(ctx, ins, &outs, Functor(), axis);
  } else {
    funcs::BroadcastKernel<bool>(ctx, ins, &outs, Functor(), axis);
  }
}

#ifndef PADDLE_WITH_XPU_KP
// Compares `x` and `y` element-wise with Functor and reduces the flags to a
// single bool in `*out`: true only if every element pair compared true.
// Tensors with different shapes are reported unequal without launching the
// element-wise kernel.
template <typename T, typename Context, typename Functor>
inline void CompareAllKernelImpl(const Context& ctx,
                                 const DenseTensor& x,
                                 const DenseTensor& y,
                                 DenseTensor* out) {
  bool* out_data = ctx.template Alloc<bool>(out);

  if (x.dims() != y.dims()) {
    // Shape mismatch: result is trivially false. out_data is device memory,
    // so write through a thrust device_ptr rather than dereferencing on host.
    thrust::device_ptr<bool> out_dev_ptr(out_data);
    thrust::fill(out_dev_ptr, out_dev_ptr + 1, false);
    return;
  }

  // Element-wise comparison into a temporary bool tensor of x's shape.
  DenseTensor tmp;
  tmp.Resize(x.dims());
  ctx.template Alloc<bool>(&tmp);

  std::vector<const DenseTensor*> ins{&x, &y};
  std::vector<DenseTensor*> outs{&tmp};
  funcs::ElementwiseKernel<bool>(ctx, ins, &outs, Functor());

  // Reduce over every dimension with BitwiseAdd (actually bitwise AND — see
  // its definition above), yielding the conjunction of all comparisons.
  std::vector<int> reduce_dims(tmp.dims().size());
  for (size_t i = 0; i < reduce_dims.size(); ++i) {
    reduce_dims[i] = static_cast<int>(i);
  }
  funcs::ReduceKernel<bool, bool, BitwiseAdd, kps::IdentityFunctor<bool>>(
      ctx, tmp, out, kps::IdentityFunctor<bool>(), reduce_dims);
}
#endif

}  // namespace phi

#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(less_than, KPS, ALL_LAYOUT, phi::LessThanKernel, int) {
  kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
PD_REGISTER_KERNEL(less_equal, KPS, ALL_LAYOUT, phi::LessEqualKernel, int) {
  kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
PD_REGISTER_KERNEL(greater_than, KPS, ALL_LAYOUT, phi::GreaterThanKernel, int) {
  kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
PD_REGISTER_KERNEL(
    greater_equal, KPS, ALL_LAYOUT, phi::GreaterEqualKernel, int) {
  kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
PD_REGISTER_KERNEL(equal, KPS, ALL_LAYOUT, phi::EqualKernel, int) {
  kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
PD_REGISTER_KERNEL(not_equal, KPS, ALL_LAYOUT, phi::NotEqualKernel, int) {
  kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}

#else

PD_REGISTER_KERNEL(equal_all,
                   KPS,
                   ALL_LAYOUT,
                   phi::EqualAllKernel,
                   bool,
                   int,
                   int64_t,
                   float,
                   double) {
  kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}

// Registers one compare kernel over the full dtype list. The output dtype is
// pinned to BOOL, consistent with every other registration in this file (the
// macro body was previously empty, leaving the output dtype to default to the
// input dtype even though the kernels allocate bool outputs).
#define PD_REGISTER_COMPARE_KERNEL(name, func)            \
  PD_REGISTER_KERNEL(name,                                \
                     KPS,                                 \
                     ALL_LAYOUT,                          \
                     phi::func##Kernel,                   \
                     bool,                                \
                     int16_t,                             \
                     int,                                 \
                     int64_t,                             \
                     float,                               \
                     double,                              \
                     phi::dtype::float16) {               \
    kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
  }

PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan)
PD_REGISTER_COMPARE_KERNEL(greater_equal, GreaterEqual)
PD_REGISTER_COMPARE_KERNEL(equal, Equal)
PD_REGISTER_COMPARE_KERNEL(not_equal, NotEqual)

#endif