/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/math_kernel.h"

#include <vector>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/gpu/reduce.h"

#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

/**
 * DEFINE_CUDA_ELEMENTWISE_OP(name) defines the kernel template
 *
 *   template <typename T, typename Context>
 *   void name##RawKernel(const Context& dev_ctx, const DenseTensor& x,
 *                        const DenseTensor& y, int axis, DenseTensor* out);
 *
 * The generated kernel applies funcs::name##Functor<T> to the pair (x, y)
 * through funcs::BroadcastKernel (binary elementwise, input and output
 * element type both T) and writes the result into `out`.
 *
 * Parameters of the generated kernel:
 *   dev_ctx - device context; used to allocate `out` and passed through to
 *             funcs::BroadcastKernel.
 *   x, y    - the two input tensors, passed to the functor in that order.
 *   axis    - forwarded unchanged to funcs::BroadcastKernel; its broadcasting
 *             semantics are defined there.
 *   out     - output tensor; storage of element type T is allocated here
 *             before the broadcast launch.
 *
 * NOTE: every line of the macro body except the last must end with '\';
 * any line without it silently terminates the definition.
 */
#define DEFINE_CUDA_ELEMENTWISE_OP(name)                                  \
  template <typename T, typename Context>                                 \
  void name##RawKernel(const Context& dev_ctx,                            \
                       const DenseTensor& x,                              \
                       const DenseTensor& y,                              \
                       int axis,                                          \
                       DenseTensor* out) {                                \
    std::vector<const DenseTensor*> inputs = {&x, &y};                    \
    std::vector<DenseTensor*> outputs = {out};                            \
    /* BroadcastKernel writes into `out`, so allocate it first. */        \
    dev_ctx.template Alloc<T>(out);                                       \
    funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(               \
        dev_ctx, inputs, &outputs, axis, funcs::name##Functor<T>());      \
  }

/**
 * Kernels
 */

// Create the definition of Add
DEFINE_CUDA_ELEMENTWISE_OP(Add)
// Create the definition of Subtract
DEFINE_CUDA_ELEMENTWISE_OP(Subtract)
// Create the definition of Multiply
DEFINE_CUDA_ELEMENTWISE_OP(Multiply)
// Create the definition of Divide
DEFINE_CUDA_ELEMENTWISE_OP(Divide)

}  // namespace phi
69

70
// Short aliases for the phi dtype names used in the kernel registrations
// that follow.
using float16 = phi::dtype::float16;
using bfloat16 = phi::dtype::bfloat16;
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;

// Registers phi::AddRawKernel as "add_raw" for the GPU backend, all layouts,
// and each element type listed below.
PD_REGISTER_KERNEL(add_raw,
                   GPU,
                   ALL_LAYOUT,
                   phi::AddRawKernel,
                   float,
                   double,
                   int16_t,
                   int,
                   int64_t,
                   float16,
                   bfloat16,
                   complex64,
                   complex128) {}
// Registers phi::SubtractRawKernel as "subtract_raw" for the GPU backend, all
// layouts, and each element type listed below.
PD_REGISTER_KERNEL(subtract_raw,
                   GPU,
                   ALL_LAYOUT,
                   phi::SubtractRawKernel,
                   float,
                   double,
                   int16_t,
                   int,
                   int64_t,
                   float16,
                   bfloat16,
                   complex64,
                   complex128) {}
// Registers phi::DivideRawKernel as "divide_raw" for the GPU backend, all
// layouts, and each element type listed below. Unlike add_raw and
// subtract_raw, no int16_t instantiation is registered here.
PD_REGISTER_KERNEL(divide_raw,
                   GPU,
                   ALL_LAYOUT,
                   phi::DivideRawKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   float16,
                   bfloat16,
                   complex64,
                   complex128) {}
// Registers phi::MultiplyRawKernel as "multiply_raw" for the GPU backend, all
// layouts, and each element type listed below. Of the raw kernels registered
// in this file, only multiply_raw includes bool; it has no int16_t entry.
PD_REGISTER_KERNEL(multiply_raw,
                   GPU,
                   ALL_LAYOUT,
                   phi::MultiplyRawKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   bool,
                   float16,
                   complex64,
                   complex128,
                   bfloat16) {}