diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc
index 8a405cc6fc1baefe997fb5b6133a56d6a2fc0438..7910d94298e7efb2cb5dc8616793013910a449d6 100644
--- a/paddle/fluid/operators/gather_op.cc
+++ b/paddle/fluid/operators/gather_op.cc
@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/gather_op.h"
 #include <memory>
 #include <string>
 #include <vector>
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/phi/core/ddim.h"
 
@@ -198,17 +198,7 @@ REGISTER_OPERATOR(gather, ops::GatherOp, ops::GatherOpMaker,
                   ops::GatherGradOpMaker<paddle::framework::OpDesc>,
                   ops::GatherGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(gather_grad, ops::GatherGradOp,
                   ops::GatherGradNoNeedBufferVarInferer);
-REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel<float>,
-                       ops::GatherOpKernel<double>, ops::GatherOpKernel<int>,
-                       ops::GatherOpKernel<uint8_t>,
-                       ops::GatherOpKernel<int64_t>,
-                       ops::GatherOpKernel<paddle::platform::bfloat16>);
-REGISTER_OP_CPU_KERNEL(gather_grad, ops::GatherGradientOpKernel<float>,
-                       ops::GatherGradientOpKernel<double>,
-                       ops::GatherGradientOpKernel<int>,
-                       ops::GatherGradientOpKernel<uint8_t>,
-                       ops::GatherGradientOpKernel<int64_t>,
-                       ops::GatherGradientOpKernel<paddle::platform::bfloat16>);
+
 REGISTER_OP_VERSION(gather)
     .AddCheckpoint(R"ROC(upgrad gather, add a new input [Axis])ROC",
                    paddle::framework::compatible::OpVersionDesc().NewInput(
diff --git a/paddle/fluid/operators/gather_op.cu b/paddle/fluid/operators/gather_op.cu
deleted file mode 100644
index e0db2f26d3e0534f924cc709b98689fb3f1a5cc6..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/gather_op.cu
+++ /dev/null
@@ -1,152 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/framework/convert_utils.h"
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/operators/gather_op.h"
-#include "paddle/phi/kernels/funcs/gather.cu.h"
-#include "paddle/phi/kernels/funcs/scatter.cu.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename T>
-class GatherOpCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
-                      platform::errors::PreconditionNotMet(
-                          "This kernel only runs on GPU device."));
-    auto *x = ctx.Input<Tensor>("X");
-    auto *index = ctx.Input<Tensor>("Index");
-    auto *output = ctx.Output<Tensor>("Out");
-
-    int axis = ctx.Attr<int>("axis");
-
-    // get axis from tensor
-    if (ctx.HasInput("Axis")) {
-      Tensor cpu_axis;
-      const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
-      framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis);
-      const auto &axis_type =
-          framework::TransToProtoVarType(axis_tensor->dtype());
-      if (axis_type == framework::proto::VarType::INT32) {
-        axis = static_cast<int>(cpu_axis.data<int32_t>()[0]);
-      } else if (axis_type == framework::proto::VarType::INT64) {
-        axis = static_cast<int>(cpu_axis.data<int64_t>()[0]);
-      } else if (axis_type == framework::proto::VarType::INT16) {
-        axis = static_cast<int>(cpu_axis.data<int16_t>()[0]);
-      }
-    }
-    const auto &place = ctx.GetPlace();
-    const auto &index_type = framework::TransToProtoVarType(index->dtype());
-    const auto &dev_ctx = ctx.cuda_device_context();
-    if (axis != 0) {
-      if (index_type == framework::proto::VarType::INT32) {
-        phi::funcs::GatherV2CUDAFunction<T, int32_t>(x, index, axis, output,
-                                                     dev_ctx);
-      } else if (index_type == framework::proto::VarType::INT64) {
-        phi::funcs::GatherV2CUDAFunction<T, int64_t>(x, index, axis, output,
-                                                     dev_ctx);
-      } else if (index_type == framework::proto::VarType::INT16) {
-        phi::funcs::GatherV2CUDAFunction<T, int16_t>(x, index, axis, output,
-                                                     dev_ctx);
-      }
-      return;
-    }
-
-    output->mutable_data<T>(ctx.GetPlace());
-    if (x->numel() == 0) return;
-    if (index_type == framework::proto::VarType::INT32) {
-      phi::funcs::GPUGather<T, int>(dev_ctx, *x, *index, output);
-    } else if (index_type == framework::proto::VarType::INT64) {
-      phi::funcs::GPUGather<T, int64_t>(dev_ctx, *x, *index, output);
-    } else if (index_type == framework::proto::VarType::INT16) {
-      phi::funcs::GPUGather<T, int16_t>(dev_ctx, *x, *index, output);
-    }
-  }
-};
-
-template <typename T>
-class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
-                      platform::errors::PreconditionNotMet(
-                          "This kernel only runs on GPU device."));
-    auto *index = ctx.Input<Tensor>("Index");
-    auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
-
-    int axis = ctx.Attr<int>("axis");
-    if (ctx.HasInput("Axis")) {
-      const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
-      Tensor cpu_axis;
-      framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis);
-      const auto &axis_type =
-          framework::TransToProtoVarType(axis_tensor->dtype());
-      if (axis_type == framework::proto::VarType::INT32) {
-        axis = static_cast<int>(cpu_axis.data<int32_t>()[0]);
-      } else if (axis_type == framework::proto::VarType::INT64) {
-        axis = static_cast<int>(cpu_axis.data<int64_t>()[0]);
-      }
-    }
-
-    const auto &dev_ctx = ctx.cuda_device_context();
-    const auto &index_type = framework::TransToProtoVarType(index->dtype());
-    if (axis != 0) {
-      if (index_type == framework::proto::VarType::INT32) {
-        phi::funcs::GatherV2GradCUDAFunction<T, int32_t>(dO, index, axis, dX,
-                                                         dev_ctx);
-      } else if (index_type == framework::proto::VarType::INT64) {
-        phi::funcs::GatherV2GradCUDAFunction<T, int64_t>(dO, index, axis, dX,
-                                                         dev_ctx);
-      }
-      return;
-    }
-
-    dX->mutable_data<T>(ctx.GetPlace());
-    auto dxt = framework::EigenVector<T>::Flatten(*dX);
-    auto &place = *ctx.template device_context<platform::CUDADeviceContext>()
-                       .eigen_device();
-    dxt.device(place) = dxt.constant(static_cast<T>(0));
-    if (dO->numel() == 0) return;
-    if (index_type == framework::proto::VarType::INT32) {
-      phi::funcs::GPUScatterAssign<T, int>(dev_ctx, *dO, *index, dX,
-                                           ctx.Attr<bool>("overwrite"));
-    } else if (index_type == framework::proto::VarType::INT64) {
-      phi::funcs::GPUScatterAssign<T, int64_t>(dev_ctx, *dO, *index, dX,
-                                               ctx.Attr<bool>("overwrite"));
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel<float>,
-                        ops::GatherOpCUDAKernel<double>,
-                        ops::GatherOpCUDAKernel<int64_t>,
-                        ops::GatherOpCUDAKernel<int>,
-                        ops::GatherOpCUDAKernel<int16_t>,
-                        ops::GatherOpCUDAKernel<plat::float16>,
-                        ops::GatherOpCUDAKernel<plat::bfloat16>);
-REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>,
-                        ops::GatherGradOpCUDAKernel<double>,
-                        ops::GatherGradOpCUDAKernel<int64_t>,
-                        ops::GatherGradOpCUDAKernel<int>,
-                        ops::GatherGradOpCUDAKernel<plat::float16>,
-                        ops::GatherGradOpCUDAKernel<plat::bfloat16>);
diff --git a/paddle/fluid/operators/gather_op.h b/paddle/fluid/operators/gather_op.h
deleted file mode 100644
index 94de694b2f9bc484cdb60298b60d5a9433dac181..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/gather_op.h
+++ /dev/null
@@ -1,133 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-#include "paddle/fluid/framework/convert_utils.h"
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/phi/kernels/funcs/gather.h"
-#include "paddle/phi/kernels/funcs/scatter.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-template <typename T>
-class GatherOpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(
-        platform::is_cpu_place(ctx.GetPlace()), true,
-        platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
-
-    auto *x = ctx.Input<Tensor>("X");
-    auto *index = ctx.Input<Tensor>("Index");
-    auto *output = ctx.Output<Tensor>("Out");
-
-    int axis = ctx.Attr<int>("axis");
-    // get axis from tensor
-    if (ctx.HasInput("Axis")) {
-      const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
-      const auto &axis_type = axis_tensor->dtype();
-      if (axis_type == phi::DataType::INT32) {
-        axis = static_cast<int>(axis_tensor->data<int32_t>()[0]);
-      } else if (axis_type == phi::DataType::INT64) {
-        axis = static_cast<int>(axis_tensor->data<int64_t>()[0]);
-      }
-    }
-    const auto &index_type = index->dtype();
-    auto &dev_ctx = ctx.template device_context<phi::CPUContext>();
-    if (axis != 0) {
-      if (index_type == phi::DataType::INT32) {
-        phi::funcs::GatherV2Function<T, int32_t>(dev_ctx, x, index, axis,
-                                                 output);
-      } else if (index_type == phi::DataType::INT64) {
-        phi::funcs::GatherV2Function<T, int64_t>(dev_ctx, x, index, axis,
-                                                 output);
-      }
-      return;
-    }
-
-    output->mutable_data<T>(ctx.GetPlace());
-    if (x->numel() == 0) return;
-    if (index_type == phi::DataType::INT32) {
-      phi::funcs::CPUGather<T, int>(dev_ctx, *x, *index, output);
-    } else if (index_type == phi::DataType::INT64) {
-      phi::funcs::CPUGather<T, int64_t>(dev_ctx, *x, *index, output);
-    }
-  }
-};
-
-template <typename T>
-class GatherGradientOpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(
-        platform::is_cpu_place(ctx.GetPlace()), true,
-        platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
-
-    auto *index = ctx.Input<Tensor>("Index");
-    auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
-
-    int axis = ctx.Attr<int>("axis");
-    if (ctx.HasInput("Axis")) {
-      const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
-      const auto &axis_type = axis_tensor->dtype();
-      if (axis_type == phi::DataType::INT32) {
-        axis = static_cast<int>(axis_tensor->data<int32_t>()[0]);
-      } else if (axis_type == phi::DataType::INT64) {
-        axis = static_cast<int>(axis_tensor->data<int64_t>()[0]);
-      }
-    }
-    const auto &index_type = index->dtype();
-    auto &dev_ctx = ctx.template device_context<phi::CPUContext>();
-
-    if (axis != 0) {
-      if (index_type == phi::DataType::INT32) {
-        phi::funcs::GatherV2GradFunction<T, int32_t>(dev_ctx, dO, index, axis,
-                                                     dX);
-      } else if (index_type == phi::DataType::INT64) {
-        phi::funcs::GatherV2GradFunction<T, int64_t>(dev_ctx, dO, index, axis,
-                                                     dX);
-      }
-      return;
-    }
-
-    dX->mutable_data<T>(ctx.GetPlace());
-    auto dxt = framework::EigenVector<T>::Flatten(*dX);
-    auto &place = *dev_ctx.eigen_device();
-    dxt.device(place) = dxt.constant(static_cast<T>(0));
-    if (dO->numel() == 0) return;
-    bool overwrite = ctx.Attr<bool>("overwrite");
-
-    if (index_type == phi::DataType::INT32) {
-      if (overwrite) {
-        phi::funcs::ScatterAssign<T, int32_t>(dev_ctx, *dO, *index, dX);
-      } else {
-        phi::funcs::ScatterAssignAdd<T, int32_t>(dev_ctx, *dO, *index, dX);
-      }
-    } else if (index_type == phi::DataType::INT64) {
-      if (overwrite) {
-        phi::funcs::ScatterAssign<T, int64_t>(dev_ctx, *dO, *index, dX);
-      } else {
-        phi::funcs::ScatterAssignAdd<T, int64_t>(dev_ctx, *dO, *index, dX);
-      }
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
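The two deletions above and the phi files added below are a mechanical port of the same logic. As a schematic comparison (signatures copied from this diff, not a compilable unit on its own), the interface change looks like this:

```cpp
// Old fluid style: a class keyed only on T; inputs, attributes and the
// device context are fetched from an ExecutionContext by string name at
// runtime.
template <typename T>
class GatherOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *x = ctx.Input<Tensor>("X");   // by-name lookup
    int axis = ctx.Attr<int>("axis");   // by-name lookup
    // ... kernel body elided ...
  }
};

// New phi style: a free function template; every input, attribute and the
// device context is an explicit, typed parameter.
template <typename T, typename Context>
void GatherKernel(const Context& dev_ctx,
                  const DenseTensor& x,
                  const DenseTensor& index,
                  const Scalar& axis,
                  DenseTensor* out);
```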
diff --git a/paddle/fluid/operators/gather_op_npu.cc b/paddle/fluid/operators/gather_op_npu.cc
index 21093f585b59eea24a231b4dcdf264dc16178fbd..f996b1ede2f0fdbf7739d579380d71e9dc3448e7 100644
--- a/paddle/fluid/operators/gather_op_npu.cc
+++ b/paddle/fluid/operators/gather_op_npu.cc
@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/gather_op.h"
 #include <memory>
 #include <string>
 #include <vector>
+
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/platform/device/npu/npu_info.h"
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
diff --git a/paddle/fluid/operators/gather_op_npu_test.cc b/paddle/fluid/operators/gather_op_npu_test.cc
index 3dce380360815c292153ef2bfb1a447357c90acb..b42050eabe300bea59c95c50c356d9e115c0dddf 100644
--- a/paddle/fluid/operators/gather_op_npu_test.cc
+++ b/paddle/fluid/operators/gather_op_npu_test.cc
@@ -24,16 +24,15 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/program_desc.h"
-#include "paddle/fluid/operators/gather_op.h"
 #include "paddle/fluid/string/printf.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace f = paddle::framework;
 namespace p = paddle::platform;
 
-USE_OP(gather);
+USE_OP_ITSELF(gather);
 USE_OP_DEVICE_KERNEL(gather, NPU);
-USE_OP(gather_grad);
+USE_OP_ITSELF(gather_grad);
 USE_OP_DEVICE_KERNEL(gather_grad, NPU);
 
 template <typename T>
diff --git a/paddle/fluid/operators/gather_op_xpu.cc b/paddle/fluid/operators/gather_op_xpu.cc
index 28f2f7d473bef308f581266bdb1925864aca4b78..6c691aa14ae77acc3c4ebc2077ea9182e4354d54 100644
--- a/paddle/fluid/operators/gather_op_xpu.cc
+++ b/paddle/fluid/operators/gather_op_xpu.cc
@@ -13,15 +13,18 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #ifdef PADDLE_WITH_XPU
-#include "paddle/fluid/operators/gather_op.h"
 #include <memory>
 #include <string>
 #include <vector>
+
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/phi/core/ddim.h"
 
 namespace paddle {
 namespace operators {
+
+using Tensor = framework::Tensor;
 
 template <typename T>
 class GatherOpXPUKernel : public framework::OpKernel<T> {
   using XPUType = typename XPUTypeTrait<T>::Type;
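A note on the test change: `USE_OP(op)` declares the operator and also pulls in its fluid CPU kernel registration, which this patch deletes; `USE_OP_ITSELF(op)` references only the operator definition, so the test now resolves kernels through the phi registry instead. `USE_OP_DEVICE_KERNEL(gather, NPU)` is untouched because the NPU kernel still lives in fluid.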
diff --git a/paddle/phi/kernels/cpu/gather_grad_kernel.cc b/paddle/phi/kernels/cpu/gather_grad_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f0a6948018afce277725c50e3cbb0e17ab495a83
--- /dev/null
+++ b/paddle/phi/kernels/cpu/gather_grad_kernel.cc
@@ -0,0 +1,82 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/gather_grad_kernel.h"
+
+#include "paddle/phi/common/bfloat16.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/gather.h"
+#include "paddle/phi/kernels/funcs/scatter.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void GatherGradKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      const DenseTensor& index,
+                      const DenseTensor& out_grad,
+                      const Scalar& axis,
+                      bool overwrite,
+                      DenseTensor* x_grad) {
+  const auto& index_type = index.dtype();
+  auto axis_v = axis.to<int>();
+
+  if (axis_v != 0) {
+    if (index_type == phi::DataType::INT32) {
+      phi::funcs::GatherV2GradFunction<T, int32_t>(
+          dev_ctx, &out_grad, &index, axis_v, x_grad);
+    } else if (index_type == phi::DataType::INT64) {
+      phi::funcs::GatherV2GradFunction<T, int64_t>(
+          dev_ctx, &out_grad, &index, axis_v, x_grad);
+    }
+    return;
+  }
+
+  dev_ctx.template Alloc<T>(x_grad);
+
+  auto dxt = EigenVector<T>::Flatten(*x_grad);
+  auto& place = *dev_ctx.eigen_device();
+  dxt.device(place) = dxt.constant(static_cast<T>(0));
+  if (x_grad->numel() == 0) return;
+
+  if (index_type == phi::DataType::INT32) {
+    if (overwrite) {
+      phi::funcs::ScatterAssign<T, int32_t>(dev_ctx, out_grad, index, x_grad);
+    } else {
+      phi::funcs::ScatterAssignAdd<T, int32_t>(
+          dev_ctx, out_grad, index, x_grad);
+    }
+  } else if (index_type == phi::DataType::INT64) {
+    if (overwrite) {
+      phi::funcs::ScatterAssign<T, int64_t>(dev_ctx, out_grad, index, x_grad);
+    } else {
+      phi::funcs::ScatterAssignAdd<T, int64_t>(
+          dev_ctx, out_grad, index, x_grad);
+    }
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(gather_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::GatherGradKernel,
+                   float,
+                   double,
+                   int,
+                   uint8_t,
+                   int64_t,
+                   phi::dtype::bfloat16) {}
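The `overwrite` flag selects between `ScatterAssign` and `ScatterAssignAdd` above. A self-contained sketch of the two behaviors on plain arrays (illustrative only, not the phi API):

```cpp
#include <cstdio>

// Illustrates the axis == 0 gather backward above: rows of d(Out) are
// scattered back into d(X) at positions given by index. With
// overwrite=false, gradients for a repeated index accumulate
// (ScatterAssignAdd); with overwrite=true, the last write wins
// (ScatterAssign).
int main() {
  const int index[3] = {1, 1, 2};  // row 1 was gathered twice
  const float d_out[3] = {10.f, 20.f, 30.f};
  float dx_add[4] = {0.f, 0.f, 0.f, 0.f};
  float dx_overwrite[4] = {0.f, 0.f, 0.f, 0.f};

  for (int i = 0; i < 3; ++i) {
    dx_add[index[i]] += d_out[i];       // accumulate semantics
    dx_overwrite[index[i]] = d_out[i];  // overwrite semantics
  }

  // dx_add:       0 30 30 0  (10 + 20 summed into row 1)
  // dx_overwrite: 0 20 30 0  (second write to row 1 wins)
  for (int i = 0; i < 4; ++i) printf("%g %g\n", dx_add[i], dx_overwrite[i]);
  return 0;
}
```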
diff --git a/paddle/phi/kernels/cpu/gather_kernel.cc b/paddle/phi/kernels/cpu/gather_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..9207a05b9dcce1daed95a1dbdb99db3c23c5c90d
--- /dev/null
+++ b/paddle/phi/kernels/cpu/gather_kernel.cc
@@ -0,0 +1,66 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/gather_kernel.h"
+
+#include "paddle/phi/common/bfloat16.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/gather.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void GatherKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& index,
+                  const Scalar& axis,
+                  DenseTensor* out) {
+  const auto& index_type = index.dtype();
+  auto axis_v = axis.to<int>();
+  if (axis_v != 0) {
+    if (index_type == phi::DataType::INT32) {
+      phi::funcs::GatherV2Function<T, int32_t>(
+          dev_ctx, &x, &index, axis_v, out);
+    } else if (index_type == phi::DataType::INT64) {
+      phi::funcs::GatherV2Function<T, int64_t>(
+          dev_ctx, &x, &index, axis_v, out);
+    }
+    return;
+  }
+
+  dev_ctx.template Alloc<T>(out);
+
+  if (x.numel() == 0) {
+    return;
+  }
+
+  if (index_type == phi::DataType::INT32) {
+    phi::funcs::CPUGather<T, int>(dev_ctx, x, index, out);
+  } else if (index_type == phi::DataType::INT64) {
+    phi::funcs::CPUGather<T, int64_t>(dev_ctx, x, index, out);
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(gather,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::GatherKernel,
+                   float,
+                   double,
+                   int,
+                   uint8_t,
+                   int64_t,
+                   phi::dtype::bfloat16) {}
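For the `axis == 0` fast path, `CPUGather` copies whole rows: `out[i] = x[index[i]]`. A minimal standalone illustration of that contract (plain C++, not the phi API):

```cpp
#include <cstdio>

// out[i][:] = x[index[i]][:] for a 2-D x, matching the axis == 0
// behavior of phi::funcs::CPUGather in the kernel above.
int main() {
  const float x[3][2] = {{0.f, 1.f}, {10.f, 11.f}, {20.f, 21.f}};
  const int index[2] = {2, 0};
  float out[2][2];

  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j) out[i][j] = x[index[i]][j];

  // Prints: 20 21 / 0 1
  for (int i = 0; i < 2; ++i) printf("%g %g\n", out[i][0], out[i][1]);
  return 0;
}
```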
diff --git a/paddle/phi/kernels/gather_grad_kernel.h b/paddle/phi/kernels/gather_grad_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..e53da7b471c7b82efef2319915cc57537ee824b5
--- /dev/null
+++ b/paddle/phi/kernels/gather_grad_kernel.h
@@ -0,0 +1,31 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/common/scalar.h"
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void GatherGradKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      const DenseTensor& index,
+                      const DenseTensor& out_grad,
+                      const Scalar& axis,
+                      bool overwrite,
+                      DenseTensor* x_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/gather_kernel.h b/paddle/phi/kernels/gather_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..78ac09125b69298c59622fc69469ba8d28cae919
--- /dev/null
+++ b/paddle/phi/kernels/gather_kernel.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/common/scalar.h"
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void GatherKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& index,
+                  const Scalar& axis,
+                  DenseTensor* out);
+
+}  // namespace phi
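Both declarations take the axis as a `Scalar` because, under the fluid op definition, the axis can arrive either as the `axis` attribute or as the first element of an `Axis` input tensor; `Scalar` erases that difference and the kernels simply call `axis.to<int>()`. A toy stand-in (hypothetical `ToyScalar`, not the phi type) to illustrate the design choice:

```cpp
#include <cstdint>
#include <cstdio>
#include <variant>

// A toy analogue of phi::Scalar: the stored value may originate as a
// compile-time attribute (int32_t) or as an element read from a small
// integer tensor (int64_t); to<T>() hides which one it was.
struct ToyScalar {
  std::variant<int32_t, int64_t> v;
  template <typename T>
  T to() const {
    return std::visit([](auto x) { return static_cast<T>(x); }, v);
  }
};

int main() {
  ToyScalar from_attr{int32_t{1}};
  ToyScalar from_tensor{int64_t{2}};  // e.g. copied from an "Axis" input
  printf("%d %d\n", from_attr.to<int>(), from_tensor.to<int>());
  return 0;
}
```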
diff --git a/paddle/phi/kernels/gpu/gather_grad_kernel.cu b/paddle/phi/kernels/gpu/gather_grad_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..04149a2f9ee41e797a66eedcb2d797fb87519041
--- /dev/null
+++ b/paddle/phi/kernels/gpu/gather_grad_kernel.cu
@@ -0,0 +1,73 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/gather_kernel.h"
+
+#include "paddle/phi/common/bfloat16.h"
+#include "paddle/phi/common/float16.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/gather.cu.h"
+#include "paddle/phi/kernels/funcs/scatter.cu.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void GatherGradKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      const DenseTensor& index,
+                      const DenseTensor& out_grad,
+                      const Scalar& axis,
+                      bool overwrite,
+                      DenseTensor* x_grad) {
+  const auto& index_type = index.dtype();
+  auto axis_v = axis.to<int>();
+
+  if (axis_v != 0) {
+    if (index_type == DataType::INT32) {
+      phi::funcs::GatherV2GradCUDAFunction<T, int32_t>(
+          &out_grad, &index, axis_v, x_grad, dev_ctx);
+    } else if (index_type == DataType::INT64) {
+      phi::funcs::GatherV2GradCUDAFunction<T, int64_t>(
+          &out_grad, &index, axis_v, x_grad, dev_ctx);
+    }
+    return;
+  }
+
+  dev_ctx.template Alloc<T>(x_grad);
+  auto dxt = EigenVector<T>::Flatten(*x_grad);
+  auto& place = *dev_ctx.eigen_device();
+  dxt.device(place) = dxt.constant(static_cast<T>(0));
+  if (out_grad.numel() == 0) return;
+  if (index_type == DataType::INT32) {
+    phi::funcs::GPUScatterAssign<T, int>(
+        dev_ctx, out_grad, index, x_grad, overwrite);
+  } else if (index_type == DataType::INT64) {
+    phi::funcs::GPUScatterAssign<T, int64_t>(
+        dev_ctx, out_grad, index, x_grad, overwrite);
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(gather_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::GatherGradKernel,
+                   float,
+                   double,
+                   int64_t,
+                   int,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/gather_kernel.cu b/paddle/phi/kernels/gpu/gather_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..7e0c6cc168564e94c5af2e26a8f9ba4acc0594ed
--- /dev/null
+++ b/paddle/phi/kernels/gpu/gather_kernel.cu
@@ -0,0 +1,70 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/gather_kernel.h"
+
+#include "paddle/phi/common/bfloat16.h"
+#include "paddle/phi/common/float16.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/gather.cu.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void GatherKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& index,
+                  const Scalar& axis,
+                  DenseTensor* out) {
+  const auto& index_type = index.dtype();
+  auto axis_v = axis.to<int>();
+  if (axis_v != 0) {
+    if (index_type == phi::DataType::INT32) {
+      phi::funcs::GatherV2CUDAFunction<T, int32_t>(
+          &x, &index, axis_v, out, dev_ctx);
+    } else if (index_type == phi::DataType::INT64) {
+      phi::funcs::GatherV2CUDAFunction<T, int64_t>(
+          &x, &index, axis_v, out, dev_ctx);
+    } else if (index_type == phi::DataType::INT16) {
+      phi::funcs::GatherV2CUDAFunction<T, int16_t>(
+          &x, &index, axis_v, out, dev_ctx);
+    }
+    return;
+  }
+
+  dev_ctx.template Alloc<T>(out);
+
+  if (x.numel() == 0) return;
+  if (index_type == phi::DataType::INT32) {
+    phi::funcs::GPUGather<T, int>(dev_ctx, x, index, out);
+  } else if (index_type == phi::DataType::INT64) {
+    phi::funcs::GPUGather<T, int64_t>(dev_ctx, x, index, out);
+  } else if (index_type == phi::DataType::INT16) {
+    phi::funcs::GPUGather<T, int16_t>(dev_ctx, x, index, out);
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(gather,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::GatherKernel,
+                   float,
+                   double,
+                   int64_t,
+                   int,
+                   int16_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
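The `axis != 0` branch (`GatherV2Function` / `GatherV2CUDAFunction`) gathers along an arbitrary dimension: for a 2-D `x` and `axis == 1`, `out[i][j] = x[i][index[j]]`. A standalone sketch of that semantics (plain C++, not the phi API):

```cpp
#include <cstdio>

// out[i][j] = x[i][index[j]] : gather along axis 1 of a 2-D tensor,
// the general case the GatherV2 helpers handle in the kernels above.
int main() {
  const float x[2][3] = {{0.f, 1.f, 2.f}, {10.f, 11.f, 12.f}};
  const int index[2] = {2, 0};
  float out[2][2];

  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j) out[i][j] = x[i][index[j]];

  // Prints: 2 0 / 12 10
  for (int i = 0; i < 2; ++i) printf("%g %g\n", out[i][0], out[i][1]);
  return 0;
}
```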
diff --git a/paddle/phi/ops/compat/gather_sig.cc b/paddle/phi/ops/compat/gather_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6c47bbe48b8ee18527cfef41fad3488bef6c1dd9
--- /dev/null
+++ b/paddle/phi/ops/compat/gather_sig.cc
@@ -0,0 +1,44 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature GatherOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  if (ctx.HasInput("Axis")) {
+    return KernelSignature("gather", {"X", "Index"}, {"Axis"}, {"Out"});
+  } else {
+    return KernelSignature("gather", {"X", "Index"}, {"axis"}, {"Out"});
+  }
+}
+
+KernelSignature GatherGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  if (ctx.HasInput("Axis")) {
+    return KernelSignature("gather_grad",
+                           {"X", "Index", GradVarName("Out")},
+                           {"Axis", "overwrite"},
+                           {GradVarName("X")});
+  } else {
+    return KernelSignature("gather_grad",
+                           {"X", "Index", GradVarName("Out")},
+                           {"axis", "overwrite"},
+                           {GradVarName("X")});
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(gather, phi::GatherOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(gather_grad, phi::GatherGradOpArgumentMapping);
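These mappings are what let the new phi kernels serve the unchanged fluid op: if the op instance carries an `Axis` input tensor (the input added at the version checkpoint registered in gather_op.cc above), the kernel's `Scalar axis` argument is fed from that tensor; otherwise it is fed from the plain `axis` attribute. Note that `overwrite` appears only in the grad signature, matching the kernel declarations above.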