diff --git a/paddle/fluid/operators/merge_selected_rows_op.cc b/paddle/fluid/operators/merge_selected_rows_op.cc index ef89a730a0ff9fd73c2caf299325a23fdb2ea994..b408e42488b89d64c33b7c4e4eb565792c773918 100644 --- a/paddle/fluid/operators/merge_selected_rows_op.cc +++ b/paddle/fluid/operators/merge_selected_rows_op.cc @@ -12,33 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/merge_selected_rows_op.h" - #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" + namespace paddle { namespace operators { class MergeSelectedRowsOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "MergeSelectedRows"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "MergeSelectedRows"); - PADDLE_ENFORCE_EQ( - ctx->GetInputsVarType("X").front(), - framework::proto::VarType::SELECTED_ROWS, - platform::errors::InvalidArgument("Input(X) of MergeSelectedRowsOp " - "should be of type SelectedRows.")); - PADDLE_ENFORCE_EQ( - ctx->GetOutputsVarType("Out").front(), - framework::proto::VarType::SELECTED_ROWS, - platform::errors::InvalidArgument("Output(Out) of MergeSelectedRowsOp " - "should be of type SelectedRows.")); - - ctx->ShareDim("X", /*->*/ "Out"); - } }; class MergeSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker { @@ -95,11 +81,13 @@ class MergeSelectedRowsOpInferVarType namespace ops = paddle::operators; namespace plat = paddle::platform; + +DECLARE_INFER_SHAPE_FUNCTOR(merge_selected_rows, + MergeSelectedRowsInferMetaFunctor, + PD_INFER_META(phi::UnchangedInferMeta)); + REGISTER_OPERATOR(merge_selected_rows, ops::MergeSelectedRowsOp, ops::MergeSelectedRowsOpMaker, - ops::MergeSelectedRowsOpInferVarType); - -REGISTER_OP_CPU_KERNEL(merge_selected_rows, - ops::MergeSelectedRowsKernel, - ops::MergeSelectedRowsKernel); + ops::MergeSelectedRowsOpInferVarType, + MergeSelectedRowsInferMetaFunctor); diff --git a/paddle/fluid/operators/merge_selected_rows_op.cu.cc b/paddle/fluid/operators/merge_selected_rows_op.cu.cc deleted file mode 100644 index 16b9b5dc6bdf13443dfbf8528c8f34391c6fe8aa..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/merge_selected_rows_op.cu.cc +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/merge_selected_rows_op.h" - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -REGISTER_OP_CUDA_KERNEL(merge_selected_rows, - ops::MergeSelectedRowsKernel, - ops::MergeSelectedRowsKernel); diff --git a/paddle/fluid/operators/merge_selected_rows_op.h b/paddle/fluid/operators/merge_selected_rows_op.h deleted file mode 100644 index d0f18b22b27971b9c45e921be5a8f1390bb041d5..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/merge_selected_rows_op.h +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/selected_rows_functor.h" - -namespace paddle { -namespace operators { - -template -class MergeSelectedRowsKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* x = context.Input("X"); - auto* out = context.Output("Out"); - - math::scatter::MergeAdd merge_func; - merge_func(context.template device_context(), *x, out); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/unity_build_rule.cmake b/paddle/fluid/operators/unity_build_rule.cmake index 62aa990ca7bc826df129c3c961e779622d69f173..69e6539feaca3d151ad54ef9dab6405473c34734 100644 --- a/paddle/fluid/operators/unity_build_rule.cmake +++ b/paddle/fluid/operators/unity_build_rule.cmake @@ -365,7 +365,6 @@ register_unity_group( split_op.cu.cc activation_cudnn_op.cu.cc assign_value_op.cu.cc - merge_selected_rows_op.cu.cc run_program_op.cu.cc warpctc_op.cu.cc) register_unity_group( @@ -469,7 +468,6 @@ register_unity_group( lookup_table_v2_op.cu margin_rank_loss_op.cu masked_select_op.cu - merge_selected_rows_op.cu lstmp_op.cu shuffle_channel_op.cu softmax_cudnn_op.cu diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 64171facb51a0e8337228c37c76c3478031f8a60..7905d2094f0b3cb90d29c2645b473465e4557a32 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -1762,6 +1762,14 @@ func : mean_all backward : mean_all_grad +- op : merge_selected_rows + args : (Tensor x) + output : Tensor + infer_meta : + func : UnchangedInferMeta + kernel : + func : merge_selected_rows {selected_rows -> selected_rows} + - op : merged_adam_ args : (Tensor[] param, Tensor[] grad, Tensor[] learning_rate, Tensor[] moment1, Tensor[] moment2, Tensor[] beta1_pow, Tensor[] beta2_pow, Tensor[] master_param, Scalar beta1, Scalar beta2, Scalar epsilon, bool multi_precision, bool use_global_beta_pow) output : Tensor[](param_out){param.size()}, Tensor[](moment1_out){param.size()}, Tensor[](moment2_out){param.size()}, Tensor[](beta1_pow_out){param.size()}, Tensor[](beta2_pow_out){param.size()}, Tensor[](master_param_out){param.size()} diff --git a/paddle/phi/kernels/selected_rows/merge_selected_rows_kernel.cc b/paddle/phi/kernels/selected_rows/merge_selected_rows_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..7bbecdbb4a8aeedba2094a5d87e2e2c3ad5ee1b9 --- /dev/null +++ b/paddle/phi/kernels/selected_rows/merge_selected_rows_kernel.cc @@ -0,0 +1,51 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/selected_rows/merge_selected_rows_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/fluid/operators/math/selected_rows_functor.h" + +namespace phi { +namespace sr { + +template +void MergeSelectedRowsKernel(const Context& dev_ctx, + const SelectedRows& x, + SelectedRows* out) { + paddle::operators::math::scatter::MergeAdd merge_func; + merge_func(dev_ctx, x, out); +} + +} // namespace sr +} // namespace phi + +PD_REGISTER_KERNEL(merge_selected_rows, + CPU, + ALL_LAYOUT, + phi::sr::MergeSelectedRowsKernel, + float, + double) {} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PD_REGISTER_KERNEL(merge_selected_rows, + GPU, + ALL_LAYOUT, + phi::sr::MergeSelectedRowsKernel, + float, + double) {} +#endif diff --git a/paddle/phi/kernels/selected_rows/merge_selected_rows_kernel.h b/paddle/phi/kernels/selected_rows/merge_selected_rows_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..a938e9a7f2a7268bb22aeb339a6c6db7057ea4d2 --- /dev/null +++ b/paddle/phi/kernels/selected_rows/merge_selected_rows_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/selected_rows.h" + +namespace phi { +namespace sr { + +template +void MergeSelectedRowsKernel(const Context& dev_ctx, + const SelectedRows& x, + SelectedRows* out); + +} // namespace sr +} // namespace phi diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 49de18acbb2135104a7cecf766676aa7c198f624..cd309b0ccd73e03793c9470a3a6764f702e99e6c 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -13216,6 +13216,9 @@ def merge_selected_rows(x, name=None): type=fluid.core.VarDesc.VarType.SELECTED_ROWS) y = fluid.layers.merge_selected_rows(var) """ + if in_dygraph_mode(): + return _C_ops.merge_selected_rows(x) + if _non_static_mode(): return _legacy_C_ops.merge_selected_rows(x)