未验证 提交 18d33346 编写于 作者: L LiYuRio 提交者: GitHub

new mp_allreduce_sum_op (#47715)

上级 38ba5f2e
......@@ -62,12 +62,10 @@ DECLARE_INPLACE_OP_INFERER(AllreduceSumInplaceInferer, {"X", "Out"});
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(c_allreduce_sum,
ops::CAllReduceOp,
ops::CAllReduceSumOpGradMaker<paddle::framework::OpDesc>,
ops::CAllReduceSumOpGradMaker<paddle::imperative::OpBase>,
ops::CAllReduceSumOpMaker,
ops::AllreduceSumInplaceInferer);
REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_sum,
ops::CAllReduceOp,
ops::CAllReduceSumOpMaker,
ops::AllreduceSumInplaceInferer)
REGISTER_OP_CPU_KERNEL(c_allreduce_sum,
ops::CAllReduceOpCPUKernel<ops::kRedSum, float>,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace paddle {
namespace framework {
class OpDesc;
} // namespace framework
namespace imperative {
class OpBase;
} // namespace imperative
} // namespace paddle
namespace paddle {
namespace operators {
class MpAllReduceSumOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
};
class MpAllReduceSumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor), tensor to be allreduced in model parallel.");
AddOutput("Out", "(Tensor) the allreduced result in model parallel.");
AddAttr<int>("ring_id", "(int default 0) communication ring id.")
.SetDefault(0);
#if defined(PADDLE_WITH_ASCEND_CL)
AddAttr<std::string>("tag", "(string default tag) tag for all reduce.")
.SetDefault("tag");
#endif
AddAttr<bool>(
"use_calc_stream",
"(bool default false) eject CUDA operations to calculation stream.")
.SetDefault(false);
AddComment(string::Sprintf(R"DOC(
MpAllReduceSum Operator
Call collective AllReduceSum in model parallel. If input and output are
the same variable, in-place allreduce will be used.
Reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/usage/operations.html#allreduce
)DOC"));
}
};
template <typename T>
class MpAllReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("c_identity");
retv->SetInput("X", this->OutputGrad("Out"));
retv->SetOutput("Out", this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
}
};
DECLARE_INPLACE_OP_INFERER(MpAllReduceSumInplaceInferer, {"X", "Out"});
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(mp_allreduce_sum,
ops::MpAllReduceSumOp,
ops::MpAllReduceSumOpGradMaker<paddle::framework::OpDesc>,
ops::MpAllReduceSumOpGradMaker<paddle::imperative::OpBase>,
ops::MpAllReduceSumOpMaker,
ops::MpAllReduceSumInplaceInferer);
REGISTER_OP_CPU_KERNEL(mp_allreduce_sum,
ops::CAllReduceOpCPUKernel<ops::kRedSum, float>,
ops::CAllReduceOpCPUKernel<ops::kRedSum, double>,
ops::CAllReduceOpCPUKernel<ops::kRedSum, int>,
ops::CAllReduceOpCPUKernel<ops::kRedSum, int64_t>,
ops::CAllReduceOpCPUKernel<ops::kRedSum, plat::float16>)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
mp_allreduce_sum,
ops::CAllReduceOpCUDAKernel<ops::kRedSum, float>,
#if NCCL_VERSION_CODE >= 21000
ops::CAllReduceOpCUDAKernel<ops::kRedSum, plat::bfloat16>,
#endif
ops::CAllReduceOpCUDAKernel<ops::kRedSum, double>,
ops::CAllReduceOpCUDAKernel<ops::kRedSum, int>,
ops::CAllReduceOpCUDAKernel<ops::kRedSum, int64_t>,
ops::CAllReduceOpCUDAKernel<ops::kRedSum, plat::float16>)
#ifdef PADDLE_WITH_XPU_KP
// Please do not modify the following code
#if defined(__CUDA_ARCH__)
#undef __CUDA_ARCH__
#endif
#if defined(__CUDACC__)
#undef __CUDACC__
#endif
#if defined(__CUDA__)
#undef __CUDA__
#endif
#if defined(__NVCC__)
#undef __NVCC__
#endif
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_KERNEL(mp_allreduce_sum,
KP,
plat::XPUPlace,
ops::CAllReduceOpXPUKernel<ops::kRedSum, float>);
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_MLU_KERNEL(mp_allreduce_sum,
ops::CAllReduceOpMLUKernel<ops::kRedSum, float>,
ops::CAllReduceOpMLUKernel<ops::kRedSum, plat::float16>,
ops::CAllReduceOpMLUKernel<ops::kRedSum, int>,
ops::CAllReduceOpMLUKernel<ops::kRedSum, int16_t>,
ops::CAllReduceOpMLUKernel<ops::kRedSum, int8_t>,
ops::CAllReduceOpMLUKernel<ops::kRedSum, uint8_t>)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace paddle {
namespace platform {
struct ASCENDPlace;
} // namespace platform
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(
mp_allreduce_sum,
ops::CAllReduceOpASCENDKernel<ops::kRedSum, int>,
ops::CAllReduceOpASCENDKernel<ops::kRedSum, int8_t>,
ops::CAllReduceOpASCENDKernel<ops::kRedSum, float>,
ops::CAllReduceOpASCENDKernel<ops::kRedSum, plat::float16>)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(mp_allreduce_sum,
ops::CAllReduceOpXPUKernel<ops::kRedSum, float>,
ops::CAllReduceOpXPUKernel<ops::kRedSum, plat::float16>,
ops::CAllReduceOpXPUKernel<ops::kRedSum, int>)
......@@ -266,8 +266,6 @@ def _mp_allreduce(
use_calc_stream,
'ring_id',
ring_id,
"use_model_parallel",
use_model_parallel,
)
@staticmethod
......@@ -289,19 +287,17 @@ def _mp_allreduce(
ring_id = 0 if group is None else group.id
if _in_legacy_dygraph():
if op == ReduceOp.SUM:
return _legacy_C_ops.c_allreduce_sum_(
return _legacy_C_ops.mp_allreduce_sum_(
tensor,
'use_calc_stream',
use_calc_stream,
'ring_id',
ring_id,
"use_model_parallel",
use_model_parallel,
)
else:
raise ValueError("Unknown parameter: {}.".format(op))
op_type = 'c_allreduce_sum'
op_type = 'mp_allreduce_sum'
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
......@@ -319,7 +315,6 @@ def _mp_allreduce(
attrs={
'ring_id': ring_id,
'use_calc_stream': use_calc_stream,
'use_model_parallel': use_model_parallel,
},
)
return out
......@@ -602,13 +597,12 @@ def _parallel_linear(
)
if axis == 0:
main_block.append_op(
type='c_allreduce_sum',
type='mp_allreduce_sum',
inputs={'X': linear_out},
outputs={'Out': out},
attrs={
'ring_id': ring_id,
'use_calc_stream': True,
'use_model_parallel': True,
},
)
if linear.bias is not None:
......
......@@ -128,7 +128,7 @@ class TestDistTraning(unittest.TestCase):
ops = [op.type for op in ops]
self.assertEqual(
ops,
['c_split', 'matmul_v2', 'c_allreduce_sum', 'elementwise_add'],
['c_split', 'matmul_v2', 'mp_allreduce_sum', 'elementwise_add'],
)
weight = model_a.parallel_linear.weight
......@@ -156,7 +156,7 @@ class TestDistTraning(unittest.TestCase):
# print(main_program)
ops = main_program.global_block().ops
ops = [op.type for op in ops]
self.assertEqual(ops, ['c_embedding', 'c_allreduce_sum'])
self.assertEqual(ops, ['c_embedding', 'mp_allreduce_sum'])
weight = model_a.embedding.weight
self.assertEqual(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册