未验证 提交 0b79129d 编写于 作者: H huangjiyi 提交者: GitHub

[PHI decoupling] move gather_scatter_kernel from fluid to phi (#49132)

* move gather_scatter_kernel from fluid to phi

* mv gather_scatter_kernel to gather_scatter_functor
上级 7ffde4bc
......@@ -95,13 +95,7 @@ if(WITH_UNITY_BUILD)
include(unity_build_rule.cmake)
endif()
if (WITH_ROCM)
hip_library(gather_scatter_kernel SRCS gather_scatter_kernel.cc gather_scatter_kernel.cu DEPS tensor)
else()
cc_library(gather_scatter_kernel SRCS gather_scatter_kernel.cc gather_scatter_kernel.cu DEPS tensor)
endif()
set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_api_utils gather_scatter_kernel backward_infermeta sparse_backward_infermeta)
set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_api_utils backward_infermeta sparse_backward_infermeta)
register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op quantize_linear_op
recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
......
......@@ -65,7 +65,6 @@ set(COMMON_KERNEL_DEPS
deformable_conv_functor
matrix_reduce
segment_pooling
gather_scatter_kernel
pooling
maxouting
matrix_inverse
......@@ -76,7 +75,8 @@ set(COMMON_KERNEL_DEPS
fft
phi_data_layout_transform
gpc
utf8proc)
utf8proc
gather_scatter_functor)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} process_group)
if(WITH_NCCL OR WITH_RCCL)
......
......@@ -14,12 +14,12 @@
#include "paddle/phi/kernels/put_along_axis_grad_kernel.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
namespace phi {
......@@ -41,7 +41,7 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
if (x_grad) {
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
if (index_type == DataType::INT32) {
paddle::operators::cpu_scatter_input_grad_kernel<T, int32_t>(
phi::funcs::cpu_scatter_input_grad_kernel<T, int32_t>(
// Here passing an unused argument out_grad, because it's
// convenient to instantiate a bunch of template function with the
// same arguments list.
......@@ -51,7 +51,7 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
*x_grad,
dev_ctx);
} else {
paddle::operators::cpu_scatter_input_grad_kernel<T, int64_t>(
phi::funcs::cpu_scatter_input_grad_kernel<T, int64_t>(
out_grad, axis, index, *x_grad, dev_ctx);
}
}
......@@ -60,10 +60,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
value_grad->Resize(index.dims());
dev_ctx.template Alloc<T>(value_grad);
if (index_type == DataType::INT32) {
paddle::operators::cpu_gather_kernel<T, int32_t>(
phi::funcs::cpu_gather_kernel<T, int32_t>(
out_grad, axis, index, *value_grad, dev_ctx);
} else if (index_type == DataType::INT64) {
paddle::operators::cpu_gather_kernel<T, int64_t>(
phi::funcs::cpu_gather_kernel<T, int64_t>(
out_grad, axis, index, *value_grad, dev_ctx);
}
}
......
......@@ -14,12 +14,12 @@
#include "paddle/phi/kernels/put_along_axis_kernel.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
namespace phi {
......@@ -40,26 +40,26 @@ void PutAlongAxisKernel(const Context& dev_ctx,
const auto& index_type = index.dtype();
if (reduce == "add") {
if (index_type == DataType::INT32) {
paddle::operators::cpu_scatter_add_kernel<T, int32_t>(
phi::funcs::cpu_scatter_add_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == DataType::INT64) {
paddle::operators::cpu_scatter_add_kernel<T, int64_t>(
phi::funcs::cpu_scatter_add_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
} else if (reduce == "multiply" || reduce == "mul") {
if (index_type == DataType::INT32) {
paddle::operators::cpu_scatter_mul_kernel<T, int32_t>(
phi::funcs::cpu_scatter_mul_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == DataType::INT64) {
paddle::operators::cpu_scatter_mul_kernel<T, int64_t>(
phi::funcs::cpu_scatter_mul_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
} else if (reduce == "assign") {
if (index_type == DataType::INT32) {
paddle::operators::cpu_scatter_assign_kernel<T, int32_t>(
phi::funcs::cpu_scatter_assign_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == DataType::INT64) {
paddle::operators::cpu_scatter_assign_kernel<T, int64_t>(
phi::funcs::cpu_scatter_assign_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
} else {
......
......@@ -14,10 +14,10 @@
#include "paddle/phi/kernels/take_along_axis_grad_kernel.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
......@@ -46,14 +46,14 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx,
const auto& index_type =
paddle::framework::TransToProtoVarType(index.dtype());
if (index_type == paddle::framework::proto::VarType::INT32) {
paddle::operators::cpu_scatter_add_kernel<T, int32_t>(
phi::funcs::cpu_scatter_add_kernel<T, int32_t>(
*x_grad,
axis,
index,
out_grad,
dev_ctx); // the gradient of gather is scatter
} else if (index_type == paddle::framework::proto::VarType::INT64) {
paddle::operators::cpu_scatter_add_kernel<T, int64_t>(
phi::funcs::cpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, index, out_grad, dev_ctx);
}
}
......
......@@ -14,11 +14,11 @@
#include "paddle/phi/kernels/take_along_axis_kernel.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
namespace phi {
......@@ -38,11 +38,9 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
const auto& index_type = index.dtype();
if (index_type == DataType::INT32) {
paddle::operators::cpu_gather_kernel<T, int32_t>(
x, axis, index, *out, dev_ctx);
phi::funcs::cpu_gather_kernel<T, int32_t>(x, axis, index, *out, dev_ctx);
} else if (index_type == DataType::INT64) {
paddle::operators::cpu_gather_kernel<T, int64_t>(
x, axis, index, *out, dev_ctx);
phi::funcs::cpu_gather_kernel<T, int64_t>(x, axis, index, *out, dev_ctx);
}
}
......
......@@ -52,3 +52,15 @@ else()
math_library(selected_rows_functor DEPS selected_rows_utils math_function
blas mixed_vector)
endif()
# Declare the gather_scatter_functor library target. Both branches use the
# same sources (gather_scatter_functor.cc and gather_scatter_functor.cu) and
# the same dependency (tensor); only the library macro differs, selected by
# the WITH_ROCM option.
if(WITH_ROCM)
# WITH_ROCM is set: declare the target through hip_library.
hip_library(
gather_scatter_functor
SRCS gather_scatter_functor.cc gather_scatter_functor.cu
DEPS tensor)
else()
# WITH_ROCM is not set: declare the same target through cc_library.
cc_library(
gather_scatter_functor
SRCS gather_scatter_functor.cc gather_scatter_functor.cu
DEPS tensor)
endif()
......@@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/gather_scatter_kernel.h"
namespace paddle {
namespace operators {
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
namespace phi {
namespace funcs {
class TensorAssign {
public:
......@@ -54,7 +55,7 @@ struct cpu_gather_scatter_functor {
const phi::DenseTensor& src,
const std::string& method_name,
const func_t& reduce_op,
const platform::DeviceContext& ctx) {
const phi::DeviceContext& ctx) {
if (index.numel() == 0) {
return;
}
......@@ -69,7 +70,7 @@ struct cpu_gather_scatter_functor {
auto src_dims = src.dims();
if (self_size == 0 || src_size == 0 || index_size == 0) {
VLOG(3) << "zero size input found";
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"self_size, src_size, index_size cannot be 0");
return;
}
......@@ -132,7 +133,7 @@ void cpu_gather_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor result,
const platform::DeviceContext& ctx) {
const phi::DeviceContext& ctx) {
cpu_gather_scatter_functor<tensor_t,
index_t,
/*is_scatter_like=*/false>()(
......@@ -144,7 +145,7 @@ void cpu_scatter_assign_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor src,
const platform::DeviceContext& ctx) {
const phi::DeviceContext& ctx) {
cpu_gather_scatter_functor<tensor_t,
index_t,
/*is_scatter_like=*/true>()(
......@@ -156,7 +157,7 @@ void cpu_scatter_add_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor src,
const platform::DeviceContext& ctx) {
const phi::DeviceContext& ctx) {
cpu_gather_scatter_functor<tensor_t,
index_t,
/*is_scatter_like=*/true>()(
......@@ -168,7 +169,7 @@ void cpu_scatter_mul_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor src,
const platform::DeviceContext& ctx) {
const phi::DeviceContext& ctx) {
cpu_gather_scatter_functor<tensor_t,
index_t,
/*is_scatter_like=*/true>()(
......@@ -180,7 +181,7 @@ void cpu_scatter_input_grad_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor output,
const platform::DeviceContext& ctx) {
const phi::DeviceContext& ctx) {
auto* index_data = index.data<index_t>();
auto* output_data = output.data<tensor_t>();
......@@ -219,5 +220,5 @@ Instantiate_Template_Function(cpu_gather_kernel)
Instantiate_Template_Function(cpu_scatter_mul_kernel)
Instantiate_Template_Function(cpu_scatter_input_grad_kernel)
} // namespace operators
} // namespace paddle
} // namespace funcs
} // namespace phi
......@@ -12,11 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
namespace paddle {
namespace operators {
namespace phi {
namespace funcs {
class TensorAssign {
public:
......@@ -111,7 +112,7 @@ struct gpu_gather_scatter_functor {
phi::DenseTensor src,
const std::string& method_name,
const func_t& reduce_op,
const platform::DeviceContext& ctx) {
const phi::DeviceContext& ctx) {
if (index.numel() == 0) {
return;
}
......@@ -162,7 +163,7 @@ void gpu_gather_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor result,
const platform::DeviceContext& ctx) {
const phi::DeviceContext& ctx) {
gpu_gather_scatter_functor<tensor_t,
index_t,
/*is_scatter_like=*/false>()(
......@@ -175,7 +176,7 @@ void gpu_scatter_assign_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor src,
const platform::DeviceContext& ctx) {
const phi::DeviceContext& ctx) {
gpu_gather_scatter_functor<tensor_t,
index_t,
/*is_scatter_like=*/true>()(
......@@ -187,7 +188,7 @@ void gpu_scatter_add_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor src,
const platform::DeviceContext& ctx) {
const phi::DeviceContext& ctx) {
gpu_gather_scatter_functor<tensor_t,
index_t,
/*is_scatter_like=*/true>()(
......@@ -199,7 +200,7 @@ void gpu_scatter_mul_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor src,
const platform::DeviceContext& ctx) {
const phi::DeviceContext& ctx) {
gpu_gather_scatter_functor<tensor_t,
index_t,
/*is_scatter_like=*/true>()(
......@@ -232,7 +233,7 @@ void gpu_scatter_input_grad_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor grad,
const platform::DeviceContext& ctx) {
const phi::DeviceContext& ctx) {
auto* index_data = index.data<index_t>();
auto* grad_data = grad.data<tensor_t>();
......@@ -273,5 +274,5 @@ Instantiate_Template_Function(gpu_gather_kernel)
Instantiate_Template_Function(gpu_scatter_mul_kernel)
Instantiate_Template_Function(gpu_scatter_input_grad_kernel)
} // namespace operators
} // namespace paddle
} // namespace funcs
} // namespace phi
......@@ -12,103 +12,104 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#pragma once
namespace paddle {
namespace operators {
namespace phi {
namespace funcs {
#define Instantiate_Template_Function(func) \
Instantiate_Template_Function_index_t( \
func, int) Instantiate_Template_Function_index_t(func, float) \
Instantiate_Template_Function_index_t(func, double) \
Instantiate_Template_Function_index_t(func, int64_t) \
Instantiate_Template_Function_index_t(func, platform::float16) \
#define Instantiate_Template_Function(func) \
Instantiate_Template_Function_index_t( \
func, int) Instantiate_Template_Function_index_t(func, float) \
Instantiate_Template_Function_index_t(func, double) \
Instantiate_Template_Function_index_t(func, int64_t) \
Instantiate_Template_Function_index_t(func, phi::dtype::float16) \
Instantiate_Template_Function_index_t(func, unsigned char)
#define Instantiate_Template_Function_index_t(func, tensor_t) \
template void func<tensor_t, int>(phi::DenseTensor input, \
int dim, \
const phi::DenseTensor& index, \
phi::DenseTensor result, \
const platform::DeviceContext& ctx); \
template void func<tensor_t, int64_t>(phi::DenseTensor input, \
int dim, \
const phi::DenseTensor& index, \
phi::DenseTensor result, \
const platform::DeviceContext& ctx);
#define Instantiate_Template_Function_index_t(func, tensor_t) \
template void func<tensor_t, int>(phi::DenseTensor input, \
int dim, \
const phi::DenseTensor& index, \
phi::DenseTensor result, \
const phi::DeviceContext& ctx); \
template void func<tensor_t, int64_t>(phi::DenseTensor input, \
int dim, \
const phi::DenseTensor& index, \
phi::DenseTensor result, \
const phi::DeviceContext& ctx);
template <typename tensor_t, typename index_t>
void cpu_gather_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor result,
const platform::DeviceContext& ctx);
const phi::DeviceContext& ctx);
template <typename tensor_t, typename index_t>
void cpu_scatter_assign_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor src,
const platform::DeviceContext& ctx);
const phi::DeviceContext& ctx);
template <typename tensor_t, typename index_t>
void cpu_scatter_add_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor src,
const platform::DeviceContext& ctx);
const phi::DeviceContext& ctx);
template <typename tensor_t, typename index_t>
void cpu_scatter_mul_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor src,
const platform::DeviceContext& ctx);
const phi::DeviceContext& ctx);
template <typename tensor_t, typename index_t>
void cpu_scatter_input_grad_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor result,
const platform::DeviceContext& ctx);
const phi::DeviceContext& ctx);
template <typename tensor_t, typename index_t>
void gpu_gather_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor result,
const platform::DeviceContext& ctx);
const phi::DeviceContext& ctx);
template <typename tensor_t, typename index_t>
void gpu_scatter_assign_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor src,
const platform::DeviceContext& ctx);
const phi::DeviceContext& ctx);
template <typename tensor_t, typename index_t>
void gpu_scatter_add_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor src,
const platform::DeviceContext& ctx);
const phi::DeviceContext& ctx);
template <typename tensor_t, typename index_t>
void gpu_scatter_mul_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor src,
const platform::DeviceContext& ctx);
const phi::DeviceContext& ctx);
template <typename tensor_t, typename index_t>
void gpu_scatter_input_grad_kernel(phi::DenseTensor self,
int dim,
const phi::DenseTensor& index,
phi::DenseTensor result,
const platform::DeviceContext& ctx);
} // namespace operators
} // namespace paddle
const phi::DeviceContext& ctx);
} // namespace funcs
} // namespace phi
......@@ -14,12 +14,12 @@
#include "paddle/phi/kernels/put_along_axis_grad_kernel.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
namespace phi {
......@@ -41,10 +41,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
if (x_grad) {
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
if (index_type == DataType::INT32) {
paddle::operators::gpu_scatter_input_grad_kernel<T, int32_t>(
phi::funcs::gpu_scatter_input_grad_kernel<T, int32_t>(
out_grad, axis, index, *x_grad, dev_ctx);
} else {
paddle::operators::gpu_scatter_input_grad_kernel<T, int64_t>(
phi::funcs::gpu_scatter_input_grad_kernel<T, int64_t>(
out_grad, axis, index, *x_grad, dev_ctx);
}
}
......@@ -52,14 +52,14 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
value_grad->Resize(index.dims());
dev_ctx.template Alloc<T>(value_grad);
if (index_type == DataType::INT32) {
paddle::operators::gpu_gather_kernel<T, int32_t>(
phi::funcs::gpu_gather_kernel<T, int32_t>(
out_grad,
axis,
index,
*value_grad,
dev_ctx); // the gradient of scatter is gather
} else if (index_type == DataType::INT64) {
paddle::operators::gpu_gather_kernel<T, int64_t>(
phi::funcs::gpu_gather_kernel<T, int64_t>(
out_grad, axis, index, *value_grad, dev_ctx);
}
}
......
......@@ -14,12 +14,12 @@
#include "paddle/phi/kernels/put_along_axis_kernel.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
namespace phi {
......@@ -41,26 +41,26 @@ void PutAlongAxisKernel(const Context& dev_ctx,
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
if (reduce == "add") {
if (index_type == DataType::INT32) {
paddle::operators::gpu_scatter_add_kernel<T, int32_t>(
phi::funcs::gpu_scatter_add_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == DataType::INT64) {
paddle::operators::gpu_scatter_add_kernel<T, int64_t>(
phi::funcs::gpu_scatter_add_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
} else if (reduce == "multiply" || reduce == "mul") {
if (index_type == DataType::INT32) {
paddle::operators::gpu_scatter_mul_kernel<T, int32_t>(
phi::funcs::gpu_scatter_mul_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == DataType::INT64) {
paddle::operators::gpu_scatter_mul_kernel<T, int64_t>(
phi::funcs::gpu_scatter_mul_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
} else if (reduce == "assign") {
if (index_type == DataType::INT32) {
paddle::operators::gpu_scatter_assign_kernel<T, int32_t>(
phi::funcs::gpu_scatter_assign_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == DataType::INT64) {
paddle::operators::gpu_scatter_assign_kernel<T, int64_t>(
phi::funcs::gpu_scatter_assign_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
} else {
......
......@@ -14,11 +14,11 @@
#include "paddle/phi/kernels/take_along_axis_grad_kernel.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
......@@ -46,14 +46,14 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx,
const auto& index_type = index.dtype();
if (index_type == DataType::INT32) {
paddle::operators::gpu_scatter_add_kernel<T, int32_t>(
phi::funcs::gpu_scatter_add_kernel<T, int32_t>(
*x_grad,
axis,
index,
out_grad,
dev_ctx); // the gradient of gather is scatter
} else if (index_type == DataType::INT64) {
paddle::operators::gpu_scatter_add_kernel<T, int64_t>(
phi::funcs::gpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, index, out_grad, dev_ctx);
}
}
......
......@@ -14,11 +14,11 @@
#include "paddle/phi/kernels/take_along_axis_kernel.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
namespace phi {
......@@ -38,11 +38,9 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
const auto& index_type = index.dtype();
if (index_type == DataType::INT32) {
paddle::operators::gpu_gather_kernel<T, int32_t>(
x, axis, index, *out, dev_ctx);
phi::funcs::gpu_gather_kernel<T, int32_t>(x, axis, index, *out, dev_ctx);
} else if (index_type == DataType::INT64) {
paddle::operators::gpu_gather_kernel<T, int64_t>(
x, axis, index, *out, dev_ctx);
phi::funcs::gpu_gather_kernel<T, int64_t>(x, axis, index, *out, dev_ctx);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册