From 63067e90d517ccd15dd28aa57d06067ce2b8950d Mon Sep 17 00:00:00 2001 From: Fan Zhang Date: Fri, 6 May 2022 16:50:04 +0800 Subject: [PATCH] [XPUPS] Register pull_box_sparse op under XPU_KP compilation (#42354) * Adapt XPUPS - 1st version - 3.24 * Adapt XPUPS - update XPU PushSparse - 2nd version - 3.24 * Adapt XPUPS - add XPU PullSparseOp - 3nd version - 3.25 * refactor heter comm kernel * update. test=develop * Adapt XPUPS - modify by compilation - 4th version - 3.27 * update calc_shard_offset. test=develop * update xpu kernel. test=develop * update args of calc_shard_offset * update. test=develop * remove customGradMerger * update. test=develop * heter_comm update * heter_comm update * update calc_shard_offset. test=develop * heter_comm update * update args of calc_shard_offset * update. test=develop * remove customGradMerger * update. test=develop * fix. test=develop * update. test=develop * update. test=develop * update optimizer kernel * Adapt XPUPS - use WITH_XPU_KP and modify wrapper kernel function - 5th version - 3.30 * update. test=develop * update pslib.cmake * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * Adapt XPUPS - modify by kp compilation - 6th version - 3.30 * update. test=develop * update. test=develop * update. test=develop * update optimizer kernel * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * fix. test=develop * fix. test=develop * used by minxu * update heter_comm_inl * fix. test=develop * Adapt XPUPS - modify by kp compilation - 7th version - 3.30 * fix. test=develop * add optimizer kernel. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * 3.31 update * Adapt XPUPS - update kp compilation path - 8th version - 3.31 * add optimizer kernel. test=develop * fix kunlun not support size_t. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix kunlun not support size_t. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * update heter_comm_kernel.kps 3.31 * fix. test=develop * fix. test=develop * update heter_comm_kernel.kps 3.31 * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * update heter_comm.h 3.31 * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * update hashtable. test=develop * update. test=develop * Adapt XPUPS - update by kp compilation - 9th version - 4.1 * update hashtable. test=develop * fix. test=develop * update hashtable 4.1 * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * Adapt XPUPS - update by kp compilation - 10th version - 4.1 * fix. test=develop * fix. test=develop * fix. test=develop * update. test=develop * modify by compilation 4.1 * update. test=develop * update. test=develop * fix. test=develop * modify by compilation 4.1 * update. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * modify by compilation 4.1 * fix. test=develop * fix. test=develop * fix. test=develop * modify by compilation 4.1 19:30 * fix. test=develop * update ps_gpu_wrapper.kps 4.1 * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * Adapt XPUPS - update by kp compilation - 11th version - 4.1 * fix. test=develop * Adapt XPUPS - update by kp compilation - 12nd version - 4.2 * fix. test=develop * fix. test=develop * modify by compilation 4.2 * 4.2 update * fix. test=develop * template init. test=develop * update 4.6 * fix. test=develop * template init. test=develop * 4.6 modify by compilation * hashtable template init. test=develop * hashtable template init. test=develop * fix. test=develop * fix. test=develop * fix. test=devlop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=devlop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * Adapt XPUPS - update by kp compilation - 13nd version - 4.7 * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * 4.11 update * fix. test=develop * fix. test=develop * 4.11 update * update by pre-commit * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * 4.12 update * fix. test=develop * Adapt XPUPS - update by kp compilation - 14th version - 4.13 * 4.13 update * 4.14 update * 4.14 update * 4.14 update * 4.14 modify by merged latest compilation * retry CI 4.14 * 4.15 pass static check * 4.15 modify by gpups CI * 3.16 update by gpups CI - modify ps_gpu_wrapper.h * 4.16 update * 4.16 pass xpu compile * 4.16 retry CI * 4.16 update * Adapt XPUPS - adapt BKCL comm for XPUPS - 4.24 * update by compilation * Adapt XPUPS - register PSGPUTrainer for XPUPS - 4.25 * update device_worker_factory * Adapt XPUPS - split heter_ps into .cu and .cc - 4.27 * Adapt XPUPS - register pull_box_sparse op under XPU_KP - 4.28 * update Co-authored-by: zmxdream --- paddle/fluid/framework/data_feed.cc | 2 + .../framework/fleet/heter_ps/CMakeLists.txt | 5 -- .../framework/fleet/heter_ps/heter_comm_inl.h | 2 + .../framework/fleet/heter_ps/heter_ps.cc | 70 +++++++++++++++++++ .../framework/fleet/heter_ps/heter_ps.cu | 18 ----- paddle/fluid/operators/pull_box_sparse_op.cc | 6 +- paddle/fluid/operators/pull_box_sparse_op.cu | 43 ------------ paddle/fluid/operators/pull_box_sparse_op.h | 23 ++---- paddle/fluid/operators/pull_box_sparse_op.kps | 56 +++++++++++++++ .../platform/device/xpu/xpu_op_kpfirst_list.h | 2 + 10 files changed, 140 insertions(+), 87 deletions(-) create mode 100644 paddle/fluid/framework/fleet/heter_ps/heter_ps.cc delete mode 100644 paddle/fluid/operators/pull_box_sparse_op.cu create mode 100644 paddle/fluid/operators/pull_box_sparse_op.kps diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index 3b6370e1185..b63f317aae8 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -261,6 +261,8 @@ void DataFeed::CopyToFeedTensor(void* dst, const void* src, size_t size) { cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice); #elif defined(PADDLE_WITH_HIP) hipMemcpy(dst, src, size, hipMemcpyHostToDevice); +#elif defined(PADDLE_WITH_XPU_KP) + xpu_memcpy(dst, src, size, XPUMemcpyKind::XPU_HOST_TO_DEVICE); #else PADDLE_THROW(platform::errors::Unimplemented( "Not supported GPU/ROCM, please compile with option WITH_GPU=ON or " diff --git a/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt b/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt index 975ce696ece..51456457d06 100644 --- a/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt +++ b/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt @@ -30,12 +30,7 @@ IF(WITH_XPU_KP) xpu_library(heter_comm_kernel SRCS heter_comm_kernel.h heter_comm_kernel.kps feature_value.h) xpu_library(hashtable_kernel SRCS hashtable.h hashtable_kernel.kps) cc_library(heter_comm SRCS heter_comm.h heter_resource.cc DEPS ${HETERPS_DEPS} heter_comm_kernel hashtable_kernel) - # Change heter_ps.cu file suffix - # NOTE(zhangminxu): If we compile with XPU_KP, we directly copy heter_ps.cu to heter_ps.cc - file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/heter_ps.cu DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/) - file(RENAME ${CMAKE_CURRENT_BINARY_DIR}/heter_ps.cu ${CMAKE_CURRENT_BINARY_DIR}/heter_ps.cc) cc_library(heter_ps SRCS heter_ps.cc DEPS heter_comm) - # xpu_library(heter_comm SRCS heter_comm.h heter_comm_kernel.kps feature_value.h heter_resource.cc heter_resource.h hashtable.h mem_pool.h DEPS ${HETERPS_DEPS}) ENDIF() IF(WITH_ROCM) hip_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h DEPS cub device_context) diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 7ebf7660ee5..098adc2bdeb 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -193,7 +193,9 @@ void HeterComm::walk_to_dest(int start_index, memory_copy(dst_place, node.key_storage, src_place, reinterpret_cast(src_key + h_left[i]), node.key_bytes_len, node.in_stream); +#if defined(PADDLE_WITH_CUDA) // adapt for gpu-graph cudaMemsetAsync(node.val_storage, -1, node.val_bytes_len, node.in_stream); +#endif if (need_copy_val) { memory_copy(dst_place, node.val_storage, src_place, diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cc b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cc new file mode 100644 index 00000000000..700b43f18fb --- /dev/null +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cc @@ -0,0 +1,70 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/fleet/heter_ps/heter_ps.h" +#include + +#ifdef PADDLE_WITH_HETERPS + +namespace paddle { +namespace framework { + +HeterPsBase* HeterPsBase::get_instance( + size_t capacity, std::shared_ptr resource) { + return new HeterPs(capacity, resource); +} + +HeterPs::HeterPs(size_t capacity, std::shared_ptr resource) { + comm_ = + std::make_shared>( + capacity, resource); +} + +HeterPs::~HeterPs() {} + +void HeterPs::pull_sparse(int num, FeatureKey* d_keys, FeatureValue* d_vals, + size_t len) { + comm_->pull_sparse(num, d_keys, d_vals, len); +} + +void HeterPs::build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals, + size_t len, size_t chunk_size, int stream_num) { + comm_->build_ps(num, h_keys, h_vals, len, chunk_size, stream_num); +} + +int HeterPs::get_index_by_devid(int devid) { + return comm_->get_index_by_devid(devid); +} + +void HeterPs::set_sparse_sgd(const OptimizerConfig& optimizer_config) { + comm_->set_sparse_sgd(optimizer_config); +} + +void HeterPs::set_embedx_sgd(const OptimizerConfig& optimizer_config) { + comm_->set_embedx_sgd(optimizer_config); +} + +void HeterPs::end_pass() { comm_->end_pass(); } + +void HeterPs::show_one_table(int gpu_num) { comm_->show_one_table(gpu_num); } + +void HeterPs::push_sparse(int num, FeatureKey* d_keys, + FeaturePushValue* d_grads, size_t len) { + comm_->push_sparse(num, d_keys, d_grads, len); + // comm_->push_sparse_multi_node(num, d_keys, d_grads, len, opt_); +} + +} // end namespace framework +} // end namespace paddle +#endif diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu index 8a877f85076..581b0d511c2 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu @@ -29,9 +29,7 @@ HeterPs::HeterPs(size_t capacity, std::shared_ptr resource) { comm_ = std::make_shared>( capacity, resource); -#if defined(PADDLE_WITH_CUDA) opt_ = Optimizer(); -#endif } HeterPs::~HeterPs() {} @@ -50,37 +48,21 @@ int HeterPs::get_index_by_devid(int devid) { return comm_->get_index_by_devid(devid); } -#if defined(PADDLE_WITH_XPU_KP) -void HeterPs::set_sparse_sgd(const OptimizerConfig& optimizer_config) { - comm_->set_sparse_sgd(optimizer_config); -} - -void HeterPs::set_embedx_sgd(const OptimizerConfig& optimizer_config) { - comm_->set_embedx_sgd(optimizer_config); -} -#endif - void HeterPs::end_pass() { comm_->end_pass(); } void HeterPs::show_one_table(int gpu_num) { comm_->show_one_table(gpu_num); } void HeterPs::push_sparse(int num, FeatureKey* d_keys, FeaturePushValue* d_grads, size_t len) { -#if defined(PADDLE_WITH_CUDA) comm_->push_sparse(num, d_keys, d_grads, len, opt_); -#elif defined(PADDLE_WITH_XPU_KP) - comm_->push_sparse(num, d_keys, d_grads, len); -#endif // comm_->push_sparse_multi_node(num, d_keys, d_grads, len, opt_); } -#if defined(PADDLE_WITH_CUDA) void HeterPs::set_nccl_comm_and_size(const std::vector& inner_comms, const std::vector& inter_comms, int comm_size) { comm_->set_nccl_comm_and_size(inner_comms, inter_comms, comm_size); } -#endif } // end namespace framework } // end namespace paddle diff --git a/paddle/fluid/operators/pull_box_sparse_op.cc b/paddle/fluid/operators/pull_box_sparse_op.cc index 22b43910e69..8c9b5f8d90f 100644 --- a/paddle/fluid/operators/pull_box_sparse_op.cc +++ b/paddle/fluid/operators/pull_box_sparse_op.cc @@ -132,7 +132,5 @@ REGISTER_OPERATOR(pull_box_sparse, ops::PullBoxSparseOp, ops::PushBoxSparseOpMaker, ops::PushBoxSparseOpMaker); REGISTER_OPERATOR(push_box_sparse, ops::PushBoxSparseOp); -REGISTER_OP_CPU_KERNEL(pull_box_sparse, ops::PullBoxSparseCPUKernel); -REGISTER_OP_CPU_KERNEL(push_box_sparse, ops::PushBoxSparseCPUKernel); -REGISTER_OP_XPU_KERNEL(pull_box_sparse, ops::PullBoxSparseXPUKernel); -REGISTER_OP_XPU_KERNEL(push_box_sparse, ops::PushBoxSparseXPUKernel); +REGISTER_OP_CPU_KERNEL(pull_box_sparse, ops::PullBoxSparseKernel); +REGISTER_OP_CPU_KERNEL(push_box_sparse, ops::PushBoxSparseKernel); diff --git a/paddle/fluid/operators/pull_box_sparse_op.cu b/paddle/fluid/operators/pull_box_sparse_op.cu deleted file mode 100644 index e3407dd3b2e..00000000000 --- a/paddle/fluid/operators/pull_box_sparse_op.cu +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "paddle/fluid/operators/pull_box_sparse_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_info.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" - -namespace paddle { -namespace operators { -using platform::PADDLE_CUDA_NUM_THREADS; -using LoDTensor = framework::LoDTensor; - -template -class PullBoxSparseCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - PullBoxSparseFunctor(ctx); - } -}; - -template -class PushBoxSparseCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - PushBoxSparseFunctor(ctx); - } -}; -} // namespace operators -} // namespace paddle -namespace ops = paddle::operators; - -REGISTER_OP_CUDA_KERNEL(pull_box_sparse, ops::PullBoxSparseCUDAKernel); -REGISTER_OP_CUDA_KERNEL(push_box_sparse, ops::PushBoxSparseCUDAKernel); diff --git a/paddle/fluid/operators/pull_box_sparse_op.h b/paddle/fluid/operators/pull_box_sparse_op.h index 2bde9725abd..136e91121f8 100644 --- a/paddle/fluid/operators/pull_box_sparse_op.h +++ b/paddle/fluid/operators/pull_box_sparse_op.h @@ -15,8 +15,12 @@ #pragma once #include #include +#ifdef PADDLE_WITH_BOX_PS #include "paddle/fluid/framework/fleet/box_wrapper.h" +#endif +#ifdef PADDLE_WITH_HETERPS #include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h" +#endif #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor.h" @@ -100,7 +104,7 @@ static void PushBoxSparseFunctor(const framework::ExecutionContext &ctx) { using LoDTensor = framework::LoDTensor; template -class PullBoxSparseCPUKernel : public framework::OpKernel { +class PullBoxSparseKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { PullBoxSparseFunctor(ctx); @@ -108,27 +112,12 @@ class PullBoxSparseCPUKernel : public framework::OpKernel { }; template -class PushBoxSparseCPUKernel : public framework::OpKernel { +class PushBoxSparseKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { PushBoxSparseFunctor(ctx); } }; -template -class PullBoxSparseXPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - PullBoxSparseFunctor(ctx); - } -}; - -template -class PushBoxSparseXPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - PushBoxSparseFunctor(ctx); - } -}; } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/pull_box_sparse_op.kps b/paddle/fluid/operators/pull_box_sparse_op.kps new file mode 100644 index 00000000000..6b7c7c84951 --- /dev/null +++ b/paddle/fluid/operators/pull_box_sparse_op.kps @@ -0,0 +1,56 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_XPU_KP + +// Please do not modify the following code +#if defined(__CUDA_ARCH__) +#undef __CUDA_ARCH__ +#endif + +#if defined(__CUDACC__) +#undef __CUDACC__ +#endif + +#if defined(__CUDA__) +#undef __CUDA__ +#endif + +#if defined(__NVCC__) +#undef __NVCC__ +#endif + +#include // NOLINT +#include "xpu/kernel/cluster_header.h" // NOLINT +#include "xpu/kernel/debug.h" // NOLINT +#include "xpu/kernel/math.h" // NOLINT +#else +#include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#endif + +#include "paddle/fluid/operators/pull_box_sparse_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +#ifdef PADDLE_WITH_XPU_KP +REGISTER_OP_KERNEL(pull_box_sparse, KP, plat::XPUPlace, + ops::PullBoxSparseKernel); +REGISTER_OP_KERNEL(push_box_sparse, KP, plat::XPUPlace, + ops::PushBoxSparseKernel); +#else +REGISTER_OP_CUDA_KERNEL(pull_box_sparse, ops::PullBoxSparseKernel); +REGISTER_OP_CUDA_KERNEL(push_box_sparse, ops::PushBoxSparseKernel); +#endif diff --git a/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h b/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h index 43c9e63ac19..ab68ebf3a54 100644 --- a/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h +++ b/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h @@ -107,6 +107,8 @@ XPUOpMap& get_kp_ops() { {"reduce_prod", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"reduce_all", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace())})}, {"reduce_any", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace())})}, + {"pull_box_sparse", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"reduce_amax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"reduce_amin", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, }; -- GitLab