From c594f57685f828e1e07caafd30c2024d07c027f8 Mon Sep 17 00:00:00 2001 From: lw921014 Date: Tue, 23 Mar 2021 11:30:06 +0800 Subject: [PATCH] add c_reduce_sum op (#31793) add c_reduce_sum op --- cmake/generic.cmake | 1 + .../fluid/operators/collective/CMakeLists.txt | 2 + .../collective/c_allreduce_sum_op_npu_test.cc | 25 +-- .../collective/c_reduce_max_op_npu.cc | 31 ++++ .../collective/c_reduce_min_op_npu.cc | 31 ++++ .../fluid/operators/collective/c_reduce_op.h | 120 ++++++++++++++ .../collective/c_reduce_prod_op_npu.cc | 31 ++++ .../collective/c_reduce_sum_op_npu.cc | 31 ++++ .../collective/c_reduce_sum_op_npu_test.cc | 153 ++++++++++++++++++ .../collective/c_reducescatter_op_npu_test.cc | 19 +-- 10 files changed, 424 insertions(+), 20 deletions(-) create mode 100644 paddle/fluid/operators/collective/c_reduce_max_op_npu.cc create mode 100644 paddle/fluid/operators/collective/c_reduce_min_op_npu.cc create mode 100644 paddle/fluid/operators/collective/c_reduce_prod_op_npu.cc create mode 100644 paddle/fluid/operators/collective/c_reduce_sum_op_npu.cc create mode 100644 paddle/fluid/operators/collective/c_reduce_sum_op_npu_test.cc diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 7356f7fc21..7caa63daf3 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -448,6 +448,7 @@ function(cc_test TARGET_NAME) "${TARGET_NAME}" STREQUAL "c_reducescatter_op_npu_test" OR "${TARGET_NAME}" STREQUAL "c_allgather_op_npu_test" OR "${TARGET_NAME}" STREQUAL "send_v2_op_npu_test" OR + "${TARGET_NAME}" STREQUAL "c_reduce_sum_op_npu_test" OR "${TARGET_NAME}" STREQUAL "recv_v2_op_npu_test")) cc_test_run(${TARGET_NAME} COMMAND ${TARGET_NAME} diff --git a/paddle/fluid/operators/collective/CMakeLists.txt b/paddle/fluid/operators/collective/CMakeLists.txt index 3a220a4852..93caa5354d 100644 --- a/paddle/fluid/operators/collective/CMakeLists.txt +++ b/paddle/fluid/operators/collective/CMakeLists.txt @@ -45,6 +45,8 @@ if(WITH_ASCEND_CL) DEPS c_allreduce_sum_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM}) cc_test(c_allreduce_max_op_npu_test SRCS c_allreduce_max_op_npu_test.cc DEPS c_allreduce_max_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM}) + cc_test(c_reduce_sum_op_npu_test SRCS c_reduce_sum_op_npu_test.cc + DEPS c_reduce_sum_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM}) cc_test(c_reducescatter_op_npu_test SRCS c_reducescatter_op_npu_test.cc DEPS c_reducescatter_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM}) cc_test(c_allgather_op_npu_test SRCS c_allgather_op_npu_test.cc diff --git a/paddle/fluid/operators/collective/c_allreduce_sum_op_npu_test.cc b/paddle/fluid/operators/collective/c_allreduce_sum_op_npu_test.cc index ed3a7f50b9..2fff84593c 100644 --- a/paddle/fluid/operators/collective/c_allreduce_sum_op_npu_test.cc +++ b/paddle/fluid/operators/collective/c_allreduce_sum_op_npu_test.cc @@ -53,7 +53,7 @@ void PrintDebugInfo(const std::string preStr, const std::vector &data){ for (auto ele : data) { debugstring += std::to_string(ele) + std::string(","); } - VLOG(2) << preStr << ":" << std::endl < rank_ids{0, 1}; f::AttributeMap comm_init_attrs; comm_init_attrs["ring_id"] = 0; @@ -80,7 +80,7 @@ void Prepare(f::Scope* scope, const p::DeviceContext& ctx){ ctx.Wait(); } -void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx) { +void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx, int iter) { // init auto x = scope->Var("X"); auto tensor_x = x->GetMutable(); @@ -109,12 +109,12 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx) { // run f::AttributeMap attrs; - attrs["tag"]=std::string("tagx"); + attrs["tag"]=std::string("tagx_"+ std::to_string(iter)); attrs["ring_id"]=0; - auto op = f::OpRegistry::CreateOp("c_allreduce_sum", + auto op = f::OpRegistry::CreateOp("c_allreduce_sum", {{"X", {"X"}}}, - {{"Out", {"Out"}}}, + {{"Out", {"Out"}}}, attrs); for (int i = 0; i < 10; i ++) { @@ -137,9 +137,12 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx) { TEST(c_allreduce_sum, NPU) { f::Scope scope; - // only support one device, if more than one device, use first default + // only support one device, if more than one device, use first default p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str()))); Prepare(&scope, ctx); - TestHCCLAllReduceOp(&scope, ctx); + for(int i = 0; i < 1; i ++){ + VLOG(2) << "iter num: " << i; + TestHCCLAllReduceOp(&scope, ctx, i); + } } diff --git a/paddle/fluid/operators/collective/c_reduce_max_op_npu.cc b/paddle/fluid/operators/collective/c_reduce_max_op_npu.cc new file mode 100644 index 0000000000..14ddd67e13 --- /dev/null +++ b/paddle/fluid/operators/collective/c_reduce_max_op_npu.cc @@ -0,0 +1,31 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_reduce_op.h" + +namespace paddle { +namespace platform { +struct ASCENDPlace; +struct float16; +} // namespace platform +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_NPU_KERNEL(c_reduce_max, + ops::CReduceOpASCENDKernel, + ops::CReduceOpASCENDKernel, + ops::CReduceOpASCENDKernel, + ops::CReduceOpASCENDKernel) diff --git a/paddle/fluid/operators/collective/c_reduce_min_op_npu.cc b/paddle/fluid/operators/collective/c_reduce_min_op_npu.cc new file mode 100644 index 0000000000..981145a31f --- /dev/null +++ b/paddle/fluid/operators/collective/c_reduce_min_op_npu.cc @@ -0,0 +1,31 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_reduce_op.h" + +namespace paddle { +namespace platform { +struct ASCENDPlace; +struct float16; +} // namespace platform +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_NPU_KERNEL(c_reduce_min, + ops::CReduceOpASCENDKernel, + ops::CReduceOpASCENDKernel, + ops::CReduceOpASCENDKernel, + ops::CReduceOpASCENDKernel) diff --git a/paddle/fluid/operators/collective/c_reduce_op.h b/paddle/fluid/operators/collective/c_reduce_op.h index 81dc5c35bf..4e9d804fc3 100644 --- a/paddle/fluid/operators/collective/c_reduce_op.h +++ b/paddle/fluid/operators/collective/c_reduce_op.h @@ -28,11 +28,17 @@ limitations under the License. */ #include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/nccl_helper.h" #endif + #if defined(PADDLE_WITH_GLOO) #include #include "paddle/fluid/framework/fleet/gloo_wrapper.h" #endif +#if defined(PADDLE_WITH_ASCEND_CL) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/hccl_helper.h" +#endif + namespace paddle { namespace operators { @@ -110,6 +116,116 @@ class CReduceOpCPUKernel : public framework::OpKernel { } }; +template +class CReduceOpASCENDKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { +#if defined(PADDLE_WITH_ASCEND_CL) + + // we need to pre-allocate 512 Bytes before the data + // and 512 Bytes after the data, so the hccl allreduce + // can work. This is a must acooding to huawei peer. + #define PRE_MALLOC_SIZE_BYTES 512 + + auto in = ctx.Input("X"); + auto out = ctx.Output("Out"); + auto place = ctx.GetPlace(); + hcclDataType_t dtype = platform::ToHCCLDataType(in->type()); + int64_t numel = in->numel(); + + int64_t pre_tmp_size = PRE_MALLOC_SIZE_BYTES / sizeof(T); + int64_t tmp_numel = numel + pre_tmp_size * 2; + + paddle::framework::LoDTensor tmp_in, tmp_out; + tmp_in.Resize({tmp_numel}); + tmp_out.Resize({tmp_numel}); + auto p_tmp_in = tmp_in.mutable_data(place); // allocate + auto p_tmp_out = tmp_out.mutable_data(place); // allocate + + void* sendbuff = reinterpret_cast(tmp_in.data() + pre_tmp_size); + void* recvbuff = reinterpret_cast(tmp_out.data() + pre_tmp_size); + + std::string tag = ctx.Attr("tag"); + int ring_id = ctx.Attr("ring_id"); + int root_id = ctx.Attr("root_id"); + std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id); + auto comm = paddle::platform::HCCLCommContext::Instance().Get(ring_id, place); + + aclrtStream stream = nullptr; + auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); + if (ctx.Attr("use_calc_stream")) { + stream = static_cast(dev_ctx)->stream(); + } else { + stream = comm->stream(); + } + + int rank_id = comm->rank(); + + // we need to memset this memory firstly to avoid core by hccl + platform::NPUMemsetAsync(static_cast(p_tmp_in), 0, tmp_numel*sizeof(T), stream); + platform::NPUMemsetAsync(static_cast(p_tmp_out), 0, tmp_numel*sizeof(T), stream); + + auto npu_place = BOOST_GET_CONST(platform::NPUPlace, place); + + memory::Copy(npu_place, sendbuff, + npu_place, reinterpret_cast(const_cast(in->data())), + numel * sizeof(T), + stream); + + hcclRedOp_t hccl_red_type = HCCL_REP_OP_SUM; + switch (red_type) { + case kRedSum: + hccl_red_type = HCCL_REP_OP_SUM; + break; + + case kRedMax: + hccl_red_type = HCCL_REP_OP_MAX; + break; + + case kRedMin: + hccl_red_type = HCCL_REP_OP_MIN; + break; + + case kRedProd: + hccl_red_type = HCCL_REP_OP_PROD; + break; + + default: + PADDLE_THROW(platform::errors::InvalidArgument( + "Invalid reduce type: %d", red_type)); + } + + VLOG(3) << "begin hccl reduce, parameter is: " + << "input num: " << numel + << "root_id: " << root_id + << "dtype: " << dtype + << "hccl_red_type: " << hccl_red_type + << ", group is: " << group + << ", tag is " << tag; + + PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_all_reduce( + tag.c_str(), sendbuff, recvbuff, numel, dtype, hccl_red_type, group.c_str(), (void*)stream)); + + if(rank_id == root_id){ + memory::Copy(npu_place, reinterpret_cast(out->data()), + npu_place, recvbuff, + numel * sizeof(T), + stream); + }else{ + memory::Copy(npu_place, reinterpret_cast(out->data()), + npu_place, reinterpret_cast(const_cast(in->data())), + numel * sizeof(T), + stream); + } + + out->Resize(in->dims()); +#else + PADDLE_THROW(platform::errors::PreconditionNotMet( + "PaddlePaddle should compile with NPU.")); +#endif + } +}; + template class CReduceOpCUDAKernel : public framework::OpKernel { public: @@ -179,6 +295,10 @@ class CReduceOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "(Tensor) the reduced result."); AddAttr("ring_id", "(int default 0) communication ring id.") .SetDefault(0); +#if defined(PADDLE_WITH_ASCEND_CL) + AddAttr("tag", "(string default tag) tag for reduce.") + .SetDefault("tag"); +#endif AddAttr("root_id", "(int default 0) root id.").SetDefault(0); AddAttr( "use_calc_stream", diff --git a/paddle/fluid/operators/collective/c_reduce_prod_op_npu.cc b/paddle/fluid/operators/collective/c_reduce_prod_op_npu.cc new file mode 100644 index 0000000000..d87b3f3826 --- /dev/null +++ b/paddle/fluid/operators/collective/c_reduce_prod_op_npu.cc @@ -0,0 +1,31 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_reduce_op.h" + +namespace paddle { +namespace platform { +struct ASCENDPlace; +struct float16; +} // namespace platform +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_NPU_KERNEL(c_reduce_prod, + ops::CReduceOpASCENDKernel, + ops::CReduceOpASCENDKernel, + ops::CReduceOpASCENDKernel, + ops::CReduceOpASCENDKernel) diff --git a/paddle/fluid/operators/collective/c_reduce_sum_op_npu.cc b/paddle/fluid/operators/collective/c_reduce_sum_op_npu.cc new file mode 100644 index 0000000000..798f46a259 --- /dev/null +++ b/paddle/fluid/operators/collective/c_reduce_sum_op_npu.cc @@ -0,0 +1,31 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_reduce_op.h" + +namespace paddle { +namespace platform { +struct ASCENDPlace; +struct float16; +} // namespace platform +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_NPU_KERNEL(c_reduce_sum, + ops::CReduceOpASCENDKernel, + ops::CReduceOpASCENDKernel, + ops::CReduceOpASCENDKernel, + ops::CReduceOpASCENDKernel) diff --git a/paddle/fluid/operators/collective/c_reduce_sum_op_npu_test.cc b/paddle/fluid/operators/collective/c_reduce_sum_op_npu_test.cc new file mode 100644 index 0000000000..36ec6d155a --- /dev/null +++ b/paddle/fluid/operators/collective/c_reduce_sum_op_npu_test.cc @@ -0,0 +1,153 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef _WIN32 +#include +#endif + +#include +#include // NOLINT +#include +#include + +#include "gtest/gtest.h" + +#include "paddle/fluid/string/printf.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/dropout_op.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/operators/math/math_function.h" + +#include "paddle/fluid/operators/collective/c_reduce_op.h" + +#if defined(PADDLE_WITH_ASCEND_CL) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/hccl_helper.h" +#endif + +namespace f = paddle::framework; +namespace p = paddle::platform; +namespace m = paddle::operators::math; + +USE_OP(c_reduce_sum); +USE_NO_KERNEL_OP(c_comm_init_hcom); +USE_OP_DEVICE_KERNEL(c_reduce_sum, NPU); + +DECLARE_string(selected_npus); + +template +void PrintDebugInfo(const std::string preStr, const std::vector &data){ + std::string debugstring = ""; + for (auto ele : data) { + debugstring += std::to_string(ele) + std::string(","); + } + VLOG(3) << preStr << ":" << std::endl < rank_ids{0, 1}; + f::AttributeMap comm_init_attrs; + comm_init_attrs["ring_id"] = 0; + comm_init_attrs["nranks"] = 2; + comm_init_attrs["rank"] = rank_id; + comm_init_attrs["device_id"] = device_id; + comm_init_attrs["rank_ids"] = rank_ids; + auto comm_init_op = + f::OpRegistry::CreateOp("c_comm_init_hcom", {}, {}, comm_init_attrs); + auto place = ctx.GetPlace(); + comm_init_op->Run(*scope, place); + ctx.Wait(); +} + +void TestHCCLReduceOp(f::Scope* scope, const p::DeviceContext& ctx, int iter) { + // init + auto x = scope->Var("X"); + auto tensor_x = x->GetMutable(); + + int rank_id = atoi(getenv("RANK_ID")); + int num1 = 3; + int num2 = 128; + + std::vector init; + for (int64_t i = 0; i < num1 * num2; ++i) { + init.push_back(1.0 + rank_id); + } + PrintDebugInfo("input data", init); + + auto place = ctx.GetPlace(); + + TensorFromVector(init, ctx, tensor_x); + tensor_x->Resize({num1, num2}); + ctx.Wait(); + + auto out = scope->Var("Out"); + auto tensor_out = out->GetMutable(); + tensor_out->Resize({num1, num2}); + tensor_out->mutable_data(place); // allocate + ctx.Wait(); + + // run + f::AttributeMap attrs; + attrs["tag"]=std::string("tagx_"+ std::to_string(iter)); + attrs["ring_id"]=0; + int root_id = 0; + attrs["root_id"]=root_id; + + auto op = f::OpRegistry::CreateOp("c_reduce_sum", + {{"X", {"X"}}}, + {{"Out", {"Out"}}}, + attrs); + + op->Run(*scope, place); + ctx.Wait(); + + std::vector out_vec; + TensorToVector(*tensor_out, ctx, &out_vec); + ctx.Wait(); + + PrintDebugInfo("output data", out_vec); + + EXPECT_EQ(out_vec.size(), init.size()); + for (uint32_t i = 0; i < out_vec.size(); i++) { + if(rank_id == root_id){ + EXPECT_EQ(out_vec[i], 3.0); + } + else{ + EXPECT_EQ(out_vec[i], init[i]); + } + } +} + +TEST(c_reduce_sum, NPU) { + f::Scope scope; + + // only support one device, if more than one device, use first default + p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str()))); + + Prepare(&scope, ctx); + for(int i = 0; i < 2; i ++){ + VLOG(2) << "iter num: " << i; + TestHCCLReduceOp(&scope, ctx, i); + } +} diff --git a/paddle/fluid/operators/collective/c_reducescatter_op_npu_test.cc b/paddle/fluid/operators/collective/c_reducescatter_op_npu_test.cc index 5bf6539d48..1c21ab19b9 100644 --- a/paddle/fluid/operators/collective/c_reducescatter_op_npu_test.cc +++ b/paddle/fluid/operators/collective/c_reducescatter_op_npu_test.cc @@ -56,7 +56,7 @@ void PrintDebugInfo(const std::string preStr, const std::vector &data){ for (auto ele : data) { debugstring += std::to_string(ele) + std::string(","); } - VLOG(2) << preStr << ":" << std::endl < rank_ids{0, 1}; f::AttributeMap comm_init_attrs; comm_init_attrs["ring_id"] = 0; @@ -119,11 +119,12 @@ void TestHCCLReduceScatterOp(f::Scope* scope, const p::DeviceContext& ctx) { auto op = f::OpRegistry::CreateOp("c_reducescatter", {{"X", {"X"}}}, {{"Out", {"Out"}}}, attrs); - for (int i = 0; i < 10; i ++) { + int iter_num = 10; + for (int i = 0; i < iter_num; i ++) { op->Run(*scope, place); } ctx.Wait(); - + std::vector out_vec; TensorToVector(*tensor_out, ctx, &out_vec); ctx.Wait(); @@ -131,14 +132,14 @@ void TestHCCLReduceScatterOp(f::Scope* scope, const p::DeviceContext& ctx) { PrintDebugInfo("output data", out_vec); EXPECT_EQ(out_vec.size(), init.size() / 2); for (uint32_t i = 0; i < out_vec.size(); i++) { - EXPECT_EQ(out_vec[i], 2.0); + EXPECT_EQ(out_vec[i], iter_num + 1); } } TEST(c_reducescatter, NPU) { f::Scope scope; - // only support one device, if more than one device, use first default + // only support one device, if more than one device, use first default p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str()))); Prepare(&scope, ctx); -- GitLab