math_kernel.cc 6.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
//   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/pten/kernels/math_kernel.h"

#include "paddle/pten/api/ext/dispatch.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/core/kernel_registry.h"
21
#include "paddle/pten/kernels/cpu/elementwise.h"
22
#include "paddle/pten/kernels/cpu/reduce.h"
23
#include "paddle/pten/kernels/funcs/elementwise_functor.h"
24
#include "paddle/pten/kernels/funcs/reduce_functor.h"
25

26 27
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/framework/eigen.h"
28 29
#include "paddle/pten/common/bfloat16.h"
#include "paddle/pten/common/complex.h"
30 31 32

namespace pten {

33 34
// Generates `<name>RawKernel<T, Context>`, the CPU kernel for the binary
// elementwise operator `name` (e.g. Add, Subtract, Multiply).
//
// The generated kernel:
//   1. obtains the output buffer via `out->mutable_data<T>()` on the place
//      reported by `dev_ctx`;
//   2. when `x.dims() == y.dims()`, runs `SameDims<name>Functor` through
//      `SameDimsElementwiseCompute`;
//   3. otherwise calls `ElementwiseCompute` with `axis`, using
//      `funcs::<name>Functor` when `x.dims().size() >= y.dims().size()` and
//      `funcs::Inverse<name>Functor` when `y` has the larger `dims().size()`.
//
// The same-dims functor is always instantiated for `CPUContext`, regardless
// of the `Context` template argument.
// NOTE(review): the Inverse functor presumably swaps the operand order so the
// result is still `x op y` -- confirm in elementwise_functor.h.
#define DEFINE_CPU_ELEMENTWISE_OP(name)                                    \
  template <typename T, typename Context>                                  \
  void name##RawKernel(const Context& dev_ctx,                             \
                       const DenseTensor& x,                               \
                       const DenseTensor& y,                               \
                       int axis,                                           \
                       DenseTensor* out) {                                 \
    out->mutable_data<T>(dev_ctx.GetPlace());                              \
    if (x.dims() == y.dims()) {                                            \
      using SameDimsFunctor = SameDims##name##Functor<CPUContext, T>;      \
      SameDimsElementwiseCompute<SameDimsFunctor>()(dev_ctx, x, y, out);   \
      return;                                                              \
    }                                                                      \
    if (x.dims().size() >= y.dims().size()) {                              \
      using ForwardFunctor = funcs::name##Functor<T>;                      \
      ElementwiseCompute<ForwardFunctor, T>(                               \
          dev_ctx, x, y, axis, ForwardFunctor(), out);                     \
    } else {                                                               \
      using InverseFunctor = funcs::Inverse##name##Functor<T>;             \
      ElementwiseCompute<InverseFunctor, T>(                               \
          dev_ctx, x, y, axis, InverseFunctor(), out);                     \
    }                                                                      \
  }

template <typename T, typename Context>
58 59 60 61 62 63
void MeanRawKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   const std::vector<int64_t>& dims,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
64
  auto out_dtype = x.dtype();
65
  pten::Reduce<CPUContext, T, pten::funcs::MeanFunctor>(
66 67 68 69
      dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
}

template <typename T, typename Context>
70
void SumRawKernel(const Context& dev_ctx,
71
                  const DenseTensor& x,
72 73 74 75
                  const std::vector<int64_t>& dims,
                  bool keep_dim,
                  bool reduce_all,
                  DataType out_dtype,
76
                  DenseTensor* out) {
77 78 79 80 81 82 83 84 85 86
  pten::Reduce<CPUContext, T, pten::funcs::SumFunctor>(
      dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
}

template <typename T, typename Context>
void DivideRawKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     const DenseTensor& y,
                     int axis,
                     DenseTensor* out) {
87
  // allocate memory for out
88
  out->mutable_data<T>(dev_ctx.GetPlace());
89
  if (x.dims() == y.dims() && std::is_floating_point<T>::value) {
90
    SameDimsElementwiseCompute<SameDimsDivideFunctor<CPUContext, T>>()(
91 92 93 94 95
        dev_ctx, x, y, out);
  } else {
    auto x_dims = x.dims();
    auto y_dims = y.dims();
    if (x_dims.size() >= y_dims.size()) {
96 97
      ElementwiseCompute<funcs::DivideFunctor<T>, T>(
          dev_ctx, x, y, axis, funcs::DivideFunctor<T>(), out);
98
    } else {
99 100
      ElementwiseCompute<funcs::InverseDivideFunctor<T>, T>(
          dev_ctx, x, y, axis, funcs::InverseDivideFunctor<T>(), out);
101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
    }
  }
}

