From 8fa8a37f6aec0b95bb6fd7ef2774f0758c08866a Mon Sep 17 00:00:00 2001 From: lilong12 Date: Fri, 23 Apr 2021 21:17:38 +0800 Subject: [PATCH] add the c_identity op (#32485) * add c_identity op, test=develop --- .../operators/collective/c_allreduce_op.h | 6 ++ .../collective/c_allreduce_sum_op.cc | 7 +- .../operators/collective/c_identity_op.cc | 92 +++++++++++++++++++ .../operators/collective/c_identity_op.cu.cc | 48 ++++++++++ .../operators/collective/c_identity_op.h | 38 ++++++++ .../fluid/tests/unittests/CMakeLists.txt | 2 + .../tests/unittests/collective_identity_op.py | 66 +++++++++++++ .../fluid/tests/unittests/test_c_identity.py | 34 +++++++ .../tests/unittests/test_collective_base.py | 5 + 9 files changed, 297 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/collective/c_identity_op.cc create mode 100644 paddle/fluid/operators/collective/c_identity_op.cu.cc create mode 100644 paddle/fluid/operators/collective/c_identity_op.h create mode 100644 python/paddle/fluid/tests/unittests/collective_identity_op.py create mode 100644 python/paddle/fluid/tests/unittests/test_c_identity.py diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h index 74673b1a52..0eaa377869 100644 --- a/paddle/fluid/operators/collective/c_allreduce_op.h +++ b/paddle/fluid/operators/collective/c_allreduce_op.h @@ -321,6 +321,12 @@ class CAllReduceOpMaker : public framework::OpProtoAndCheckerMaker { "use_calc_stream", "(bool default false) eject CUDA operations to calculation stream.") .SetDefault(false); + AddAttr( + "use_model_parallel", + "(bool default false) use this op with model parallel mode. In model " + "parallel mode, the backward is c_identity which returns itself for " + "c_allreduce_sum.") + .SetDefault(false); AddComment(string::Sprintf(R"DOC( CAllReduce %s Operator diff --git a/paddle/fluid/operators/collective/c_allreduce_sum_op.cc b/paddle/fluid/operators/collective/c_allreduce_sum_op.cc index 68061e6ae6..23ed98bb04 100644 --- a/paddle/fluid/operators/collective/c_allreduce_sum_op.cc +++ b/paddle/fluid/operators/collective/c_allreduce_sum_op.cc @@ -37,7 +37,12 @@ class CAllReduceSumOpGradMaker : public framework::SingleGradOpMaker { protected: void Apply(GradOpPtr retv) const override { - retv->SetType("c_allreduce_sum"); + bool use_mp = BOOST_GET_CONST(bool, this->GetAttr("use_model_parallel")); + if (use_mp) { + retv->SetType("c_identity"); + } else { + retv->SetType("c_allreduce_sum"); + } retv->SetInput("X", this->OutputGrad("Out")); retv->SetOutput("Out", this->InputGrad("X")); retv->SetAttrMap(this->Attrs()); diff --git a/paddle/fluid/operators/collective/c_identity_op.cc b/paddle/fluid/operators/collective/c_identity_op.cc new file mode 100644 index 0000000000..646c27b90e --- /dev/null +++ b/paddle/fluid/operators/collective/c_identity_op.cc @@ -0,0 +1,92 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_identity_op.h" + +namespace paddle { +namespace operators { + +class CIdentityOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "c_identity"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "c_identity"); + int ring_id = ctx->Attrs().Get("ring_id"); + PADDLE_ENFORCE_GE( + ring_id, 0, + platform::errors::InvalidArgument( + "The ring_id (%d) for c_identity must be non-negative.", ring_id)); + framework::DDim dim = ctx->GetInputDim("X"); + ctx->SetOutputDim("Out", dim); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } +}; + +class CIdentityOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "(Tensor) identity tensor."); + AddOutput("Out", "(Tensor) identity tensor."); + AddAttr("ring_id", "(int default 0) nccl communication ring id.") + .SetDefault(0); + AddAttr( + "use_calc_stream", + "(bool default true) eject CUDA operations to calculation stream.") + .SetDefault(true); + AddAttr("use_model_parallel", + "(bool default true) use this op with model parallel.") + .SetDefault(true); + AddComment(R"DOC( +Identity Operator which returns a copy of itself. +)DOC"); + } +}; + +template +class CIdentityOpGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr retv) const override { + retv->SetType("c_allreduce_sum"); + retv->SetInput("X", this->OutputGrad("Out")); + retv->SetOutput("Out", this->InputGrad("X")); + retv->SetAttrMap(this->Attrs()); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OPERATOR(c_identity, ops::CIdentityOp, + ops::CIdentityOpGradMaker, + ops::CIdentityOpGradMaker, + ops::CIdentityOpMaker); + +REGISTER_OP_CPU_KERNEL(c_identity, ops::CIdentityOpCPUKernel, + ops::CIdentityOpCPUKernel, + ops::CIdentityOpCPUKernel, + ops::CIdentityOpCPUKernel, + ops::CIdentityOpCPUKernel); diff --git a/paddle/fluid/operators/collective/c_identity_op.cu.cc b/paddle/fluid/operators/collective/c_identity_op.cu.cc new file mode 100644 index 0000000000..8ccf40e317 --- /dev/null +++ b/paddle/fluid/operators/collective/c_identity_op.cu.cc @@ -0,0 +1,48 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_identity_op.h" + +namespace paddle { +namespace operators { + +template +class CIdentityOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto x = ctx.Input("X"); + auto out = ctx.Output("Out"); + + int rid = ctx.Attr("ring_id"); + PADDLE_ENFORCE_GE( + rid, 0, + platform::errors::InvalidArgument( + "The ring_id (%d) for c_identity op must be non-negative.", rid)); + out->mutable_data(ctx.GetPlace()); + + TensorCopy(*x, out->place(), out); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL(c_identity, ops::CIdentityOpCUDAKernel, + ops::CIdentityOpCUDAKernel, + ops::CIdentityOpCUDAKernel, + ops::CIdentityOpCUDAKernel, + ops::CIdentityOpCUDAKernel); diff --git a/paddle/fluid/operators/collective/c_identity_op.h b/paddle/fluid/operators/collective/c_identity_op.h new file mode 100644 index 0000000000..ca817fb6ba --- /dev/null +++ b/paddle/fluid/operators/collective/c_identity_op.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class CIdentityOpCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_THROW(platform::errors::Unavailable( + "Do not support c_identity for cpu kernel now.")); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 1db84665ca..6763b702ca 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -74,6 +74,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32) LIST(REMOVE_ITEM TEST_OPS test_c_concat) LIST(REMOVE_ITEM TEST_OPS test_c_split) LIST(REMOVE_ITEM TEST_OPS test_allgather) + LIST(REMOVE_ITEM TEST_OPS test_c_identity) LIST(REMOVE_ITEM TEST_OPS test_allreduce) LIST(REMOVE_ITEM TEST_OPS test_broadcast) LIST(REMOVE_ITEM TEST_OPS test_collective_reduce) @@ -878,6 +879,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32) set_tests_properties(test_c_concat PROPERTIES TIMEOUT 120) set_tests_properties(test_c_split PROPERTIES TIMEOUT 120) set_tests_properties(test_allgather PROPERTIES TIMEOUT 120) + set_tests_properties(test_c_identity PROPERTIES TIMEOUT 120) set_tests_properties(test_collective_scatter_api PROPERTIES TIMEOUT 120) set_tests_properties(test_collective_barrier_api PROPERTIES TIMEOUT 120) set_tests_properties(test_collective_scatter PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/collective_identity_op.py b/python/paddle/fluid/tests/unittests/collective_identity_op.py new file mode 100644 index 0000000000..e024b64e82 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective_identity_op.py @@ -0,0 +1,66 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import argparse +import os +import sys +import signal +import time +from contextlib import closing +from six import string_types +import math +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +import paddle.fluid.unique_name as nameGen +from paddle.fluid import core +import unittest +from multiprocessing import Process +import paddle.fluid.layers as layers +from functools import reduce +from test_collective_base import TestCollectiveRunnerBase, runtime_main + +paddle.enable_static() + + +class TestCollectiveIdentity(TestCollectiveRunnerBase): + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program): + ring_id = 0 + nranks = 2 + with fluid.program_guard(main_prog, startup_program): + tindata = layers.data( + name="tindata", shape=[10, 1000], dtype='float32') + toutdata = main_prog.current_block().create_var( + name="outofgather", + dtype='float32', + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=False) + main_prog.global_block().append_op( + type="c_identity", + inputs={'X': tindata}, + outputs={'Out': toutdata}, + attrs={'ring_id': ring_id, + 'nranks': nranks}) + return toutdata + + +if __name__ == "__main__": + runtime_main(TestCollectiveIdentity, "identity", 0) diff --git a/python/paddle/fluid/tests/unittests/test_c_identity.py b/python/paddle/fluid/tests/unittests/test_c_identity.py new file mode 100644 index 0000000000..c780f800d1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_c_identity.py @@ -0,0 +1,34 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np +import paddle + +from test_collective_base import TestDistBase + +paddle.enable_static() + + +class TestIdentityOp(TestDistBase): + def _setup_config(self): + pass + + def test_identity(self, col_type="identity"): + self.check_with_place("collective_identity_op.py", col_type) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_collective_base.py b/python/paddle/fluid/tests/unittests/test_collective_base.py index 0d592ec185..697e8d32d6 100644 --- a/python/paddle/fluid/tests/unittests/test_collective_base.py +++ b/python/paddle/fluid/tests/unittests/test_collective_base.py @@ -274,6 +274,11 @@ class TestDistBase(unittest.TestCase): self.assertTrue( np.allclose( tr1_out, need_result, rtol=1e-05, atol=1e-05)) + elif col_type == "identity": + need_result1 = input1 + need_result2 = input2 + self.assertTrue(np.allclose(tr0_out, need_result1, rtol=0, atol=0)) + self.assertTrue(np.allclose(tr1_out, need_result2, rtol=0, atol=0)) elif col_type == "reduce_slicegather": slicesize = input1.shape[0] // 2 tmp10 = input1[0:slicesize] -- GitLab