diff --git a/paddle/fluid/operators/collective/c_allreduce_sum_op.cc b/paddle/fluid/operators/collective/c_allreduce_sum_op.cc
index b7831972872894500c0b7b6ed60d0e19d228b9ff..1fd77d3ab96d6c6eb99b21b3d28cf3d13c8f89b8 100644
--- a/paddle/fluid/operators/collective/c_allreduce_sum_op.cc
+++ b/paddle/fluid/operators/collective/c_allreduce_sum_op.cc
@@ -62,12 +62,10 @@ DECLARE_INPLACE_OP_INFERER(AllreduceSumInplaceInferer, {"X", "Out"});
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OPERATOR(c_allreduce_sum,
-                  ops::CAllReduceOp,
-                  ops::CAllReduceSumOpGradMaker<paddle::framework::OpDesc>,
-                  ops::CAllReduceSumOpGradMaker<paddle::imperative::OpBase>,
-                  ops::CAllReduceSumOpMaker,
-                  ops::AllreduceSumInplaceInferer);
+REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_sum,
+                             ops::CAllReduceOp,
+                             ops::CAllReduceSumOpMaker,
+                             ops::AllreduceSumInplaceInferer)
 
 REGISTER_OP_CPU_KERNEL(c_allreduce_sum,
                        ops::CAllReduceOpCPUKernel<ops::kRedSum, float>,
diff --git a/paddle/fluid/operators/collective/mp_allreduce_sum_op.cc b/paddle/fluid/operators/collective/mp_allreduce_sum_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7000c46d64504dd19ce436edd9b9d54f93a76482
--- /dev/null
+++ b/paddle/fluid/operators/collective/mp_allreduce_sum_op.cc
@@ -0,0 +1,97 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/collective/c_allreduce_op.h"
+
+namespace paddle {
+namespace framework {
+class OpDesc;
+}  // namespace framework
+namespace imperative {
+class OpBase;
+}  // namespace imperative
+}  // namespace paddle
+
+namespace paddle {
+namespace operators {
+
+class MpAllReduceSumOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+  }
+};
+
+class MpAllReduceSumOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() {
+    AddInput("X", "(Tensor), tensor to be allreduced in model parallel.");
+    AddOutput("Out", "(Tensor) the allreduced result in model parallel.");
+    AddAttr<int>("ring_id", "(int default 0) communication ring id.")
+        .SetDefault(0);
+#if defined(PADDLE_WITH_ASCEND_CL)
+    AddAttr<std::string>("tag", "(string default tag) tag for all reduce.")
+        .SetDefault("tag");
+#endif
+    AddAttr<bool>(
+        "use_calc_stream",
+        "(bool default false) eject CUDA operations to calculation stream.")
+        .SetDefault(false);
+    AddComment(string::Sprintf(R"DOC(
+MpAllReduceSum Operator
+
+Call collective AllReduceSum in model parallel. If input and output are
+the same variable, in-place allreduce will be used.
+Reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/usage/operations.html#allreduce
+)DOC"));
+  }
+};
+
+template <typename T>
+class MpAllReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> retv) const override {
+    retv->SetType("c_identity");
+    retv->SetInput("X", this->OutputGrad("Out"));
+    retv->SetOutput("Out", this->InputGrad("X"));
+    retv->SetAttrMap(this->Attrs());
+  }
+};
+
+DECLARE_INPLACE_OP_INFERER(MpAllReduceSumInplaceInferer, {"X", "Out"});
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OPERATOR(mp_allreduce_sum,
+                  ops::MpAllReduceSumOp,
+                  ops::MpAllReduceSumOpGradMaker<paddle::framework::OpDesc>,
+                  ops::MpAllReduceSumOpGradMaker<paddle::imperative::OpBase>,
+                  ops::MpAllReduceSumOpMaker,
+                  ops::MpAllReduceSumInplaceInferer);
+
+REGISTER_OP_CPU_KERNEL(mp_allreduce_sum,
+                       ops::CAllReduceOpCPUKernel<ops::kRedSum, float>,
+                       ops::CAllReduceOpCPUKernel<ops::kRedSum, double>,
+                       ops::CAllReduceOpCPUKernel<ops::kRedSum, int>,
+                       ops::CAllReduceOpCPUKernel<ops::kRedSum, int64_t>,
+                       ops::CAllReduceOpCPUKernel<ops::kRedSum, plat::float16>)
diff --git a/paddle/fluid/operators/collective/mp_allreduce_sum_op.cu.cc b/paddle/fluid/operators/collective/mp_allreduce_sum_op.cu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..26092a55e0f1a432f3852a9ef41ff3b9f9a91cd6
--- /dev/null
+++ b/paddle/fluid/operators/collective/mp_allreduce_sum_op.cu.cc
@@ -0,0 +1,30 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/collective/c_allreduce_op.h"
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_CUDA_KERNEL(
+    mp_allreduce_sum,
+    ops::CAllReduceOpCUDAKernel<ops::kRedSum, float>,
+#if NCCL_VERSION_CODE >= 21000
+    ops::CAllReduceOpCUDAKernel<ops::kRedSum, plat::bfloat16>,
+#endif
+    ops::CAllReduceOpCUDAKernel<ops::kRedSum, double>,
+    ops::CAllReduceOpCUDAKernel<ops::kRedSum, int>,
+    ops::CAllReduceOpCUDAKernel<ops::kRedSum, int64_t>,
+    ops::CAllReduceOpCUDAKernel<ops::kRedSum, plat::float16>)
diff --git a/paddle/fluid/operators/collective/mp_allreduce_sum_op.kps b/paddle/fluid/operators/collective/mp_allreduce_sum_op.kps
new file mode 100644
index 0000000000000000000000000000000000000000..e80fb0f2c105ee7847c1b6fa3ef9039643a14781
--- /dev/null
+++ b/paddle/fluid/operators/collective/mp_allreduce_sum_op.kps
@@ -0,0 +1,30 @@
+#ifdef PADDLE_WITH_XPU_KP
+
+// Please do not modify the following code
+#if defined(__CUDA_ARCH__)
+#undef __CUDA_ARCH__
+#endif
+
+#if defined(__CUDACC__)
+#undef __CUDACC__
+#endif
+
+#if defined(__CUDA__)
+#undef __CUDA__
+#endif
+
+#if defined(__NVCC__)
+#undef __NVCC__
+#endif
+
+#include "paddle/fluid/operators/collective/c_allreduce_op.h"
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_KERNEL(mp_allreduce_sum,
+                   KP,
+                   plat::XPUPlace,
+                   ops::CAllReduceOpXPUKernel<ops::kRedSum, float>);
+
+#endif
diff --git a/paddle/fluid/operators/collective/mp_allreduce_sum_op_mlu.cc b/paddle/fluid/operators/collective/mp_allreduce_sum_op_mlu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2b0bd6efc887ebc436520b48bd207ab88e5c0ee1
--- /dev/null
+++ b/paddle/fluid/operators/collective/mp_allreduce_sum_op_mlu.cc
@@ -0,0 +1,26 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/collective/c_allreduce_op.h"
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_MLU_KERNEL(mp_allreduce_sum,
+                       ops::CAllReduceOpMLUKernel<ops::kRedSum, float>,
+                       ops::CAllReduceOpMLUKernel<ops::kRedSum, plat::float16>,
+                       ops::CAllReduceOpMLUKernel<ops::kRedSum, int>,
+                       ops::CAllReduceOpMLUKernel<ops::kRedSum, int16_t>,
+                       ops::CAllReduceOpMLUKernel<ops::kRedSum, int8_t>,
+                       ops::CAllReduceOpMLUKernel<ops::kRedSum, uint8_t>)
diff --git a/paddle/fluid/operators/collective/mp_allreduce_sum_op_npu.cc b/paddle/fluid/operators/collective/mp_allreduce_sum_op_npu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0054cfa4687466cdd8bffb16c58c9bedeec9a76f
--- /dev/null
+++ b/paddle/fluid/operators/collective/mp_allreduce_sum_op_npu.cc
@@ -0,0 +1,31 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/collective/c_allreduce_op.h"
+
+namespace paddle {
+namespace platform {
+struct ASCENDPlace;
+}  // namespace platform
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_NPU_KERNEL(
+    mp_allreduce_sum,
+    ops::CAllReduceOpASCENDKernel<ops::kRedSum, int>,
+    ops::CAllReduceOpASCENDKernel<ops::kRedSum, int8_t>,
+    ops::CAllReduceOpASCENDKernel<ops::kRedSum, float>,
+    ops::CAllReduceOpASCENDKernel<ops::kRedSum, plat::float16>)
diff --git a/paddle/fluid/operators/collective/mp_allreduce_sum_op_xpu.cc b/paddle/fluid/operators/collective/mp_allreduce_sum_op_xpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..39dcc3470f6ac24356e9cac59c4eb3ba3ef8aafe
--- /dev/null
+++ b/paddle/fluid/operators/collective/mp_allreduce_sum_op_xpu.cc
@@ -0,0 +1,23 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/collective/c_allreduce_op.h"
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_XPU_KERNEL(mp_allreduce_sum,
+                       ops::CAllReduceOpXPUKernel<ops::kRedSum, float>,
+                       ops::CAllReduceOpXPUKernel<ops::kRedSum, plat::float16>,
+                       ops::CAllReduceOpXPUKernel<ops::kRedSum, int>)
diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
index 40795ac2586cb8fe68f8a9216c660846f53c5462..07fc4e7172b9f00ded095df45a5bb47c171eebe4 100644
--- a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
+++ b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
@@ -266,8 +266,6 @@ def _mp_allreduce(
                         use_calc_stream,
                         'ring_id',
                         ring_id,
-                        "use_model_parallel",
-                        use_model_parallel,
                     )
 
             @staticmethod
@@ -289,19 +287,17 @@ def _mp_allreduce(
     ring_id = 0 if group is None else group.id
     if _in_legacy_dygraph():
         if op == ReduceOp.SUM:
-            return _legacy_C_ops.c_allreduce_sum_(
+            return _legacy_C_ops.mp_allreduce_sum_(
                 tensor,
                 'use_calc_stream',
                 use_calc_stream,
                 'ring_id',
                 ring_id,
-                "use_model_parallel",
-                use_model_parallel,
             )
         else:
             raise ValueError("Unknown parameter: {}.".format(op))
 
-    op_type = 'c_allreduce_sum'
+    op_type = 'mp_allreduce_sum'
     helper = LayerHelper(op_type, **locals())
     out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
 
@@ -319,7 +315,6 @@ def _mp_allreduce(
         attrs={
             'ring_id': ring_id,
             'use_calc_stream': use_calc_stream,
-            'use_model_parallel': use_model_parallel,
         },
     )
     return out
@@ -602,13 +597,12 @@ def _parallel_linear(
     )
     if axis == 0:
         main_block.append_op(
-            type='c_allreduce_sum',
+            type='mp_allreduce_sum',
             inputs={'X': linear_out},
             outputs={'Out': out},
             attrs={
                 'ring_id': ring_id,
                 'use_calc_stream': True,
-                'use_model_parallel': True,
             },
         )
         if linear.bias is not None:
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_static_mp_layers.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_static_mp_layers.py
index f459cf67a33cf586301d3b5f90b13997bfb5200d..d43c30675344fcaf52527c430770140c462c56ce 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_static_mp_layers.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_static_mp_layers.py
@@ -128,7 +128,7 @@ class TestDistTraning(unittest.TestCase):
             ops = [op.type for op in ops]
             self.assertEqual(
                 ops,
-                ['c_split', 'matmul_v2', 'c_allreduce_sum', 'elementwise_add'],
+                ['c_split', 'matmul_v2', 'mp_allreduce_sum', 'elementwise_add'],
             )
 
             weight = model_a.parallel_linear.weight
@@ -156,7 +156,7 @@ class TestDistTraning(unittest.TestCase):
             # print(main_program)
             ops = main_program.global_block().ops
             ops = [op.type for op in ops]
-            self.assertEqual(ops, ['c_embedding', 'c_allreduce_sum'])
+            self.assertEqual(ops, ['c_embedding', 'mp_allreduce_sum'])
 
             weight = model_a.embedding.weight
             self.assertEqual(
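
Note (not part of the patch): the sketch below only illustrates how the renamed op ends up in a static-graph program; it mirrors the LayerHelper/append_op call in mp_ops.py above. The tensor names, shapes, and the standalone program here are assumptions, and actually executing the op still needs an initialized communication ring for the given ring_id, so this snippet only builds and inspects the program.

# Illustrative static-graph sketch for the renamed mp_allreduce_sum op.
# Requires a Paddle build that already contains this patch.
import paddle

paddle.enable_static()

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    # Hypothetical input tensor; any name/shape works for building the program.
    x = paddle.static.data(name="x", shape=[8, 16], dtype="float32")
    block = main_program.global_block()
    out = block.create_var(name="allreduce_out", dtype=x.dtype, shape=x.shape)
    # Same attribute set the patch settles on: only ring_id and use_calc_stream;
    # the old use_model_parallel attribute is gone.
    block.append_op(
        type="mp_allreduce_sum",
        inputs={"X": x},
        outputs={"Out": out},
        attrs={"ring_id": 0, "use_calc_stream": True},
    )

# The op list now contains the new name, matching the updated unit test.
print([op.type for op in main_program.global_block().ops])  # ['mp_allreduce_sum']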