diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 882706dc8dd62c5dcf7e64ec645ef6efa1c13698..4aeb3d6b74398032344ad28fabe7cfac724eabf4 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -95,13 +95,7 @@ if(WITH_UNITY_BUILD) include(unity_build_rule.cmake) endif() -if (WITH_ROCM) - hip_library(gather_scatter_kernel SRCS gather_scatter_kernel.cc gather_scatter_kernel.cu DEPS tensor) -else() - cc_library(gather_scatter_kernel SRCS gather_scatter_kernel.cc gather_scatter_kernel.cu DEPS tensor) -endif() - -set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_api_utils gather_scatter_kernel backward_infermeta sparse_backward_infermeta) +set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_api_utils backward_infermeta sparse_backward_infermeta) register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op quantize_linear_op recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS}) diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 735ba7beaa2d1e78fa5d5351cef5f3e872fcb7f9..cbe7b25ea0755fd34a7b81fc6e8c8f98b03a82df 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -65,7 +65,6 @@ set(COMMON_KERNEL_DEPS deformable_conv_functor matrix_reduce segment_pooling - gather_scatter_kernel pooling maxouting matrix_inverse @@ -76,7 +75,8 @@ set(COMMON_KERNEL_DEPS fft phi_data_layout_transform gpc - utf8proc) + utf8proc + gather_scatter_functor) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} process_group) if(WITH_NCCL OR WITH_RCCL) diff --git a/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc b/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc index 969c5b9fe330643e157be7cd90b5f5a573b5615c..a4af5f6db57daf6ea0c34a76074236c727831dcd 100644 --- a/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc @@ -14,12 +14,12 @@ #include "paddle/phi/kernels/put_along_axis_grad_kernel.h" -#include "paddle/fluid/operators/gather_scatter_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/gather_scatter_functor.h" namespace phi { @@ -41,7 +41,7 @@ void PutAlongAxisGradKernel(const Context& dev_ctx, if (x_grad) { phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad); if (index_type == DataType::INT32) { - paddle::operators::cpu_scatter_input_grad_kernel( + phi::funcs::cpu_scatter_input_grad_kernel( // Here passing an unused argument out_grad, because it's // convenient to instantiate a bunch of template function with the // same arguments list. @@ -51,7 +51,7 @@ void PutAlongAxisGradKernel(const Context& dev_ctx, *x_grad, dev_ctx); } else { - paddle::operators::cpu_scatter_input_grad_kernel( + phi::funcs::cpu_scatter_input_grad_kernel( out_grad, axis, index, *x_grad, dev_ctx); } } @@ -60,10 +60,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx, value_grad->Resize(index.dims()); dev_ctx.template Alloc(value_grad); if (index_type == DataType::INT32) { - paddle::operators::cpu_gather_kernel( + phi::funcs::cpu_gather_kernel( out_grad, axis, index, *value_grad, dev_ctx); } else if (index_type == DataType::INT64) { - paddle::operators::cpu_gather_kernel( + phi::funcs::cpu_gather_kernel( out_grad, axis, index, *value_grad, dev_ctx); } } diff --git a/paddle/phi/kernels/cpu/put_along_axis_kernel.cc b/paddle/phi/kernels/cpu/put_along_axis_kernel.cc index e0cf5f6730c247c428c4cbe7d798bc1a822bb8c7..34516a7e4dd935303659292a9e03857946d0a9cd 100644 --- a/paddle/phi/kernels/cpu/put_along_axis_kernel.cc +++ b/paddle/phi/kernels/cpu/put_along_axis_kernel.cc @@ -14,12 +14,12 @@ #include "paddle/phi/kernels/put_along_axis_kernel.h" -#include "paddle/fluid/operators/gather_scatter_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/gather_scatter_functor.h" namespace phi { @@ -40,26 +40,26 @@ void PutAlongAxisKernel(const Context& dev_ctx, const auto& index_type = index.dtype(); if (reduce == "add") { if (index_type == DataType::INT32) { - paddle::operators::cpu_scatter_add_kernel( + phi::funcs::cpu_scatter_add_kernel( *out, axis, index, value, dev_ctx); } else if (index_type == DataType::INT64) { - paddle::operators::cpu_scatter_add_kernel( + phi::funcs::cpu_scatter_add_kernel( *out, axis, index, value, dev_ctx); } } else if (reduce == "multiply" || reduce == "mul") { if (index_type == DataType::INT32) { - paddle::operators::cpu_scatter_mul_kernel( + phi::funcs::cpu_scatter_mul_kernel( *out, axis, index, value, dev_ctx); } else if (index_type == DataType::INT64) { - paddle::operators::cpu_scatter_mul_kernel( + phi::funcs::cpu_scatter_mul_kernel( *out, axis, index, value, dev_ctx); } } else if (reduce == "assign") { if (index_type == DataType::INT32) { - paddle::operators::cpu_scatter_assign_kernel( + phi::funcs::cpu_scatter_assign_kernel( *out, axis, index, value, dev_ctx); } else if (index_type == DataType::INT64) { - paddle::operators::cpu_scatter_assign_kernel( + phi::funcs::cpu_scatter_assign_kernel( *out, axis, index, value, dev_ctx); } } else { diff --git a/paddle/phi/kernels/cpu/take_along_axis_grad_kernel.cc b/paddle/phi/kernels/cpu/take_along_axis_grad_kernel.cc index e8fc15cc171b5d7d1d7e5abc1f6f98fbec14fbc2..435490d93266e047b4672a6d4efbb829d39a946d 100644 --- a/paddle/phi/kernels/cpu/take_along_axis_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/take_along_axis_grad_kernel.cc @@ -14,10 +14,10 @@ #include "paddle/phi/kernels/take_along_axis_grad_kernel.h" -#include "paddle/fluid/operators/gather_scatter_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/gather_scatter_functor.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace phi { @@ -46,14 +46,14 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx, const auto& index_type = paddle::framework::TransToProtoVarType(index.dtype()); if (index_type == paddle::framework::proto::VarType::INT32) { - paddle::operators::cpu_scatter_add_kernel( + phi::funcs::cpu_scatter_add_kernel( *x_grad, axis, index, out_grad, dev_ctx); // the gradient of gather is scatter } else if (index_type == paddle::framework::proto::VarType::INT64) { - paddle::operators::cpu_scatter_add_kernel( + phi::funcs::cpu_scatter_add_kernel( *x_grad, axis, index, out_grad, dev_ctx); } } diff --git a/paddle/phi/kernels/cpu/take_along_axis_kernel.cc b/paddle/phi/kernels/cpu/take_along_axis_kernel.cc index cd1ff2e92628871a74a59f2cde182ae761d7ebc5..417b4169d3fdb2f57f59af89d7f06ffaca433e02 100644 --- a/paddle/phi/kernels/cpu/take_along_axis_kernel.cc +++ b/paddle/phi/kernels/cpu/take_along_axis_kernel.cc @@ -14,11 +14,11 @@ #include "paddle/phi/kernels/take_along_axis_kernel.h" -#include "paddle/fluid/operators/gather_scatter_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/gather_scatter_functor.h" namespace phi { @@ -38,11 +38,9 @@ void TakeAlongAxisKernel(const Context& dev_ctx, const auto& index_type = index.dtype(); if (index_type == DataType::INT32) { - paddle::operators::cpu_gather_kernel( - x, axis, index, *out, dev_ctx); + phi::funcs::cpu_gather_kernel(x, axis, index, *out, dev_ctx); } else if (index_type == DataType::INT64) { - paddle::operators::cpu_gather_kernel( - x, axis, index, *out, dev_ctx); + phi::funcs::cpu_gather_kernel(x, axis, index, *out, dev_ctx); } } diff --git a/paddle/phi/kernels/funcs/CMakeLists.txt b/paddle/phi/kernels/funcs/CMakeLists.txt index 0d2cfa150f292de1993718202bb395bb962e4197..1b0e3a32f7e2fa20ebeb0aec1c4f8384b5d048e1 100644 --- a/paddle/phi/kernels/funcs/CMakeLists.txt +++ b/paddle/phi/kernels/funcs/CMakeLists.txt @@ -52,3 +52,15 @@ else() math_library(selected_rows_functor DEPS selected_rows_utils math_function blas mixed_vector) endif() + +if(WITH_ROCM) + hip_library( + gather_scatter_functor + SRCS gather_scatter_functor.cc gather_scatter_functor.cu + DEPS tensor) +else() + cc_library( + gather_scatter_functor + SRCS gather_scatter_functor.cc gather_scatter_functor.cu + DEPS tensor) +endif() diff --git a/paddle/fluid/operators/gather_scatter_kernel.cc b/paddle/phi/kernels/funcs/gather_scatter_functor.cc similarity index 92% rename from paddle/fluid/operators/gather_scatter_kernel.cc rename to paddle/phi/kernels/funcs/gather_scatter_functor.cc index 1c6b2e6c1a0951c034ef9dd69ac93fca7f2bdeb4..67af6a3322d9a6d44d31ac9690ed7ee38fbf8d4c 100644 --- a/paddle/fluid/operators/gather_scatter_kernel.cc +++ b/paddle/phi/kernels/funcs/gather_scatter_functor.cc @@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/gather_scatter_kernel.h" -namespace paddle { -namespace operators { +#include "paddle/phi/kernels/funcs/gather_scatter_functor.h" + +namespace phi { +namespace funcs { class TensorAssign { public: @@ -54,7 +55,7 @@ struct cpu_gather_scatter_functor { const phi::DenseTensor& src, const std::string& method_name, const func_t& reduce_op, - const platform::DeviceContext& ctx) { + const phi::DeviceContext& ctx) { if (index.numel() == 0) { return; } @@ -69,7 +70,7 @@ struct cpu_gather_scatter_functor { auto src_dims = src.dims(); if (self_size == 0 || src_size == 0 || index_size == 0) { VLOG(3) << "zero size input found"; - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "self_size, src_size, index_size cannot be 0"); return; } @@ -132,7 +133,7 @@ void cpu_gather_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, phi::DenseTensor result, - const platform::DeviceContext& ctx) { + const phi::DeviceContext& ctx) { cpu_gather_scatter_functor()( @@ -144,7 +145,7 @@ void cpu_scatter_assign_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, phi::DenseTensor src, - const platform::DeviceContext& ctx) { + const phi::DeviceContext& ctx) { cpu_gather_scatter_functor()( @@ -156,7 +157,7 @@ void cpu_scatter_add_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, phi::DenseTensor src, - const platform::DeviceContext& ctx) { + const phi::DeviceContext& ctx) { cpu_gather_scatter_functor()( @@ -168,7 +169,7 @@ void cpu_scatter_mul_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, phi::DenseTensor src, - const platform::DeviceContext& ctx) { + const phi::DeviceContext& ctx) { cpu_gather_scatter_functor()( @@ -180,7 +181,7 @@ void cpu_scatter_input_grad_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, phi::DenseTensor output, - const platform::DeviceContext& ctx) { + const phi::DeviceContext& ctx) { auto* index_data = index.data(); auto* output_data = output.data(); @@ -219,5 +220,5 @@ Instantiate_Template_Function(cpu_gather_kernel) Instantiate_Template_Function(cpu_scatter_mul_kernel) Instantiate_Template_Function(cpu_scatter_input_grad_kernel) -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/operators/gather_scatter_kernel.cu b/paddle/phi/kernels/funcs/gather_scatter_functor.cu similarity index 95% rename from paddle/fluid/operators/gather_scatter_kernel.cu rename to paddle/phi/kernels/funcs/gather_scatter_functor.cu index 1cb4e4a4e9d78223d1cbf30e0b2b76f45cecb035..b53de3beef9aa4871d9b60862e020c75083171a6 100644 --- a/paddle/fluid/operators/gather_scatter_kernel.cu +++ b/paddle/phi/kernels/funcs/gather_scatter_functor.cu @@ -12,11 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/gather_scatter_kernel.h" +#include "paddle/phi/kernels/funcs/gather_scatter_functor.h" +#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" -namespace paddle { -namespace operators { +namespace phi { +namespace funcs { class TensorAssign { public: @@ -111,7 +112,7 @@ struct gpu_gather_scatter_functor { phi::DenseTensor src, const std::string& method_name, const func_t& reduce_op, - const platform::DeviceContext& ctx) { + const phi::DeviceContext& ctx) { if (index.numel() == 0) { return; } @@ -162,7 +163,7 @@ void gpu_gather_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, phi::DenseTensor result, - const platform::DeviceContext& ctx) { + const phi::DeviceContext& ctx) { gpu_gather_scatter_functor()( @@ -175,7 +176,7 @@ void gpu_scatter_assign_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, phi::DenseTensor src, - const platform::DeviceContext& ctx) { + const phi::DeviceContext& ctx) { gpu_gather_scatter_functor()( @@ -187,7 +188,7 @@ void gpu_scatter_add_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, phi::DenseTensor src, - const platform::DeviceContext& ctx) { + const phi::DeviceContext& ctx) { gpu_gather_scatter_functor()( @@ -199,7 +200,7 @@ void gpu_scatter_mul_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, phi::DenseTensor src, - const platform::DeviceContext& ctx) { + const phi::DeviceContext& ctx) { gpu_gather_scatter_functor()( @@ -232,7 +233,7 @@ void gpu_scatter_input_grad_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, phi::DenseTensor grad, - const platform::DeviceContext& ctx) { + const phi::DeviceContext& ctx) { auto* index_data = index.data(); auto* grad_data = grad.data(); @@ -273,5 +274,5 @@ Instantiate_Template_Function(gpu_gather_kernel) Instantiate_Template_Function(gpu_scatter_mul_kernel) Instantiate_Template_Function(gpu_scatter_input_grad_kernel) -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/operators/gather_scatter_kernel.h b/paddle/phi/kernels/funcs/gather_scatter_functor.h similarity index 76% rename from paddle/fluid/operators/gather_scatter_kernel.h rename to paddle/phi/kernels/funcs/gather_scatter_functor.h index 9cf3c3e33009a0bdc2e08022e20dc5fece5120ef..00e8f45d8ffb0ad75ed5732620b598494f02f2f4 100644 --- a/paddle/fluid/operators/gather_scatter_kernel.h +++ b/paddle/phi/kernels/funcs/gather_scatter_functor.h @@ -12,103 +12,104 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/float16.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" #pragma once -namespace paddle { -namespace operators { +namespace phi { +namespace funcs { -#define Instantiate_Template_Function(func) \ - Instantiate_Template_Function_index_t( \ - func, int) Instantiate_Template_Function_index_t(func, float) \ - Instantiate_Template_Function_index_t(func, double) \ - Instantiate_Template_Function_index_t(func, int64_t) \ - Instantiate_Template_Function_index_t(func, platform::float16) \ +#define Instantiate_Template_Function(func) \ + Instantiate_Template_Function_index_t( \ + func, int) Instantiate_Template_Function_index_t(func, float) \ + Instantiate_Template_Function_index_t(func, double) \ + Instantiate_Template_Function_index_t(func, int64_t) \ + Instantiate_Template_Function_index_t(func, phi::dtype::float16) \ Instantiate_Template_Function_index_t(func, unsigned char) -#define Instantiate_Template_Function_index_t(func, tensor_t) \ - template void func(phi::DenseTensor input, \ - int dim, \ - const phi::DenseTensor& index, \ - phi::DenseTensor result, \ - const platform::DeviceContext& ctx); \ - template void func(phi::DenseTensor input, \ - int dim, \ - const phi::DenseTensor& index, \ - phi::DenseTensor result, \ - const platform::DeviceContext& ctx); +#define Instantiate_Template_Function_index_t(func, tensor_t) \ + template void func(phi::DenseTensor input, \ + int dim, \ + const phi::DenseTensor& index, \ + phi::DenseTensor result, \ + const phi::DeviceContext& ctx); \ + template void func(phi::DenseTensor input, \ + int dim, \ + const phi::DenseTensor& index, \ + phi::DenseTensor result, \ + const phi::DeviceContext& ctx); template void cpu_gather_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, phi::DenseTensor result, - const platform::DeviceContext& ctx); + const phi::DeviceContext& ctx); template void cpu_scatter_assign_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, phi::DenseTensor src, - const platform::DeviceContext& ctx); + const phi::DeviceContext& ctx); template void cpu_scatter_add_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, phi::DenseTensor src, - const platform::DeviceContext& ctx); + const phi::DeviceContext& ctx); template void cpu_scatter_mul_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, phi::DenseTensor src, - const platform::DeviceContext& ctx); + const phi::DeviceContext& ctx); template void cpu_scatter_input_grad_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, phi::DenseTensor result, - const platform::DeviceContext& ctx); + const phi::DeviceContext& ctx); template void gpu_gather_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, phi::DenseTensor result, - const platform::DeviceContext& ctx); + const phi::DeviceContext& ctx); template void gpu_scatter_assign_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, phi::DenseTensor src, - const platform::DeviceContext& ctx); + const phi::DeviceContext& ctx); template void gpu_scatter_add_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, phi::DenseTensor src, - const platform::DeviceContext& ctx); + const phi::DeviceContext& ctx); template void gpu_scatter_mul_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, phi::DenseTensor src, - const platform::DeviceContext& ctx); + const phi::DeviceContext& ctx); template void gpu_scatter_input_grad_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, phi::DenseTensor result, - const platform::DeviceContext& ctx); -} // namespace operators -} // namespace paddle + const phi::DeviceContext& ctx); + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu b/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu index fcf43f9f42718d8bb035977c7e0e4ccc9f3c5c66..afc9984463661f5749a4c497f5bc6278fa275814 100644 --- a/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu @@ -14,12 +14,12 @@ #include "paddle/phi/kernels/put_along_axis_grad_kernel.h" -#include "paddle/fluid/operators/gather_scatter_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/utils/data_type.h" +#include "paddle/phi/kernels/funcs/gather_scatter_functor.h" namespace phi { @@ -41,10 +41,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx, if (x_grad) { phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad); if (index_type == DataType::INT32) { - paddle::operators::gpu_scatter_input_grad_kernel( + phi::funcs::gpu_scatter_input_grad_kernel( out_grad, axis, index, *x_grad, dev_ctx); } else { - paddle::operators::gpu_scatter_input_grad_kernel( + phi::funcs::gpu_scatter_input_grad_kernel( out_grad, axis, index, *x_grad, dev_ctx); } } @@ -52,14 +52,14 @@ void PutAlongAxisGradKernel(const Context& dev_ctx, value_grad->Resize(index.dims()); dev_ctx.template Alloc(value_grad); if (index_type == DataType::INT32) { - paddle::operators::gpu_gather_kernel( + phi::funcs::gpu_gather_kernel( out_grad, axis, index, *value_grad, dev_ctx); // the gradient of scatter is gather } else if (index_type == DataType::INT64) { - paddle::operators::gpu_gather_kernel( + phi::funcs::gpu_gather_kernel( out_grad, axis, index, *value_grad, dev_ctx); } } diff --git a/paddle/phi/kernels/gpu/put_along_axis_kernel.cu b/paddle/phi/kernels/gpu/put_along_axis_kernel.cu index b43d6fafa72a6cc5b59e61a51d169ba1829913d4..12127516f0ef32c952c4d8922aa0d4f11dd5c681 100644 --- a/paddle/phi/kernels/gpu/put_along_axis_kernel.cu +++ b/paddle/phi/kernels/gpu/put_along_axis_kernel.cu @@ -14,12 +14,12 @@ #include "paddle/phi/kernels/put_along_axis_kernel.h" -#include "paddle/fluid/operators/gather_scatter_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/utils/data_type.h" +#include "paddle/phi/kernels/funcs/gather_scatter_functor.h" namespace phi { @@ -41,26 +41,26 @@ void PutAlongAxisKernel(const Context& dev_ctx, phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); if (reduce == "add") { if (index_type == DataType::INT32) { - paddle::operators::gpu_scatter_add_kernel( + phi::funcs::gpu_scatter_add_kernel( *out, axis, index, value, dev_ctx); } else if (index_type == DataType::INT64) { - paddle::operators::gpu_scatter_add_kernel( + phi::funcs::gpu_scatter_add_kernel( *out, axis, index, value, dev_ctx); } } else if (reduce == "multiply" || reduce == "mul") { if (index_type == DataType::INT32) { - paddle::operators::gpu_scatter_mul_kernel( + phi::funcs::gpu_scatter_mul_kernel( *out, axis, index, value, dev_ctx); } else if (index_type == DataType::INT64) { - paddle::operators::gpu_scatter_mul_kernel( + phi::funcs::gpu_scatter_mul_kernel( *out, axis, index, value, dev_ctx); } } else if (reduce == "assign") { if (index_type == DataType::INT32) { - paddle::operators::gpu_scatter_assign_kernel( + phi::funcs::gpu_scatter_assign_kernel( *out, axis, index, value, dev_ctx); } else if (index_type == DataType::INT64) { - paddle::operators::gpu_scatter_assign_kernel( + phi::funcs::gpu_scatter_assign_kernel( *out, axis, index, value, dev_ctx); } } else { diff --git a/paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu b/paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu index 07afc3ba8bb187f4e5d84684921e83272bbe9685..8530e3af1892f0978b8f1e108a8fef719efc6c06 100644 --- a/paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu @@ -14,11 +14,11 @@ #include "paddle/phi/kernels/take_along_axis_grad_kernel.h" -#include "paddle/fluid/operators/gather_scatter_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/utils/data_type.h" +#include "paddle/phi/kernels/funcs/gather_scatter_functor.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace phi { @@ -46,14 +46,14 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx, const auto& index_type = index.dtype(); if (index_type == DataType::INT32) { - paddle::operators::gpu_scatter_add_kernel( + phi::funcs::gpu_scatter_add_kernel( *x_grad, axis, index, out_grad, dev_ctx); // the gradient of gather is scatter } else if (index_type == DataType::INT64) { - paddle::operators::gpu_scatter_add_kernel( + phi::funcs::gpu_scatter_add_kernel( *x_grad, axis, index, out_grad, dev_ctx); } } diff --git a/paddle/phi/kernels/gpu/take_along_axis_kernel.cu b/paddle/phi/kernels/gpu/take_along_axis_kernel.cu index 28a1c9b657d7ae661e2e72f4bfe9c16178c13685..a548ac1a14ddffabac180d30d2ae20f4fba2c5d6 100644 --- a/paddle/phi/kernels/gpu/take_along_axis_kernel.cu +++ b/paddle/phi/kernels/gpu/take_along_axis_kernel.cu @@ -14,11 +14,11 @@ #include "paddle/phi/kernels/take_along_axis_kernel.h" -#include "paddle/fluid/operators/gather_scatter_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/utils/data_type.h" +#include "paddle/phi/kernels/funcs/gather_scatter_functor.h" namespace phi { @@ -38,11 +38,9 @@ void TakeAlongAxisKernel(const Context& dev_ctx, const auto& index_type = index.dtype(); if (index_type == DataType::INT32) { - paddle::operators::gpu_gather_kernel( - x, axis, index, *out, dev_ctx); + phi::funcs::gpu_gather_kernel(x, axis, index, *out, dev_ctx); } else if (index_type == DataType::INT64) { - paddle::operators::gpu_gather_kernel( - x, axis, index, *out, dev_ctx); + phi::funcs::gpu_gather_kernel(x, axis, index, *out, dev_ctx); } }