// Expand DEFINE_CPU_ELEMENTWISE_OP to define the AddRawKernel template.
DEFINE_CPU_ELEMENTWISE_OP(Add)

// Expand DEFINE_CPU_ELEMENTWISE_OP to define the SubtractRawKernel template.
DEFINE_CPU_ELEMENTWISE_OP(Subtract)

// Expand DEFINE_CPU_ELEMENTWISE_OP to define the MultiplyRawKernel template.
DEFINE_CPU_ELEMENTWISE_OP(Multiply)

}  // namespace pten

// File-scope shorthands for the complex element types used in the kernel
// registrations in this file.
using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>;

// NOTE(chenweihang): a `bfloat16` alias is deliberately not declared here;
// doing so would cause a redefinition conflict with the XPU bfloat16.
// using bfloat16 = ::paddle::platform::bfloat16;
121
// Register pten::AddRawKernel as the CPU `add_raw` kernel (ALL_LAYOUT) for
// the real, integer and complex element types listed below.
PT_REGISTER_KERNEL(add_raw,
                   CPU,
                   ALL_LAYOUT,
                   pten::AddRawKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   complex64,
                   complex128) {}
131
// Register pten::SubtractRawKernel as the CPU `subtract_raw` kernel
// (ALL_LAYOUT) for the real, integer and complex element types listed below.
PT_REGISTER_KERNEL(subtract_raw,
                   CPU,
                   ALL_LAYOUT,
                   pten::SubtractRawKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   complex64,
                   complex128) {}
141
// Register pten::DivideRawKernel as the CPU `divide_raw` kernel (ALL_LAYOUT)
// for the real, integer and complex element types listed below.
PT_REGISTER_KERNEL(divide_raw,
                   CPU,
                   ALL_LAYOUT,
                   pten::DivideRawKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   complex64,
                   complex128) {}
151
// Register pten::MultiplyRawKernel as the CPU `multiply_raw` kernel
// (ALL_LAYOUT). Unlike add/subtract/divide, the type list also includes
// `bool`.
PT_REGISTER_KERNEL(multiply_raw,
                   CPU,
                   ALL_LAYOUT,
                   pten::MultiplyRawKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   bool,
                   complex64,
                   complex128) {}
162
// Register pten::SumRawKernel as the CPU `sum_raw` kernel (ALL_LAYOUT) for
// the element types listed below.
PT_REGISTER_KERNEL(sum_raw,
                   CPU,
                   ALL_LAYOUT,
                   pten::SumRawKernel,
                   bool,
                   float,
                   double,
                   paddle::platform::float16,
                   int,
                   int64_t,
                   complex64,
                   complex128) {
  // Mark output 0's data type as UNDEFINED in the kernel signature.
  // NOTE(review): presumably because the actual output dtype is chosen at
  // call time through SumRawKernel's `out_dtype` argument -- confirm.
  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
176 177
// Register pten::MeanRawKernel as the CPU `mean_raw` kernel (ALL_LAYOUT) for
// float, double and bool element types.
PT_REGISTER_KERNEL(
    mean_raw, CPU, ALL_LAYOUT, pten::MeanRawKernel, float, double, bool) {}