//   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/math_kernel.h"

#include "paddle/phi/api/ext/dispatch.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/elementwise.h"
#include "paddle/phi/kernels/cpu/reduce.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/framework/eigen.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h"

namespace phi {

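// Generates the CPU kernel <name>RawKernel: it allocates the output, takes
// the same-dims fast path when x and y share a shape, and otherwise falls
// back to the broadcasting path, switching to the Inverse##name functor
// when y has the higher rank.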
#define DEFINE_CPU_ELEMENTWISE_OP(name)                                     \
  template <typename T, typename Context>                                   \
  void name##RawKernel(const Context& dev_ctx,                              \
                       const DenseTensor& x,                                \
                       const DenseTensor& y,                                \
                       int axis,                                            \
                       DenseTensor* out) {                                  \
    dev_ctx.template Alloc<T>(out);                                         \
    if (x.dims() == y.dims()) {                                             \
      SameDimsElementwiseCompute<SameDims##name##Functor<CPUContext, T>>()( \
          dev_ctx, x, y, out);                                              \
    } else {                                                                \
      auto x_dims = x.dims();                                               \
      auto y_dims = y.dims();                                               \
      if (x_dims.size() >= y_dims.size()) {                                 \
        ElementwiseCompute<funcs::name##Functor<T>, T>(                     \
            dev_ctx, x, y, axis, funcs::name##Functor<T>(), out);           \
      } else {                                                              \
        ElementwiseCompute<funcs::Inverse##name##Functor<T>, T>(            \
            dev_ctx, x, y, axis, funcs::Inverse##name##Functor<T>(), out);  \
      }                                                                     \
    }                                                                       \
  }

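// Mean reduction over `dims` (all axes when reduce_all is set); the output
// keeps the input's dtype.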
template <typename T, typename Context>
void MeanRawKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   const std::vector<int64_t>& dims,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
  auto out_dtype = x.dtype();
  phi::Reduce<CPUContext, T, phi::funcs::MeanFunctor>(
      dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
}

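// Sum reduction over `dims`; unlike MeanRawKernel, the caller passes
// out_dtype explicitly, e.g. to accumulate into a wider type than x's.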
template <typename T, typename Context>
void SumRawKernel(const Context& dev_ctx,
                  const DenseTensor& x,
                  const std::vector<int64_t>& dims,
                  bool keep_dim,
                  bool reduce_all,
                  DataType out_dtype,
                  DenseTensor* out) {
  phi::Reduce<CPUContext, T, phi::funcs::SumFunctor>(
      dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
}

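// Divide is spelled out rather than generated by DEFINE_CPU_ELEMENTWISE_OP
// because its same-dims fast path is additionally restricted to
// floating-point T.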
template <typename T, typename Context>
void DivideRawKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     const DenseTensor& y,
                     int axis,
                     DenseTensor* out) {
  // allocate memory for out
  dev_ctx.template Alloc<T>(out);
  if (x.dims() == y.dims() && std::is_floating_point<T>::value) {
    SameDimsElementwiseCompute<SameDimsDivideFunctor<CPUContext, T>>()(
        dev_ctx, x, y, out);
  } else {
    auto x_dims = x.dims();
    auto y_dims = y.dims();
    if (x_dims.size() >= y_dims.size()) {
      ElementwiseCompute<funcs::DivideFunctor<T>, T>(
          dev_ctx, x, y, axis, funcs::DivideFunctor<T>(), out);
    } else {
      ElementwiseCompute<funcs::InverseDivideFunctor<T>, T>(
          dev_ctx, x, y, axis, funcs::InverseDivideFunctor<T>(), out);
    }
  }
}

// Create the definition of Add
DEFINE_CPU_ELEMENTWISE_OP(Add)

// Create the definition of Subtract
DEFINE_CPU_ELEMENTWISE_OP(Subtract)

// Create the definition of Multiply
DEFINE_CPU_ELEMENTWISE_OP(Multiply)
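
// A minimal usage sketch (illustrative only; dev_ctx, x, and y are assumed
// to be an initialized CPUContext and pre-shaped, filled DenseTensors):
//
//   phi::DenseTensor out;
//   out.Resize(x.dims());
//   AddRawKernel<float, phi::CPUContext>(dev_ctx, x, y, /*axis=*/-1, &out);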

}  // namespace phi

using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;

// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::phi::dtype::bfloat16;
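
// Each PD_REGISTER_KERNEL call below binds a kernel name to the CPU
// backend, a layout, the kernel function, and the dtypes it is
// instantiated for.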
PD_REGISTER_KERNEL(add_raw,
                   CPU,
                   ALL_LAYOUT,
                   phi::AddRawKernel,
                   float,
                   double,
                   int16_t,
                   int,
                   int64_t,
                   complex64,
                   complex128) {}
PD_REGISTER_KERNEL(subtract_raw,
                   CPU,
                   ALL_LAYOUT,
                   phi::SubtractRawKernel,
                   float,
                   double,
                   int16_t,
                   int,
                   int64_t,
                   complex64,
                   complex128,
                   phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(divide_raw,
                   CPU,
                   ALL_LAYOUT,
                   phi::DivideRawKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   complex64,
                   complex128) {}
PD_REGISTER_KERNEL(multiply_raw,
                   CPU,
                   ALL_LAYOUT,
                   phi::MultiplyRawKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   bool,
                   complex64,
                   complex128,
                   phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(sum_raw,
                   CPU,
                   ALL_LAYOUT,
                   phi::SumRawKernel,
                   bool,
                   float,
                   double,
                   phi::dtype::float16,
                   int16_t,
                   int,
                   int64_t,
                   complex64,
                   complex128) {
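  // UNDEFINED defers the output dtype decision to runtime, where it is
  // resolved from SumRawKernel's out_dtype argument.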
  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
PD_REGISTER_KERNEL(
    mean_raw, CPU, ALL_LAYOUT, phi::MeanRawKernel, float, double, bool) {}