//   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/math_kernel.h"

#include "paddle/phi/api/ext/dispatch.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/elementwise.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/framework/eigen.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h"

namespace phi {

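// Generates a CPU elementwise kernel named <name>RawKernel. Tensors with
// identical dims take the vectorized same-dims fast path; otherwise the call
// falls back to broadcast ElementwiseCompute, picking the Inverse##name
// functor when y has the higher rank (the inverse functor swaps the operand
// roles so the higher-rank tensor always drives the broadcast).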
#define DEFINE_CPU_ELEMENTWISE_OP(name)                                     \
  template <typename T, typename Context>                                   \
  void name##RawKernel(const Context& dev_ctx,                              \
                       const DenseTensor& x,                                \
                       const DenseTensor& y,                                \
                       int axis,                                            \
                       DenseTensor* out) {                                  \
    dev_ctx.template Alloc<T>(out);                                         \
    if (x.dims() == y.dims()) {                                             \
      SameDimsElementwiseCompute<SameDims##name##Functor<CPUContext, T>>()( \
          dev_ctx, x, y, out);                                              \
    } else {                                                                \
      auto x_dims = x.dims();                                               \
      auto y_dims = y.dims();                                               \
      if (x_dims.size() >= y_dims.size()) {                                 \
        funcs::ElementwiseCompute<funcs::name##Functor<T>, T>(              \
            dev_ctx, x, y, axis, funcs::name##Functor<T>(), out);           \
      } else {                                                              \
        funcs::ElementwiseCompute<funcs::Inverse##name##Functor<T>, T>(     \
            dev_ctx, x, y, axis, funcs::Inverse##name##Functor<T>(), out);  \
      }                                                                     \
    }                                                                       \
  }

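// Divide is written out by hand rather than generated by the macro above:
// its same-dims fast path is additionally gated on std::is_floating_point<T>,
// so integral division always goes through the functor-based
// ElementwiseCompute path (presumably because the vectorized same-dims
// divide is only valid for floating-point types).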
template <typename T, typename Context>
void DivideRawKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     const DenseTensor& y,
                     int axis,
                     DenseTensor* out) {
  // allocate memory for out
  dev_ctx.template Alloc<T>(out);
  if (x.dims() == y.dims() && std::is_floating_point<T>::value) {
    SameDimsElementwiseCompute<SameDimsDivideFunctor<CPUContext, T>>()(
        dev_ctx, x, y, out);
  } else {
    auto x_dims = x.dims();
    auto y_dims = y.dims();
    if (x_dims.size() >= y_dims.size()) {
      funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T>(
          dev_ctx, x, y, axis, funcs::DivideFunctor<T>(), out);
    } else {
      funcs::ElementwiseCompute<funcs::InverseDivideFunctor<T>, T>(
          dev_ctx, x, y, axis, funcs::InverseDivideFunctor<T>(), out);
    }
  }
}

// Create the definition of Add
DEFINE_CPU_ELEMENTWISE_OP(Add)

// Create the definition of Subtract
DEFINE_CPU_ELEMENTWISE_OP(Subtract)

// Create the definition of Multiply
DEFINE_CPU_ELEMENTWISE_OP(Multiply)
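
// A minimal usage sketch (illustrative only: the context and tensor setup
// below are assumptions, not part of this file). Passing axis = -1 follows
// the usual elementwise convention of inferring the broadcast axis:
//
//   phi::CPUContext dev_ctx;
//   phi::DenseTensor x, y, out;  // assume x and y hold initialized data
//   phi::AddRawKernel<float>(dev_ctx, x, y, /*axis=*/-1, &out);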

}  // namespace phi

using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;

// NOTE(chenweihang): a file-level `using bfloat16` alias would collide with
// the XPU bfloat16 definition, so the full name is spelled out below.
// using bfloat16 = ::phi::dtype::bfloat16;
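// Register each raw kernel with the CPU backend; every registration
// instantiates the corresponding phi::*RawKernel above for each dtype in
// its list.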
PD_REGISTER_KERNEL(add_raw,
                   CPU,
                   ALL_LAYOUT,
                   phi::AddRawKernel,
                   float,
                   double,
                   int16_t,
                   int,
                   int64_t,
                   complex64,
                   complex128) {}
PD_REGISTER_KERNEL(subtract_raw,
                   CPU,
                   ALL_LAYOUT,
                   phi::SubtractRawKernel,
                   float,
                   double,
                   int16_t,
                   int,
                   int64_t,
                   complex64,
                   complex128,
                   phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(divide_raw,
                   CPU,
                   ALL_LAYOUT,
                   phi::DivideRawKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   complex64,
                   complex128) {}
PD_REGISTER_KERNEL(multiply_raw,
                   CPU,
                   ALL_LAYOUT,
                   phi::MultiplyRawKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   bool,
                   complex64,
                   complex128,
                   phi::dtype::bfloat16) {}