未验证 提交 0b79129d 编写于 作者: H huangjiyi 提交者: GitHub

[PHI decoupling] move gather_scatter_kernel from fluid to phi (#49132)

* move gather_scatter_kernel from fluid to phi

* mv gather_scatter_kernel to gather_scatter_functor
上级 7ffde4bc
...@@ -95,13 +95,7 @@ if(WITH_UNITY_BUILD) ...@@ -95,13 +95,7 @@ if(WITH_UNITY_BUILD)
include(unity_build_rule.cmake) include(unity_build_rule.cmake)
endif() endif()
if (WITH_ROCM) set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_api_utils backward_infermeta sparse_backward_infermeta)
hip_library(gather_scatter_kernel SRCS gather_scatter_kernel.cc gather_scatter_kernel.cu DEPS tensor)
else()
cc_library(gather_scatter_kernel SRCS gather_scatter_kernel.cc gather_scatter_kernel.cu DEPS tensor)
endif()
set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_api_utils gather_scatter_kernel backward_infermeta sparse_backward_infermeta)
register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op quantize_linear_op register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op quantize_linear_op
recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS}) recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
......
...@@ -65,7 +65,6 @@ set(COMMON_KERNEL_DEPS ...@@ -65,7 +65,6 @@ set(COMMON_KERNEL_DEPS
deformable_conv_functor deformable_conv_functor
matrix_reduce matrix_reduce
segment_pooling segment_pooling
gather_scatter_kernel
pooling pooling
maxouting maxouting
matrix_inverse matrix_inverse
...@@ -76,7 +75,8 @@ set(COMMON_KERNEL_DEPS ...@@ -76,7 +75,8 @@ set(COMMON_KERNEL_DEPS
fft fft
phi_data_layout_transform phi_data_layout_transform
gpc gpc
utf8proc) utf8proc
gather_scatter_functor)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} process_group) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} process_group)
if(WITH_NCCL OR WITH_RCCL) if(WITH_NCCL OR WITH_RCCL)
......
...@@ -14,12 +14,12 @@ ...@@ -14,12 +14,12 @@
#include "paddle/phi/kernels/put_along_axis_grad_kernel.h" #include "paddle/phi/kernels/put_along_axis_grad_kernel.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
namespace phi { namespace phi {
...@@ -41,7 +41,7 @@ void PutAlongAxisGradKernel(const Context& dev_ctx, ...@@ -41,7 +41,7 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
if (x_grad) { if (x_grad) {
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad); phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
if (index_type == DataType::INT32) { if (index_type == DataType::INT32) {
paddle::operators::cpu_scatter_input_grad_kernel<T, int32_t>( phi::funcs::cpu_scatter_input_grad_kernel<T, int32_t>(
// Here passing an unused argument out_grad, because it's // Here passing an unused argument out_grad, because it's
// convenient to instantiate a bunch of template function with the // convenient to instantiate a bunch of template function with the
// same arguments list. // same arguments list.
...@@ -51,7 +51,7 @@ void PutAlongAxisGradKernel(const Context& dev_ctx, ...@@ -51,7 +51,7 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
*x_grad, *x_grad,
dev_ctx); dev_ctx);
} else { } else {
paddle::operators::cpu_scatter_input_grad_kernel<T, int64_t>( phi::funcs::cpu_scatter_input_grad_kernel<T, int64_t>(
out_grad, axis, index, *x_grad, dev_ctx); out_grad, axis, index, *x_grad, dev_ctx);
} }
} }
...@@ -60,10 +60,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx, ...@@ -60,10 +60,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
value_grad->Resize(index.dims()); value_grad->Resize(index.dims());
dev_ctx.template Alloc<T>(value_grad); dev_ctx.template Alloc<T>(value_grad);
if (index_type == DataType::INT32) { if (index_type == DataType::INT32) {
paddle::operators::cpu_gather_kernel<T, int32_t>( phi::funcs::cpu_gather_kernel<T, int32_t>(
out_grad, axis, index, *value_grad, dev_ctx); out_grad, axis, index, *value_grad, dev_ctx);
} else if (index_type == DataType::INT64) { } else if (index_type == DataType::INT64) {
paddle::operators::cpu_gather_kernel<T, int64_t>( phi::funcs::cpu_gather_kernel<T, int64_t>(
out_grad, axis, index, *value_grad, dev_ctx); out_grad, axis, index, *value_grad, dev_ctx);
} }
} }
......
...@@ -14,12 +14,12 @@ ...@@ -14,12 +14,12 @@
#include "paddle/phi/kernels/put_along_axis_kernel.h" #include "paddle/phi/kernels/put_along_axis_kernel.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
namespace phi { namespace phi {
...@@ -40,26 +40,26 @@ void PutAlongAxisKernel(const Context& dev_ctx, ...@@ -40,26 +40,26 @@ void PutAlongAxisKernel(const Context& dev_ctx,
const auto& index_type = index.dtype(); const auto& index_type = index.dtype();
if (reduce == "add") { if (reduce == "add") {
if (index_type == DataType::INT32) { if (index_type == DataType::INT32) {
paddle::operators::cpu_scatter_add_kernel<T, int32_t>( phi::funcs::cpu_scatter_add_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx); *out, axis, index, value, dev_ctx);
} else if (index_type == DataType::INT64) { } else if (index_type == DataType::INT64) {
paddle::operators::cpu_scatter_add_kernel<T, int64_t>( phi::funcs::cpu_scatter_add_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx); *out, axis, index, value, dev_ctx);
} }
} else if (reduce == "multiply" || reduce == "mul") { } else if (reduce == "multiply" || reduce == "mul") {
if (index_type == DataType::INT32) { if (index_type == DataType::INT32) {
paddle::operators::cpu_scatter_mul_kernel<T, int32_t>( phi::funcs::cpu_scatter_mul_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx); *out, axis, index, value, dev_ctx);
} else if (index_type == DataType::INT64) { } else if (index_type == DataType::INT64) {
paddle::operators::cpu_scatter_mul_kernel<T, int64_t>( phi::funcs::cpu_scatter_mul_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx); *out, axis, index, value, dev_ctx);
} }
} else if (reduce == "assign") { } else if (reduce == "assign") {
if (index_type == DataType::INT32) { if (index_type == DataType::INT32) {
paddle::operators::cpu_scatter_assign_kernel<T, int32_t>( phi::funcs::cpu_scatter_assign_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx); *out, axis, index, value, dev_ctx);
} else if (index_type == DataType::INT64) { } else if (index_type == DataType::INT64) {
paddle::operators::cpu_scatter_assign_kernel<T, int64_t>( phi::funcs::cpu_scatter_assign_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx); *out, axis, index, value, dev_ctx);
} }
} else { } else {
......
...@@ -14,10 +14,10 @@ ...@@ -14,10 +14,10 @@
#include "paddle/phi/kernels/take_along_axis_grad_kernel.h" #include "paddle/phi/kernels/take_along_axis_grad_kernel.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
namespace phi { namespace phi {
...@@ -46,14 +46,14 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx, ...@@ -46,14 +46,14 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx,
const auto& index_type = const auto& index_type =
paddle::framework::TransToProtoVarType(index.dtype()); paddle::framework::TransToProtoVarType(index.dtype());
if (index_type == paddle::framework::proto::VarType::INT32) { if (index_type == paddle::framework::proto::VarType::INT32) {
paddle::operators::cpu_scatter_add_kernel<T, int32_t>( phi::funcs::cpu_scatter_add_kernel<T, int32_t>(
*x_grad, *x_grad,
axis, axis,
index, index,
out_grad, out_grad,
dev_ctx); // the gradient of gather is scatter dev_ctx); // the gradient of gather is scatter
} else if (index_type == paddle::framework::proto::VarType::INT64) { } else if (index_type == paddle::framework::proto::VarType::INT64) {
paddle::operators::cpu_scatter_add_kernel<T, int64_t>( phi::funcs::cpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, index, out_grad, dev_ctx); *x_grad, axis, index, out_grad, dev_ctx);
} }
} }
......
...@@ -14,11 +14,11 @@ ...@@ -14,11 +14,11 @@
#include "paddle/phi/kernels/take_along_axis_kernel.h" #include "paddle/phi/kernels/take_along_axis_kernel.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
namespace phi { namespace phi {
...@@ -38,11 +38,9 @@ void TakeAlongAxisKernel(const Context& dev_ctx, ...@@ -38,11 +38,9 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
const auto& index_type = index.dtype(); const auto& index_type = index.dtype();
if (index_type == DataType::INT32) { if (index_type == DataType::INT32) {
paddle::operators::cpu_gather_kernel<T, int32_t>( phi::funcs::cpu_gather_kernel<T, int32_t>(x, axis, index, *out, dev_ctx);
x, axis, index, *out, dev_ctx);
} else if (index_type == DataType::INT64) { } else if (index_type == DataType::INT64) {
paddle::operators::cpu_gather_kernel<T, int64_t>( phi::funcs::cpu_gather_kernel<T, int64_t>(x, axis, index, *out, dev_ctx);
x, axis, index, *out, dev_ctx);
} }
} }
......
...@@ -52,3 +52,15 @@ else() ...@@ -52,3 +52,15 @@ else()
math_library(selected_rows_functor DEPS selected_rows_utils math_function math_library(selected_rows_functor DEPS selected_rows_utils math_function
blas mixed_vector) blas mixed_vector)
endif() endif()
if(WITH_ROCM)
hip_library(
gather_scatter_functor
SRCS gather_scatter_functor.cc gather_scatter_functor.cu
DEPS tensor)
else()
cc_library(
gather_scatter_functor
SRCS gather_scatter_functor.cc gather_scatter_functor.cu
DEPS tensor)
endif()
...@@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/gather_scatter_kernel.h" #include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
namespace paddle {
namespace operators { namespace phi {
namespace funcs {
class TensorAssign { class TensorAssign {
public: public:
...@@ -54,7 +55,7 @@ struct cpu_gather_scatter_functor { ...@@ -54,7 +55,7 @@ struct cpu_gather_scatter_functor {
const phi::DenseTensor& src, const phi::DenseTensor& src,
const std::string& method_name, const std::string& method_name,
const func_t& reduce_op, const func_t& reduce_op,
const platform::DeviceContext& ctx) { const phi::DeviceContext& ctx) {
if (index.numel() == 0) { if (index.numel() == 0) {
return; return;
} }
...@@ -69,7 +70,7 @@ struct cpu_gather_scatter_functor { ...@@ -69,7 +70,7 @@ struct cpu_gather_scatter_functor {
auto src_dims = src.dims(); auto src_dims = src.dims();
if (self_size == 0 || src_size == 0 || index_size == 0) { if (self_size == 0 || src_size == 0 || index_size == 0) {
VLOG(3) << "zero size input found"; VLOG(3) << "zero size input found";
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"self_size, src_size, index_size cannot be 0"); "self_size, src_size, index_size cannot be 0");
return; return;
} }
...@@ -132,7 +133,7 @@ void cpu_gather_kernel(phi::DenseTensor self, ...@@ -132,7 +133,7 @@ void cpu_gather_kernel(phi::DenseTensor self,
int dim, int dim,
const phi::DenseTensor& index, const phi::DenseTensor& index,
phi::DenseTensor result, phi::DenseTensor result,
const platform::DeviceContext& ctx) { const phi::DeviceContext& ctx) {
cpu_gather_scatter_functor<tensor_t, cpu_gather_scatter_functor<tensor_t,
index_t, index_t,
/*is_scatter_like=*/false>()( /*is_scatter_like=*/false>()(
...@@ -144,7 +145,7 @@ void cpu_scatter_assign_kernel(phi::DenseTensor self, ...@@ -144,7 +145,7 @@ void cpu_scatter_assign_kernel(phi::DenseTensor self,
int dim, int dim,
const phi::DenseTensor& index, const phi::DenseTensor& index,
phi::DenseTensor src, phi::DenseTensor src,
const platform::DeviceContext& ctx) { const phi::DeviceContext& ctx) {
cpu_gather_scatter_functor<tensor_t, cpu_gather_scatter_functor<tensor_t,
index_t, index_t,
/*is_scatter_like=*/true>()( /*is_scatter_like=*/true>()(
...@@ -156,7 +157,7 @@ void cpu_scatter_add_kernel(phi::DenseTensor self, ...@@ -156,7 +157,7 @@ void cpu_scatter_add_kernel(phi::DenseTensor self,
int dim, int dim,
const phi::DenseTensor& index, const phi::DenseTensor& index,
phi::DenseTensor src, phi::DenseTensor src,
const platform::DeviceContext& ctx) { const phi::DeviceContext& ctx) {
cpu_gather_scatter_functor<tensor_t, cpu_gather_scatter_functor<tensor_t,
index_t, index_t,
/*is_scatter_like=*/true>()( /*is_scatter_like=*/true>()(
...@@ -168,7 +169,7 @@ void cpu_scatter_mul_kernel(phi::DenseTensor self, ...@@ -168,7 +169,7 @@ void cpu_scatter_mul_kernel(phi::DenseTensor self,
int dim, int dim,
const phi::DenseTensor& index, const phi::DenseTensor& index,
phi::DenseTensor src, phi::DenseTensor src,
const platform::DeviceContext& ctx) { const phi::DeviceContext& ctx) {
cpu_gather_scatter_functor<tensor_t, cpu_gather_scatter_functor<tensor_t,
index_t, index_t,
/*is_scatter_like=*/true>()( /*is_scatter_like=*/true>()(
...@@ -180,7 +181,7 @@ void cpu_scatter_input_grad_kernel(phi::DenseTensor self, ...@@ -180,7 +181,7 @@ void cpu_scatter_input_grad_kernel(phi::DenseTensor self,
int dim, int dim,
const phi::DenseTensor& index, const phi::DenseTensor& index,
phi::DenseTensor output, phi::DenseTensor output,
const platform::DeviceContext& ctx) { const phi::DeviceContext& ctx) {
auto* index_data = index.data<index_t>(); auto* index_data = index.data<index_t>();
auto* output_data = output.data<tensor_t>(); auto* output_data = output.data<tensor_t>();
...@@ -219,5 +220,5 @@ Instantiate_Template_Function(cpu_gather_kernel) ...@@ -219,5 +220,5 @@ Instantiate_Template_Function(cpu_gather_kernel)
Instantiate_Template_Function(cpu_scatter_mul_kernel) Instantiate_Template_Function(cpu_scatter_mul_kernel)
Instantiate_Template_Function(cpu_scatter_input_grad_kernel) Instantiate_Template_Function(cpu_scatter_input_grad_kernel)
} // namespace operators } // namespace funcs
} // namespace paddle } // namespace phi
...@@ -12,11 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,11 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/gather_scatter_kernel.h" #include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/backends/gpu/gpu_primitives.h"
namespace paddle { namespace phi {
namespace operators { namespace funcs {
class TensorAssign { class TensorAssign {
public: public:
...@@ -111,7 +112,7 @@ struct gpu_gather_scatter_functor { ...@@ -111,7 +112,7 @@ struct gpu_gather_scatter_functor {
phi::DenseTensor src, phi::DenseTensor src,
const std::string& method_name, const std::string& method_name,
const func_t& reduce_op, const func_t& reduce_op,
const platform::DeviceContext& ctx) { const phi::DeviceContext& ctx) {
if (index.numel() == 0) { if (index.numel() == 0) {
return; return;
} }
...@@ -162,7 +163,7 @@ void gpu_gather_kernel(phi::DenseTensor self, ...@@ -162,7 +163,7 @@ void gpu_gather_kernel(phi::DenseTensor self,
int dim, int dim,
const phi::DenseTensor& index, const phi::DenseTensor& index,
phi::DenseTensor result, phi::DenseTensor result,
const platform::DeviceContext& ctx) { const phi::DeviceContext& ctx) {
gpu_gather_scatter_functor<tensor_t, gpu_gather_scatter_functor<tensor_t,
index_t, index_t,
/*is_scatter_like=*/false>()( /*is_scatter_like=*/false>()(
...@@ -175,7 +176,7 @@ void gpu_scatter_assign_kernel(phi::DenseTensor self, ...@@ -175,7 +176,7 @@ void gpu_scatter_assign_kernel(phi::DenseTensor self,
int dim, int dim,
const phi::DenseTensor& index, const phi::DenseTensor& index,
phi::DenseTensor src, phi::DenseTensor src,
const platform::DeviceContext& ctx) { const phi::DeviceContext& ctx) {
gpu_gather_scatter_functor<tensor_t, gpu_gather_scatter_functor<tensor_t,
index_t, index_t,
/*is_scatter_like=*/true>()( /*is_scatter_like=*/true>()(
...@@ -187,7 +188,7 @@ void gpu_scatter_add_kernel(phi::DenseTensor self, ...@@ -187,7 +188,7 @@ void gpu_scatter_add_kernel(phi::DenseTensor self,
int dim, int dim,
const phi::DenseTensor& index, const phi::DenseTensor& index,
phi::DenseTensor src, phi::DenseTensor src,
const platform::DeviceContext& ctx) { const phi::DeviceContext& ctx) {
gpu_gather_scatter_functor<tensor_t, gpu_gather_scatter_functor<tensor_t,
index_t, index_t,
/*is_scatter_like=*/true>()( /*is_scatter_like=*/true>()(
...@@ -199,7 +200,7 @@ void gpu_scatter_mul_kernel(phi::DenseTensor self, ...@@ -199,7 +200,7 @@ void gpu_scatter_mul_kernel(phi::DenseTensor self,
int dim, int dim,
const phi::DenseTensor& index, const phi::DenseTensor& index,
phi::DenseTensor src, phi::DenseTensor src,
const platform::DeviceContext& ctx) { const phi::DeviceContext& ctx) {
gpu_gather_scatter_functor<tensor_t, gpu_gather_scatter_functor<tensor_t,
index_t, index_t,
/*is_scatter_like=*/true>()( /*is_scatter_like=*/true>()(
...@@ -232,7 +233,7 @@ void gpu_scatter_input_grad_kernel(phi::DenseTensor self, ...@@ -232,7 +233,7 @@ void gpu_scatter_input_grad_kernel(phi::DenseTensor self,
int dim, int dim,
const phi::DenseTensor& index, const phi::DenseTensor& index,
phi::DenseTensor grad, phi::DenseTensor grad,
const platform::DeviceContext& ctx) { const phi::DeviceContext& ctx) {
auto* index_data = index.data<index_t>(); auto* index_data = index.data<index_t>();
auto* grad_data = grad.data<tensor_t>(); auto* grad_data = grad.data<tensor_t>();
...@@ -273,5 +274,5 @@ Instantiate_Template_Function(gpu_gather_kernel) ...@@ -273,5 +274,5 @@ Instantiate_Template_Function(gpu_gather_kernel)
Instantiate_Template_Function(gpu_scatter_mul_kernel) Instantiate_Template_Function(gpu_scatter_mul_kernel)
Instantiate_Template_Function(gpu_scatter_input_grad_kernel) Instantiate_Template_Function(gpu_scatter_input_grad_kernel)
} // namespace operators } // namespace funcs
} // namespace paddle } // namespace phi
...@@ -12,21 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,21 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/tensor.h" #include "paddle/phi/common/float16.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/phi/core/device_context.h"
#pragma once #pragma once
namespace paddle { namespace phi {
namespace operators { namespace funcs {
#define Instantiate_Template_Function(func) \ #define Instantiate_Template_Function(func) \
Instantiate_Template_Function_index_t( \ Instantiate_Template_Function_index_t( \
func, int) Instantiate_Template_Function_index_t(func, float) \ func, int) Instantiate_Template_Function_index_t(func, float) \
Instantiate_Template_Function_index_t(func, double) \ Instantiate_Template_Function_index_t(func, double) \
Instantiate_Template_Function_index_t(func, int64_t) \ Instantiate_Template_Function_index_t(func, int64_t) \
Instantiate_Template_Function_index_t(func, platform::float16) \ Instantiate_Template_Function_index_t(func, phi::dtype::float16) \
Instantiate_Template_Function_index_t(func, unsigned char) Instantiate_Template_Function_index_t(func, unsigned char)
#define Instantiate_Template_Function_index_t(func, tensor_t) \ #define Instantiate_Template_Function_index_t(func, tensor_t) \
...@@ -34,81 +34,82 @@ namespace operators { ...@@ -34,81 +34,82 @@ namespace operators {
int dim, \ int dim, \
const phi::DenseTensor& index, \ const phi::DenseTensor& index, \
phi::DenseTensor result, \ phi::DenseTensor result, \
const platform::DeviceContext& ctx); \ const phi::DeviceContext& ctx); \
template void func<tensor_t, int64_t>(phi::DenseTensor input, \ template void func<tensor_t, int64_t>(phi::DenseTensor input, \
int dim, \ int dim, \
const phi::DenseTensor& index, \ const phi::DenseTensor& index, \
phi::DenseTensor result, \ phi::DenseTensor result, \
const platform::DeviceContext& ctx); const phi::DeviceContext& ctx);
template <typename tensor_t, typename index_t> template <typename tensor_t, typename index_t>
void cpu_gather_kernel(phi::DenseTensor self, void cpu_gather_kernel(phi::DenseTensor self,
int dim, int dim,
const phi::DenseTensor& index, const phi::DenseTensor& index,
phi::DenseTensor result, phi::DenseTensor result,
const platform::DeviceContext& ctx); const phi::DeviceContext& ctx);
template <typename tensor_t, typename index_t> template <typename tensor_t, typename index_t>
void cpu_scatter_assign_kernel(phi::DenseTensor self, void cpu_scatter_assign_kernel(phi::DenseTensor self,
int dim, int dim,
const phi::DenseTensor& index, const phi::DenseTensor& index,
phi::DenseTensor src, phi::DenseTensor src,
const platform::DeviceContext& ctx); const phi::DeviceContext& ctx);
template <typename tensor_t, typename index_t> template <typename tensor_t, typename index_t>
void cpu_scatter_add_kernel(phi::DenseTensor self, void cpu_scatter_add_kernel(phi::DenseTensor self,
int dim, int dim,
const phi::DenseTensor& index, const phi::DenseTensor& index,
phi::DenseTensor src, phi::DenseTensor src,
const platform::DeviceContext& ctx); const phi::DeviceContext& ctx);
template <typename tensor_t, typename index_t> template <typename tensor_t, typename index_t>
void cpu_scatter_mul_kernel(phi::DenseTensor self, void cpu_scatter_mul_kernel(phi::DenseTensor self,
int dim, int dim,
const phi::DenseTensor& index, const phi::DenseTensor& index,
phi::DenseTensor src, phi::DenseTensor src,
const platform::DeviceContext& ctx); const phi::DeviceContext& ctx);
template <typename tensor_t, typename index_t> template <typename tensor_t, typename index_t>
void cpu_scatter_input_grad_kernel(phi::DenseTensor self, void cpu_scatter_input_grad_kernel(phi::DenseTensor self,
int dim, int dim,
const phi::DenseTensor& index, const phi::DenseTensor& index,
phi::DenseTensor result, phi::DenseTensor result,
const platform::DeviceContext& ctx); const phi::DeviceContext& ctx);
template <typename tensor_t, typename index_t> template <typename tensor_t, typename index_t>
void gpu_gather_kernel(phi::DenseTensor self, void gpu_gather_kernel(phi::DenseTensor self,
int dim, int dim,
const phi::DenseTensor& index, const phi::DenseTensor& index,
phi::DenseTensor result, phi::DenseTensor result,
const platform::DeviceContext& ctx); const phi::DeviceContext& ctx);
template <typename tensor_t, typename index_t> template <typename tensor_t, typename index_t>
void gpu_scatter_assign_kernel(phi::DenseTensor self, void gpu_scatter_assign_kernel(phi::DenseTensor self,
int dim, int dim,
const phi::DenseTensor& index, const phi::DenseTensor& index,
phi::DenseTensor src, phi::DenseTensor src,
const platform::DeviceContext& ctx); const phi::DeviceContext& ctx);
template <typename tensor_t, typename index_t> template <typename tensor_t, typename index_t>
void gpu_scatter_add_kernel(phi::DenseTensor self, void gpu_scatter_add_kernel(phi::DenseTensor self,
int dim, int dim,
const phi::DenseTensor& index, const phi::DenseTensor& index,
phi::DenseTensor src, phi::DenseTensor src,
const platform::DeviceContext& ctx); const phi::DeviceContext& ctx);
template <typename tensor_t, typename index_t> template <typename tensor_t, typename index_t>
void gpu_scatter_mul_kernel(phi::DenseTensor self, void gpu_scatter_mul_kernel(phi::DenseTensor self,
int dim, int dim,
const phi::DenseTensor& index, const phi::DenseTensor& index,
phi::DenseTensor src, phi::DenseTensor src,
const platform::DeviceContext& ctx); const phi::DeviceContext& ctx);
template <typename tensor_t, typename index_t> template <typename tensor_t, typename index_t>
void gpu_scatter_input_grad_kernel(phi::DenseTensor self, void gpu_scatter_input_grad_kernel(phi::DenseTensor self,
int dim, int dim,
const phi::DenseTensor& index, const phi::DenseTensor& index,
phi::DenseTensor result, phi::DenseTensor result,
const platform::DeviceContext& ctx); const phi::DeviceContext& ctx);
} // namespace operators
} // namespace paddle } // namespace funcs
} // namespace phi
...@@ -14,12 +14,12 @@ ...@@ -14,12 +14,12 @@
#include "paddle/phi/kernels/put_along_axis_grad_kernel.h" #include "paddle/phi/kernels/put_along_axis_grad_kernel.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/utils/data_type.h" #include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
namespace phi { namespace phi {
...@@ -41,10 +41,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx, ...@@ -41,10 +41,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
if (x_grad) { if (x_grad) {
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad); phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
if (index_type == DataType::INT32) { if (index_type == DataType::INT32) {
paddle::operators::gpu_scatter_input_grad_kernel<T, int32_t>( phi::funcs::gpu_scatter_input_grad_kernel<T, int32_t>(
out_grad, axis, index, *x_grad, dev_ctx); out_grad, axis, index, *x_grad, dev_ctx);
} else { } else {
paddle::operators::gpu_scatter_input_grad_kernel<T, int64_t>( phi::funcs::gpu_scatter_input_grad_kernel<T, int64_t>(
out_grad, axis, index, *x_grad, dev_ctx); out_grad, axis, index, *x_grad, dev_ctx);
} }
} }
...@@ -52,14 +52,14 @@ void PutAlongAxisGradKernel(const Context& dev_ctx, ...@@ -52,14 +52,14 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
value_grad->Resize(index.dims()); value_grad->Resize(index.dims());
dev_ctx.template Alloc<T>(value_grad); dev_ctx.template Alloc<T>(value_grad);
if (index_type == DataType::INT32) { if (index_type == DataType::INT32) {
paddle::operators::gpu_gather_kernel<T, int32_t>( phi::funcs::gpu_gather_kernel<T, int32_t>(
out_grad, out_grad,
axis, axis,
index, index,
*value_grad, *value_grad,
dev_ctx); // the gradient of scatter is gather dev_ctx); // the gradient of scatter is gather
} else if (index_type == DataType::INT64) { } else if (index_type == DataType::INT64) {
paddle::operators::gpu_gather_kernel<T, int64_t>( phi::funcs::gpu_gather_kernel<T, int64_t>(
out_grad, axis, index, *value_grad, dev_ctx); out_grad, axis, index, *value_grad, dev_ctx);
} }
} }
......
...@@ -14,12 +14,12 @@ ...@@ -14,12 +14,12 @@
#include "paddle/phi/kernels/put_along_axis_kernel.h" #include "paddle/phi/kernels/put_along_axis_kernel.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/utils/data_type.h" #include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
namespace phi { namespace phi {
...@@ -41,26 +41,26 @@ void PutAlongAxisKernel(const Context& dev_ctx, ...@@ -41,26 +41,26 @@ void PutAlongAxisKernel(const Context& dev_ctx,
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
if (reduce == "add") { if (reduce == "add") {
if (index_type == DataType::INT32) { if (index_type == DataType::INT32) {
paddle::operators::gpu_scatter_add_kernel<T, int32_t>( phi::funcs::gpu_scatter_add_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx); *out, axis, index, value, dev_ctx);
} else if (index_type == DataType::INT64) { } else if (index_type == DataType::INT64) {
paddle::operators::gpu_scatter_add_kernel<T, int64_t>( phi::funcs::gpu_scatter_add_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx); *out, axis, index, value, dev_ctx);
} }
} else if (reduce == "multiply" || reduce == "mul") { } else if (reduce == "multiply" || reduce == "mul") {
if (index_type == DataType::INT32) { if (index_type == DataType::INT32) {
paddle::operators::gpu_scatter_mul_kernel<T, int32_t>( phi::funcs::gpu_scatter_mul_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx); *out, axis, index, value, dev_ctx);
} else if (index_type == DataType::INT64) { } else if (index_type == DataType::INT64) {
paddle::operators::gpu_scatter_mul_kernel<T, int64_t>( phi::funcs::gpu_scatter_mul_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx); *out, axis, index, value, dev_ctx);
} }
} else if (reduce == "assign") { } else if (reduce == "assign") {
if (index_type == DataType::INT32) { if (index_type == DataType::INT32) {
paddle::operators::gpu_scatter_assign_kernel<T, int32_t>( phi::funcs::gpu_scatter_assign_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx); *out, axis, index, value, dev_ctx);
} else if (index_type == DataType::INT64) { } else if (index_type == DataType::INT64) {
paddle::operators::gpu_scatter_assign_kernel<T, int64_t>( phi::funcs::gpu_scatter_assign_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx); *out, axis, index, value, dev_ctx);
} }
} else { } else {
......
...@@ -14,11 +14,11 @@ ...@@ -14,11 +14,11 @@
#include "paddle/phi/kernels/take_along_axis_grad_kernel.h" #include "paddle/phi/kernels/take_along_axis_grad_kernel.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h" #include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
namespace phi { namespace phi {
...@@ -46,14 +46,14 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx, ...@@ -46,14 +46,14 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx,
const auto& index_type = index.dtype(); const auto& index_type = index.dtype();
if (index_type == DataType::INT32) { if (index_type == DataType::INT32) {
paddle::operators::gpu_scatter_add_kernel<T, int32_t>( phi::funcs::gpu_scatter_add_kernel<T, int32_t>(
*x_grad, *x_grad,
axis, axis,
index, index,
out_grad, out_grad,
dev_ctx); // the gradient of gather is scatter dev_ctx); // the gradient of gather is scatter
} else if (index_type == DataType::INT64) { } else if (index_type == DataType::INT64) {
paddle::operators::gpu_scatter_add_kernel<T, int64_t>( phi::funcs::gpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, index, out_grad, dev_ctx); *x_grad, axis, index, out_grad, dev_ctx);
} }
} }
......
...@@ -14,11 +14,11 @@ ...@@ -14,11 +14,11 @@
#include "paddle/phi/kernels/take_along_axis_kernel.h" #include "paddle/phi/kernels/take_along_axis_kernel.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h" #include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
namespace phi { namespace phi {
...@@ -38,11 +38,9 @@ void TakeAlongAxisKernel(const Context& dev_ctx, ...@@ -38,11 +38,9 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
const auto& index_type = index.dtype(); const auto& index_type = index.dtype();
if (index_type == DataType::INT32) { if (index_type == DataType::INT32) {
paddle::operators::gpu_gather_kernel<T, int32_t>( phi::funcs::gpu_gather_kernel<T, int32_t>(x, axis, index, *out, dev_ctx);
x, axis, index, *out, dev_ctx);
} else if (index_type == DataType::INT64) { } else if (index_type == DataType::INT64) {
paddle::operators::gpu_gather_kernel<T, int64_t>( phi::funcs::gpu_gather_kernel<T, int64_t>(x, axis, index, *out, dev_ctx);
x, axis, index, *out, dev_ctx);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册