//   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/math_kernel.h"

#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

// Mean reduction over the axes listed in `dims`.
// Thin wrapper that forwards to MeanRawKernel with reduce_all fixed to
// false, i.e. only the explicitly requested axes are reduced.
template <typename T, typename Context>
void MeanKernel(const Context& dev_ctx,
                const DenseTensor& x,
                const std::vector<int64_t>& dims,
                bool keep_dim,
                DenseTensor* out) {
  constexpr bool kReduceAll = false;
  MeanRawKernel<T>(dev_ctx, x, dims, keep_dim, kReduceAll, out);
}

template <typename T, typename Context>
void SumKernel(const Context& dev_ctx,
               const DenseTensor& x,
               const std::vector<int64_t>& dims,
               DataType out_dtype,
37
               bool keep_dim,
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
               DenseTensor* out) {
  bool reduce_all = false;
  SumRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out_dtype, out);
}

// Elementwise x + y.
// Forwards to AddRawKernel with axis = -1 (the default broadcast axis).
template <typename T, typename Context>
void AddKernel(const Context& dev_ctx,
               const DenseTensor& x,
               const DenseTensor& y,
               DenseTensor* out) {
  constexpr int kDefaultAxis = -1;
  AddRawKernel<T>(dev_ctx, x, y, kDefaultAxis, out);
}

// Elementwise x - y.
// Forwards to SubtractRawKernel with axis = -1 (the default broadcast axis).
template <typename T, typename Context>
void SubtractKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const DenseTensor& y,
                    DenseTensor* out) {
  constexpr int kDefaultAxis = -1;
  SubtractRawKernel<T>(dev_ctx, x, y, kDefaultAxis, out);
}

// Elementwise x / y.
// Forwards to DivideRawKernel with axis = -1 (the default broadcast axis).
template <typename T, typename Context>
void DivideKernel(const Context& dev_ctx,
                  const DenseTensor& x,
                  const DenseTensor& y,
                  DenseTensor* out) {
  constexpr int kDefaultAxis = -1;
  DivideRawKernel<T>(dev_ctx, x, y, kDefaultAxis, out);
}

// Elementwise x * y.
// Forwards to MultiplyRawKernel with axis = -1 (the default broadcast axis).
template <typename T, typename Context>
void MultiplyKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const DenseTensor& y,
                    DenseTensor* out) {
  constexpr int kDefaultAxis = -1;
  MultiplyRawKernel<T>(dev_ctx, x, y, kDefaultAxis, out);
}

}  // namespace phi

using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;
83 84

PT_REGISTER_KERNEL(
85
    mean, CPU, ALL_LAYOUT, phi::MeanKernel, float, double, bool) {}
86 87 88 89

// CPU sum over bool, integral, floating and complex dtypes. The output
// dtype is registered as UNDEFINED here — presumably so it can be resolved
// at runtime from the kernel's `out_dtype` argument (NOTE(review): inferred
// from the SetDataType call; confirm against the kernel registry).
PT_REGISTER_KERNEL(sum,
                   CPU,
                   ALL_LAYOUT,
                   phi::SumKernel,
                   bool,
                   float,
                   double,
                   phi::dtype::float16,
                   int16_t,
                   int,
                   int64_t,
                   complex64,
                   complex128) {
  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}

// CPU elementwise add over real, integral and complex dtypes.
PT_REGISTER_KERNEL(add,
                   CPU,
                   ALL_LAYOUT,
                   phi::AddKernel,
                   float,
                   double,
                   int16_t,
                   int,
                   int64_t,
                   complex64,
                   complex128) {}
// CPU elementwise subtract over real, integral and complex dtypes.
PT_REGISTER_KERNEL(subtract,
                   CPU,
                   ALL_LAYOUT,
                   phi::SubtractKernel,
                   float,
                   double,
                   int16_t,
                   int,
                   int64_t,
                   complex64,
                   complex128) {}
// CPU elementwise divide. Note: unlike add/subtract, int16_t is not
// registered for divide.
PT_REGISTER_KERNEL(divide,
                   CPU,
                   ALL_LAYOUT,
                   phi::DivideKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   complex64,
                   complex128) {}
// CPU elementwise multiply; additionally registered for bool.
PT_REGISTER_KERNEL(multiply,
                   CPU,
                   ALL_LAYOUT,
                   phi::MultiplyKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   bool,
                   complex64,
                   complex128) {}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// GPU mean: wider dtype coverage than the CPU variant (adds int, int64_t
// and float16).
PT_REGISTER_KERNEL(mean,
                   GPU,
                   ALL_LAYOUT,
                   phi::MeanKernel,
                   float,
                   double,
                   bool,
                   int,
                   int64_t,
                   phi::dtype::float16) {}
// GPU sum: same dtype list as the CPU registration, and the same
// UNDEFINED output dtype (resolved later from `out_dtype` — NOTE(review):
// inferred; confirm against the kernel registry).
PT_REGISTER_KERNEL(sum,
                   GPU,
                   ALL_LAYOUT,
                   phi::SumKernel,
                   bool,
                   float,
                   double,
                   phi::dtype::float16,
                   int16_t,
                   int,
                   int64_t,
                   complex64,
                   complex128) {
  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
// GPU elementwise add: CPU dtype list plus float16.
PT_REGISTER_KERNEL(add,
                   GPU,
                   ALL_LAYOUT,
                   phi::AddKernel,
                   float,
                   double,
                   int16_t,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   complex64,
                   complex128) {}
// GPU elementwise subtract: CPU dtype list plus float16.
PT_REGISTER_KERNEL(subtract,
                   GPU,
                   ALL_LAYOUT,
                   phi::SubtractKernel,
                   float,
                   double,
                   int16_t,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   complex64,
                   complex128) {}
// GPU elementwise divide: CPU dtype list plus float16 (no int16_t, matching
// the CPU registration).
PT_REGISTER_KERNEL(divide,
                   GPU,
                   ALL_LAYOUT,
                   phi::DivideKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   complex64,
                   complex128) {}
// GPU elementwise multiply: CPU dtype list plus float16.
PT_REGISTER_KERNEL(multiply,
                   GPU,
                   ALL_LAYOUT,
                   phi::MultiplyKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   bool,
                   phi::dtype::float16,
                   complex64,
                   complex128) {}
#endif