/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cstdint>

#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"

namespace phi {
namespace funcs {
23

24 25 26 27 28 29
#define Instantiate_Template_Function(func)                                    \
  Instantiate_Template_Function_index_t(                                       \
      func, int) Instantiate_Template_Function_index_t(func, float)            \
      Instantiate_Template_Function_index_t(func, double)                      \
          Instantiate_Template_Function_index_t(func, int64_t)                 \
              Instantiate_Template_Function_index_t(func, phi::dtype::float16) \
30 31
                  Instantiate_Template_Function_index_t(func, unsigned char)

32 33 34 35 36 37 38 39 40 41 42
#define Instantiate_Template_Function_index_t(func, tensor_t)          \
  template void func<tensor_t, int>(phi::DenseTensor input,            \
                                    int dim,                           \
                                    const phi::DenseTensor& index,     \
                                    phi::DenseTensor result,           \
                                    const phi::DeviceContext& ctx);    \
  template void func<tensor_t, int64_t>(phi::DenseTensor input,        \
                                        int dim,                       \
                                        const phi::DenseTensor& index, \
                                        phi::DenseTensor result,       \
                                        const phi::DeviceContext& ctx);
43 44

template <typename tensor_t, typename index_t>
45
void cpu_gather_kernel(phi::DenseTensor self,
46
                       int dim,
47
                       const phi::DenseTensor& index,
48
                       phi::DenseTensor result,
49
                       const phi::DeviceContext& ctx);
50

51
template <typename tensor_t, typename index_t>
52
void cpu_scatter_assign_kernel(phi::DenseTensor self,
53
                               int dim,
54
                               const phi::DenseTensor& index,
55
                               phi::DenseTensor src,
56
                               const phi::DeviceContext& ctx);
57

58
template <typename tensor_t, typename index_t>
59
void cpu_scatter_add_kernel(phi::DenseTensor self,
60
                            int dim,
61
                            const phi::DenseTensor& index,
62
                            phi::DenseTensor src,
63
                            const phi::DeviceContext& ctx);
64

65
template <typename tensor_t, typename index_t>
66
void cpu_scatter_mul_kernel(phi::DenseTensor self,
67
                            int dim,
68
                            const phi::DenseTensor& index,
69
                            phi::DenseTensor src,
70
                            const phi::DeviceContext& ctx);
71 72

template <typename tensor_t, typename index_t>
73
void cpu_scatter_input_grad_kernel(phi::DenseTensor self,
74
                                   int dim,
75
                                   const phi::DenseTensor& index,
76
                                   phi::DenseTensor result,
77
                                   const phi::DeviceContext& ctx);
78

79
template <typename tensor_t, typename index_t>
80
void gpu_gather_kernel(phi::DenseTensor self,
81
                       int dim,
82
                       const phi::DenseTensor& index,
83
                       phi::DenseTensor result,
84
                       const phi::DeviceContext& ctx);
85

86
template <typename tensor_t, typename index_t>
87
void gpu_scatter_assign_kernel(phi::DenseTensor self,
88
                               int dim,
89
                               const phi::DenseTensor& index,
90
                               phi::DenseTensor src,
91
                               const phi::DeviceContext& ctx);
92

93
template <typename tensor_t, typename index_t>
94
void gpu_scatter_add_kernel(phi::DenseTensor self,
95
                            int dim,
96
                            const phi::DenseTensor& index,
97
                            phi::DenseTensor src,
98
                            const phi::DeviceContext& ctx);
99

100
template <typename tensor_t, typename index_t>
101
void gpu_scatter_mul_kernel(phi::DenseTensor self,
102
                            int dim,
103
                            const phi::DenseTensor& index,
104
                            phi::DenseTensor src,
105
                            const phi::DeviceContext& ctx);
106 107

template <typename tensor_t, typename index_t>
108
void gpu_scatter_input_grad_kernel(phi::DenseTensor self,
109
                                   int dim,
110
                                   const phi::DenseTensor& index,
111
                                   phi::DenseTensor result,
112 113 114 115
                                   const phi::DeviceContext& ctx);

}  // namespace funcs
}  // namespace phi