math_kernel.cc 6.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
//   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/phi/kernels/math_kernel.h"
16

17 18
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
19

20
namespace phi {
21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36

template <typename T, typename Context>
void MeanKernel(const Context& dev_ctx,
                const DenseTensor& x,
                const std::vector<int64_t>& dims,
                bool keep_dim,
                DenseTensor* out) {
  bool reduce_all = false;
  MeanRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}

template <typename T, typename Context>
void SumKernel(const Context& dev_ctx,
               const DenseTensor& x,
               const std::vector<int64_t>& dims,
               DataType out_dtype,
37
               bool keep_dim,
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
               DenseTensor* out) {
  bool reduce_all = false;
  SumRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out_dtype, out);
}

template <typename T, typename Context>
void AddKernel(const Context& dev_ctx,
               const DenseTensor& x,
               const DenseTensor& y,
               DenseTensor* out) {
  int axis = -1;
  AddRawKernel<T>(dev_ctx, x, y, axis, out);
}

template <typename T, typename Context>
void SubtractKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const DenseTensor& y,
                    DenseTensor* out) {
  int axis = -1;
  SubtractRawKernel<T>(dev_ctx, x, y, axis, out);
}

template <typename T, typename Context>
void DivideKernel(const Context& dev_ctx,
                  const DenseTensor& x,
                  const DenseTensor& y,
                  DenseTensor* out) {
  int axis = -1;
  DivideRawKernel<T>(dev_ctx, x, y, axis, out);
}

template <typename T, typename Context>
void MultiplyKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const DenseTensor& y,
                    DenseTensor* out) {
  int axis = -1;
  MultiplyRawKernel<T>(dev_ctx, x, y, axis, out);
}

79
}  // namespace phi
80

81 82
// Short aliases for the complex element types used in the kernel
// registrations of this file.
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;

// CPU `mean`: phi::MeanKernel instantiated for float, double and bool.
PD_REGISTER_KERNEL(
    mean, CPU, ALL_LAYOUT, phi::MeanKernel, float, double, bool) {}

// CPU `sum`: phi::SumKernel for bool, floating-point, integer and complex
// element types.
PD_REGISTER_KERNEL(sum,
                   CPU,
                   ALL_LAYOUT,
                   phi::SumKernel,
                   bool,
                   float,
                   double,
                   phi::dtype::float16,
                   int16_t,
                   int,
                   int64_t,
                   complex64,
                   complex128) {
  // The output dtype is reset to UNDEFINED rather than fixed to T; presumably
  // the real dtype is chosen at run time from `out_dtype` — confirm.
  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}

// CPU `add`: phi::AddKernel for floating-point, integer and complex types.
// NOTE(review): unlike the CPU `subtract` and `multiply` registrations in this
// file, phi::dtype::bfloat16 is not listed here — confirm whether that is
// intentional.
PD_REGISTER_KERNEL(add,
                   CPU,
                   ALL_LAYOUT,
                   phi::AddKernel,
                   float,
                   double,
                   int16_t,
                   int,
                   int64_t,
                   complex64,
                   complex128) {}

// CPU `subtract`: phi::SubtractKernel for floating-point (incl. bfloat16),
// integer and complex types.
PD_REGISTER_KERNEL(subtract,
                   CPU,
                   ALL_LAYOUT,
                   phi::SubtractKernel,
                   float,
                   double,
                   int16_t,
                   int,
                   int64_t,
                   complex64,
                   complex128,
                   phi::dtype::bfloat16) {}

// CPU `divide`: phi::DivideKernel for float, double, int, int64_t and complex
// types.
PD_REGISTER_KERNEL(divide,
                   CPU,
                   ALL_LAYOUT,
                   phi::DivideKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   complex64,
                   complex128) {}

// CPU `multiply`: phi::MultiplyKernel for floating-point (incl. bfloat16),
// integer, bool and complex types.
PD_REGISTER_KERNEL(multiply,
                   CPU,
                   ALL_LAYOUT,
                   phi::MultiplyKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   bool,
                   complex64,
                   complex128,
                   phi::dtype::bfloat16) {}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// GPU registrations, compiled only for CUDA or HIP builds.

// GPU `mean`: adds int, int64_t and float16 on top of the CPU type list.
PD_REGISTER_KERNEL(mean,
                   GPU,
                   ALL_LAYOUT,
                   phi::MeanKernel,
                   float,
                   double,
                   bool,
                   int,
                   int64_t,
                   phi::dtype::float16) {}

// GPU `sum`.
PD_REGISTER_KERNEL(sum,
                   GPU,
                   ALL_LAYOUT,
                   phi::SumKernel,
                   bool,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   int16_t,
                   int,
                   int64_t,
                   complex64,
                   complex128) {
  // As on CPU, the output dtype is left UNDEFINED; presumably it is resolved
  // at run time from `out_dtype` — confirm.
  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}

// GPU `add`.
PD_REGISTER_KERNEL(add,
                   GPU,
                   ALL_LAYOUT,
                   phi::AddKernel,
                   float,
                   double,
                   int16_t,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   complex64,
                   complex128) {}

// GPU `subtract`.
// NOTE(review): GPU `add` lists phi::dtype::bfloat16 but this list does not —
// confirm whether that is intentional.
PD_REGISTER_KERNEL(subtract,
                   GPU,
                   ALL_LAYOUT,
                   phi::SubtractKernel,
                   float,
                   double,
                   int16_t,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   complex64,
                   complex128) {}

// GPU `divide`.
PD_REGISTER_KERNEL(divide,
                   GPU,
                   ALL_LAYOUT,
                   phi::DivideKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   complex64,
                   complex128) {}

// GPU `multiply`.
PD_REGISTER_KERNEL(multiply,
                   GPU,
                   ALL_LAYOUT,
                   phi::MultiplyKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   bool,
                   phi::dtype::float16,
                   complex64,
                   complex128) {}
#endif