diff --git a/paddle/fluid/operators/gather_nd_op.cc b/paddle/fluid/operators/gather_nd_op.cc
index 8da900d84f9bcedd5e4b318837fe1bb29697a6be..fcd3384ac2444451dddcf41f9761330b29e1d64b 100644
--- a/paddle/fluid/operators/gather_nd_op.cc
+++ b/paddle/fluid/operators/gather_nd_op.cc
@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/gather_nd_op.h"
-#include <memory>
-#include <string>
-#include <vector>
-#include "paddle/phi/core/ddim.h"
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/infermeta/backward.h"
+#include "paddle/phi/infermeta/binary.h"
+#include "paddle/phi/infermeta/ternary.h"
 
 namespace paddle {
 namespace operators {
@@ -25,48 +25,10 @@ class GatherNdOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
-                      platform::errors::InvalidArgument(
-                          "Input(X) of GatherNdOp should not be null."));
-    PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
-                      platform::errors::InvalidArgument(
-                          "Input(Index) of GatherNdOp should not be null."));
-    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
-                      platform::errors::InvalidArgument(
-                          "Output(Out) of GatherNdOp should not be null."));
-
-    auto x_dims = ctx->GetInputDim("X");
-    auto x_dims_size = x_dims.size();
-    auto index_dims = ctx->GetInputDim("Index");
-    auto index_dims_size = index_dims.size();
-
-    PADDLE_ENFORCE_LE(
-        index_dims[index_dims_size - 1], x_dims_size,
-        platform::errors::InvalidArgument(
-            "Input(Index).shape[-1] should be no greater than Input(X).rank"));
-    PADDLE_ENFORCE_GE(index_dims_size, 1UL,
-                      platform::errors::InvalidArgument(
-                          "The rank of Input(Index) should be greater than 1"));
-
-    std::vector<int64_t> result_dims;
-    // The result dims is
-    //   Index.shape[:-1] + X.shape[Index.shape[-1]:]
-    for (int i = 0; i < index_dims_size - 1; ++i) {
-      result_dims.emplace_back(index_dims[i]);
-    }
-    for (int i = index_dims[index_dims_size - 1]; i < x_dims_size; ++i) {
-      result_dims.emplace_back(x_dims[i]);
-    }
-
-    ctx->SetOutputDim("Out", phi::make_ddim(result_dims));
-    ctx->ShareLoD("X", /*->*/ "Out");
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    auto* x = ctx.Input<Tensor>("X");
+    auto* x = ctx.Input<framework::Tensor>("X");
     const auto& x_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
     return framework::OpKernelType(
         x_type,
@@ -80,11 +42,6 @@ class GatherNdGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
-    ctx->ShareLoD("X", /*-->*/ framework::GradVarName("X"));
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
@@ -173,23 +130,17 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(GatherNdGradNoNeedBufferVarInferer, "X");
 
 namespace ops = paddle::operators;
 
+DELCARE_INFER_SHAPE_FUNCTOR(gather_nd, GatherNdInferShapeFunctor,
+                            PT_INFER_META(phi::GatherNdInferMeta));
+
+DELCARE_INFER_SHAPE_FUNCTOR(gather_nd_grad, GatherNdGradInferShapeFunctor,
+                            PT_INFER_META(phi::GatherNdGradInferMeta));
+
 REGISTER_OPERATOR(gather_nd, ops::GatherNdOp, ops::GatherNdOpMaker,
                   ops::GatherNdGradOpMaker<paddle::framework::OpDesc>,
-                  ops::GatherNdGradOpMaker<paddle::imperative::OpBase>);
+                  ops::GatherNdGradOpMaker<paddle::imperative::OpBase>,
+                  GatherNdInferShapeFunctor);
 
 REGISTER_OPERATOR(gather_nd_grad, ops::GatherNdGradOp,
-                  ops::GatherNdGradNoNeedBufferVarInferer);
-
-REGISTER_OP_CPU_KERNEL(gather_nd, ops::GatherNdOpKernel<float>,
-                       ops::GatherNdOpKernel<double>,
-                       ops::GatherNdOpKernel<int64_t>,
-                       ops::GatherNdOpKernel<int>,
-                       ops::GatherNdOpKernel<int16_t>,
-                       ops::GatherNdOpKernel<bool>,
-                       ops::GatherNdOpKernel<uint8_t>);
-
-REGISTER_OP_CPU_KERNEL(gather_nd_grad, ops::GatherNdGradOpKernel<float>,
-                       ops::GatherNdGradOpKernel<double>,
-                       ops::GatherNdGradOpKernel<int64_t>,
-                       ops::GatherNdGradOpKernel<int>,
-                       ops::GatherNdGradOpKernel<uint8_t>);
+                  ops::GatherNdGradNoNeedBufferVarInferer,
+                  GatherNdGradInferShapeFunctor);
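An aside, not part of the patch: the shape rule that moves from the deleted InferShape above into phi::GatherNdInferMeta (later in this diff) is Out.shape = Index.shape[:-1] + X.shape[Index.shape[-1]:]. A minimal standalone sketch of that rule, with a hypothetical helper name:

```cpp
// Standalone sketch of gather_nd's output-shape rule (hypothetical helper,
// not part of the patch): Out.shape = Index.shape[:-1] + X.shape[Index.shape[-1]:]
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> GatherNdOutDims(const std::vector<int64_t>& x_dims,
                                     const std::vector<int64_t>& index_dims) {
  std::vector<int64_t> out;
  // Index.shape[:-1] -- the batch of lookups.
  for (size_t i = 0; i + 1 < index_dims.size(); ++i) {
    out.push_back(index_dims[i]);
  }
  // X.shape[Index.shape[-1]:] -- the trailing slice each lookup copies.
  for (int64_t i = index_dims.back(); i < static_cast<int64_t>(x_dims.size());
       ++i) {
    out.push_back(x_dims[i]);
  }
  return out;
}

int main() {
  // X: [4, 5, 6], Index: [2, 2] -> each of the 2 index rows picks a [6] slice.
  assert((GatherNdOutDims({4, 5, 6}, {2, 2}) == std::vector<int64_t>{2, 6}));
  return 0;
}
```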
diff --git a/paddle/fluid/operators/gather_nd_op.cu b/paddle/fluid/operators/gather_nd_op.cu
deleted file mode 100644
index 338c44116183415ab09881c470e6d34283b015ed..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/gather_nd_op.cu
+++ /dev/null
@@ -1,109 +0,0 @@
-/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/operators/gather_nd_op.h"
-#include "paddle/phi/kernels/funcs/gather.cu.h"
-#include "paddle/phi/kernels/funcs/scatter.cu.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename T>
-class GatherNdOpCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
-                      platform::errors::PreconditionNotMet(
-                          "This kernel only runs on GPU device."));
-    auto *x = ctx.Input<Tensor>("X");
-    auto *index = ctx.Input<Tensor>("Index");
-    auto *output = ctx.Output<Tensor>("Out");
-
-    output->mutable_data<T>(ctx.GetPlace());
-    if (x->numel() == 0) return;
-    const auto &index_type = index->dtype();
-    bool index_type_match = index_type == phi::DataType::INT32 ||
-                            index_type == phi::DataType::INT64;
-    PADDLE_ENFORCE_EQ(
-        index_type_match, true,
-        platform::errors::InvalidArgument(
-            "Index holds the wrong type, it holds [%s], but "
-            "desires to be [%s] or [%s].",
-            index_type, phi::DataType::INT32, phi::DataType::INT64));
-    auto &dev_ctx = ctx.cuda_device_context();
-    if (index_type == phi::DataType::INT32) {
-      phi::funcs::GPUGatherNd<T, int>(dev_ctx, *x, *index, output);
-    } else if (index_type == phi::DataType::INT64) {
-      phi::funcs::GPUGatherNd<T, int64_t>(dev_ctx, *x, *index, output);
-    }
-  }
-};
-
-template <typename T>
-class GatherNdGradOpCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
-                      platform::errors::PreconditionNotMet(
-                          "This kernel only runs on GPU device."));
-    auto *index = ctx.Input<Tensor>("Index");
-    auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
-
-    dX->mutable_data<T>(ctx.GetPlace());
-    auto dxt = framework::EigenVector<T>::Flatten(*dX);
-    auto &place = *ctx.template device_context<platform::CUDADeviceContext>()
-                       .eigen_device();
-    dxt.device(place) = dxt.constant(static_cast<T>(0));
-    if (dO->numel() == 0) return;
-
-    const auto &index_type = index->dtype();
-    bool index_type_match = index_type == phi::DataType::INT32 ||
-                            index_type == phi::DataType::INT64;
-
-    PADDLE_ENFORCE_EQ(
-        index_type_match, true,
-        platform::errors::InvalidArgument(
-            "Index holds the wrong type, it holds [%s],"
-            "but desires to be [%s] or [%s].",
-            index_type, phi::DataType::INT32, phi::DataType::INT64));
-
-    auto &dev_ctx = ctx.cuda_device_context();
-    if (index_type == phi::DataType::INT32) {
-      phi::funcs::GPUScatterNdAdd<T, int>(dev_ctx, *dO, *index, dX);
-    } else if (index_type == phi::DataType::INT64) {
-      phi::funcs::GPUScatterNdAdd<T, int64_t>(dev_ctx, *dO, *index, dX);
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(gather_nd, ops::GatherNdOpCUDAKernel<float>,
-                        ops::GatherNdOpCUDAKernel<double>,
-                        ops::GatherNdOpCUDAKernel<int64_t>,
-                        ops::GatherNdOpCUDAKernel<int>,
-                        ops::GatherNdOpCUDAKernel<int16_t>,
-                        ops::GatherNdOpCUDAKernel<plat::float16>,
-                        ops::GatherNdOpCUDAKernel<bool>);
-
-REGISTER_OP_CUDA_KERNEL(gather_nd_grad, ops::GatherNdGradOpCUDAKernel<float>,
-                        ops::GatherNdGradOpCUDAKernel<double>,
-                        ops::GatherNdGradOpCUDAKernel<int64_t>,
-                        ops::GatherNdGradOpCUDAKernel<int>,
-                        ops::GatherNdGradOpCUDAKernel<plat::float16>);
diff --git a/paddle/fluid/operators/gather_nd_op.h b/paddle/fluid/operators/gather_nd_op.h
deleted file mode 100644
index d54261008e47b89151248a8372ede4b524d999bf..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/gather_nd_op.h
+++ /dev/null
@@ -1,97 +0,0 @@
-/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/phi/kernels/funcs/gather.h"
-#include "paddle/phi/kernels/funcs/scatter.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-template <typename T>
-class GatherNdOpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(
-        platform::is_cpu_place(ctx.GetPlace()), true,
-        platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
-
-    auto *x = ctx.Input<Tensor>("X");
-    auto *index = ctx.Input<Tensor>("Index");
-    auto *output = ctx.Output<Tensor>("Out");
-
-    output->mutable_data<T>(ctx.GetPlace());
-    if (x->numel() == 0) return;
-
-    auto index_type = index->dtype();
-    bool index_type_match = index_type == phi::DataType::INT32 ||
-                            index_type == phi::DataType::INT64;
-    PADDLE_ENFORCE_EQ(
-        index_type_match, true,
-        platform::errors::InvalidArgument(
-            "Index holds the wrong type, it holds [%s],"
-            "but desires to be [%s] or [%s]",
-            index_type, phi::DataType::INT32, phi::DataType::INT64));
-    auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
-    if (index_type == phi::DataType::INT32) {
-      phi::funcs::CPUGatherNd<T, int32_t>(dev_ctx, *x, *index, output);
-    } else if (index_type == phi::DataType::INT64) {
-      phi::funcs::CPUGatherNd<T, int64_t>(dev_ctx, *x, *index, output);
-    }
-  }
-};
-
-template <typename T>
-class GatherNdGradOpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(
-        platform::is_cpu_place(ctx.GetPlace()), true,
-        platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
-
-    auto *index = ctx.Input<Tensor>("Index");
-    auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    dX->mutable_data<T>(ctx.GetPlace());
-    auto dxt = framework::EigenVector<T>::Flatten(*dX);
-    auto &place = *ctx.template device_context<platform::CPUDeviceContext>()
-                       .eigen_device();
-    dxt.device(place) = dxt.constant(static_cast<T>(0));
-    if (dO->numel() == 0) return;
-
-    auto index_type = index->dtype();
-    bool index_type_match = index_type == phi::DataType::INT32 ||
-                            index_type == phi::DataType::INT64;
-    PADDLE_ENFORCE_EQ(
-        index_type_match, true,
-        platform::errors::InvalidArgument(
-            "Index holds the wrong type, it holds [%s],"
-            "but desires to be [%s] or [%s]",
-            index_type, phi::DataType::INT32, phi::DataType::INT64));
-
-    auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
-    if (index_type == phi::DataType::INT32) {
-      phi::funcs::ScatterNdAdd<T, int32_t>(dev_ctx, *dO, *index, dX);
-    } else if (index_type == phi::DataType::INT64) {
-      phi::funcs::ScatterNdAdd<T, int64_t>(dev_ctx, *dO, *index, dX);
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/fluid/operators/gather_nd_op_npu.cc b/paddle/fluid/operators/gather_nd_op_npu.cc
index 995ab5d0ddf0fda19a163ec31a00a14985b5dbb9..c916f44b874a08a13fb967aae1f8b6a136023b31 100644
--- a/paddle/fluid/operators/gather_nd_op_npu.cc
+++ b/paddle/fluid/operators/gather_nd_op_npu.cc
@@ -12,8 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/gather_nd_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
+#include "paddle/fluid/platform/device_context.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/gather_nd_op_xpu.cc b/paddle/fluid/operators/gather_nd_op_xpu.cc
index 9f4c522bd145bedd09fd746781cef5efec15c139..d4cb799e825b640a2a4e0a464e18d63c5e5ed516 100644
--- a/paddle/fluid/operators/gather_nd_op_xpu.cc
+++ b/paddle/fluid/operators/gather_nd_op_xpu.cc
@@ -11,7 +11,10 @@ limitations under the License. */
 
 #ifdef PADDLE_WITH_XPU
-#include "paddle/fluid/operators/gather_nd_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/platform/device_context.h"
 
 namespace paddle {
 namespace operators {
@@ -20,9 +23,9 @@ template <typename T>
 class GatherNdXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    auto *x = ctx.Input<Tensor>("X");
-    auto *index = ctx.Input<Tensor>("Index");
-    auto *out = ctx.Output<Tensor>("Out");
+    auto *x = ctx.Input<framework::Tensor>("X");
+    auto *index = ctx.Input<framework::Tensor>("Index");
+    auto *out = ctx.Output<framework::Tensor>("Out");
 
     out->template mutable_data<T>(ctx.GetPlace());
     if (x->numel() == 0) return;
diff --git a/paddle/fluid/operators/scatter_nd_add_op.cc b/paddle/fluid/operators/scatter_nd_add_op.cc
index bb02bb541e14f551bb749c890877e4753d225c3c..b7be4cfb2a3950575710e0cfea52695a22a43e56 100644
--- a/paddle/fluid/operators/scatter_nd_add_op.cc
+++ b/paddle/fluid/operators/scatter_nd_add_op.cc
@@ -12,10 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/scatter_nd_add_op.h"
 #include <memory>
 #include <vector>
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/phi/core/ddim.h"
+#include "paddle/phi/infermeta/backward.h"
+#include "paddle/phi/infermeta/ternary.h"
 
 namespace paddle {
 namespace operators {
@@ -24,73 +27,6 @@ class ScatterNdAddOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
-                      platform::errors::InvalidArgument(
-                          "Input(X) of ScatterNdAddOp should not be null."));
-    PADDLE_ENFORCE_EQ(
-        ctx->HasInput("Index"), true,
-        platform::errors::InvalidArgument(
-            "Input(Index) of ScatterNdAddOp should not be null."));
-    PADDLE_ENFORCE_EQ(
-        ctx->HasInput("Updates"), true,
-        platform::errors::InvalidArgument(
-            "Input(Updates) of ScatterNdAddOp should not be null."));
-    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
-                      platform::errors::InvalidArgument(
-                          "Output(Out) of ScatterNdAddOp should not be null."));
-
-    auto ref_dims = ctx->GetInputDim("X");
-    auto ref_dims_size = ref_dims.size();
-    auto index_dims = ctx->GetInputDim("Index");
-    auto index_dims_size = index_dims.size();
-    auto updates_dims = ctx->GetInputDim("Updates");
-    auto updates_dims_size = updates_dims.size();
-
-    PADDLE_ENFORCE_LE(
-        index_dims[index_dims_size - 1], ref_dims_size,
-        platform::errors::InvalidArgument(
-            "The last dimension of Input(Index)'s shape should be no greater "
-            "than the rank of Input(X), but received the last dimension of "
-            "Input(Index)'s shape is %d, the rank of Input(X) is %d.",
-            index_dims[index_dims_size - 1], ref_dims_size));
-    PADDLE_ENFORCE_GE(index_dims_size, 2UL,
-                      platform::errors::InvalidArgument(
-                          "The rank of Input(Index) should be greater than 1, "
-                          "but received the rank of Input(Index) is %d.",
-                          index_dims_size));
-
-    // update.shape = index.shape[:-1] + output.shape[index.shape[-1]:]
-    std::vector<int64_t> r_updates_dims;
-    for (int64_t i = 0; i < index_dims_size - 1; ++i) {
-      r_updates_dims.emplace_back(index_dims[i]);
-    }
-    for (int64_t i = index_dims[index_dims_size - 1]; i < ref_dims_size; ++i) {
-      r_updates_dims.emplace_back(ref_dims[i]);
-    }
-
-    PADDLE_ENFORCE_EQ(
-        r_updates_dims.size(), updates_dims_size,
-        platform::errors::InvalidArgument(
-            "Updates has wrong shape. The shape of Updates and Input(Updates) "
-            "should be same, but received the shape of Updates is %d, "
-            "the shape of Input(Updates) is %d.",
-            r_updates_dims.size(), updates_dims_size));
-
-    for (int64_t i = 0; i < updates_dims_size; ++i) {
-      PADDLE_ENFORCE_EQ(
-          r_updates_dims[i], updates_dims[i],
-          platform::errors::InvalidArgument(
-              "Updates has wrong shape. The dimensions of Updates and "
-              "Input(Updates) should match, but received Updates's"
-              "%d-th dimension is %d, Input(Updates)'s %d-th "
-              "dimension is %d.",
-              i, r_updates_dims[i], i, updates_dims[i]));
-    }
-    ctx->SetOutputDim("Out", ref_dims);
-    ctx->ShareLoD("X", /*->*/ "Out");
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
@@ -99,7 +35,8 @@ class ScatterNdAddOp : public framework::OperatorWithKernel {
                       platform::errors::InvalidArgument(
                           "Ref and Updates must have same type"));
     return framework::OpKernelType(
-        framework::TransToProtoVarType(ctx.Input<Tensor>("X")->type()),
+        framework::TransToProtoVarType(
+            ctx.Input<framework::Tensor>("X")->type()),
         ctx.device_context());
   }
 };
@@ -108,17 +45,6 @@ class ScatterNdAddGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    if (ctx->HasOutput(framework::GradVarName("Updates"))) {
-      ctx->SetOutputDim(framework::GradVarName("Updates"),
-                        ctx->GetInputDim("Updates"));
-    }
-    if (ctx->HasOutput(framework::GradVarName("X"))) {
-      ctx->SetOutputDim(framework::GradVarName("X"),
-                        ctx->GetInputDim(framework::GradVarName("Out")));
-    }
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
@@ -193,22 +119,18 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(ScatterNdAddGradNoNeedBufferVarsInferer,
 
 namespace ops = paddle::operators;
 
+DELCARE_INFER_SHAPE_FUNCTOR(scatter_nd_add, ScatterNdAddInferShapeFunctor,
+                            PT_INFER_META(phi::ScatterNdAddInferMeta));
+
+DELCARE_INFER_SHAPE_FUNCTOR(scatter_nd_add_grad,
+                            ScatterNdAddGradInferShapeFunctor,
+                            PT_INFER_META(phi::ScatterNdAddGradInferMeta));
+
 REGISTER_OPERATOR(scatter_nd_add, ops::ScatterNdAddOp, ops::ScatterNdAddOpMaker,
                   ops::ScatterNdAddGradMaker<paddle::framework::OpDesc>,
-                  ops::ScatterNdAddGradMaker<paddle::imperative::OpBase>);
+                  ops::ScatterNdAddGradMaker<paddle::imperative::OpBase>,
+                  ScatterNdAddInferShapeFunctor);
 
 REGISTER_OPERATOR(scatter_nd_add_grad, ops::ScatterNdAddGradOp,
-                  ops::ScatterNdAddGradNoNeedBufferVarsInferer);
-
-REGISTER_OP_CPU_KERNEL(scatter_nd_add, ops::ScatterNdAddOpKernel<float>,
-                       ops::ScatterNdAddOpKernel<double>,
-                       ops::ScatterNdAddOpKernel<int64_t>,
-                       ops::ScatterNdAddOpKernel<int>,
-                       ops::ScatterNdAddOpKernel<uint8_t>);
-
-REGISTER_OP_CPU_KERNEL(scatter_nd_add_grad,
-                       ops::ScatterNdAddGradientOpKernel<float>,
-                       ops::ScatterNdAddGradientOpKernel<double>,
-                       ops::ScatterNdAddGradientOpKernel<int64_t>,
-                       ops::ScatterNdAddGradientOpKernel<int>,
-                       ops::ScatterNdAddGradientOpKernel<uint8_t>);
+                  ops::ScatterNdAddGradNoNeedBufferVarsInferer,
+                  ScatterNdAddGradInferShapeFunctor);
diff --git a/paddle/fluid/operators/scatter_nd_add_op.cu b/paddle/fluid/operators/scatter_nd_add_op.cu
deleted file mode 100644
index 2fe3fcb759d348b36cd6a7a2609bea210d24705f..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/scatter_nd_add_op.cu
+++ /dev/null
@@ -1,101 +0,0 @@
-/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/gather_op.h"
-#include "paddle/fluid/operators/scatter_nd_add_op.h"
-#include "paddle/phi/kernels/funcs/gather.cu.h"
-#include "paddle/phi/kernels/funcs/scatter.cu.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename T>
-class ScatterNdAddOpCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
-                      platform::errors::PreconditionNotMet(
-                          "This kernel only runs on GPU device."));
-    auto *X = ctx.Input<Tensor>("X");
-    auto *Ids = ctx.Input<Tensor>("Index");
-    auto *Updates = ctx.Input<Tensor>("Updates");
-    auto *Out = ctx.Output<Tensor>("Out");
-
-    framework::TensorCopySync(*X, ctx.GetPlace(), Out);
-    const auto &index_type = Ids->dtype();
-    bool index_type_match = index_type == phi::DataType::INT32 ||
-                            index_type == phi::DataType::INT64;
-    PADDLE_ENFORCE_EQ(
-        index_type_match, true,
-        platform::errors::InvalidArgument(
-            "Index holds the wrong type, it holds [%s], but "
-            "desires to be [%s] or [%s].",
-            index_type, phi::DataType::INT32, phi::DataType::INT64));
-    auto &dev_ctx = ctx.cuda_device_context();
-    if (index_type == phi::DataType::INT32) {
-      phi::funcs::GPUScatterNdAdd<T, int32_t>(dev_ctx, *Updates, *Ids, Out);
-    } else {
-      phi::funcs::GPUScatterNdAdd<T, int64_t>(dev_ctx, *Updates, *Ids, Out);
-    }
-  }
-};
-
-template <typename T>
-class ScatterNdAddGradOpCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
-                      platform::errors::PreconditionNotMet(
-                          "This kernel only runs on GPU device."));
-    auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
-    auto *Ids = ctx.Input<Tensor>("Index");
-    auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    if (dX) {
-      framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
-    }
-    if (dUpdates) {
-      dUpdates->mutable_data<T>(ctx.GetPlace());
-      auto &dev_ctx = ctx.cuda_device_context();
-      // Gradient by Gather
-      const auto &index_type = Ids->dtype();
-      if (index_type == phi::DataType::INT32) {
-        phi::funcs::GPUGatherNd<T, int32_t>(dev_ctx, *dOut, *Ids, dUpdates);
-      } else {
-        phi::funcs::GPUGatherNd<T, int64_t>(dev_ctx, *dOut, *Ids, dUpdates);
-      }
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-using CUDA = paddle::platform::CUDADeviceContext;
-namespace plat = paddle::platform;
-
-REGISTER_OP_CUDA_KERNEL(scatter_nd_add,
-                        ops::ScatterNdAddOpCUDAKernel<float>,
-                        ops::ScatterNdAddOpCUDAKernel<double>,
-                        ops::ScatterNdAddOpCUDAKernel<int64_t>,
-                        ops::ScatterNdAddOpCUDAKernel<int>,
-                        ops::ScatterNdAddOpCUDAKernel<plat::float16>);
-
-REGISTER_OP_CUDA_KERNEL(scatter_nd_add_grad,
-                        ops::ScatterNdAddGradOpCUDAKernel<float>,
-                        ops::ScatterNdAddGradOpCUDAKernel<double>,
-                        ops::ScatterNdAddGradOpCUDAKernel<int64_t>,
-                        ops::ScatterNdAddGradOpCUDAKernel<int>,
-                        ops::ScatterNdAddGradOpCUDAKernel<plat::float16>);
diff --git a/paddle/fluid/operators/scatter_nd_add_op.h b/paddle/fluid/operators/scatter_nd_add_op.h
deleted file mode 100644
index 81c95fe55abaad2e126a52ac7ab97dea24fe67f0..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/scatter_nd_add_op.h
+++ /dev/null
@@ -1,89 +0,0 @@
-/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/phi/kernels/funcs/gather.h"
-#include "paddle/phi/kernels/funcs/scatter.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-template <typename T>
-class ScatterNdAddOpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(
-        platform::is_cpu_place(ctx.GetPlace()), true,
-        platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
-    auto *X = ctx.Input<Tensor>("X");
-    auto *Ids = ctx.Input<Tensor>("Index");
-    auto *Updates = ctx.Input<Tensor>("Updates");
-    auto *Out = ctx.Output<Tensor>("Out");
-
-    // In place output: Out = X
-    framework::TensorCopySync(*X, ctx.GetPlace(), Out);
-    const auto &index_type = Ids->dtype();
-    bool index_type_match = index_type == phi::DataType::INT32 ||
-                            index_type == phi::DataType::INT64;
-    PADDLE_ENFORCE_EQ(
-        index_type_match, true,
-        platform::errors::InvalidArgument(
-            "Index holds the wrong type, it holds [%s], but "
-            "desires to be [%s] or [%s].",
-            index_type, phi::DataType::INT32, phi::DataType::INT64));
-
-    auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
-    if (index_type == phi::DataType::INT32) {
-      phi::funcs::ScatterNdAdd<T, int32_t>(dev_ctx, *Updates, *Ids, Out);
-    } else {
-      phi::funcs::ScatterNdAdd<T, int64_t>(dev_ctx, *Updates, *Ids, Out);
-    }
-  }
-};
-
-template <typename T>
-class ScatterNdAddGradientOpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(
-        platform::is_cpu_place(ctx.GetPlace()), true,
-        platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
-    auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
-    auto *Ids = ctx.Input<Tensor>("Index");
-    auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
-
-    if (dX) {
-      framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
-    }
-    if (dUpdates) {
-      dUpdates->mutable_data<T>(ctx.GetPlace());
-      // Gradient by Gather: dUpdates = dO[Ids]
-      const auto &index_type = Ids->dtype();
-      auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
-      if (index_type == phi::DataType::INT32) {
-        phi::funcs::CPUGatherNd<T, int32_t>(dev_ctx, *dOut, *Ids, dUpdates);
-      } else {
-        phi::funcs::CPUGatherNd<T, int64_t>(dev_ctx, *dOut, *Ids, dUpdates);
-      }
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/fluid/operators/scatter_op.cc b/paddle/fluid/operators/scatter_op.cc
index 3174f07e96e227c8a2f1103d3d6664673c7a2d56..fec003305fdc6508bcc0153cf14b8bd08eddeebf 100644
--- a/paddle/fluid/operators/scatter_op.cc
+++ b/paddle/fluid/operators/scatter_op.cc
@@ -12,9 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/scatter_op.h"
 #include <memory>
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/phi/core/ddim.h"
+#include "paddle/phi/infermeta/backward.h"
+#include "paddle/phi/infermeta/ternary.h"
 
 namespace paddle {
 namespace operators {
@@ -23,46 +26,6 @@ class ScatterOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
-                      platform::errors::InvalidArgument(
-                          "Input(X) of ScatterOp should not be null."));
-    PADDLE_ENFORCE_EQ(ctx->HasInput("Ids"), true,
-                      platform::errors::InvalidArgument(
-                          "Input(Ids) of ScatterOp should not be null."));
-    PADDLE_ENFORCE_EQ(ctx->HasInput("Updates"), true,
-                      platform::errors::InvalidArgument(
-                          "Input(Updates) of ScatterOp should not be null."));
-    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
-                      platform::errors::InvalidArgument(
-                          "Output(Out) of ScatterOp should not be null."));
-
-    auto updates_dims = ctx->GetInputDim("Updates");
-    auto ref_dims = ctx->GetInputDim("X");
-    PADDLE_ENFORCE_EQ(
-        ctx->GetInputDim("Ids").size(), 1,
-        platform::errors::InvalidArgument(
-            "The size of Input(Ids)'s shape should be equal to 1, but "
-            "received the rank of Input(Ids) is %d.",
-            ctx->GetInputDim("Ids").size()));
-    PADDLE_ENFORCE_EQ(
-        ref_dims.size(), updates_dims.size(),
-        platform::errors::InvalidArgument(
-            "Input(X) and Input(Updates) should have the same shape size, "
-            "but received the size of Input(x)'s shape is %d, the size of "
-            "Input(Updates)'s shape is %d.",
-            ref_dims.size(), updates_dims.size()));
-    PADDLE_ENFORCE_EQ(
-        ctx->GetInputDim("Updates")[0], ctx->GetInputDim("Ids")[0],
-        platform::errors::InvalidArgument(
-            "Input(Updates) and Input(Ids) should have same batch-size, but"
-            " received Input(Updates)'s batch-size is %d, Input(Ids)'s "
-            "batch-size is %d.",
-            ctx->GetInputDim("Updates")[0], ctx->GetInputDim("Ids")[0]));
-    ctx->SetOutputDim("Out", ref_dims);
-    ctx->ShareLoD("X", /*->*/ "Out");
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
@@ -76,17 +39,6 @@ class ScatterGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    if (ctx->HasOutput(framework::GradVarName("Updates"))) {
-      ctx->SetOutputDim(framework::GradVarName("Updates"),
-                        ctx->GetInputDim("Updates"));
-    }
-    if (ctx->HasOutput(framework::GradVarName("X"))) {
-      ctx->SetOutputDim(framework::GradVarName("X"),
-                        ctx->GetInputDim(framework::GradVarName("Out")));
-    }
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
@@ -151,17 +103,17 @@ DECLARE_INPLACE_OP_INFERER(ScatterInplaceInferer, {"X", "Out"});
 }  // namespace operators
 }  // namespace paddle
 
+DELCARE_INFER_SHAPE_FUNCTOR(scatter, ScatterInferShapeFunctor,
+                            PT_INFER_META(phi::ScatterInferMeta));
+
+DELCARE_INFER_SHAPE_FUNCTOR(scatter_grad, ScatterGradInferShapeFunctor,
+                            PT_INFER_META(phi::ScatterGradInferMeta));
+
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(scatter, ops::ScatterOp, ops::ScatterOpMaker,
                   ops::ScatterGradMaker<paddle::framework::OpDesc>,
                   ops::ScatterGradMaker<paddle::imperative::OpBase>,
-                  ops::ScatterInplaceInferer);
+                  ops::ScatterInplaceInferer, ScatterInferShapeFunctor);
 REGISTER_OPERATOR(scatter_grad, ops::ScatterGradOp,
-                  ops::ScatterGradNoNeedBufferVarsInferer);
-REGISTER_OP_CPU_KERNEL(scatter, ops::ScatterOpKernel<float>,
-                       ops::ScatterOpKernel<double>, ops::ScatterOpKernel<int>,
-                       ops::ScatterOpKernel<int64_t>);
-REGISTER_OP_CPU_KERNEL(scatter_grad, ops::ScatterGradientOpKernel<float>,
-                       ops::ScatterGradientOpKernel<double>,
-                       ops::ScatterGradientOpKernel<int>,
-                       ops::ScatterGradientOpKernel<int64_t>);
+                  ops::ScatterGradNoNeedBufferVarsInferer,
+                  ScatterGradInferShapeFunctor);
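An aside, not part of the patch: the overwrite attribute that ScatterInferMeta's signature threads through selects between the two dispatch branches in the kernels (ScatterAssign vs. ScatterAssignAdd). A hedged reference sketch of those semantics, with hypothetical names and rows collapsed to scalars:

```cpp
// Reference semantics of scatter's overwrite attribute (hypothetical
// standalone sketch; the real kernels dispatch to phi::funcs::ScatterAssign /
// ScatterAssignAdd on whole rows).
#include <cstdio>
#include <vector>

void ScatterRef(std::vector<float>* out, const std::vector<int>& ids,
                const std::vector<float>& updates, bool overwrite) {
  if (!overwrite) {
    // Accumulate mode zeroes the target entries first, so duplicate ids
    // sum their updates rather than adding onto the old values of x.
    for (int id : ids) (*out)[id] = 0.0f;
  }
  for (size_t i = 0; i < ids.size(); ++i) {
    if (overwrite) {
      (*out)[ids[i]] = updates[i];   // last write wins on duplicate ids
    } else {
      (*out)[ids[i]] += updates[i];  // duplicates accumulate
    }
  }
}

int main() {
  std::vector<float> out = {1.0f, 1.0f, 1.0f};  // plays the role of X
  ScatterRef(&out, {0, 0}, {2.0f, 3.0f}, /*overwrite=*/false);
  std::printf("%g %g %g\n", out[0], out[1], out[2]);  // prints: 5 1 1
  return 0;
}
```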
diff --git a/paddle/fluid/operators/scatter_op.cu b/paddle/fluid/operators/scatter_op.cu
deleted file mode 100644
index 7755e376bc1956a1f9e09dc2eb8aead9fa083157..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/scatter_op.cu
+++ /dev/null
@@ -1,116 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/gather_op.h"
-#include "paddle/fluid/operators/scatter_op.h"
-#include "paddle/phi/kernels/funcs/gather.cu.h"
-#include "paddle/phi/kernels/funcs/scatter.cu.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename T>
-class ScatterOpCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
-                      platform::errors::PreconditionNotMet(
-                          "This kernel only runs on GPU device."));
-    auto *X = ctx.Input<Tensor>("X");
-    auto *Ids = ctx.Input<Tensor>("Ids");
-    auto *Updates = ctx.Input<Tensor>("Updates");
-    auto *Out = ctx.Output<Tensor>("Out");
-    bool overwrite = ctx.Attr<bool>("overwrite");
-
-    framework::TensorCopy(*X, ctx.GetPlace(), Out);
-    // use template class to support int32_t and int64_t
-    auto index_type = Ids->dtype();
-    bool index_type_match = index_type == phi::DataType::INT32 ||
-                            index_type == phi::DataType::INT64;
-    PADDLE_ENFORCE_EQ(
-        index_type_match, true,
-        platform::errors::InvalidArgument(
-            "scatter_op Index holds the wrong type, it holds [%s],"
-            "but desires to be [%s] or [%s].",
-            index_type, phi::DataType::INT32, phi::DataType::INT64));
-    auto &dev_ctx = ctx.cuda_device_context();
-    if (index_type == phi::DataType::INT32) {
-      phi::funcs::GPUScatterAssign<T, int32_t>(dev_ctx, *Updates, *Ids, Out,
-                                               overwrite);
-    } else {
-      phi::funcs::GPUScatterAssign<T, int64_t>(dev_ctx, *Updates, *Ids, Out,
-                                               overwrite);
-    }
-  }
-};
-
-template <typename T>
-class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
-                      platform::errors::PreconditionNotMet(
-                          "This kernel only runs on GPU device."));
-    auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
-    auto *Ids = ctx.Input<Tensor>("Ids");
-    auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
-
-    auto index_type = Ids->dtype();
-    bool index_type_match = index_type == phi::DataType::INT32 ||
-                            index_type == phi::DataType::INT64;
-    PADDLE_ENFORCE_EQ(
-        index_type_match, true,
-        platform::errors::InvalidArgument(
-            "scatter_op index holds the wrong type, it holds [%s],"
-            "but desires to be [%s] or [%s]",
-            index_type, phi::DataType::INT32, phi::DataType::INT64));
-
-    auto &dev_ctx = ctx.cuda_device_context();
-    if (dX) {
-      framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
-      if (index_type == phi::DataType::INT32) {
-        phi::funcs::GPUScatterGradForX<T, int32_t>(dev_ctx, *Ids, dX);
-      } else {
-        phi::funcs::GPUScatterGradForX<T, int64_t>(dev_ctx, *Ids, dX);
-      }
-    }
-
-    if (dUpdates) {
-      dUpdates->mutable_data<T>(ctx.GetPlace());
-      // Gradient by Gather: dUpdates = dO[Ids]
-      if (index_type == phi::DataType::INT32) {
-        phi::funcs::GPUGather<T, int32_t>(dev_ctx, *dOut, *Ids, dUpdates);
-      } else {
-        phi::funcs::GPUGather<T, int64_t>(dev_ctx, *dOut, *Ids, dUpdates);
-      }
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(scatter, ops::ScatterOpCUDAKernel<float>,
-                        ops::ScatterOpCUDAKernel<double>,
-                        ops::ScatterOpCUDAKernel<int>,
-                        ops::ScatterOpCUDAKernel<int64_t>,
-                        ops::ScatterOpCUDAKernel<paddle::platform::float16>);
-
-REGISTER_OP_CUDA_KERNEL(
-    scatter_grad, ops::ScatterGradOpCUDAKernel<float>,
-    ops::ScatterGradOpCUDAKernel<double>, ops::ScatterOpCUDAKernel<int>,
-    ops::ScatterOpCUDAKernel<int64_t>,
-    ops::ScatterGradOpCUDAKernel<paddle::platform::float16>);
diff --git a/paddle/fluid/operators/scatter_op.h b/paddle/fluid/operators/scatter_op.h
deleted file mode 100644
index 7733181a93fb60c116ff3da964336b0a85d9a84c..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/scatter_op.h
+++ /dev/null
@@ -1,113 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/phi/kernels/funcs/gather.h"
-#include "paddle/phi/kernels/funcs/scatter.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-template <typename T>
-class ScatterOpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(
-        platform::is_cpu_place(ctx.GetPlace()), true,
-        platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
-    auto *X = ctx.Input<Tensor>("X");
-    auto *Ids = ctx.Input<Tensor>("Ids");
-    auto *Updates = ctx.Input<Tensor>("Updates");
-    auto *Out = ctx.Output<Tensor>("Out");
-    double overwrite = ctx.Attr<bool>("overwrite");
-
-    // In place output: Out = X, Out[Ids] = Updates
-    framework::TensorCopy(*X, ctx.GetPlace(), Out);
-    // Apply ScatterUpdate: Out[index] = Updates[:]
-    const auto &index_type = Ids->dtype();
-    bool index_type_match = index_type == phi::DataType::INT32 ||
-                            index_type == phi::DataType::INT64;
-    PADDLE_ENFORCE_EQ(
-        index_type_match, true,
-        platform::errors::InvalidArgument(
-            "Index holds the wrong type, it holds [%s],"
-            "but desires to be [%s] or [%s].",
-            index_type, phi::DataType::INT32, phi::DataType::INT64));
-    auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
-    if (overwrite) {
-      if (index_type == phi::DataType::INT32) {
-        phi::funcs::ScatterAssign<T, int32_t>(dev_ctx, *Updates, *Ids, Out);
-      } else {
-        phi::funcs::ScatterAssign<T, int64_t>(dev_ctx, *Updates, *Ids, Out);
-      }
-    } else {
-      if (index_type == phi::DataType::INT32) {
-        phi::funcs::ScatterAssignAdd<T, int32_t>(dev_ctx, *Updates, *Ids, Out);
-      } else {
-        phi::funcs::ScatterAssignAdd<T, int64_t>(dev_ctx, *Updates, *Ids, Out);
-      }
-    }
-  }
-};
-
-template <typename T>
-class ScatterGradientOpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(
-        platform::is_cpu_place(ctx.GetPlace()), true,
-        platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
-    auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
-    auto *Ids = ctx.Input<Tensor>("Ids");
-    auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
-
-    const auto &index_type = Ids->dtype();
-    bool index_type_match = index_type == phi::DataType::INT32 ||
-                            index_type == phi::DataType::INT64;
-    PADDLE_ENFORCE_EQ(
-        index_type_match, true,
-        platform::errors::InvalidArgument(
-            "scatter_op index holds the wrong type, it holds [%s],"
-            "but desires to be [%s] or [%s]",
-            index_type, phi::DataType::INT32, phi::DataType::INT64));
-
-    auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
-    if (dX) {
-      framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
-      if (index_type == phi::DataType::INT32) {
-        phi::funcs::CPUScatterGradForX<T, int32_t>(dev_ctx, *Ids, dX);
-      } else {
-        phi::funcs::CPUScatterGradForX<T, int64_t>(dev_ctx, *Ids, dX);
-      }
-    }
-
-    if (dUpdates) {
-      dUpdates->mutable_data<T>(ctx.GetPlace());
-      // Gradient by Gather: dUpdates = dO[Ids]
-      if (index_type == phi::DataType::INT32) {
-        phi::funcs::CPUGather<T, int32_t>(dev_ctx, *dOut, *Ids, dUpdates);
-      } else {
-        phi::funcs::CPUGather<T, int64_t>(dev_ctx, *dOut, *Ids, dUpdates);
-      }
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/fluid/operators/scatter_op_npu.cc b/paddle/fluid/operators/scatter_op_npu.cc
index fa5f03a092882ec1f63e9556bc38d94ed40c9a7f..815984ac307fdce14a64f01a661b4b7f7ce1d616 100644
--- a/paddle/fluid/operators/scatter_op_npu.cc
+++ b/paddle/fluid/operators/scatter_op_npu.cc
@@ -17,7 +17,6 @@
 limitations under the License. */
 
 #include <memory>
 #include "paddle/fluid/operators/kron_op.h"
-#include "paddle/fluid/operators/scatter_op.h"
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
 
 namespace paddle {
diff --git a/paddle/fluid/operators/scatter_op_xpu.cc b/paddle/fluid/operators/scatter_op_xpu.cc
index 9f0b74e8a3f80c5c8a22c2db109f75e6ee316be1..07dd2f2d85fe9ac330be1f85d283c85207b1b78c 100644
--- a/paddle/fluid/operators/scatter_op_xpu.cc
+++ b/paddle/fluid/operators/scatter_op_xpu.cc
@@ -16,7 +16,10 @@ limitations under the License. */
 
 #include <memory>
 #include <string>
 
-#include "paddle/fluid/operators/scatter_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/platform/device_context.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index 7d403fee94300e9517fcc517f4d088470d772e35..4ddef5b0002e286181ce5ac1ad198136424861a9 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -105,4 +105,49 @@ void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
   dx->share_meta(dout);
 }
 
+void GatherNdGradInferMeta(const MetaTensor& x,
+                           const MetaTensor& index,
+                           const MetaTensor& out_grad,
+                           MetaTensor* x_grad) {
+  const auto& dtype = out_grad.dtype();
+  x_grad->set_dims(x.dims());
+  x_grad->share_lod(x);
+  x_grad->set_dtype(dtype);
+}
+
+void ScatterGradInferMeta(const MetaTensor& index,
+                          const MetaTensor& updates,
+                          const MetaTensor& out_grad,
+                          bool overwrite,
+                          MetaTensor* x_grad,
+                          MetaTensor* updates_grad) {
+  const auto& dtype = out_grad.dtype();
+  if (updates_grad) {
+    updates_grad->set_dims(updates.dims());
+    updates_grad->set_dtype(dtype);
+  }
+
+  if (x_grad) {
+    x_grad->set_dims(out_grad.dims());
+    x_grad->set_dtype(dtype);
+  }
+}
+
+void ScatterNdAddGradInferMeta(const MetaTensor& index,
+                               const MetaTensor& updates,
+                               const MetaTensor& out_grad,
+                               MetaTensor* x_grad,
+                               MetaTensor* updates_grad) {
+  const auto& dtype = out_grad.dtype();
+  if (updates_grad) {
+    updates_grad->set_dims(updates.dims());
+    updates_grad->set_dtype(dtype);
+  }
+
+  if (x_grad) {
+    x_grad->set_dims(out_grad.dims());
+    x_grad->set_dtype(dtype);
+  }
+}
+
 }  // namespace phi
diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h
index f2c0cf8a6896610332fa0470965e32ea2fcd5530..f7b0eed5dd929e180810af52914e9a3139676e8a 100644
--- a/paddle/phi/infermeta/backward.h
+++ b/paddle/phi/infermeta/backward.h
@@ -46,4 +46,18 @@ void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
                                 const MetaTensor& dout,
                                 int axis,
                                 MetaTensor* dx);
+
+void ScatterGradInferMeta(const MetaTensor& index,
+                          const MetaTensor& updates,
+                          const MetaTensor& out_grad,
+                          bool overwrite,
+                          MetaTensor* x_grad,
+                          MetaTensor* updates_grad);
+
+void ScatterNdAddGradInferMeta(const MetaTensor& index,
+                               const MetaTensor& updates,
+                               const MetaTensor& out_grad,
+                               MetaTensor* x_grad,
+                               MetaTensor* updates_grad);
+
 }  // namespace phi
diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index 1f6f0b211b66ddb3e8695e8f94854559bdeea8e5..745ddffabbe33f22492fc985b19251cb88c4f551 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -397,6 +397,39 @@ void BCELossInferMeta(const MetaTensor& input,
   out->share_lod(input);
 }
 
+void GatherNdInferMeta(const MetaTensor& x,
+                       const MetaTensor& index,
+                       MetaTensor* out) {
+  auto x_dims = x.dims();
+  auto x_dims_size = x_dims.size();
+  auto index_dims = index.dims();
+  auto index_dims_size = index_dims.size();
+
+  PADDLE_ENFORCE_LE(
+      index_dims[index_dims_size - 1],
+      x_dims_size,
+      phi::errors::InvalidArgument(
+          "Input(Index).shape[-1] should be no greater than Input(X).rank"));
+  PADDLE_ENFORCE_GE(index_dims_size,
+                    1UL,
+                    phi::errors::InvalidArgument(
+                        "The rank of Input(Index) should be greater than 1"));
+
+  std::vector<int64_t> result_dims;
+  // The result dims is
+  //   Index.shape[:-1] + X.shape[Index.shape[-1]:]
+  for (int i = 0; i < index_dims_size - 1; ++i) {
+    result_dims.emplace_back(index_dims[i]);
+  }
+  for (int i = index_dims[index_dims_size - 1]; i < x_dims_size; ++i) {
+    result_dims.emplace_back(x_dims[i]);
+  }
+
+  out->set_dims(phi::make_ddim(result_dims));
+  out->share_lod(x);
+  out->set_dtype(x.dtype());
+}
+
 void GatherTreeMeta(const MetaTensor& ids,
                     const MetaTensor& parents,
                     MetaTensor* out) {
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h
index 47745f8ce13dc3d96c521c8365f3cb0f06df0aa2..2ec744636988f29c80baaabc11b5a43617b47465 100644
--- a/paddle/phi/infermeta/binary.h
+++ b/paddle/phi/infermeta/binary.h
@@ -78,6 +78,10 @@ void BCELossInferMeta(const MetaTensor& input,
                       MetaTensor* out,
                       MetaConfig config = MetaConfig());
 
+void GatherNdInferMeta(const MetaTensor& x,
+                       const MetaTensor& index,
+                       MetaTensor* out);
+
 void GatherTreeMeta(const MetaTensor& ids,
                     const MetaTensor& parents,
                     MetaTensor* out);
diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc
index 1c1497fb0e4569a4fe183b687a2155b0bf22110e..c3472a24801fd8d3c67187e786ad104d3da59ab7 100644
--- a/paddle/phi/infermeta/ternary.cc
+++ b/paddle/phi/infermeta/ternary.cc
@@ -89,6 +89,109 @@ void AddmmInferMeta(const MetaTensor& input,
   out->set_dtype(input.dtype());
 }
 
+void ScatterInferMeta(const MetaTensor& x,
+                      const MetaTensor& index,
+                      const MetaTensor& updates,
+                      bool overwrite,
+                      MetaTensor* out) {
+  const auto& updates_dims = updates.dims();
+  const auto& ref_dims = x.dims();
+  const auto& index_dims = index.dims();
+  PADDLE_ENFORCE_EQ(
+      index_dims.size(),
+      1,
+      phi::errors::InvalidArgument(
+          "The size of Input(Ids)'s shape should be equal to 1, but "
+          "received the rank of Input(Ids) is %d.",
+          index_dims.size()));
+  PADDLE_ENFORCE_EQ(
+      ref_dims.size(),
+      updates_dims.size(),
+      phi::errors::InvalidArgument(
+          "Input(X) and Input(Updates) should have the same shape size, "
+          "but received the size of Input(x)'s shape is %d, the size of "
+          "Input(Updates)'s shape is %d.",
+          ref_dims.size(),
+          updates_dims.size()));
+  PADDLE_ENFORCE_EQ(
+      updates_dims[0],
+      index_dims[0],
+      phi::errors::InvalidArgument(
+          "Input(Updates) and Input(Ids) should have same batch-size, but"
+          " received Input(Updates)'s batch-size is %d, Input(Ids)'s "
+          "batch-size is %d.",
+          updates_dims[0],
+          index_dims[0]));
+  out->set_dims(ref_dims);
+  out->share_lod(x);
+  out->set_dtype(x.dtype());
+}
+
+void ScatterNdAddInferMeta(const MetaTensor& x,
+                           const MetaTensor& index,
+                           const MetaTensor& updates,
+                           MetaTensor* out) {
+  const auto& ref_dims = x.dims();
+  auto ref_dims_size = ref_dims.size();
+  const auto& index_dims = index.dims();
+  auto index_dims_size = index_dims.size();
+  const auto& updates_dims = updates.dims();
+  auto updates_dims_size = updates_dims.size();
+
+  PADDLE_ENFORCE_LE(
+      index_dims[index_dims_size - 1],
+      ref_dims_size,
+      phi::errors::InvalidArgument(
+          "The last dimension of Input(Index)'s shape should be no greater "
+          "than the rank of Input(X), but received the last dimension of "
+          "Input(Index)'s shape is %d, the rank of Input(X) is %d.",
+          index_dims[index_dims_size - 1],
+          ref_dims_size));
+  PADDLE_ENFORCE_GE(index_dims_size,
+                    2UL,
+                    phi::errors::InvalidArgument(
+                        "The rank of Input(Index) should be greater than 1, "
+                        "but received the rank of Input(Index) is %d.",
+                        index_dims_size));
+
+  // update.shape = index.shape[:-1] + output.shape[index.shape[-1]:]
+  std::vector<int64_t> r_updates_dims;
+  for (int64_t i = 0; i < index_dims_size - 1; ++i) {
+    r_updates_dims.emplace_back(index_dims[i]);
+  }
+  for (int64_t i = index_dims[index_dims_size - 1]; i < ref_dims_size; ++i) {
+    r_updates_dims.emplace_back(ref_dims[i]);
+  }
+
+  PADDLE_ENFORCE_EQ(
+      r_updates_dims.size(),
+      updates_dims_size,
+      phi::errors::InvalidArgument(
+          "Updates has wrong shape. The shape of Updates and Input(Updates) "
+          "should be same, but received the shape of Updates is %d, "
+          "the shape of Input(Updates) is %d.",
+          r_updates_dims.size(),
+          updates_dims_size));
+
+  for (int64_t i = 0; i < updates_dims_size; ++i) {
+    PADDLE_ENFORCE_EQ(
+        r_updates_dims[i],
+        updates_dims[i],
+        phi::errors::InvalidArgument(
+            "Updates has wrong shape. The dimensions of Updates and "
+            "Input(Updates) should match, but received Updates's"
+            "%d-th dimension is %d, Input(Updates)'s %d-th "
+            "dimension is %d.",
+            i,
+            r_updates_dims[i],
+            i,
+            updates_dims[i]));
+  }
+  out->set_dims(ref_dims);
+  out->share_lod(x);
+  out->set_dtype(x.dtype());
+}
+
 void LerpInferMeta(const MetaTensor& x,
                    const MetaTensor& y,
                    const MetaTensor& weight,
diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h
index 5679c5b533f1e308be3d8d064eafa653582cd19e..cff57e1ba7078c1765732c19e243aa6655397ec3 100644
--- a/paddle/phi/infermeta/ternary.h
+++ b/paddle/phi/infermeta/ternary.h
@@ -37,6 +37,22 @@ void AddmmInferMeta(const MetaTensor& input,
                     float beta,
                     MetaTensor* out);
 
+void GatherNdGradInferMeta(const MetaTensor& x,
+                           const MetaTensor& index,
+                           const MetaTensor& out_grad,
+                           MetaTensor* x_grad);
+
+void ScatterInferMeta(const MetaTensor& x,
+                      const MetaTensor& index,
+                      const MetaTensor& updates,
+                      bool overwrite,
+                      MetaTensor* out);
+
+void ScatterNdAddInferMeta(const MetaTensor& x,
+                           const MetaTensor& index,
+                           const MetaTensor& updates,
+                           MetaTensor* out);
+
 void LerpInferMeta(const MetaTensor& x,
                    const MetaTensor& y,
                    const MetaTensor& weight,
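An aside, not part of the patch: ScatterNdAddInferMeta above encodes the companion constraint update.shape = index.shape[:-1] + x.shape[index.shape[-1]:]. A worked instance with illustrative values (not taken from the patch):

```cpp
// Worked instance of the Updates-shape constraint checked by
// ScatterNdAddInferMeta (illustrative shapes only):
//   X:       [50, 6]        rank 2
//   Index:   [3, 1]         Index.shape[-1] = 1 <= rank(X)
//   Updates: [3, 6]   ==    Index.shape[:-1] (= [3]) + X.shape[1:] (= [6])
//   Out:     [50, 6]        always X's shape; rows named by Index get += Updates
```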
diff --git a/paddle/phi/kernels/cpu/gather_nd_grad_kernel.cc b/paddle/phi/kernels/cpu/gather_nd_grad_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b375a7ec4691c723f2f029c39b7e364b8332c402
--- /dev/null
+++ b/paddle/phi/kernels/cpu/gather_nd_grad_kernel.cc
@@ -0,0 +1,64 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/gather_nd_grad_kernel.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
+#include "paddle/phi/kernels/funcs/scatter.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void GatherNdGradKernel(const Context &ctx,
+                        const DenseTensor &x,
+                        const DenseTensor &index,
+                        const DenseTensor &out_grad,
+                        DenseTensor *x_grad) {
+  ctx.template Alloc<T>(x_grad);
+  auto dxt = phi::EigenVector<T>::Flatten(*x_grad);
+  auto &place = *ctx.eigen_device();
+  dxt.device(place) = dxt.constant(static_cast<T>(0));
+  if (out_grad.numel() == 0) return;
+
+  auto index_type = index.dtype();
+  bool index_type_match =
+      index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64;
+  PADDLE_ENFORCE_EQ(
+      index_type_match,
+      true,
+      phi::errors::InvalidArgument("Index holds the wrong type, it holds [%s],"
+                                   "but desires to be [%s] or [%s]",
+                                   index_type,
+                                   phi::DataType::INT32,
+                                   phi::DataType::INT64));
+
+  if (index_type == phi::DataType::INT32) {
+    phi::funcs::ScatterNdAdd<T, int32_t>(ctx, out_grad, index, x_grad);
+  } else if (index_type == phi::DataType::INT64) {
+    phi::funcs::ScatterNdAdd<T, int64_t>(ctx, out_grad, index, x_grad);
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(gather_nd_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::GatherNdGradKernel,
+                   float,
+                   double,
+                   int64_t,
+                   int,
+                   uint8_t) {}
diff --git a/paddle/phi/kernels/cpu/gather_nd_kernel.cc b/paddle/phi/kernels/cpu/gather_nd_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..aa32d036934e838b7630a19a152e0c14de907253
--- /dev/null
+++ b/paddle/phi/kernels/cpu/gather_nd_kernel.cc
@@ -0,0 +1,60 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/gather_nd_kernel.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/gather.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void GatherNdKernel(const Context &ctx,
+                    const DenseTensor &x,
+                    const DenseTensor &index,
+                    DenseTensor *out) {
+  ctx.template Alloc<T>(out);
+  if (x.numel() == 0) return;
+
+  auto index_type = index.dtype();
+  bool index_type_match =
+      index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64;
+  PADDLE_ENFORCE_EQ(
+      index_type_match,
+      true,
+      phi::errors::InvalidArgument("Index holds the wrong type, it holds [%s],"
+                                   "but desires to be [%s] or [%s]",
+                                   index_type,
+                                   phi::DataType::INT32,
+                                   phi::DataType::INT64));
+  if (index_type == phi::DataType::INT32) {
+    phi::funcs::CPUGatherNd<T, int32_t>(ctx, x, index, out);
+  } else if (index_type == phi::DataType::INT64) {
+    phi::funcs::CPUGatherNd<T, int64_t>(ctx, x, index, out);
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(gather_nd,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::GatherNdKernel,
+                   float,
+                   double,
+                   int64_t,
+                   int,
+                   int16_t,
+                   bool,
+                   uint8_t) {}
diff --git a/paddle/phi/kernels/cpu/scatter_grad_kernel.cc b/paddle/phi/kernels/cpu/scatter_grad_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..62fd58704c4fef7c23cd8255d6958103b9755bff
--- /dev/null
+++ b/paddle/phi/kernels/cpu/scatter_grad_kernel.cc
@@ -0,0 +1,73 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/scatter_grad_kernel.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/copy_kernel.h"
+#include "paddle/phi/kernels/funcs/gather.h"
+#include "paddle/phi/kernels/funcs/scatter.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void ScatterGradKernel(const Context &ctx,
+                       const DenseTensor &index,
+                       const DenseTensor &updates,
+                       const DenseTensor &out_grad,
+                       bool overwrite,
+                       DenseTensor *x_grad,
+                       DenseTensor *updates_grad) {
+  const auto &index_type = index.dtype();
+  bool index_type_match =
+      index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64;
+  PADDLE_ENFORCE_EQ(index_type_match,
+                    true,
+                    phi::errors::InvalidArgument(
+                        "scatter_op index holds the wrong type, it holds [%s],"
+                        "but desires to be [%s] or [%s]",
+                        index_type,
+                        phi::DataType::INT32,
+                        phi::DataType::INT64));
+
+  if (x_grad) {
+    phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad);
+    if (index_type == phi::DataType::INT32) {
+      phi::funcs::CPUScatterGradForX<T, int32_t>(ctx, index, x_grad);
+    } else {
+      phi::funcs::CPUScatterGradForX<T, int64_t>(ctx, index, x_grad);
+    }
+  }
+
+  if (updates_grad) {
+    ctx.template Alloc<T>(updates_grad);
+    // Gradient by Gather: dUpdates = dO[Ids]
+    if (index_type == phi::DataType::INT32) {
+      phi::funcs::CPUGather<T, int32_t>(ctx, out_grad, index, updates_grad);
+    } else {
+      phi::funcs::CPUGather<T, int64_t>(ctx, out_grad, index, updates_grad);
+    }
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(scatter_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::ScatterGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
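An aside, not part of the patch: the grad kernel above fixes scatter's backward rule in one place. A hedged scalar sketch of the same rule (hypothetical standalone code, not the phi implementation):

```cpp
// Scalar sketch of ScatterGradKernel's rule (hypothetical helper):
//   dX       = dOut, then dX[ids] = 0   (x's values at ids never reach Out)
//   dUpdates = dOut[ids]                (mirrors the CPUGather call above)
#include <cassert>
#include <vector>

void ScatterGradRef(const std::vector<float>& d_out,
                    const std::vector<int>& ids, std::vector<float>* d_x,
                    std::vector<float>* d_updates) {
  *d_x = d_out;
  for (int id : ids) (*d_x)[id] = 0.0f;  // CPUScatterGradForX analogue
  d_updates->assign(ids.size(), 0.0f);
  for (size_t i = 0; i < ids.size(); ++i) (*d_updates)[i] = d_out[ids[i]];
}

int main() {
  std::vector<float> dx, dup;
  ScatterGradRef({1.0f, 2.0f, 3.0f}, {1}, &dx, &dup);
  assert(dx[0] == 1.0f && dx[1] == 0.0f && dx[2] == 3.0f);
  assert(dup.size() == 1 && dup[0] == 2.0f);
  return 0;
}
```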
+ +#include "paddle/phi/kernels/scatter_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/funcs/scatter.h" + +namespace phi { + +template +void ScatterKernel(const Context &ctx, + const DenseTensor &x, + const DenseTensor &index, + const DenseTensor &updates, + bool overwrite, + DenseTensor *out) { + // In place output: Out = X, Out[Ids] = Updates + phi::Copy(ctx, x, ctx.GetPlace(), false, out); + // Apply ScatterUpdate: Out[index] = Updates[:] + const auto &index_type = index.dtype(); + bool index_type_match = + index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64; + PADDLE_ENFORCE_EQ( + index_type_match, + true, + phi::errors::InvalidArgument("Index holds the wrong type, it holds [%s]," + "but desires to be [%s] or [%s].", + index_type, + phi::DataType::INT32, + phi::DataType::INT64)); + if (overwrite) { + if (index_type == phi::DataType::INT32) { + phi::funcs::ScatterAssign(ctx, updates, index, out); + } else { + phi::funcs::ScatterAssign(ctx, updates, index, out); + } + } else { + if (index_type == phi::DataType::INT32) { + phi::funcs::ScatterAssignAdd(ctx, updates, index, out); + } else { + phi::funcs::ScatterAssignAdd(ctx, updates, index, out); + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + scatter, CPU, ALL_LAYOUT, phi::ScatterKernel, float, double, int, int64_t) { +} diff --git a/paddle/phi/kernels/cpu/scatter_nd_add_grad_kernel.cc b/paddle/phi/kernels/cpu/scatter_nd_add_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..cc143ba8d0e4557f8aaf07a4d4606bbf6c2b4d73 --- /dev/null +++ b/paddle/phi/kernels/cpu/scatter_nd_add_grad_kernel.cc @@ -0,0 +1,55 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/scatter_nd_add_grad_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/funcs/gather.h" + +namespace phi { + +template +void ScatterNdAddGradKernel(const Context &ctx, + const DenseTensor &index, + const DenseTensor &updates, + const DenseTensor &out_grad, + DenseTensor *x_grad, + DenseTensor *updates_grad) { + if (x_grad) { + Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad); + } + if (updates_grad) { + ctx.template Alloc(updates_grad); + // Gradient by Gather: dUpdates = dO[Ids] + const auto &index_type = index.dtype(); + if (index_type == phi::DataType::INT32) { + phi::funcs::CPUGatherNd(ctx, out_grad, index, updates_grad); + } else { + phi::funcs::CPUGatherNd(ctx, out_grad, index, updates_grad); + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(scatter_nd_add_grad, + CPU, + ALL_LAYOUT, + phi::ScatterNdAddGradKernel, + float, + double, + int64_t, + int, + uint8_t) {} diff --git a/paddle/phi/kernels/cpu/scatter_nd_add_kernel.cc b/paddle/phi/kernels/cpu/scatter_nd_add_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..04ae10f5e8b5d551819a97ea1594140e535e6a12 --- /dev/null +++ b/paddle/phi/kernels/cpu/scatter_nd_add_kernel.cc @@ -0,0 +1,60 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/scatter_nd_add_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/funcs/scatter.h" + +namespace phi { + +template +void ScatterNdAddKernel(const Context &ctx, + const DenseTensor &x, + const DenseTensor &index, + const DenseTensor &updates, + DenseTensor *out) { + // In place output: Out = X + Copy(ctx, x, ctx.GetPlace(), true, out); + const auto &index_type = index.dtype(); + bool index_type_match = + index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64; + PADDLE_ENFORCE_EQ(index_type_match, + true, + phi::errors::InvalidArgument( + "Index holds the wrong type, it holds [%s], but " + "desires to be [%s] or [%s].", + index_type, + phi::DataType::INT32, + phi::DataType::INT64)); + + if (index_type == phi::DataType::INT32) { + phi::funcs::ScatterNdAdd(ctx, updates, index, out); + } else { + phi::funcs::ScatterNdAdd(ctx, updates, index, out); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(scatter_nd_add, + CPU, + ALL_LAYOUT, + phi::ScatterNdAddKernel, + float, + double, + int64_t, + int, + uint8_t) {} diff --git a/paddle/phi/kernels/gather_nd_grad_kernel.h b/paddle/phi/kernels/gather_nd_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..050034714957fe749d6608243dfedc4d30d66e88 --- /dev/null +++ b/paddle/phi/kernels/gather_nd_grad_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void GatherNdGradKernel(const Context &ctx,
+                        const DenseTensor &x,
+                        const DenseTensor &index,
+                        const DenseTensor &out_grad,
+                        DenseTensor *x_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/gather_nd_kernel.h b/paddle/phi/kernels/gather_nd_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..d2393eb3b0709345cb2d3ec63d739cb223a79683
--- /dev/null
+++ b/paddle/phi/kernels/gather_nd_kernel.h
@@ -0,0 +1,27 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void GatherNdKernel(const Context &ctx,
+                    const DenseTensor &x,
+                    const DenseTensor &index,
+                    DenseTensor *out);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/gpu/gather_nd_grad_kernel.cu b/paddle/phi/kernels/gpu/gather_nd_grad_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..5273902804a200bdf36e8e36748639758145e742
--- /dev/null
+++ b/paddle/phi/kernels/gpu/gather_nd_grad_kernel.cu
@@ -0,0 +1,65 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
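+// Gradient sketch (illustrative): gather_nd selects slices of X, so its
+// gradient scatter-adds out_grad into a zero tensor of X's shape:
+//   dX = 0; dX[Index[k]] += dOut[k]
+// Duplicate entries in Index must accumulate, which is why the kernel below
+// zero-fills x_grad and then dispatches to GPUScatterNdAdd rather than doing
+// a plain assignment.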
+ +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/scatter.cu.h" +#include "paddle/phi/kernels/gather_nd_grad_kernel.h" + +namespace phi { + +template +void GatherNdGradKernel(const Context &ctx, + const DenseTensor &x, + const DenseTensor &index, + const DenseTensor &out_grad, + DenseTensor *x_grad) { + ctx.template Alloc(x_grad); + auto dxt = phi::EigenVector::Flatten(*x_grad); + auto &place = *ctx.eigen_device(); + dxt.device(place) = dxt.constant(static_cast(0)); + if (out_grad.numel() == 0) return; + + const auto &index_type = index.dtype(); + bool index_type_match = + index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64; + + PADDLE_ENFORCE_EQ( + index_type_match, + true, + phi::errors::InvalidArgument("Index holds the wrong type, it holds [%s]," + "but desires to be [%s] or [%s].", + index_type, + phi::DataType::INT32, + phi::DataType::INT64)); + + if (index_type == phi::DataType::INT32) { + phi::funcs::GPUScatterNdAdd(ctx, out_grad, index, x_grad); + } else if (index_type == phi::DataType::INT64) { + phi::funcs::GPUScatterNdAdd(ctx, out_grad, index, x_grad); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(gather_nd_grad, + GPU, + ALL_LAYOUT, + phi::GatherNdGradKernel, + float, + double, + int64_t, + int, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/gather_nd_kernel.cu b/paddle/phi/kernels/gpu/gather_nd_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..33745ef5f07e82387edfabf89d527b5b698fbabe --- /dev/null +++ b/paddle/phi/kernels/gpu/gather_nd_kernel.cu @@ -0,0 +1,60 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/gather.cu.h" +#include "paddle/phi/kernels/funcs/scatter.cu.h" +#include "paddle/phi/kernels/gather_nd_kernel.h" + +namespace phi { + +template +void GatherNdKernel(const Context &ctx, + const DenseTensor &x, + const DenseTensor &index, + DenseTensor *out) { + ctx.template Alloc(out); + if (x.numel() == 0) return; + const auto &index_type = index.dtype(); + bool index_type_match = + index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64; + PADDLE_ENFORCE_EQ(index_type_match, + true, + phi::errors::InvalidArgument( + "Index holds the wrong type, it holds [%s], but " + "desires to be [%s] or [%s].", + index_type, + phi::DataType::INT32, + phi::DataType::INT64)); + if (index_type == phi::DataType::INT32) { + phi::funcs::GPUGatherNd(ctx, x, index, out); + } else if (index_type == phi::DataType::INT64) { + phi::funcs::GPUGatherNd(ctx, x, index, out); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(gather_nd, + GPU, + ALL_LAYOUT, + phi::GatherNdKernel, + float, + double, + int64_t, + int, + int16_t, + bool, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/scatter_grad_kernel.cu b/paddle/phi/kernels/gpu/scatter_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..75506e2a0a17b269cb1327c01c1c4f3825eadb37 --- /dev/null +++ b/paddle/phi/kernels/gpu/scatter_grad_kernel.cu @@ -0,0 +1,74 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/funcs/gather.cu.h" +#include "paddle/phi/kernels/funcs/scatter.cu.h" +#include "paddle/phi/kernels/scatter_grad_kernel.h" + +namespace phi { + +template +void ScatterGradKernel(const Context &ctx, + const DenseTensor &index, + const DenseTensor &updates, + const DenseTensor &out_grad, + bool overwrite, + DenseTensor *x_grad, + DenseTensor *updates_grad) { + auto index_type = index.dtype(); + bool index_type_match = + index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64; + PADDLE_ENFORCE_EQ(index_type_match, + true, + phi::errors::InvalidArgument( + "scatter_op index holds the wrong type, it holds [%s]," + "but desires to be [%s] or [%s]", + index_type, + phi::DataType::INT32, + phi::DataType::INT64)); + + if (x_grad) { + phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad); + if (index_type == phi::DataType::INT32) { + phi::funcs::GPUScatterGradForX(ctx, index, x_grad); + } else { + phi::funcs::GPUScatterGradForX(ctx, index, x_grad); + } + } + + if (updates_grad) { + ctx.template Alloc(updates_grad); + // Gradient by Gather: dUpdates = dO[Ids] + if (index_type == phi::DataType::INT32) { + phi::funcs::GPUGather(ctx, out_grad, index, updates_grad); + } else { + phi::funcs::GPUGather(ctx, out_grad, index, updates_grad); + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(scatter_grad, + GPU, + ALL_LAYOUT, + phi::ScatterGradKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/scatter_kernel.cu b/paddle/phi/kernels/gpu/scatter_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..811eae1bc028ec5484d173d1b4373111546d73b4 --- /dev/null +++ b/paddle/phi/kernels/gpu/scatter_kernel.cu @@ -0,0 +1,62 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/funcs/scatter.cu.h" +#include "paddle/phi/kernels/scatter_kernel.h" + +namespace phi { + +template +void ScatterKernel(const Context &ctx, + const DenseTensor &x, + const DenseTensor &index, + const DenseTensor &updates, + bool overwrite, + DenseTensor *out) { + phi::Copy(ctx, x, ctx.GetPlace(), false, out); + // use template class to support int32_t and int64_t + auto index_type = index.dtype(); + bool index_type_match = + index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64; + PADDLE_ENFORCE_EQ(index_type_match, + true, + phi::errors::InvalidArgument( + "scatter_op Index holds the wrong type, it holds [%s]," + "but desires to be [%s] or [%s].", + index_type, + phi::DataType::INT32, + phi::DataType::INT64)); + if (index_type == phi::DataType::INT32) { + phi::funcs::GPUScatterAssign( + ctx, updates, index, out, overwrite); + } else { + phi::funcs::GPUScatterAssign( + ctx, updates, index, out, overwrite); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(scatter, + GPU, + ALL_LAYOUT, + phi::ScatterKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/scatter_nd_add_grad_kernel.cu b/paddle/phi/kernels/gpu/scatter_nd_add_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..71924befe8cf9383c523eee059c3efa79dd6a262 --- /dev/null +++ b/paddle/phi/kernels/gpu/scatter_nd_add_grad_kernel.cu @@ -0,0 +1,55 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/funcs/gather.cu.h" +#include "paddle/phi/kernels/scatter_nd_add_grad_kernel.h" + +namespace phi { + +template +void ScatterNdAddGradKernel(const Context &ctx, + const DenseTensor &index, + const DenseTensor &updates, + const DenseTensor &out_grad, + DenseTensor *x_grad, + DenseTensor *updates_grad) { + if (x_grad) { + Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad); + } + if (updates_grad) { + ctx.template Alloc(updates_grad); + // Gradient by Gather + const auto &index_type = index.dtype(); + if (index_type == phi::DataType::INT32) { + phi::funcs::GPUGatherNd(ctx, out_grad, index, updates_grad); + } else { + phi::funcs::GPUGatherNd(ctx, out_grad, index, updates_grad); + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(scatter_nd_add_grad, + GPU, + ALL_LAYOUT, + phi::ScatterNdAddGradKernel, + float, + double, + int64_t, + int, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/scatter_nd_add_kernel.cu b/paddle/phi/kernels/gpu/scatter_nd_add_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..eadd91773c00810e3f4187d079926028733a4945 --- /dev/null +++ b/paddle/phi/kernels/gpu/scatter_nd_add_kernel.cu @@ -0,0 +1,58 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/funcs/scatter.cu.h" +#include "paddle/phi/kernels/scatter_nd_add_kernel.h" + +namespace phi { + +template +void ScatterNdAddKernel(const Context &ctx, + const DenseTensor &x, + const DenseTensor &index, + const DenseTensor &updates, + DenseTensor *out) { + Copy(ctx, x, ctx.GetPlace(), true, out); + const auto &index_type = index.dtype(); + bool index_type_match = + index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64; + PADDLE_ENFORCE_EQ(index_type_match, + true, + phi::errors::InvalidArgument( + "Index holds the wrong type, it holds [%s], but " + "desires to be [%s] or [%s].", + index_type, + phi::DataType::INT32, + phi::DataType::INT64)); + if (index_type == phi::DataType::INT32) { + phi::funcs::GPUScatterNdAdd(ctx, updates, index, out); + } else { + phi::funcs::GPUScatterNdAdd(ctx, updates, index, out); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(scatter_nd_add, + GPU, + ALL_LAYOUT, + phi::ScatterNdAddKernel, + float, + double, + int64_t, + int, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/scatter_grad_kernel.h b/paddle/phi/kernels/scatter_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..cf1482fca7f667eae530e09a95df0020186aef77 --- /dev/null +++ b/paddle/phi/kernels/scatter_grad_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void ScatterGradKernel(const Context &ctx,
+                       const DenseTensor &index,
+                       const DenseTensor &updates,
+                       const DenseTensor &out_grad,
+                       bool overwrite,
+                       DenseTensor *x_grad,
+                       DenseTensor *updates_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/scatter_kernel.h b/paddle/phi/kernels/scatter_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..5191d6bce45f26c69d5a2fe61c23f71039c2433c
--- /dev/null
+++ b/paddle/phi/kernels/scatter_kernel.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void ScatterKernel(const Context &ctx,
+                   const DenseTensor &x,
+                   const DenseTensor &index,
+                   const DenseTensor &updates,
+                   bool overwrite,
+                   DenseTensor *out);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/scatter_nd_add_grad_kernel.h b/paddle/phi/kernels/scatter_nd_add_grad_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..bcfdb2cdb2f09e53ed9da1bc3995587737f9492c
--- /dev/null
+++ b/paddle/phi/kernels/scatter_nd_add_grad_kernel.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void ScatterNdAddGradKernel(const Context &ctx,
+                            const DenseTensor &index,
+                            const DenseTensor &updates,
+                            const DenseTensor &out_grad,
+                            DenseTensor *x_grad,
+                            DenseTensor *updates_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/scatter_nd_add_kernel.h b/paddle/phi/kernels/scatter_nd_add_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..c20709dccc08c16650dd1fc00ec10a91c333e13f
--- /dev/null
+++ b/paddle/phi/kernels/scatter_nd_add_kernel.h
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void ScatterNdAddKernel(const Context &ctx,
+                        const DenseTensor &x,
+                        const DenseTensor &index,
+                        const DenseTensor &updates,
+                        DenseTensor *out);
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/gather_scatter_sig.cc b/paddle/phi/ops/compat/gather_scatter_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f71e30f85b09df041b02fbd4f34b69c0e85f92da
--- /dev/null
+++ b/paddle/phi/ops/compat/gather_scatter_sig.cc
@@ -0,0 +1,46 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature GatherNdGradArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("gather_nd_grad",
+                         {"X", "Index", GradVarName("Out")},
+                         {},
+                         {GradVarName("X")});
+}
+
+KernelSignature ScatterGradArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("scatter_grad",
+                         {"Ids", "Updates", GradVarName("Out")},
+                         {"overwrite"},
+                         {GradVarName("X"), GradVarName("Updates")});
+}
+
+KernelSignature ScatterNdAddGradArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("scatter_nd_add_grad",
+                         {"Index", "Updates", GradVarName("Out")},
+                         {},
+                         {GradVarName("X"), GradVarName("Updates")});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(gather_nd_grad, phi::GatherNdGradArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(scatter_grad, phi::ScatterGradArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(scatter_nd_add_grad,
+                           phi::ScatterNdAddGradArgumentMapping);