diff --git a/cmake/xpu_kp.cmake b/cmake/xpu_kp.cmake index 9cddbe14964781485eec509ee600cda961033925..166f8786337b14174df7dc806b46588ca77b5b20 100644 --- a/cmake/xpu_kp.cmake +++ b/cmake/xpu_kp.cmake @@ -122,6 +122,12 @@ macro(compile_kernel COMPILE_ARGS) string(REPLACE ";" " " XPU_CXX_DEFINES "${XPU_CXX_DEFINES}" ) separate_arguments(XPU_CXX_DEFINES UNIX_COMMAND "${XPU_CXX_DEFINES}") + set(ABI_VERSION "") + if(WITH_HETERPS AND WITH_PSLIB) + set(ABI_VERSION "-D_GLIBCXX_USE_CXX11_ABI=0") + else() + set(ABI_VERSION "-D_GLIBCXX_USE_CXX11_ABI=1") + endif() add_custom_command( OUTPUT kernel_build/${kernel_name}.bin.o @@ -130,7 +136,7 @@ macro(compile_kernel COMPILE_ARGS) COMMAND ${CMAKE_COMMAND} -E copy ${kernel_path}/${kernel_name}.kps kernel_build/${kernel_name}.xpu COMMAND - ${XPU_CLANG} --sysroot=${CXX_DIR} -std=c++11 -D_GLIBCXX_USE_CXX11_ABI=1 ${OPT_LEVEL} -fno-builtin -mcpu=xpu2 -fPIC ${XPU_CXX_DEFINES} ${XPU_CXX_FLAGS} ${XPU_CXX_INCLUDES} + ${XPU_CLANG} --sysroot=${CXX_DIR} -std=c++11 ${ABI_VERSION} ${OPT_LEVEL} -fno-builtin -mcpu=xpu2 -fPIC ${XPU_CXX_DEFINES} ${XPU_CXX_FLAGS} ${XPU_CXX_INCLUDES} -I. -o kernel_build/${kernel_name}.bin.o.sec kernel_build/${kernel_name}.xpu --xpu-device-only -c -v COMMAND @@ -153,7 +159,7 @@ macro(compile_kernel COMPILE_ARGS) COMMAND ${CMAKE_COMMAND} -E copy ${kernel_path}/${kernel_name}.kps kernel_build/${kernel_name}.xpu COMMAND - ${XPU_CLANG} --sysroot=${CXX_DIR} -std=c++11 -D_GLIBCXX_USE_CXX11_ABI=1 ${OPT_LEVEL} -fno-builtin -mcpu=xpu2 -fPIC ${XPU_CXX_DEFINES} ${XPU_CXX_FLAGS} ${XPU_CXX_INCLUDES} + ${XPU_CLANG} --sysroot=${CXX_DIR} -std=c++11 ${ABI_VERSION} ${OPT_LEVEL} -fno-builtin -mcpu=xpu2 -fPIC ${XPU_CXX_DEFINES} ${XPU_CXX_FLAGS} ${XPU_CXX_INCLUDES} -I. -o kernel_build/${kernel_name}.host.o kernel_build/${kernel_name}.xpu --xpu-host-only -c -v WORKING_DIRECTORY diff --git a/paddle/fluid/framework/fleet/heter_context.h b/paddle/fluid/framework/fleet/heter_context.h index 8e51f0e2405bfe6ab218148ca5006c210aaa34e7..218d7fcbb33136b17dff996aad7f16f0e9e3c824 100755 --- a/paddle/fluid/framework/fleet/heter_context.h +++ b/paddle/fluid/framework/fleet/heter_context.h @@ -22,7 +22,7 @@ limitations under the License. 
*/ #include #ifdef PADDLE_WITH_PSLIB -#include "common_value.h" // NOLINT +#include "common/common_value.h" // NOLINT #endif #ifdef PADDLE_WITH_PSCORE diff --git a/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt b/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt index 983208c0608ae74c9d5985b9d160b01bc52c1350..cac366d6b22a1480eb75904e968d81c4cd43b72f 100644 --- a/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt +++ b/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt @@ -7,7 +7,9 @@ IF(WITH_GPU) get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS) SET(HETERPS_DEPS ${HETERPS_DEPS} ${RPC_DEPS}) endif() - nv_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h mem_pool.h DEPS ${HETERPS_DEPS}) + nv_library(heter_comm_kernel SRCS heter_comm_kernel.cu feature_value.h DEPS ${HETERPS_DEPS}) + nv_library(hashtable_kernel SRCS hashtable_kernel.cu feature_value.h DEPS ${HETERPS_DEPS}) + nv_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h mem_pool.h DEPS ${HETERPS_DEPS} heter_comm_kernel hashtable_kernel) nv_test(test_heter_comm SRCS feature_value.h DEPS heter_comm) nv_library(heter_ps SRCS heter_ps.cu DEPS heter_comm) if(WITH_PSCORE) diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h b/paddle/fluid/framework/fleet/heter_ps/feature_value.h index db11fca109bc31cd91232cbf62277542fa42adb1..b633394e7a81179ba8edf74950014951ffda2ee3 100644 --- a/paddle/fluid/framework/fleet/heter_ps/feature_value.h +++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h @@ -52,18 +52,18 @@ struct FeaturePushValue { float lr_g; float mf_g[MF_DIM]; - __device__ __forceinline__ FeaturePushValue - operator+(const FeaturePushValue& a) const { - FeaturePushValue out; - out.slot = a.slot; - out.show = a.show + show; - out.clk = a.clk + clk; - out.lr_g = a.lr_g + lr_g; - for (int i = 0; i < MF_DIM; ++i) { - out.mf_g[i] = a.mf_g[i] + mf_g[i]; - } - return out; - } + // __device__ __forceinline__ FeaturePushValue + // operator+(const FeaturePushValue& a) const { + // FeaturePushValue out; + // out.slot = a.slot; + // out.show = a.show + show; + // out.clk = a.clk + clk; + // out.lr_g = a.lr_g + lr_g; + // for (int i = 0; i < MF_DIM; ++i) { + // out.mf_g[i] = a.mf_g[i] + mf_g[i]; + // } + // return out; + // } }; } // end namespace framework diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable.h b/paddle/fluid/framework/fleet/heter_ps/hashtable.h old mode 100755 new mode 100644 index e8eb91f6f6b14ef4af73291793c138f8a6af27b5..b821ccecf0a29c725e91fa827441b7c226dcebec --- a/paddle/fluid/framework/fleet/heter_ps/hashtable.h +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable.h @@ -13,28 +13,38 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once +#ifdef PADDLE_WITH_HETERPS #include #include #include #include + #ifdef PADDLE_WITH_PSLIB #include "common_value.h" // NOLINT #endif -#ifdef PADDLE_WITH_PSCORE + +#if defined(PADDLE_WITH_PSCORE) #include "paddle/fluid/distributed/ps/table/depends/feature_value.h" #endif +#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" #include "paddle/phi/core/utils/rw_lock.h" -#include "thrust/pair.h" -// #include "cudf/concurrent_unordered_map.cuh.h" + +#if defined(PADDLE_WITH_CUDA) #include "paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h" -#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" #include "paddle/fluid/framework/fleet/heter_ps/mem_pool.h" -#ifdef PADDLE_WITH_HETERPS #include "paddle/fluid/platform/device/gpu/gpu_types.h" +#include "thrust/pair.h" +#elif defined(__xpu__) +#include +#include "xpu/kernel/cluster_header.h" +#include "xpu/kernel/math.h" +#include "xpu/kernel/simd.h" +#endif namespace paddle { namespace framework { +#if defined(PADDLE_WITH_CUDA) template class TableContainer : public concurrent_unordered_map::max()>( capacity, ValType()) {} }; +#elif defined(PADDLE_WITH_XPU_KP) + +template +class XPUCacheArray { + public: + explicit XPUCacheArray(size_t capacity) : capacity_(capacity), size_(0) { + xpu_malloc(reinterpret_cast(&keys), capacity_ * sizeof(KeyType)); + xpu_malloc(reinterpret_cast(&vals), capacity_ * sizeof(ValType)); + } + + virtual ~XPUCacheArray() { + xpu_free(keys); + xpu_free(vals); + } + + void print() {} + // ValType* find(const KeyType& key) { return NULL; } + // bool insert(const KeyType& key, const ValType& val) { return true; } + + int prefetch(const int dev_id, XPUStream stream = NULL) { return 0; } + size_t size() { return size_; } + + private: + long long capacity_; + long long size_; + KeyType* keys; + ValType* vals; +}; +#endif template class HashTable { public: - HashTable(size_t capacity); + explicit HashTable(size_t capacity); virtual ~HashTable(); HashTable(const HashTable&) = delete; HashTable& operator=(const HashTable&) = delete; + + template void insert(const KeyType* d_keys, const ValType* d_vals, size_t len, - gpuStream_t stream); + StreamType stream); + + template void insert(const KeyType* d_keys, size_t len, char* pool, size_t start_index, - gpuStream_t stream); + StreamType stream); + + template void get(const KeyType* d_keys, ValType* d_vals, size_t len, - gpuStream_t stream); - void get(const KeyType* d_keys, char* d_vals, size_t len, gpuStream_t stream); + StreamType stream); + + template + void get(const KeyType* d_keys, char* d_vals, size_t len, StreamType stream); + void show(); - void dump_to_cpu(int devid, cudaStream_t stream); - template + template + void dump_to_cpu(int devid, StreamType stream); + +#if defined(PADDLE_WITH_CUDA) + + template void update(const KeyType* d_keys, const GradType* d_grads, size_t len, - Sgd sgd, gpuStream_t stream); + Sgd sgd, StreamType stream); - template + template void update(const KeyType* d_keys, const char* d_grads, size_t len, Sgd sgd, - gpuStream_t stream); + StreamType stream); + +#elif defined(PADDLE_WITH_XPU_KP) + template + void update(const KeyType* d_keys, const GradType* d_grads, size_t len, + StreamType stream); + + template + void update(const KeyType* d_keys, const char* d_grads, size_t len, + StreamType stream); + +#endif int size() { return container_->size(); } @@ -84,7 +147,11 @@ class HashTable { std::unique_ptr rwlock_{nullptr}; private: +#if defined(PADDLE_WITH_CUDA) TableContainer* container_; +#elif 
defined(PADDLE_WITH_XPU_KP) + XPUCacheArray* container_; +#endif int BLOCK_SIZE_{256}; float LOAD_FACTOR{0.75f}; size_t capacity_; @@ -94,5 +161,4 @@ class HashTable { }; } // end namespace framework } // end namespace paddle -#include "hashtable_inl.h" #endif diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu similarity index 75% rename from paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h rename to paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu index 0297e71c35e279e40e9b89d730c36c531f51990e..cac1b9c17e077f3dd94a1dd405abdd09be355a62 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu @@ -1,4 +1,4 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,10 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. */ #ifdef PADDLE_WITH_HETERPS +#include +#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h" +#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h" namespace paddle { namespace framework { +#if defined(PADDLE_WITH_CUDA) + template struct ReplaceOp { __host__ __device__ value_type operator()(value_type new_value, @@ -87,6 +92,7 @@ __global__ void dy_mf_search_kernel(Table* table, } } } + template __global__ void update_kernel(Table* table, const typename Table::key_type* const keys, @@ -135,8 +141,9 @@ void HashTable::show() { } template +template void HashTable::get(const KeyType* d_keys, ValType* d_vals, - size_t len, gpuStream_t stream) { + size_t len, StreamType stream) { if (len == 0) { return; } @@ -146,8 +153,9 @@ void HashTable::get(const KeyType* d_keys, ValType* d_vals, } template +template void HashTable::get(const KeyType* d_keys, char* d_vals, - size_t len, gpuStream_t stream) { + size_t len, StreamType stream) { if (len == 0) { return; } @@ -157,9 +165,10 @@ void HashTable::get(const KeyType* d_keys, char* d_vals, } template +template void HashTable::insert(const KeyType* d_keys, const ValType* d_vals, size_t len, - gpuStream_t stream) { + StreamType stream) { if (len == 0) { return; } @@ -169,22 +178,24 @@ void HashTable::insert(const KeyType* d_keys, } template +template void HashTable::insert(const KeyType* d_keys, size_t len, char* pool, size_t start_index, - gpuStream_t stream) { + StreamType stream) { if (len == 0) { return; } - const int grid_size = (len - 1) / BLOCK_SIZE_ + 1; if (pool == NULL) { return; } + const int grid_size = (len - 1) / BLOCK_SIZE_ + 1; insert_kernel<<>>(container_, d_keys, len, pool, start_index); } template -void HashTable::dump_to_cpu(int devid, cudaStream_t stream) { +template +void HashTable::dump_to_cpu(int devid, StreamType stream) { container_->prefetch(cudaCpuDeviceId, stream); std::vector threads; size_t num = container_->size(); @@ -260,10 +271,10 @@ void HashTable::dump_to_cpu(int devid, cudaStream_t stream) { } template -template +template void HashTable::update(const KeyType* d_keys, const GradType* d_grads, size_t len, - Sgd sgd, gpuStream_t stream) { + Sgd sgd, StreamType stream) { if (len == 0) { return; } @@ -273,19 +284,66 @@ void HashTable::update(const KeyType* d_keys, } template -template +template void HashTable::update(const KeyType* d_keys, const char* d_grads, 
size_t len, - Sgd sgd, gpuStream_t stream) { + Sgd sgd, StreamType stream) { if (len == 0) { return; } const int grid_size = (len - 1) / BLOCK_SIZE_ + 1; - dy_mf_update_kernel<<>>( container_, d_keys, d_grads, len, sgd, push_grad_value_size_); } +template class HashTable; + +template void HashTable::get< + cudaStream_t>(const unsigned long* d_keys, + paddle::framework::FeatureValue* d_vals, size_t len, + cudaStream_t stream); + +// template void +// HashTable::get( +// const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t +// stream); + +template void HashTable::insert< + cudaStream_t>(const unsigned long* d_keys, + const paddle::framework::FeatureValue* d_vals, size_t len, + cudaStream_t stream); + +// template void HashTable::insert< +// cudaStream_t>(const unsigned long* d_keys, size_t len, char* pool, +// size_t start_index, cudaStream_t stream); + +template void HashTable:: + dump_to_cpu(int devid, cudaStream_t stream); + +template void HashTable::update< + paddle::framework::FeaturePushValue, + Optimizer, + cudaStream_t>(const unsigned long* d_keys, + const paddle::framework::FeaturePushValue* d_grads, + size_t len, Optimizer + sgd, + cudaStream_t stream); + +// template void HashTable::update< +// Optimizer, +// cudaStream_t>(const unsigned long* d_keys, const char* d_grads, size_t +// len, +// Optimizer +// sgd, +// cudaStream_t stream); + +#endif } // end namespace framework } // end namespace paddle #endif diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.kps b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.kps new file mode 100644 index 0000000000000000000000000000000000000000..55edf883271b95a27d054f313211e3a078c864ae --- /dev/null +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.kps @@ -0,0 +1,344 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef PADDLE_WITH_HETERPS +#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h" + +namespace optimizer_config { +extern _global_ptr_ float* nonclk_coeff; +extern _global_ptr_ float* clk_coeff; + +extern _global_ptr_ float* min_bound; +extern _global_ptr_ float* max_bound; +extern _global_ptr_ float* learning_rate; +extern _global_ptr_ float* initial_g2sum; +extern _global_ptr_ float* initial_range; + +extern _global_ptr_ float* mf_create_thresholds; +extern _global_ptr_ float* mf_learning_rate; +extern _global_ptr_ float* mf_initial_g2sum; +extern _global_ptr_ float* mf_initial_range; +extern _global_ptr_ float* mf_min_bound; +extern _global_ptr_ float* mf_max_bound; +} + +namespace paddle { +namespace framework { + +#if defined(PADDLE_WITH_XPU_KP) + +__device__ void update_lr(float& w, float& g2sum, float g, // NOLINT + float scale) { + __local__ float local_learning_rate; + __local__ float local_initial_g2sum; + __local__ float local_min_bound; + __local__ float local_max_bound; + + GM2LM(optimizer_config::learning_rate, &local_learning_rate, sizeof(float)); + GM2LM(optimizer_config::initial_g2sum, &local_initial_g2sum, sizeof(float)); + GM2LM(optimizer_config::min_bound, &local_min_bound, sizeof(float)); + GM2LM(optimizer_config::max_bound, &local_max_bound, sizeof(float)); + + double add_g2sum = 0; + double ratio = local_learning_rate * + sqrt(local_initial_g2sum / (local_initial_g2sum + g2sum)); + double scaled_grad = g / scale; + + w += scaled_grad * ratio; + + if (w < local_min_bound) w = local_min_bound; + if (w > local_max_bound) w = local_max_bound; + + add_g2sum += scaled_grad * scaled_grad; + + g2sum += add_g2sum; +} + +__device__ void update_mf(int n, float* w, float& g2sum, const float* g, + float scale) { + __local__ float local_mf_learning_rate; + __local__ float local_mf_initial_g2sum; + __local__ float local_mf_min_bound; + __local__ float local_mf_max_bound; + + GM2LM(optimizer_config::mf_learning_rate, &local_mf_learning_rate, + sizeof(float)); + GM2LM(optimizer_config::mf_initial_g2sum, &local_mf_initial_g2sum, + sizeof(float)); + GM2LM(optimizer_config::mf_min_bound, &local_mf_min_bound, sizeof(float)); + GM2LM(optimizer_config::mf_max_bound, &local_mf_max_bound, sizeof(float)); + + double add_g2sum = 0; + double ratio = + local_mf_learning_rate * + sqrt(local_mf_initial_g2sum / (local_mf_initial_g2sum + g2sum)); + for (int i = 0; i < n; ++i) { + double scaled_grad = g[i] / scale; + w[i] += scaled_grad * ratio; + + if (w[i] < local_mf_min_bound) w[i] = local_mf_min_bound; + if (w[i] > local_mf_max_bound) w[i] = local_mf_max_bound; + add_g2sum += scaled_grad * scaled_grad; + } + + g2sum += add_g2sum / n; +} + +__device__ float xpu_rand_uniform() { return 0.1; } + +template <typename ValType, typename GradType> +__device__ void update_value(ValType& val, const GradType& grad) { // NOLINT + val.slot = grad.slot; + val.show += grad.show; + val.clk += grad.clk; + + __local__ float local_nonclk_coeff; + __local__ float local_clk_coeff; + + __local__ float local_mf_create_thresholds; + __local__ float local_mf_initial_range; + + GM2LM(optimizer_config::nonclk_coeff, &local_nonclk_coeff, sizeof(float)); + GM2LM(optimizer_config::clk_coeff, &local_clk_coeff, sizeof(float)); + GM2LM(optimizer_config::mf_create_thresholds, &local_mf_create_thresholds, + sizeof(float)); + + val.delta_score += + local_nonclk_coeff * (grad.show - grad.clk) + local_clk_coeff * grad.clk; + + update_lr(val.lr, val.lr_g2sum, grad.lr_g, grad.show); + + if (val.mf_size == 0) { + if (local_mf_create_thresholds <= + 
local_nonclk_coeff * (val.show - val.clk) + local_clk_coeff * val.clk) { + val.mf_size = MF_DIM + 1; + val.mf[0] = 0; + + for (int i = 0; i < MF_DIM; ++i) { + val.mf[i + 1] = (xpu_rand_uniform()) * local_mf_initial_range; + } + } + } else { + update_mf(MF_DIM, &val.mf[1], val.mf[0], grad.mf_g, grad.show); + } +} + +template +__global__ void insert_kernel(Table* table, const KeyType* const keys, + const ValType* const vals, size_t len) { + int cid = core_id(); + int ncores = core_num(); + if (cid >= ncores) { + return; + } + int thread_id = ncores * cluster_id() + cid; + int nthreads = ncores * cluster_num(); + + const int buf_size = 150; + __local__ KeyType local_keys[buf_size]; + __local__ ValType local_vals[buf_size]; + int len_per_loop = min(buf_size, roundup_div(len, nthreads)); + + for (int i = thread_id * len_per_loop; i < len; + i += nthreads * len_per_loop) { + int read_len = min(len_per_loop, len - i); + GM2LM(keys, local_keys, read_len * sizeof(KeyType)); + GM2LM(vals, local_vals, read_len * sizeof(ValType)); + for (int k = 0; k < read_len; k++) { + // auto status = table->insert(local_keys[k], local_vals[k]); + // assert(status != false && "error: insert fails: table is full"); + } + } +} + +template +__global__ void search_kernel(Table* table, const KeyType* const keys, + ValType* const vals, size_t len) { + int cid = core_id(); + int ncores = core_num(); + if (cid >= ncores) { + return; + } + int thread_id = ncores * cluster_id() + cid; + int nthreads = ncores * cluster_num(); + + const int buf_size = 150; + __local__ KeyType local_keys[buf_size]; + __local__ ValType local_vals[buf_size]; + + int len_per_loop = min(buf_size, roundup_div(len, nthreads)); + for (int i = thread_id * len_per_loop; i < len; + i += nthreads * len_per_loop) { + int read_len = min(len_per_loop, len - i); + GM2LM(keys, local_keys, read_len * sizeof(KeyType)); + for (int k = 0; k < read_len; k++) { + // ValType* val = table->find(local_keys[k]); + // if (val != NULL) { + // local_vals[k] = *val; + // } + } + LM2GM(local_vals, vals + i, read_len * sizeof(ValType)); + } +} + +template +__global__ void update_kernel(Table* table, const KeyType* const keys, + const GradType* const grads, size_t len) { + int cid = core_id(); + int ncores = core_num(); + if (cid >= ncores) { + return; + } + int thread_id = ncores * cluster_id() + cid; + int nthreads = ncores * cluster_num(); + + const int buf_size = 250; + __local__ KeyType local_keys[buf_size]; + __local__ GradType local_grads[buf_size]; + + int len_per_loop = min(buf_size, roundup_div(len, nthreads)); + for (int i = thread_id * len_per_loop; i < len; + i += nthreads * len_per_loop) { + int read_len = min(len_per_loop, len - i); + + GM2LM(keys, local_keys, read_len * sizeof(KeyType)); + GM2LM(grads, local_grads, read_len * sizeof(GradType)); + + for (int k = 0; k < read_len; k++) { + // ValType* val = table->find(local_keys[k]); + // if (val != NULL) { + // update_value(*val, grads[i]); + //} + } + } +} + +template +HashTable::HashTable(size_t capacity) { + auto tmp_container = XPUCacheArray(capacity); + xpu_malloc(reinterpret_cast(&container_), + sizeof(XPUCacheArray)); + xpu_memcpy(container_, &tmp_container, + sizeof(XPUCacheArray), XPU_HOST_TO_DEVICE); + rwlock_.reset(new phi::RWLock); +} + +template +HashTable::~HashTable() { + xpu_free((void*)container_); +} + +template +void HashTable::show() { + container_->print(); +} + +template +template +void HashTable::get(const KeyType* d_keys, ValType* d_vals, + size_t len, StreamType stream) { + if (len == 0) 
{ + return; + } + search_kernel<<<4, 64, stream>>>(container_, d_keys, d_vals, len); +} + +template +template +void HashTable::get(const KeyType* d_keys, char* d_vals, + size_t len, StreamType stream) { + if (len == 0) { + return; + } + // TODO(zhangminxu): to be implemented +} + +template +template +void HashTable::insert(const KeyType* d_keys, + const ValType* d_vals, size_t len, + StreamType stream) { + if (len == 0) { + return; + } + insert_kernel<<<4, 64, stream>>>(container_, d_keys, d_vals, len); +} + +template +template +void HashTable::dump_to_cpu(int devid, StreamType stream) { + // TODO(zhangminxu): to be implemented +} + +template +template +void HashTable::update(const KeyType* d_keys, + const GradType* d_grads, size_t len, + StreamType stream) { + if (len == 0) { + return; + } + update_kernel<<<4, 64, stream>>>(container_, d_keys, d_grads, len); +} + +template +template +void HashTable::update(const KeyType* d_keys, + const char* d_grads, size_t len, + StreamType stream) { + if (len == 0) { + return; + } + // TODO(zhangminxu): to be implemented +} + +template class HashTable; + +template void HashTable::get< + XPUStream>(const unsigned long* d_keys, + paddle::framework::FeatureValue* d_vals, size_t len, + XPUStream stream); + +// template void +// HashTable::get( +// const unsigned long* d_keys, char* d_vals, size_t len, XPUStream stream); + +template void HashTable::insert< + XPUStream>(const unsigned long* d_keys, + const paddle::framework::FeatureValue* d_vals, size_t len, + XPUStream stream); + +// template void HashTable::insert< +// XPUStream>(const unsigned long* d_keys, size_t len, char* pool, +// size_t start_index, XPUStream stream); + +template void HashTable:: + dump_to_cpu(int devid, XPUStream stream); + +template void HashTable::update< + paddle::framework::FeaturePushValue, XPUStream>( + const unsigned long* d_keys, + const paddle::framework::FeaturePushValue* d_grads, size_t len, + XPUStream stream); + +// template void HashTable::update< +// XPUStream>(const unsigned long* d_keys, const char* d_grads, +// size_t len, XPUStream stream); + +#endif +} // end namespace framework +} // end namespace paddle +#endif diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h index 1fca8cdf8bb801a57ec36ee957b27236f488a4b3..419bd716eb304738915adb2e74d08c9dd275bb95 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h @@ -15,39 +15,28 @@ limitations under the License. 
*/ #pragma once #include #include -#include "cub/cub.cuh" -#include "cub/util_allocator.cuh" -#include "hashtable.h" // NOLINT -#include "heter_resource.h" // NOLINT +#if defined(PADDLE_WITH_CUDA) #include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h" -#include "paddle/fluid/memory/allocation/allocator.h" -#include "paddle/fluid/memory/memory.h" #include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/dynload/nccl.h" -#include "paddle/fluid/platform/place.h" #include "thrust/pair.h" +#elif defined(PADDLE_WITH_XPU_KP) +#include +#include "paddle/fluid/platform/device/xpu/enforce_xpu.h" +#endif + +#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h" +#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h" +#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h" +#include "paddle/fluid/memory/allocation/allocator.h" +#include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/platform/place.h" #ifdef PADDLE_WITH_HETERPS namespace paddle { namespace framework { -struct CustomGradMerger { - template - CUB_RUNTIME_FUNCTION __forceinline__ __device__ T - operator()(const T& a, const T& b) const { - T out; - out.slot = a.slot; - out.show = a.show + b.show; - out.clk = a.clk + b.clk; - out.lr_g = a.lr_g + b.lr_g; - for (int i = 0; i < MF_DIM; ++i) { - out.mf_g[i] = a.mf_g[i] + b.mf_g[i]; - } - return out; - } -}; - template class HeterComm { public: @@ -67,10 +56,21 @@ class HeterComm { void show_one_table(int gpu_num); int get_index_by_devid(int devid); +#if defined(PADDLE_WITH_CUDA) template void push_sparse(int num, KeyType* d_keys, GradType* d_grads, size_t len, Sgd& sgd); // NOLINT +#elif defined(PADDLE_WITH_XPU_KP) + void push_sparse(int num, KeyType* d_keys, GradType* d_grads, size_t len); +#endif + + int log2i(int x); + template + void memory_copy(DstPlace dst_place, void* dst, SrcPlace src_place, + const void* src, size_t count, StreamType stream = 0); + +#if defined(PADDLE_WITH_CUDA) template void push_sparse_multi_node(int num, KeyType* d_keys, GradType* d_grads, size_t len, Sgd& sgd); // NOLINT @@ -85,8 +85,6 @@ class HeterComm { int gather_multi_node_grad(int num, KeyType* d_keys, GradType* d_grads, int len); - int log2i(int x); - void set_nccl_comm_and_size(const std::vector& inner_comms, const std::vector& inter_comms, int comm_size) { @@ -94,6 +92,7 @@ class HeterComm { nccl_inter_comms_ = inter_comms; node_size_ = comm_size; } +#endif bool need_transfer(int send_id, int receive_id) { return ((send_id / 4 != receive_id / 4) && (send_id + 4) % 8 != receive_id); @@ -101,19 +100,19 @@ class HeterComm { // void dump_to_cpu(int index); - void end_pass(); - int get_transfer_devid(int send_id) { return (send_id + 4) % 8; } + void end_pass(); + struct Node { - cudaStream_t in_stream; - cudaStream_t out_stream; + ppStream in_stream; + ppStream out_stream; char* key_storage; char* val_storage; int sync; int key_bytes_len; int val_bytes_len; - int gpu_num; + int dev_num; }; struct Path { @@ -133,7 +132,7 @@ class HeterComm { alloc(size, true); } - void alloc(int size, bool force = false) { + void alloc(size_t size, bool force = false) { if (force || size > all_keys_mem->size()) { all_keys_mem.reset(); all_grads_mem.reset(); @@ -152,7 +151,11 @@ class HeterComm { } } +#if defined(PADDLE_WITH_CUDA) platform::CUDAPlace place_; +#elif defined(PADDLE_WITH_XPU_KP) + platform::XPUPlace place_; +#endif std::shared_ptr all_keys_mem; std::shared_ptr all_grads_mem; KeyType* all_keys; @@ -166,6 +169,33 @@ class HeterComm { void 
init_path(); + template + void sync_stream(const StreamType& stream) { +#if defined(PADDLE_WITH_CUDA) + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); +#elif defined(PADDLE_WITH_XPU_KP) + PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(stream)); +#endif + } + + template + void create_stream(StreamType* stream) { +#if defined(PADDLE_WITH_CUDA) + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(stream)); +#elif defined(PADDLE_WITH_XPU_KP) + PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(stream)); +#endif + } + + template + void destroy_stream(StreamType stream) { +#if defined(PADDLE_WITH_CUDA) + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream)); +#elif defined(PADDLE_WITH_XPU_KP) + PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_destroy(stream)); +#endif + } + void create_storage(int start_index, int end_index, int keylen, int vallen); void destroy_storage(int start_index, int end_index); void walk_to_dest(int start_index, int gpu_num, int* h_left, int* h_right, @@ -182,15 +212,18 @@ class HeterComm { int block_size_{256}; private: + std::unique_ptr heter_comm_kernel_; std::vector storage_; - CustomGradMerger merger_; int topo_aware_{0}; int feanum_{1800 * 2048}; int multi_node_{0}; + int node_size_; + +#if defined(PADDLE_WITH_CUDA) std::vector nccl_inner_comms_; std::vector nccl_inter_comms_; - int node_size_; std::vector> allocators_; +#endif }; } // end namespace framework diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index f85ed330dc8ea4eb4199b6ab006ac54be1b30b0d..1e66b3cb250313850e79407846339e30f7525b14 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -13,115 +13,46 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once #ifdef PADDLE_WITH_HETERPS -//#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h" #include +#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h" +#include "paddle/fluid/platform/device_context.h" +#ifdef PADDLE_WITH_XPU_KP +#include "paddle/fluid/platform/device/xpu/xpu_info.h" +#endif namespace paddle { namespace framework { -template -__global__ void fill_idx(T* idx, size_t len) { - const size_t i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < len) { - idx[i] = i; - } -} - -template -void show_tensor(T* input, size_t len, gpuStream_t stream, std::string name) { - T tmp[len]; // NOLINT - cudaMemcpyAsync(&tmp, input, sizeof(T) * len, cudaMemcpyDeviceToHost, stream); - cudaStreamSynchronize(stream); - std::cout << name; - for (int i = 0; i < len; ++i) { - std::cout << ":" << tmp[i]; - } - std::cout << std::endl; -} - -template -__global__ void calc_shard_offset(T* idx, T* left, T* right, size_t len) { - const size_t i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < len - 1) { - if (idx[i] != idx[i + 1]) { - right[idx[i]] = i; - left[idx[i + 1]] = i + 1; - } - } - if (i == 0) { - left[idx[i]] = i; - } - if (i == (len - 1)) { - right[idx[i]] = i; - } -} - -template -__global__ void calc_shard_index(KeyType* d_keys, size_t len, T* shard_index, - int total_gpu) { - const size_t i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < len) { - shard_index[i] = d_keys[i] % total_gpu; - } -} - -template -__global__ void fill_shard_key(KeyType* d_shard_keys, KeyType* d_keys, T* idx, - size_t len) { - const size_t i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < len) { - d_shard_keys[i] = d_keys[idx[i]]; - } -} - -template -__global__ void fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys, - GradType* d_shard_grads, GradType* d_grads, - T* idx, size_t len) { - const size_t i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < len) { - d_shard_keys[i] = d_keys[idx[i]]; - d_shard_grads[i] = d_grads[idx[i]]; - } -} - -template -__global__ void fill_dvals(ValType* d_shard_vals, ValType* d_vals, T* idx, - size_t len) { - const size_t i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < len) { - d_vals[idx[i]] = d_shard_vals[i]; - } -} - template HeterComm::HeterComm( size_t capacity, std::shared_ptr resource) { resource_ = resource; - storage_.resize(resource_->total_gpu()); - for (int i = 0; i < resource_->total_gpu(); ++i) { + storage_.resize(resource_->total_device()); + for (int i = 0; i < resource_->total_device(); ++i) { +#if defined(PADDLE_WITH_CUDA) platform::CUDADeviceGuard guard(resource_->dev_id(i)); allocators_.push_back(std::make_shared( 8, 1, (unsigned int)-1, (size_t)-1, false, false)); // NOLINT +#endif auto table = new Table(capacity / load_factor_); tables_.push_back(table); if (multi_node_) { storage_[i].init(feanum_, resource_->dev_id(i)); } } + heter_comm_kernel_ = std::make_unique(block_size_); init_path(); } template void HeterComm::init_path() { - int total_gpu = resource_->total_gpu(); - path_.resize(total_gpu); - + int total_device = resource_->total_device(); + path_.resize(total_device); if (!topo_aware_) { VLOG(0) << "init path without topo aware"; - for (int i = 0; i < total_gpu; ++i) { - path_[i].resize(total_gpu); - for (int j = 0; j < total_gpu; ++j) { + for (int i = 0; i < total_device; ++i) { + path_[i].resize(total_device); + for (int j = 0; j < total_device; ++j) { auto& nodes = path_[i][j].nodes_; nodes.resize(1); nodes[0].in_stream = resource_->comm_stream(i, j); @@ -129,17 +60,18 @@ void HeterComm::init_path() { 
nodes[0].key_storage = NULL; nodes[0].val_storage = NULL; nodes[0].sync = 0; - nodes[0].gpu_num = j; + nodes[0].dev_num = j; } } } else { VLOG(0) << "init path with topo aware"; - for (int i = 0; i < total_gpu; ++i) { - path_[i].resize(total_gpu); - for (int j = 0; j < total_gpu; ++j) { + for (int i = 0; i < total_device; ++i) { + path_[i].resize(total_device); + for (int j = 0; j < total_device; ++j) { auto& nodes = path_[i][j].nodes_; int from = resource_->dev_id(i); int to = resource_->dev_id(j); + int transfer_id = i; if (need_transfer(from, to)) { transfer_id = resource_->get_index_by_devid(get_transfer_devid(from)); @@ -150,7 +82,7 @@ void HeterComm::init_path() { node.key_storage = NULL; node.val_storage = NULL; node.sync = 1; - node.gpu_num = transfer_id; + node.dev_num = transfer_id; } nodes.push_back(Node()); Node& node = nodes.back(); @@ -159,148 +91,222 @@ void HeterComm::init_path() { node.key_storage = NULL; node.val_storage = NULL; node.sync = 0; - node.gpu_num = j; + node.dev_num = j; } } } } +template +template +void HeterComm::memory_copy( + DstPlace dst_place, void* dst, SrcPlace src_place, const void* src, + size_t count, StreamType stream) { +#if defined(PADDLE_WITH_CUDA) + cudaMemcpyAsync(dst, src, count, cudaMemcpyDefault, stream); + if (stream == 0) { + cudaStreamSynchronize(0); + } +#elif defined(PADDLE_WITH_XPU_KP) + memory::Copy(dst_place, dst, src_place, src, count); +#endif +} + template void HeterComm::create_storage(int start_index, int end_index, int keylen, int vallen) { +#if defined(PADDLE_WITH_CUDA) auto& allocator = allocators_[start_index]; auto& nodes = path_[start_index][end_index].nodes_; for (size_t i = 0; i < nodes.size(); ++i) { - platform::CUDADeviceGuard guard(resource_->dev_id(nodes[i].gpu_num)); + platform::CUDADeviceGuard guard(resource_->dev_id(nodes[i].dev_num)); allocator->DeviceAllocate( - resource_->dev_id(nodes[i].gpu_num), + resource_->dev_id(nodes[i].dev_num), (void**)&(nodes[i].key_storage), // NOLINT - keylen, resource_->remote_stream(nodes[i].gpu_num, start_index)); + keylen, resource_->remote_stream(nodes[i].dev_num, start_index)); allocator->DeviceAllocate( - resource_->dev_id(nodes[i].gpu_num), + resource_->dev_id(nodes[i].dev_num), (void**)&(nodes[i].val_storage), // NOLINT - vallen, resource_->remote_stream(nodes[i].gpu_num, start_index)); - + vallen, resource_->remote_stream(nodes[i].dev_num, start_index)); + nodes[i].key_bytes_len = keylen; + nodes[i].val_bytes_len = vallen; + } +#elif defined(PADDLE_WITH_XPU_KP) + auto& nodes = path_[start_index][end_index].nodes_; + for (size_t i = 0; i < nodes.size(); ++i) { + platform::XPUDeviceGuard guard(resource_->dev_id(nodes[i].dev_num)); + auto place = DevPlace(resource_->dev_id(nodes[i].dev_num)); + auto node_keys_mem = memory::Alloc(place, keylen); + nodes[i].key_storage = reinterpret_cast(node_keys_mem->ptr()); + auto node_vals_mem = memory::Alloc(place, vallen); + nodes[i].val_storage = reinterpret_cast(node_vals_mem->ptr()); nodes[i].key_bytes_len = keylen; nodes[i].val_bytes_len = vallen; } +#endif } template void HeterComm::destroy_storage(int start_index, int end_index) { +#if defined(PADDLE_WITH_CUDA) auto& allocator = allocators_[start_index]; auto& nodes = path_[start_index][end_index].nodes_; for (size_t i = 0; i < nodes.size(); ++i) { - platform::CUDADeviceGuard guard(resource_->dev_id(nodes[i].gpu_num)); + platform::CUDADeviceGuard guard(resource_->dev_id(nodes[i].dev_num)); - allocator->DeviceFree(resource_->dev_id(nodes[i].gpu_num), + 
allocator->DeviceFree(resource_->dev_id(nodes[i].dev_num), nodes[i].key_storage); - allocator->DeviceFree(resource_->dev_id(nodes[i].gpu_num), + allocator->DeviceFree(resource_->dev_id(nodes[i].dev_num), nodes[i].val_storage); } +#endif } template -void HeterComm::walk_to_dest( - int start_index, int gpu_num, int* h_left, int* h_right, KeyType* src_key, - GradType* src_val) { +void HeterComm::walk_to_dest(int start_index, + int num, int* h_left, + int* h_right, + KeyType* src_key, + GradType* src_val) { int need_copy_val = 0; if (src_val) { need_copy_val = 1; } std::queue que; - for (int i = 0; i < gpu_num; i++) { + for (int i = 0; i < num; i++) { if (h_left[i] == -1 || h_right[i] == -1) { continue; } - int size = path_[start_index][i].nodes_.size(); + // int size = path_[start_index][i].nodes_.size(); auto& node = path_[start_index][i].nodes_[0]; + CopyTask t(&path_[start_index][i], 0); que.push(t); - cudaMemcpyAsync(node.key_storage, - reinterpret_cast(src_key + h_left[i]), - node.key_bytes_len, cudaMemcpyDefault, node.in_stream); + auto src_dev_id = resource_->dev_id(start_index); + auto dst_dev_id = resource_->dev_id(i); + auto src_place = DevPlace(src_dev_id); + auto dst_place = DevPlace(dst_dev_id); + + memory_copy(dst_place, node.key_storage, src_place, + reinterpret_cast(src_key + h_left[i]), + node.key_bytes_len, node.in_stream); if (need_copy_val) { - cudaMemcpyAsync(node.val_storage, - reinterpret_cast(src_val + h_left[i]), - node.val_bytes_len, cudaMemcpyDefault, node.in_stream); + memory_copy(dst_place, node.val_storage, src_place, + reinterpret_cast(src_val + h_left[i]), + node.val_bytes_len, node.in_stream); } } while (!que.empty()) { CopyTask& cur_task = que.front(); que.pop(); if (cur_task.path->nodes_[cur_task.step].sync) { - cudaStreamSynchronize(cur_task.path->nodes_[cur_task.step].in_stream); + sync_stream(cur_task.path->nodes_[cur_task.step].in_stream); } - if (cur_task.step != cur_task.path->nodes_.size() - 1) { + if (static_cast(cur_task.step) != + cur_task.path->nodes_.size() - 1) { int cur_step = cur_task.step; CopyTask c(cur_task.path, cur_step + 1); que.push(c); - cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].key_storage, - cur_task.path->nodes_[cur_step].key_storage, - cur_task.path->nodes_[cur_step + 1].key_bytes_len, - cudaMemcpyDefault, - cur_task.path->nodes_[cur_step + 1].in_stream); + + auto src_dev_id = + resource_->dev_id(cur_task.path->nodes_[cur_step].dev_num); + auto dst_dev_id = + resource_->dev_id(cur_task.path->nodes_[cur_step + 1].dev_num); + auto src_place = DevPlace(src_dev_id); + auto dst_place = DevPlace(dst_dev_id); + + memory_copy(dst_place, cur_task.path->nodes_[cur_step + 1].key_storage, + src_place, cur_task.path->nodes_[cur_step].key_storage, + cur_task.path->nodes_[cur_step + 1].key_bytes_len, + cur_task.path->nodes_[cur_step + 1].in_stream); if (need_copy_val) { - cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].val_storage, - cur_task.path->nodes_[cur_step].val_storage, - cur_task.path->nodes_[cur_step + 1].val_bytes_len, - cudaMemcpyDefault, - cur_task.path->nodes_[cur_step + 1].in_stream); + memory_copy(dst_place, cur_task.path->nodes_[cur_step + 1].val_storage, + src_place, cur_task.path->nodes_[cur_step].val_storage, + cur_task.path->nodes_[cur_step + 1].val_bytes_len, + cur_task.path->nodes_[cur_step + 1].in_stream); } } } } template -void HeterComm::walk_to_src( - int start_index, int gpu_num, int* h_left, int* h_right, ValType* src_val) { +void HeterComm::walk_to_src(int start_index, + int num, int* h_left, + int* 
h_right, + ValType* src_val) { std::queue que; - for (int i = 0; i < gpu_num; i++) { + + for (int i = 0; i < num; i++) { if (h_left[i] == -1 || h_right[i] == -1) { continue; } int cur_step = path_[start_index][i].nodes_.size() - 1; auto& node = path_[start_index][i].nodes_[cur_step]; + + auto src_dev_id = resource_->dev_id(i); + auto src_place = DevPlace(src_dev_id); + if (cur_step == 0) { - cudaMemcpyAsync(reinterpret_cast(src_val + h_left[i]), - node.val_storage, node.val_bytes_len, cudaMemcpyDefault, - node.out_stream); + auto dst_dev_id = resource_->dev_id(start_index); + auto dst_place = DevPlace(dst_dev_id); + memory_copy(dst_place, reinterpret_cast(src_val + h_left[i]), + src_place, node.val_storage, node.val_bytes_len, + node.out_stream); } else { CopyTask t(&path_[start_index][i], cur_step - 1); que.push(t); - cudaMemcpyAsync(path_[start_index][i].nodes_[cur_step - 1].val_storage, - node.val_storage, - path_[start_index][i].nodes_[cur_step - 1].val_bytes_len, - cudaMemcpyDefault, - path_[start_index][i].nodes_[cur_step - 1].out_stream); + + auto dst_dev_id = + resource_->dev_id(path_[start_index][i].nodes_[cur_step - 1].dev_num); + auto dst_place = DevPlace(dst_dev_id); + + memory_copy(dst_place, + path_[start_index][i].nodes_[cur_step - 1].val_storage, + src_place, node.val_storage, + path_[start_index][i].nodes_[cur_step - 1].val_bytes_len, + path_[start_index][i].nodes_[cur_step - 1].out_stream); } } + while (!que.empty()) { CopyTask& cur_task = que.front(); que.pop(); int cur_step = cur_task.step; if (cur_task.path->nodes_[cur_step].sync) { - cudaStreamSynchronize(cur_task.path->nodes_[cur_step].out_stream); + sync_stream(cur_task.path->nodes_[cur_step].out_stream); } + + auto src_dev_id = + resource_->dev_id(cur_task.path->nodes_[cur_step].dev_num); + auto src_place = DevPlace(src_dev_id); + if (cur_step > 0) { CopyTask c(cur_task.path, cur_step - 1); que.push(c); - cudaMemcpyAsync(cur_task.path->nodes_[cur_step - 1].val_storage, - cur_task.path->nodes_[cur_step].val_storage, - cur_task.path->nodes_[cur_step - 1].val_bytes_len, - cudaMemcpyDefault, - cur_task.path->nodes_[cur_step - 1].out_stream); + + auto dst_dev_id = + resource_->dev_id(cur_task.path->nodes_[cur_step - 1].dev_num); + auto dst_place = DevPlace(dst_dev_id); + + memory_copy(dst_place, cur_task.path->nodes_[cur_step - 1].val_storage, + src_place, cur_task.path->nodes_[cur_step].val_storage, + cur_task.path->nodes_[cur_step - 1].val_bytes_len, + cur_task.path->nodes_[cur_step - 1].out_stream); + } else if (cur_step == 0) { - int end_index = cur_task.path->nodes_.back().gpu_num; - cudaMemcpyAsync(reinterpret_cast(src_val + h_left[end_index]), - cur_task.path->nodes_[cur_step].val_storage, - cur_task.path->nodes_[cur_step].val_bytes_len, - cudaMemcpyDefault, - cur_task.path->nodes_[cur_step].out_stream); + int end_index = cur_task.path->nodes_.back().dev_num; + + auto dst_dev_id = resource_->dev_id(end_index); + auto dst_place = DevPlace(dst_dev_id); + + memory_copy(dst_place, + reinterpret_cast(src_val + h_left[end_index]), + src_place, cur_task.path->nodes_[cur_step].val_storage, + cur_task.path->nodes_[cur_step].val_bytes_len, + cur_task.path->nodes_[cur_step].out_stream); } } } @@ -314,8 +320,8 @@ HeterComm::~HeterComm() { } template -void HeterComm::show_one_table(int gpu_num) { - tables_[gpu_num]->show(); +void HeterComm::show_one_table(int num) { + tables_[num]->show(); } template @@ -333,24 +339,22 @@ int HeterComm::get_index_by_devid(int devid) { } template -void HeterComm::build_ps(int num, KeyType* 
h_keys, - ValType* h_vals, - size_t len, - size_t chunk_size, - int stream_num) { +void HeterComm::build_ps( + int dev_num, KeyType* h_keys, ValType* h_vals, size_t len, + size_t chunk_size, int stream_num) { if (len <= 0) { return; } - int dev_id = resource_->dev_id(num); - platform::CUDAPlace place = platform::CUDAPlace(dev_id); - platform::CUDADeviceGuard guard(dev_id); + int dev_id = resource_->dev_id(dev_num); std::vector d_key_bufs; std::vector d_val_bufs; - gpuStream_t streams[stream_num]; // NOLINT + DevPlace place = DevPlace(dev_id); + AnyDeviceGuard guard(dev_id); + ppStream streams[stream_num]; // NOLINT for (int i = 0; i < stream_num; ++i) { - PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&(streams[i]))); + create_stream(&(streams[i])); auto d_k_buf = memory::Alloc(place, chunk_size * sizeof(KeyType)); auto d_v_buf = memory::Alloc(place, chunk_size * sizeof(ValType)); d_key_bufs.push_back(std::move(d_k_buf)); @@ -360,39 +364,48 @@ void HeterComm::build_ps(int num, KeyType* h_keys, int cur_len = 0; int cur_stream = 0; - while (cur_len < len) { + while (static_cast(cur_len) < len) { cur_stream = cur_stream % stream_num; + auto cur_use_stream = streams[cur_stream]; +#if defined(PADDLE_WITH_XPU_KP) + cur_use_stream = 0; +#endif + int tmp_len = cur_len + chunk_size > len ? len - cur_len : chunk_size; - PADDLE_ENFORCE_GPU_SUCCESS( - cudaMemcpyAsync(d_key_bufs[cur_stream]->ptr(), h_keys + cur_len, - sizeof(KeyType) * tmp_len, cudaMemcpyHostToDevice, - streams[cur_stream])); - PADDLE_ENFORCE_GPU_SUCCESS( - cudaMemcpyAsync(d_val_bufs[cur_stream]->ptr(), h_vals + cur_len, - sizeof(ValType) * tmp_len, cudaMemcpyHostToDevice, - streams[cur_stream])); - tables_[num]->insert( + + auto dst_place = place; + auto src_place = platform::CPUPlace(); + + memory_copy( + dst_place, reinterpret_cast(d_key_bufs[cur_stream]->ptr()), + src_place, h_keys + cur_len, sizeof(KeyType) * tmp_len, cur_use_stream); + memory_copy( + dst_place, reinterpret_cast(d_val_bufs[cur_stream]->ptr()), + src_place, h_vals + cur_len, sizeof(ValType) * tmp_len, cur_use_stream); + + tables_[dev_num]->insert( reinterpret_cast(d_key_bufs[cur_stream]->ptr()), reinterpret_cast(d_val_bufs[cur_stream]->ptr()), tmp_len, - streams[cur_stream]); + cur_use_stream); + cur_stream += 1; cur_len += tmp_len; } - for (int i = 0; i < stream_num; ++i) { - cudaStreamSynchronize(streams[i]); - PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(streams[i])); + sync_stream(streams[i]); + destroy_stream(streams[i]); } } template void HeterComm::merge_grad( - int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len, + int dev_num, KeyType* d_keys, GradType* d_grads, size_t len, int& uniq_len) { // NOLINT - int dev_id = resource_->dev_id(gpu_num); - platform::CUDAPlace place = platform::CUDAPlace(dev_id); - platform::CUDADeviceGuard guard(dev_id); - auto stream = resource_->local_stream(gpu_num, 0); + + int dev_id = resource_->dev_id(dev_num); + DevPlace place = DevPlace(dev_id); + AnyDeviceGuard guard(dev_id); + auto stream = resource_->local_stream(dev_num, 0); size_t temp_storage_bytes; @@ -403,48 +416,50 @@ void HeterComm::merge_grad( GradType* d_merge_grads_ptr = reinterpret_cast(d_merge_grads->ptr()); - PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs( - NULL, temp_storage_bytes, d_keys, d_merge_keys_ptr, d_grads, - d_merge_grads_ptr, len, 0, 8 * sizeof(KeyType), stream, false)); + heter_comm_kernel_->sort_pairs(NULL, temp_storage_bytes, d_keys, + d_merge_keys_ptr, d_grads, d_merge_grads_ptr, + len, 0, 8 * sizeof(KeyType), stream, false); - 
void* d_buff = NULL; auto d_temp_storage = memory::Alloc(place, temp_storage_bytes); - PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs( + heter_comm_kernel_->sort_pairs( d_temp_storage->ptr(), temp_storage_bytes, d_keys, d_merge_keys_ptr, - d_grads, d_merge_grads_ptr, len, 0, 8 * sizeof(KeyType), stream, false)); + d_grads, d_merge_grads_ptr, len, 0, 8 * sizeof(KeyType), stream, false); temp_storage_bytes = 0; auto d_num_runs_out_mem = memory::Alloc(place, sizeof(int)); int* d_num_runs_out = reinterpret_cast(d_num_runs_out_mem->ptr()); - PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::ReduceByKey( - NULL, temp_storage_bytes, d_merge_keys_ptr, d_keys, d_merge_grads_ptr, - d_grads, d_num_runs_out, merger_, len, stream, false)); + heter_comm_kernel_->reduce_by_key(NULL, temp_storage_bytes, d_merge_keys_ptr, + d_keys, d_merge_grads_ptr, d_grads, + d_num_runs_out, len, stream, false); if (d_temp_storage->size() < temp_storage_bytes) { d_temp_storage = NULL; d_temp_storage = memory::Alloc(place, temp_storage_bytes); } - PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::ReduceByKey( + heter_comm_kernel_->reduce_by_key( d_temp_storage->ptr(), temp_storage_bytes, d_merge_keys_ptr, d_keys, - d_merge_grads_ptr, d_grads, d_num_runs_out, merger_, len, stream, false)); + d_merge_grads_ptr, d_grads, d_num_runs_out, len, stream, false); - cudaMemcpyAsync(&uniq_len, d_num_runs_out, sizeof(int), - cudaMemcpyDeviceToHost, stream); - PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + auto dst_place = platform::CPUPlace(); + auto src_place = place; + memory_copy(dst_place, &uniq_len, src_place, d_num_runs_out, sizeof(int), + stream); + + sync_stream(stream); } template void HeterComm::split_input_to_shard( KeyType* d_keys, int* d_idx_ptr, size_t len, int* left, int* right, - int gpu_num) { - int total_gpu = resource_->total_gpu(); - int dev_id = resource_->dev_id(gpu_num); - platform::CUDAPlace place = platform::CUDAPlace(dev_id); - platform::CUDADeviceGuard guard(dev_id); - auto stream = resource_->local_stream(gpu_num, 0); + int dev_num) { + int total_device = resource_->total_device(); + int dev_id = resource_->dev_id(dev_num); + DevPlace place = DevPlace(dev_id); + AnyDeviceGuard guard(dev_id); + auto stream = resource_->local_stream(dev_num, 0); auto d_idx_tmp = memory::Alloc(place, len * sizeof(int)); int* d_idx_tmp_ptr = reinterpret_cast(d_idx_tmp->ptr()); @@ -455,24 +470,28 @@ void HeterComm::split_input_to_shard( auto d_shard_index_tmp = memory::Alloc(place, len * sizeof(int)); int* d_shard_index_tmp_ptr = reinterpret_cast(d_shard_index_tmp->ptr()); - int grid_size = (len - 1) / block_size_ + 1; - fill_idx<<>>(d_idx_tmp_ptr, len); - calc_shard_index<<>>( - d_keys, len, d_shard_index_tmp_ptr, total_gpu); + // int grid_size = (len - 1) / block_size_ + 1; + + heter_comm_kernel_->fill_idx(d_idx_tmp_ptr, len, stream); + heter_comm_kernel_->calc_shard_index(d_keys, len, d_shard_index_tmp_ptr, + total_device, stream); size_t temp_storage_bytes; - const int num_bits = 1 + log2i(total_gpu); - PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs( + const int num_bits = 1 + log2i(total_device); + + heter_comm_kernel_->sort_pairs( NULL, temp_storage_bytes, d_shard_index_tmp_ptr, d_shard_index_ptr, - d_idx_tmp_ptr, d_idx_ptr, len, 0, num_bits, stream)); + d_idx_tmp_ptr, d_idx_ptr, len, 0, num_bits, stream); auto d_temp_storage = memory::Alloc(place, temp_storage_bytes); - PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs( + + heter_comm_kernel_->sort_pairs( d_temp_storage->ptr(), 
temp_storage_bytes, d_shard_index_tmp_ptr, - d_shard_index_ptr, d_idx_tmp_ptr, d_idx_ptr, len, 0, num_bits, stream)); - calc_shard_offset<<>>(d_shard_index_ptr, - left, right, len); - cudaStreamSynchronize(stream); + d_shard_index_ptr, d_idx_tmp_ptr, d_idx_ptr, len, 0, num_bits, stream); + + heter_comm_kernel_->calc_shard_offset(d_shard_index_ptr, left, right, len, + total_device, stream); + sync_stream(stream); } template @@ -484,25 +503,43 @@ void HeterComm::pull_sparse(int num, return; } - int total_gpu = resource_->total_gpu(); + int total_device = resource_->total_device(); int dev_id = resource_->dev_id(num); - platform::CUDAPlace place = platform::CUDAPlace(dev_id); - platform::CUDADeviceGuard guard(dev_id); + DevPlace place = DevPlace(dev_id); + AnyDeviceGuard guard(dev_id); auto stream = resource_->local_stream(num, 0); - int grid_size = (len - 1) / block_size_ + 1; + // int grid_size = (len - 1) / block_size_ + 1; - int h_left[total_gpu]; // NOLINT - int h_right[total_gpu]; // NOLINT + int h_left[total_device]; // NOLINT + int h_right[total_device]; // NOLINT - auto d_left = memory::Alloc(place, total_gpu * sizeof(int)); - auto d_right = memory::Alloc(place, total_gpu * sizeof(int)); + auto d_left = memory::Alloc(place, total_device * sizeof(int)); + auto d_right = memory::Alloc(place, total_device * sizeof(int)); int* d_left_ptr = reinterpret_cast(d_left->ptr()); int* d_right_ptr = reinterpret_cast(d_right->ptr()); - cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream); - cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream); - // +#if defined(PADDLE_WITH_CUDA) + cudaMemsetAsync(d_left_ptr, -1, total_device * sizeof(int), stream); + cudaMemsetAsync(d_right_ptr, -1, total_device * sizeof(int), stream); + +#elif defined(PADDLE_WITH_XPU_KP) + // get XPUDeviceContext according to xpu place + paddle::platform::XPUDeviceContext xpu_dev_ctx(place); + auto xpu_context = xpu_dev_ctx.x_context(); + + int r = xpu::constant(xpu_context, d_left_ptr, total_device, -1); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External( + "XPU constant kernel return wrong value[%d %s]", r, + XPUAPIErrorMsg[r])); + int r2 = xpu::constant(xpu_context, d_right_ptr, total_device, -1); + PADDLE_ENFORCE_EQ(r2, XPU_SUCCESS, + platform::errors::External( + "XPU constant kernel return wrong value[%d %s]", r2, + XPUAPIErrorMsg[r2])); +#endif + auto d_idx = memory::Alloc(place, len * sizeof(int)); int* d_idx_ptr = reinterpret_cast(d_idx->ptr()); @@ -513,17 +550,20 @@ void HeterComm::pull_sparse(int num, split_input_to_shard(d_keys, d_idx_ptr, len, d_left_ptr, d_right_ptr, num); - fill_shard_key<<>>(d_shard_keys_ptr, - d_keys, d_idx_ptr, len); + heter_comm_kernel_->fill_shard_key(d_shard_keys_ptr, d_keys, d_idx_ptr, len, + stream); - cudaStreamSynchronize(stream); + sync_stream(stream); - cudaMemcpy(h_left, d_left_ptr, total_gpu * sizeof(int), - cudaMemcpyDeviceToHost); - cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int), - cudaMemcpyDeviceToHost); + auto dst_place = platform::CPUPlace(); + auto src_place = place; - for (int i = 0; i < total_gpu; ++i) { + memory_copy(dst_place, h_left, src_place, d_left_ptr, + total_device * sizeof(int), stream); + memory_copy(dst_place, h_right, src_place, d_right_ptr, + total_device * sizeof(int), stream); + + for (int i = 0; i < total_device; ++i) { int shard_len = h_right[i] - h_left[i] + 1; if (shard_len == 0) { continue; @@ -532,47 +572,53 @@ void HeterComm::pull_sparse(int num, shard_len * sizeof(ValType)); } - walk_to_dest(num, 
total_gpu, h_left, h_right, d_shard_keys_ptr, NULL); + walk_to_dest(num, total_device, h_left, h_right, d_shard_keys_ptr, NULL); - for (int i = 0; i < total_gpu; ++i) { + for (int i = 0; i < total_device; ++i) { if (h_left[i] == -1) { continue; } auto& node = path_[num][i].nodes_.back(); - cudaStreamSynchronize(node.in_stream); - platform::CUDADeviceGuard guard(resource_->dev_id(i)); + sync_stream(node.in_stream); + + AnyDeviceGuard guard(resource_->dev_id(i)); + tables_[i]->rwlock_->RDLock(); tables_[i]->get(reinterpret_cast(node.key_storage), reinterpret_cast(node.val_storage), h_right[i] - h_left[i] + 1, resource_->remote_stream(i, num)); } - for (int i = 0; i < total_gpu; ++i) { - cudaStreamSynchronize(resource_->remote_stream(i, num)); + + for (int i = 0; i < total_device; ++i) { + sync_stream(resource_->remote_stream(i, num)); if (h_left[i] == -1) { continue; } tables_[i]->rwlock_->UNLock(); } - walk_to_src(num, total_gpu, h_left, h_right, d_shard_vals_ptr); + walk_to_src(num, total_device, h_left, h_right, d_shard_vals_ptr); - for (int i = 0; i < total_gpu; ++i) { + for (int i = 0; i < total_device; ++i) { auto& node = path_[num][i].nodes_.front(); - cudaStreamSynchronize(node.out_stream); + sync_stream(node.out_stream); } - fill_dvals<<>>(d_shard_vals_ptr, d_vals, - d_idx_ptr, len); - cudaStreamSynchronize(stream); - for (int i = 0; i < total_gpu; ++i) { + heter_comm_kernel_->fill_dvals(d_shard_vals_ptr, d_vals, d_idx_ptr, len, + stream); + + sync_stream(stream); + + for (int i = 0; i < total_device; ++i) { destroy_storage(num, i); } } +#if defined(PADDLE_WITH_CUDA) template template -void HeterComm::push_sparse(int gpu_num, +void HeterComm::push_sparse(int dev_num, KeyType* d_keys, GradType* d_grads, size_t len, @@ -581,23 +627,42 @@ void HeterComm::push_sparse(int gpu_num, return; } - int total_gpu = resource_->total_gpu(); - int dev_id = resource_->dev_id(gpu_num); - platform::CUDAPlace place = platform::CUDAPlace(dev_id); - platform::CUDADeviceGuard guard(dev_id); - auto stream = resource_->local_stream(gpu_num, 0); + int total_device = resource_->total_device(); + int dev_id = resource_->dev_id(dev_num); - int h_left[total_gpu]; // NOLINT - int h_right[total_gpu]; // NOLINT + DevPlace place = DevPlace(dev_id); + AnyDeviceGuard guard(dev_id); + auto stream = resource_->local_stream(dev_num, 0); - auto d_left = memory::Alloc(place, total_gpu * sizeof(int)); - auto d_right = memory::Alloc(place, total_gpu * sizeof(int)); + int h_left[total_device]; // NOLINT + int h_right[total_device]; // NOLINT + + auto d_left = memory::Alloc(place, total_device * sizeof(int)); + auto d_right = memory::Alloc(place, total_device * sizeof(int)); int* d_left_ptr = reinterpret_cast(d_left->ptr()); int* d_right_ptr = reinterpret_cast(d_right->ptr()); - cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream); - cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream); - // +#if defined(PADDLE_WITH_CUDA) + cudaMemsetAsync(d_left_ptr, -1, total_device * sizeof(int), stream); + cudaMemsetAsync(d_right_ptr, -1, total_device * sizeof(int), stream); + +#elif defined(PADDLE_WITH_XPU_KP) + // get XPUDeviceContext according to xpu place + paddle::platform::XPUDeviceContext xpu_dev_ctx(place); + auto xpu_context = xpu_dev_ctx.x_context(); + + int r = xpu::constant(xpu_context, d_left_ptr, total_device, -1); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External( + "XPU constant kernel return wrong value[%d %s]", r, + XPUAPIErrorMsg[r])); + int r2 = xpu::constant(xpu_context, 
d_right_ptr, total_device, -1); + PADDLE_ENFORCE_EQ(r2, XPU_SUCCESS, + platform::errors::External( + "XPU constant kernel return wrong value[%d %s]", r2, + XPUAPIErrorMsg[r2])); +#endif + auto d_idx = memory::Alloc(place, len * sizeof(int)); int* d_idx_ptr = reinterpret_cast(d_idx->ptr()); @@ -608,61 +673,183 @@ void HeterComm::push_sparse(int gpu_num, reinterpret_cast(d_shard_grads->ptr()); int uniq_len = len; - merge_grad(gpu_num, d_keys, d_grads, len, uniq_len); + merge_grad(dev_num, d_keys, d_grads, len, uniq_len); - int grid_size = (uniq_len - 1) / block_size_ + 1; + // int grid_size = (uniq_len - 1) / block_size_ + 1; split_input_to_shard(d_keys, d_idx_ptr, uniq_len, d_left_ptr, d_right_ptr, - gpu_num); + dev_num); - fill_shard_grads<<>>( - d_shard_keys_ptr, d_keys, d_shard_grads_ptr, d_grads, d_idx_ptr, - uniq_len); + heter_comm_kernel_->fill_shard_grads(d_shard_keys_ptr, d_keys, + d_shard_grads_ptr, d_grads, d_idx_ptr, + uniq_len, stream); - cudaStreamSynchronize(stream); + sync_stream(stream); - cudaMemcpy(h_left, d_left_ptr, total_gpu * sizeof(int), - cudaMemcpyDeviceToHost); - cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int), - cudaMemcpyDeviceToHost); + auto dst_place = platform::CPUPlace(); + auto src_place = place; + memory_copy(dst_place, h_left, src_place, d_left_ptr, + total_device * sizeof(int), stream); + memory_copy(dst_place, h_right, src_place, d_right_ptr, + total_device * sizeof(int), stream); - for (int i = 0; i < total_gpu; ++i) { + for (int i = 0; i < total_device; ++i) { int shard_len = h_right[i] - h_left[i] + 1; if (h_left[i] == -1 || h_right[i] == -1) { continue; } - create_storage(gpu_num, i, shard_len * sizeof(KeyType), + create_storage(dev_num, i, shard_len * sizeof(KeyType), shard_len * sizeof(GradType)); } - walk_to_dest(gpu_num, total_gpu, h_left, h_right, d_shard_keys_ptr, + walk_to_dest(dev_num, total_device, h_left, h_right, d_shard_keys_ptr, d_shard_grads_ptr); - for (int i = 0; i < total_gpu; ++i) { + for (int i = 0; i < total_device; ++i) { if (h_left[i] == -1 || h_right[i] == -1) { continue; } - auto& node = path_[gpu_num][i].nodes_.back(); - cudaStreamSynchronize(node.in_stream); + auto& node = path_[dev_num][i].nodes_.back(); + sync_stream(node.in_stream); - platform::CUDADeviceGuard guard(resource_->dev_id(i)); + AnyDeviceGuard guard(resource_->dev_id(i)); tables_[i]->rwlock_->WRLock(); tables_[i]->update(reinterpret_cast(node.key_storage), reinterpret_cast(node.val_storage), h_right[i] - h_left[i] + 1, sgd, - resource_->remote_stream(i, gpu_num)); + resource_->remote_stream(i, dev_num)); } - for (int i = 0; i < total_gpu; ++i) { - cudaStreamSynchronize(resource_->remote_stream(i, gpu_num)); + + for (int i = 0; i < total_device; ++i) { + sync_stream(resource_->remote_stream(i, dev_num)); if (h_left[i] != -1) { tables_[i]->rwlock_->UNLock(); } } - for (int i = 0; i < total_gpu; ++i) { - destroy_storage(gpu_num, i); + + for (int i = 0; i < total_device; ++i) { + destroy_storage(dev_num, i); } } +#elif defined(PADDLE_WITH_XPU_KP) +template +void HeterComm::push_sparse(int dev_num, + KeyType* d_keys, + GradType* d_grads, + size_t len) { + if (len == 0) { + return; + } + + int total_device = resource_->total_device(); + int dev_id = resource_->dev_id(dev_num); + + DevPlace place = DevPlace(dev_id); + AnyDeviceGuard guard(dev_id); + auto stream = resource_->local_stream(dev_num, 0); + + int h_left[total_device]; // NOLINT + int h_right[total_device]; // NOLINT + + auto d_left = memory::Alloc(place, total_device * sizeof(int)); + auto d_right 
= memory::Alloc(place, total_device * sizeof(int)); + int* d_left_ptr = reinterpret_cast(d_left->ptr()); + int* d_right_ptr = reinterpret_cast(d_right->ptr()); + +#if defined(PADDLE_WITH_CUDA) + cudaMemsetAsync(d_left_ptr, -1, total_device * sizeof(int), stream); + cudaMemsetAsync(d_right_ptr, -1, total_device * sizeof(int), stream); + +#elif defined(PADDLE_WITH_XPU_KP) + // get XPUDeviceContext according to xpu place + paddle::platform::XPUDeviceContext xpu_dev_ctx(place); + auto xpu_context = xpu_dev_ctx.x_context(); + + int r = xpu::constant(xpu_context, d_left_ptr, total_device, -1); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External( + "XPU constant kernel return wrong value[%d %s]", r, + XPUAPIErrorMsg[r])); + int r2 = xpu::constant(xpu_context, d_right_ptr, total_device, -1); + PADDLE_ENFORCE_EQ(r2, XPU_SUCCESS, + platform::errors::External( + "XPU constant kernel return wrong value[%d %s]", r2, + XPUAPIErrorMsg[r2])); +#endif + + auto d_idx = memory::Alloc(place, len * sizeof(int)); + int* d_idx_ptr = reinterpret_cast(d_idx->ptr()); + + auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType)); + KeyType* d_shard_keys_ptr = reinterpret_cast(d_shard_keys->ptr()); + auto d_shard_grads = memory::Alloc(place, len * sizeof(GradType)); + GradType* d_shard_grads_ptr = + reinterpret_cast(d_shard_grads->ptr()); + + int uniq_len = len; + merge_grad(dev_num, d_keys, d_grads, len, uniq_len); + + // int grid_size = (uniq_len - 1) / block_size_ + 1; + + split_input_to_shard(d_keys, d_idx_ptr, uniq_len, d_left_ptr, d_right_ptr, + dev_num); + + heter_comm_kernel_->fill_shard_grads(d_shard_keys_ptr, d_keys, + d_shard_grads_ptr, d_grads, d_idx_ptr, + (long long)uniq_len, stream); + + sync_stream(stream); + + auto dst_place = platform::CPUPlace(); + auto src_place = place; + memory_copy(dst_place, h_left, src_place, d_left_ptr, + total_device * sizeof(int), stream); + memory_copy(dst_place, h_right, src_place, d_right_ptr, + total_device * sizeof(int), stream); + + for (int i = 0; i < total_device; ++i) { + int shard_len = h_right[i] - h_left[i] + 1; + if (h_left[i] == -1 || h_right[i] == -1) { + continue; + } + create_storage(dev_num, i, shard_len * sizeof(KeyType), + shard_len * sizeof(GradType)); + } + + walk_to_dest(dev_num, total_device, h_left, h_right, d_shard_keys_ptr, + d_shard_grads_ptr); + + for (int i = 0; i < total_device; ++i) { + if (h_left[i] == -1 || h_right[i] == -1) { + continue; + } + auto& node = path_[dev_num][i].nodes_.back(); + sync_stream(node.in_stream); + + AnyDeviceGuard guard(resource_->dev_id(i)); + tables_[i]->rwlock_->WRLock(); + tables_[i]->update(reinterpret_cast(node.key_storage), + reinterpret_cast(node.val_storage), + h_right[i] - h_left[i] + 1, + resource_->remote_stream(i, dev_num)); + } + + for (int i = 0; i < total_device; ++i) { + sync_stream(resource_->remote_stream(i, dev_num)); + if (h_left[i] != -1) { + tables_[i]->rwlock_->UNLock(); + } + } + + for (int i = 0; i < total_device; ++i) { + destroy_storage(dev_num, i); + } +} + +#endif + +#if defined(PADDLE_WITH_CUDA) template template void HeterComm::update_one_table( @@ -705,7 +892,7 @@ void HeterComm::push_sparse_multi_node( template int HeterComm::gather_one_node_grad( int gpu_num, KeyType* d_keys, GradType* d_grads, int len) { - int total_gpu = resource_->total_gpu(); + int total_gpu = resource_->total_device(); int dev_id = resource_->dev_id(gpu_num); auto& storage = storage_[gpu_num]; platform::CUDAPlace place = platform::CUDAPlace(dev_id); @@ -725,10 +912,10 @@ int 
HeterComm::gather_one_node_grad( // allgather grad len PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::ncclAllGather((const void*)(d_node_len + gpu_num), - (void*)d_node_len, 1, ncclInt, // NOLINT - nccl_inner_comm, stream)); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather( + (const void*)(d_node_len + gpu_num), (void*)d_node_len, 1, // NOLINT + ncclInt, // NOLINT + nccl_inner_comm, stream)); PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); cudaMemcpy(h_node_len, d_node_len, sizeof(int) * total_gpu, @@ -775,11 +962,12 @@ int HeterComm::gather_one_node_grad( cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int), cudaMemcpyDeviceToHost); - int grid_size = (h_node_len[i] - 1) / block_size_ + 1; - fill_shard_grads<<>>( + // int grid_size = (h_node_len[i] - 1) / block_size_ + 1; + heter_comm_kernel_->fill_shard_grads( storage.local_keys + merge_num, storage.all_keys + index, storage.local_grads + merge_num, storage.all_grads + index, - d_idx_ptr + h_left[gpu_num], h_right[gpu_num] - h_left[gpu_num] + 1); + d_idx_ptr + h_left[gpu_num], h_right[gpu_num] - h_left[gpu_num] + 1, + stream); merge_num = merge_num + h_right[gpu_num] - h_left[gpu_num] + 1; } @@ -848,19 +1036,21 @@ int HeterComm::gather_multi_node_grad( return ret; } +#endif + template void HeterComm::end_pass() { - int total_gpu = resource_->total_gpu(); + int total_device = resource_->total_device(); std::vector threads; auto dump_to_cpu_func = [this](int index) { auto stream = resource_->local_stream(index, 0); int dev_id = resource_->dev_id(index); - platform::CUDADeviceGuard guard(dev_id); + AnyDeviceGuard guard(dev_id); tables_[index]->dump_to_cpu(dev_id, stream); }; - for (int i = 0; i < total_gpu; ++i) { + for (int i = 0; i < total_device; ++i) { threads.push_back(std::thread(dump_to_cpu_func, i)); } for (auto& t : threads) { diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..694bdb8d563f5726bfc40509f3e58c8c5553f047 --- /dev/null +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu @@ -0,0 +1,269 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#ifdef PADDLE_WITH_HETERPS +#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h" + +namespace paddle { +namespace framework { + +#ifdef PADDLE_WITH_CUDA + +struct GPUCustomGradMerger { + template + CUB_RUNTIME_FUNCTION __forceinline__ __device__ T + operator()(const T& a, const T& b) const { + T out; + out.slot = a.slot; + out.show = a.show + b.show; + out.clk = a.clk + b.clk; + out.lr_g = a.lr_g + b.lr_g; + for (int i = 0; i < MF_DIM; ++i) { + out.mf_g[i] = a.mf_g[i] + b.mf_g[i]; + } + return out; + } +} gpu_merger; + +template +__global__ void fill_idx_kernel(T* idx, size_t len) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < len) { + idx[i] = i; + } +} + +// template +// void show_tensor(T* input, size_t len, gpuStream_t stream, std::string +// name) +// { +// T tmp[len]; // NOLINT +// cudaMemcpyAsync(&tmp, input, sizeof(T) * len, cudaMemcpyDeviceToHost, +// stream); +// cudaStreamSynchronize(stream); +// std::cout << name; +// for (int i = 0; i < len; ++i) { +// std::cout << ":" << tmp[i]; +// } +// std::cout << std::endl; +//} + +template +__global__ void calc_shard_offset_kernel(T* idx, T* left, T* right, + size_t len) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < len - 1) { + if (idx[i] != idx[i + 1]) { + right[idx[i]] = i; + left[idx[i + 1]] = i + 1; + } + } + if (i == 0) { + left[idx[i]] = i; + } + if (i == (len - 1)) { + right[idx[i]] = i; + } +} + +template +__global__ void calc_shard_index_kernel(KeyType* d_keys, size_t len, + T* shard_index, int total_gpu) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < len) { + shard_index[i] = d_keys[i] % total_gpu; + } +} + +template +__global__ void fill_shard_key_kernel(KeyType* d_shard_keys, KeyType* d_keys, + T* idx, size_t len) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < len) { + d_shard_keys[i] = d_keys[idx[i]]; + } +} + +template +__global__ void fill_shard_grads_kernel(KeyType* d_shard_keys, KeyType* d_keys, + GradType* d_shard_grads, + GradType* d_grads, T* idx, size_t len) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < len) { + d_shard_keys[i] = d_keys[idx[i]]; + d_shard_grads[i] = d_grads[idx[i]]; + } +} + +template +__global__ void fill_dvals_kernel(ValType* d_shard_vals, ValType* d_vals, + T* idx, size_t len) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < len) { + d_vals[idx[i]] = d_shard_vals[i]; + } +} + +// cuda implemention of heter_comm_kernel.h +template +void HeterCommKernel::fill_idx(T* idx, long long len, + const StreamType& stream) { + int grid_size = (len - 1) / block_size_ + 1; + size_t c_len = (size_t)len; + fill_idx_kernel<<>>(idx, c_len); +} + +template +void HeterCommKernel::calc_shard_offset(T* idx, T* left, T* right, + long long len, int total_devs, + const StreamType& stream) { + int grid_size = (len - 1) / block_size_ + 1; + size_t c_len = (size_t)len; + calc_shard_offset_kernel<<>>(idx, left, + right, c_len); +} + +template +void HeterCommKernel::calc_shard_index(KeyType* d_keys, long long len, + T* shard_index, int total_gpu, + const StreamType& stream) { + int grid_size = (len - 1) / block_size_ + 1; + size_t c_len = (size_t)len; + calc_shard_index_kernel<<>>( + d_keys, c_len, shard_index, total_gpu); +} + +template +void HeterCommKernel::fill_shard_key(KeyType* d_shard_keys, KeyType* d_keys, + T* idx, long long len, + const StreamType& stream) { + int grid_size = (len - 1) / block_size_ + 1; + size_t c_len = (size_t)len; + 
fill_shard_key_kernel<<>>( + d_shard_keys, d_keys, idx, c_len); +} + +template +void HeterCommKernel::fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys, + GradType* d_shard_grads, + GradType* d_grads, T* idx, long long len, + const StreamType& stream) { + int grid_size = (len - 1) / block_size_ + 1; + size_t c_len = (size_t)len; + fill_shard_grads_kernel<<>>( + d_shard_keys, d_keys, d_shard_grads, d_grads, idx, c_len); +} + +template +void HeterCommKernel::fill_dvals(ValType* d_shard_vals, ValType* d_vals, T* idx, + long long len, const StreamType& stream) { + int grid_size = (len - 1) / block_size_ + 1; + size_t c_len = (size_t)len; + fill_dvals_kernel<<>>(d_shard_vals, d_vals, + idx, c_len); +} + +template +void HeterCommKernel::sort_pairs(void* d_temp_storage, + size_t& temp_storage_bytes, // NOLINT + const KeyT* d_keys_in, // NOLINT + KeyT* d_keys_out, const ValueT* d_values_in, + ValueT* d_values_out, int num_items, + int begin_bit, int end_bit, StreamType stream, + bool debug_synchronous) { + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs( + d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, d_values_in, + d_values_out, num_items, begin_bit, end_bit, stream, debug_synchronous)); +} + +template +void HeterCommKernel::reduce_by_key(void* d_temp_storage, + size_t& temp_storage_bytes, // NOLINT + KeysInputIteratorT d_keys_in, + UniqueOutputIteratorT d_unique_out, + ValuesInputIteratorT d_values_in, + AggregatesOutputIteratorT d_aggregates_out, + NumRunsOutputIteratorT d_num_runs_out, + int num_items, StreamType stream, + bool debug_synchronous) { + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::ReduceByKey( + d_temp_storage, temp_storage_bytes, d_keys_in, d_unique_out, d_values_in, + d_aggregates_out, d_num_runs_out, gpu_merger, num_items, stream, + debug_synchronous)); +} + +template void HeterCommKernel::fill_idx( + int* idx, long long len, const cudaStream_t& stream); + +template void HeterCommKernel::calc_shard_offset( + int* idx, int* left, int* right, long long len, int total_devs, + const cudaStream_t& stream); +template void HeterCommKernel::calc_shard_index< + unsigned long, int, cudaStream_t>(unsigned long* d_keys, long long len, + int* shard_index, int total_devs, + const cudaStream_t& stream); + +template void HeterCommKernel::fill_shard_key( + unsigned long* d_shard_keys, unsigned long* d_keys, int* idx, long long len, + const cudaStream_t& stream); + +template void HeterCommKernel::fill_shard_grads< + unsigned long, paddle::framework::FeaturePushValue, int, cudaStream_t>( + unsigned long* d_shard_keys, unsigned long* d_keys, + paddle::framework::FeaturePushValue* d_shard_grads, + paddle::framework::FeaturePushValue* d_grads, int* idx, long long len, + const cudaStream_t& stream); + +template void +HeterCommKernel::fill_dvals( + paddle::framework::FeatureValue* d_shard_vals, + paddle::framework::FeatureValue* d_vals, int* idx, long long len, + const cudaStream_t& stream); + +template void HeterCommKernel::sort_pairs< + unsigned long, paddle::framework::FeaturePushValue, cudaStream_t>( + void* d_temp_storage, + size_t& temp_storage_bytes, // NOLINT + const unsigned long* d_keys_in, // NOLINT + unsigned long* d_keys_out, + const paddle::framework::FeaturePushValue* d_values_in, + paddle::framework::FeaturePushValue* d_values_out, int num_items, + int begin_bit, int end_bit, cudaStream_t stream, bool debug_synchronous); + +template void HeterCommKernel::sort_pairs( + void* d_temp_storage, + size_t& temp_storage_bytes, // NOLINT + const int* d_keys_in, 
// NOLINT + int* d_keys_out, const int* d_values_in, int* d_values_out, int num_items, + int begin_bit, int end_bit, cudaStream_t stream, bool debug_synchronous); + +template void HeterCommKernel::reduce_by_key< + unsigned long*, unsigned long*, paddle::framework::FeaturePushValue*, + paddle::framework::FeaturePushValue*, int*, cudaStream_t>( + void* d_temp_storage, + size_t& temp_storage_bytes, // NOLINT + unsigned long* d_keys_in, unsigned long* d_unique_out, + paddle::framework::FeaturePushValue* d_values_in, + paddle::framework::FeaturePushValue* d_aggregates_out, int* d_num_runs_out, + int num_items, cudaStream_t stream, bool debug_synchronous); + +#endif + +} // namespace framework +} // namespace paddle +#endif diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..1be3687a7dbeee7fc5017322c02c9b171a6e5223 --- /dev/null +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h @@ -0,0 +1,86 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#ifdef PADDLE_WITH_HETERPS +#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" + +#if defined(PADDLE_WITH_CUDA) +#include "cub/cub.cuh" +#include "cub/util_allocator.cuh" +#include "paddle/fluid/platform/cuda_device_guard.h" +#include "paddle/fluid/platform/enforce.h" +#endif + +namespace paddle { +namespace framework { + +class HeterCommKernel { + public: + HeterCommKernel() {} + explicit HeterCommKernel(const int block_size) : block_size_(block_size) {} + + template + void fill_idx(T* idx, long long len, const StreamType& stream); + + template + void calc_shard_offset(T* idx, T* left, T* right, long long len, + int total_devs, const StreamType& stream); + + template + void calc_shard_index(KeyType* d_keys, long long len, T* shard_index, + int total_devs, const StreamType& stream); + + template + void fill_shard_key(KeyType* d_shard_keys, KeyType* d_keys, T* idx, + long long len, const StreamType& stream); + + template + void fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys, + GradType* d_shard_grads, GradType* d_grads, T* idx, + long long len, const StreamType& stream); + + template + void fill_dvals(ValType* d_shard_vals, ValType* d_vals, T* idx, long long len, + const StreamType& stream); + + template + void sort_pairs(void* d_temp_storage, size_t& temp_storage_bytes, // NOLINT + const KeyT* d_keys_in, KeyT* d_keys_out, + const ValueT* d_values_in, ValueT* d_values_out, + int num_items, int begin_bit = 0, + int end_bit = sizeof(KeyT) * 8, StreamType stream = NULL, + bool debug_synchronous = false); + + template + void reduce_by_key(void* d_temp_storage, + size_t& temp_storage_bytes, // NOLINT + KeysInputIteratorT d_keys_in, + UniqueOutputIteratorT d_unique_out, + ValuesInputIteratorT d_values_in, + AggregatesOutputIteratorT d_aggregates_out, + NumRunsOutputIteratorT d_num_runs_out, int num_items, + StreamType 
stream = NULL, bool debug_synchronous = false); + + private: + int block_size_{256}; +}; + +} // end namespace framework +} // end namespace paddle +#endif diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.kps b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.kps new file mode 100644 index 0000000000000000000000000000000000000000..a1923a7f6019b6ce69d1fa70f02a760ead5ca507 --- /dev/null +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.kps @@ -0,0 +1,359 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_HETERPS +#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h" + +#if defined(PADDLE_WITH_XPU_KP) +#include +#include "xpu/kernel/cluster_header.h" +#include "xpu/kernel/math.h" +#include "xpu/kernel/simd.h" +#endif + +namespace paddle { +namespace framework { + +#if defined(PADDLE_WITH_XPU_KP) + +struct XPUCustomGradMerger { + template + __device__ T operator()(const T& a, const T& b) const { + T out; + out.slot = a.slot; + out.show = a.show + b.show; + out.clk = a.clk + b.clk; + out.lr_g = a.lr_g + b.lr_g; + for (int i = 0; i < MF_DIM; ++i) { + out.mf_g[i] = a.mf_g[i] + b.mf_g[i]; + } + return out; + } +} xpu_merger; + +template +__global__ void fill_idx_kernel(T* idx, long long len) { + int cid = core_id(); + int ncores = core_num(); + if (cid >= ncores) { + return; + } + int thread_id = ncores * cluster_id() + cid; + int nthreads = ncores * cluster_num(); + const int buf_size = 1024; + __local__ T local_idx[buf_size]; + int len_per_loop = min(buf_size, roundup_div(len, nthreads)); + for (int i = thread_id * len_per_loop; i < len; + i += nthreads * len_per_loop) { + int read_len = min(len_per_loop, len - i); + for (int k = 0; k < read_len; k++) { + int real_idx = i + k; + local_idx[k] = real_idx; + } + LM2GM(local_idx, idx + i, read_len * sizeof(T)); + } +} + +template +__global__ void calc_shard_offset_kernel(T* idx, T* left, T* right, + long long len, const int total_xpu) { + int cid = core_id(); + int ncores = core_num(); + if (cid >= ncores) { + return; + } + int thread_id = ncores * cluster_id() + cid; + int nthreads = ncores * cluster_num(); + + const int buf_size = 1024; + __local__ T local_idx[buf_size]; + __local__ T local_left[total_xpu]; + __local__ T local_right[total_xpu]; + + for (int i = 0; i < total_xpu; i++) { + local_left[i] = -1; + local_right[i] = -1; + } + int len_per_loop = min(buf_size, roundup_div(len, nthreads)); + for (int i = thread_id * len_per_loop; i < len; + i += nthreads * len_per_loop) { + // read batch from GM will boost performance + int read_len = min(len_per_loop, len - i); + GM2LM(idx + i, local_idx, read_len * sizeof(T)); + for (int k = 0; k < read_len; k++) { + if (local_idx[k] != local_idx[k + 1]) { + int real_idx = i + k; + local_right[local_idx[k]] = real_idx; + local_left[local_idx[k + 1]] = real_idx + 1; + } + } + if (i == 0) { + local_left[local_idx[i]] = i; + } + if (i + read_len == len) { + local_right[local_idx[len - 1]] 
= len - 1; + } + } + // to be optimized: call LM2GM too frequently + // all_reduce between threads to get global left & global right && LM2GM + for (int i = 0; i < total_xpu; i++) { + if (local_left[i] != -1) LM2GM(local_left + i, left + i, sizeof(T)); + if (local_right[i] != -1) LM2GM(local_right + i, right + i, sizeof(T)); + } +} + +template +__global__ void calc_shard_index_kernel(KeyType* d_keys, long long len, + T* shard_index, int total_xpu) { + int cid = core_id(); + int ncores = core_num(); + if (cid >= ncores) { + return; + } + int thread_id = ncores * cluster_id() + cid; + int nthreads = ncores * cluster_num(); + const int buf_size = 512; + __local__ KeyType local_keys[buf_size]; + __local__ T local_shard_index[buf_size]; + int len_per_loop = min(buf_size, roundup_div(len, nthreads)); + for (int i = thread_id * len_per_loop; i < len; + i += nthreads * len_per_loop) { + // read batch from GM will boost performance + int read_len = min(len_per_loop, len - i); + GM2LM(d_keys + i, local_keys, read_len * sizeof(KeyType)); + for (int k = 0; k < read_len; k++) { + local_shard_index[k] = local_keys[k] % total_xpu; + } + LM2GM(local_shard_index, shard_index + i, read_len * sizeof(T)); + } +} + +template +__global__ void fill_shard_key_kernel(KeyType* d_shard_keys, KeyType* d_keys, + T* idx, long long len) { + int cid = core_id(); + int ncores = core_num(); + if (cid >= ncores) { + return; + } + int thread_id = ncores * cluster_id() + cid; + int nthreads = ncores * cluster_num(); + const int buf_size = 400; + __local__ KeyType local_keys[buf_size]; + __local__ KeyType local_shard_keys[buf_size]; + __local__ T local_idx[buf_size]; + int len_per_loop = min(buf_size, roundup_div(len, nthreads)); + for (int i = thread_id * len_per_loop; i < len; + i += nthreads * len_per_loop) { + // read batch from GM will boost performance + int read_len = min(len_per_loop, len - i); + GM2LM(d_keys + i, local_keys, read_len * sizeof(KeyType)); + GM2LM(idx + i, local_idx, read_len * sizeof(T)); + for (int k = 0; k < read_len; k++) { + local_shard_keys[k] = local_keys[local_idx[k]]; + } + LM2GM(local_shard_keys, d_shard_keys + i, read_len * sizeof(KeyType)); + } +} + +// local mem too large, cause compile error +template +__global__ void fill_shard_grads_kernel(KeyType* d_shard_keys, KeyType* d_keys, + GradType* d_shard_grads, + GradType* d_grads, T* idx, + long long len) { + int cid = core_id(); + int ncores = core_num(); + if (cid >= ncores) { + return; + } + int thread_id = ncores * cluster_id() + cid; + int nthreads = ncores * cluster_num(); + + const int buf_size = 100; + __local__ KeyType local_keys[buf_size]; + __local__ GradType local_grads[buf_size]; + __local__ KeyType local_shard_keys[buf_size]; + __local__ GradType local_shard_grads[buf_size]; + __local__ T local_idx[buf_size]; + + int len_per_loop = min(buf_size, roundup_div(len, nthreads)); + for (int i = thread_id * len_per_loop; i < len; + i += nthreads * len_per_loop) { + // read batch from GM will boost performance + int read_len = min(len_per_loop, len - i); + GM2LM(d_keys + i, local_keys, read_len * sizeof(KeyType)); + GM2LM(d_grads + i, local_grads, read_len * sizeof(GradType)); + GM2LM(idx + i, local_idx, read_len * sizeof(T)); + for (int k = 0; k < read_len; k++) { + local_shard_keys[k] = local_keys[local_idx[k]]; + local_shard_grads[k] = local_grads[local_idx[k]]; + } + LM2GM(local_shard_keys, d_shard_keys + i, read_len * sizeof(KeyType)); + LM2GM(local_shard_grads, d_shard_grads + i, read_len * sizeof(GradType)); + } +} + +template 
+__global__ void fill_dvals_kernel(ValType* d_shard_vals, ValType* d_vals, + T* idx, long long len) { + int cid = core_id(); + int ncores = core_num(); + if (cid >= ncores) { + return; + } + int thread_id = ncores * cluster_id() + cid; + int nthreads = ncores * cluster_num(); + const int buf_size = 50; + __local__ ValType local_vals[buf_size]; + __local__ ValType local_shard_vals[buf_size]; + __local__ T local_idx[buf_size]; + int len_per_loop = min(buf_size, roundup_div(len, nthreads)); + for (int i = thread_id * len_per_loop; i < len; + i += nthreads * len_per_loop) { + // read batch from GM will boost performance + int read_len = min(len_per_loop, len - i); + GM2LM(idx + i, local_idx, read_len * sizeof(T)); + GM2LM(d_shard_vals + i, local_shard_vals, read_len * sizeof(ValType)); + for (int k = 0; k < read_len; k++) { + local_vals[local_idx[k]] = local_shard_vals[k]; + } + LM2GM(local_vals, d_vals + i, read_len * sizeof(ValType)); + } +} + +// xpu implementation of heter_comm_kernel.h + +template +void HeterCommKernel::fill_idx(T* idx, long long len, + const StreamType& stream) { + fill_idx_kernel<<<4, 64, stream>>>(idx, len); +} + +template +void HeterCommKernel::calc_shard_offset(T* idx, T* left, T* right, + long long len, int total_devs, + const StreamType& stream) { + calc_shard_offset_kernel<<<4, 64, stream>>>(idx, left, right, len, + total_devs); +} + +template +void HeterCommKernel::calc_shard_index(KeyType* d_keys, long long len, + T* shard_index, int total_devs, + const StreamType& stream) { + calc_shard_index_kernel<<<4, 64, stream>>>( + d_keys, len, shard_index, total_devs); +} + +template +void HeterCommKernel::fill_shard_key(KeyType* d_shard_keys, KeyType* d_keys, + T* idx, long long len, + const StreamType& stream) { + fill_shard_key_kernel<<<4, 64, stream>>>(d_shard_keys, d_keys, + idx, len); +} + +template +void HeterCommKernel::fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys, + GradType* d_shard_grads, + GradType* d_grads, T* idx, long long len, + const StreamType& stream) { + fill_shard_grads_kernel<<<4, 64, stream>>>( + d_shard_keys, d_keys, d_shard_grads, d_grads, idx, len); +} + +template +void HeterCommKernel::fill_dvals(ValType* d_shard_vals, ValType* d_vals, T* idx, + long long len, const StreamType& stream) { + fill_dvals_kernel<<<4, 64, stream>>>(d_shard_vals, d_vals, idx, + len); +} + +template +void HeterCommKernel::sort_pairs(void* d_temp_storage, + size_t& temp_storage_bytes, // NOLINT + const KeyT* d_keys_in, // NOLINT + KeyT* d_keys_out, const ValueT* d_values_in, + ValueT* d_values_out, int num_items, + int begin_bit, int end_bit, StreamType stream, + bool debug_synchronous) {} + +template ( + int* idx, long long len, const XPUStream& stream); +template void HeterCommKernel::calc_shard_offset( + int* idx, int* left, int* right, long long len, int total_devs, + const XPUStream& stream); +template void HeterCommKernel::calc_shard_index( + unsigned long* d_keys, long long len, int* shard_index, int total_devs, + const XPUStream& stream); + +template void HeterCommKernel::fill_shard_key( + unsigned long* d_shard_keys, unsigned long* d_keys, int* idx, long long len, + const XPUStream& stream); +template void HeterCommKernel::fill_shard_grads< + unsigned long, paddle::framework::FeaturePushValue, int, XPUStream>( + unsigned long* d_shard_keys, unsigned long* d_keys, + paddle::framework::FeaturePushValue* d_shard_grads, + paddle::framework::FeaturePushValue* d_grads, int* idx, long long len, + const XPUStream& stream); +template void 
+HeterCommKernel::fill_dvals( + paddle::framework::FeatureValue* d_shard_vals, + paddle::framework::FeatureValue* d_vals, int* idx, long long len, + const XPUStream& stream); + +template void HeterCommKernel::sort_pairs< + unsigned long, paddle::framework::FeaturePushValue, XPUStream>( + void* d_temp_storage, + size_t& temp_storage_bytes, // NOLINT + const unsigned long* d_keys_in, // NOLINT + unsigned long* d_keys_out, + const paddle::framework::FeaturePushValue* d_values_in, + paddle::framework::FeaturePushValue* d_values_out, int num_items, + int begin_bit, int end_bit, XPUStream stream, bool debug_synchronous); + +template void HeterCommKernel::sort_pairs( + void* d_temp_storage, + size_t& temp_storage_bytes, // NOLINT + const int* d_keys_in, // NOLINT + int* d_keys_out, const int* d_values_in, int* d_values_out, int num_items, + int begin_bit, int end_bit, XPUStream stream, bool debug_synchronous); + +template void HeterCommKernel::reduce_by_key< + unsigned long*, unsigned long*, paddle::framework::FeaturePushValue*, + paddle::framework::FeaturePushValue*, int*, XPUStream>( + void* d_temp_storage, + size_t& temp_storage_bytes, // NOLINT + unsigned long* d_keys_in, unsigned long* d_unique_out, + paddle::framework::FeaturePushValue* d_values_in, + paddle::framework::FeaturePushValue* d_aggregates_out, + int* d_num_runs_out, int num_items, XPUStream stream, + bool debug_synchronous); + +#endif + +} // end namespace framework +} // end namespace paddle +#endif diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu index 581b0d511c23ee070b6dc33af315cc420f6ef20a..583eb926a26a513e57cf1e779de41e2548969a6b 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu @@ -29,7 +29,9 @@ HeterPs::HeterPs(size_t capacity, std::shared_ptr resource) { comm_ = std::make_shared>( capacity, resource); +#if defined(PADDLE_WITH_CUDA) opt_ = Optimizer(); +#endif } HeterPs::~HeterPs() {} @@ -54,15 +56,21 @@ void HeterPs::show_one_table(int gpu_num) { comm_->show_one_table(gpu_num); } void HeterPs::push_sparse(int num, FeatureKey* d_keys, FeaturePushValue* d_grads, size_t len) { +#if defined(PADDLE_WITH_CUDA) comm_->push_sparse(num, d_keys, d_grads, len, opt_); +#elif defined(PADDLE_WITH_XPU_KP) + comm_->push_sparse(num, d_keys, d_grads, len); +#endif // comm_->push_sparse_multi_node(num, d_keys, d_grads, len, opt_); } +#if defined(PADDLE_WITH_CUDA) void HeterPs::set_nccl_comm_and_size(const std::vector& inner_comms, const std::vector& inter_comms, int comm_size) { comm_->set_nccl_comm_and_size(inner_comms, inter_comms, comm_size); } +#endif } // end namespace framework } // end namespace paddle diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.h b/paddle/fluid/framework/fleet/heter_ps/heter_ps.h index d78b6b492074deb9f3dcc1073e951353c0846abb..7fb50f4da1fce3876efceff5da86e325d70f18a8 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.h @@ -16,7 +16,9 @@ limitations under the License. 
*/ #include #include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h" +#if defined(PADDLE_WITH_CUDA) #include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h" +#endif #ifdef PADDLE_WITH_HETERPS @@ -35,9 +37,13 @@ class HeterPs : public HeterPsBase { size_t len) override; virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals, size_t len, size_t chunk_size, int stream_num) override; + +#if defined(PADDLE_WITH_CUDA) virtual void set_nccl_comm_and_size( const std::vector& inner_comms, const std::vector& inter_comms, int comm_size) override; +#endif + virtual void end_pass() override; virtual int get_index_by_devid(int devid) override; virtual void show_one_table(int gpu_num) override; @@ -46,7 +52,9 @@ class HeterPs : public HeterPsBase { private: std::shared_ptr> comm_; +#if defined(PADDLE_WITH_CUDA) Optimizer opt_; +#endif }; } // end namespace framework diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h b/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h index 05b3ecf9c3c12c6b4df1192785b0659a8ef851d0..ddbf02df6c578904d8fa79934f4704ad00c4d121 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h @@ -35,9 +35,11 @@ class HeterPsBase { virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals, size_t len, size_t chunk_size, int stream_num) = 0; virtual int get_index_by_devid(int devid) = 0; +#if defined(PADDLE_WITH_CUDA) virtual void set_nccl_comm_and_size( const std::vector& inner_comms, const std::vector& inter_comms, int comm_size) = 0; +#endif virtual void end_pass() = 0; virtual void show_one_table(int gpu_num) = 0; virtual void push_sparse(int num, FeatureKey* d_keys, diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_resource.cc b/paddle/fluid/framework/fleet/heter_ps/heter_resource.cc index cad7559af5742f9accf640cf7aa6a95fb0f17d96..7074cfb521bdf61905b51602c456b57d32e68f3b 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_resource.cc +++ b/paddle/fluid/framework/fleet/heter_ps/heter_resource.cc @@ -13,12 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #ifdef PADDLE_WITH_HETERPS -#include "heter_resource.h" +#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h" + +#ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/cuda_device_guard.h" +#endif + +#ifdef PADDLE_WITH_XPU_KP +#include "paddle/fluid/platform/device/xpu/enforce_xpu.h" +#include "paddle/fluid/platform/device/xpu/xpu_info.h" +#endif namespace paddle { namespace framework { +#if defined(PADDLE_WITH_CUDA) GPUResource::GPUResource(std::vector& dev_ids, int index) { index_ = index; dev_ids_ = dev_ids; @@ -52,7 +61,41 @@ GPUResource::~GPUResource() { } } +#elif defined(PADDLE_WITH_XPU_KP) +XPUResource::XPUResource(std::vector& dev_ids, int index) { + index_ = index; + dev_ids_ = dev_ids; + dev_id_ = dev_ids_[index]; + + platform::XPUDeviceGuard guard(dev_id_); + local_streams_.resize(dev_ids_.size()); + comm_streams_.resize(dev_ids_.size(), NULL); + remote_streams_.resize(dev_ids_.size()); + + for (size_t i = 0; i < dev_ids_.size(); ++i) { + PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&local_streams_[i])); + // PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&comm_streams_[i])); + PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&remote_streams_[i])); + } +} + +XPUResource::~XPUResource() { + platform::XPUDeviceGuard guard(dev_id_); + for (size_t i = 0; i < local_streams_.size(); ++i) { + PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_destroy(local_streams_[i])); + } + // for (size_t i = 0; i < comm_streams_.size(); ++i) { + // PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_destroy(comm_streams_[i])); + // } + for (size_t i = 0; i < remote_streams_.size(); ++i) { + PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_destroy(remote_streams_[i])); + } +} + +#endif + void HeterPsResource::enable_p2p() { +#if defined(PADDLE_WITH_CUDA) for (size_t i = 0; i < dev_ids_.size(); ++i) { platform::CUDADeviceGuard guard(dev_ids_[i]); for (size_t j = 0; j < dev_ids_.size(); ++j) { @@ -72,28 +115,28 @@ void HeterPsResource::enable_p2p() { } } } +#endif } HeterPsResource::HeterPsResource(const std::vector& dev_ids) { dev_ids_ = dev_ids; for (size_t i = 0; i < dev_ids_.size(); ++i) { - std::shared_ptr resource = - std::make_shared(dev_ids_, i); + std::shared_ptr resource = + std::make_shared(dev_ids_, i); resources_.push_back(resource); devid_2_index_[dev_ids_[i]] = i; } } -cudaStream_t HeterPsResource::comm_stream(int gpu_num, int stream_num) { - return resources_[gpu_num]->comm_stream(stream_num); +ppStream HeterPsResource::comm_stream(int dev_num, int stream_num) { + return resources_[dev_num]->comm_stream(stream_num); } - -cudaStream_t HeterPsResource::local_stream(int gpu_num, int stream_num) { - return resources_[gpu_num]->local_stream(stream_num); +ppStream HeterPsResource::local_stream(int dev_num, int stream_num) { + return resources_[dev_num]->local_stream(stream_num); } -cudaStream_t HeterPsResource::remote_stream(int gpu_num, int stream_num) { - return resources_[gpu_num]->remote_stream(stream_num); +ppStream HeterPsResource::remote_stream(int dev_num, int stream_num) { + return resources_[dev_num]->remote_stream(stream_num); } int HeterPsResource::dev_id(int num) { return dev_ids_[num]; } @@ -102,7 +145,7 @@ int HeterPsResource::get_index_by_devid(int devid) { return devid_2_index_[devid]; } -int HeterPsResource::total_gpu() { return dev_ids_.size(); } +int HeterPsResource::total_device() { return dev_ids_.size(); } void HeterPsResource::set_multi_mf(int multi_mf_dim, int max_mf_dim) { multi_mf_dim_ = multi_mf_dim; diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_resource.h 
b/paddle/fluid/framework/fleet/heter_ps/heter_resource.h index 19df8cc70f50efd0130ea68390ce9fd374cfef46..164fca22768006a8872b0eee511e9ca5ed4562d1 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_resource.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_resource.h @@ -17,7 +17,16 @@ limitations under the License. */ #include #include #include + +#ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/cuda_device_guard.h" +#endif + +#ifdef PADDLE_WITH_XPU_KP +#include // NOLINT +#include "paddle/fluid/platform/device/xpu/xpu_info.h" +#endif + #include "paddle/fluid/platform/enforce.h" #ifdef PADDLE_WITH_HETERPS @@ -25,9 +34,16 @@ limitations under the License. */ namespace paddle { namespace framework { +#if defined(PADDLE_WITH_CUDA) +using ppStream = cudaStream_t; +#elif defined(PADDLE_WITH_XPU_KP) +using ppStream = XPUStream; +#endif + +#if defined(PADDLE_WITH_CUDA) class GPUResource { public: - GPUResource(std::vector& device_id, int index); + GPUResource(std::vector& device_id, int index); // NOLINT virtual ~GPUResource(); GPUResource(const GPUResource&) = delete; GPUResource& operator=(const GPUResource&) = delete; @@ -45,23 +61,55 @@ class GPUResource { std::vector local_streams_; std::vector comm_streams_; }; +#elif defined(PADDLE_WITH_XPU_KP) +class XPUResource { + public: + XPUResource(std::vector& device_id, int index); // NOLINT + virtual ~XPUResource(); + XPUResource(const XPUResource&) = delete; + XPUResource& operator=(const XPUResource&) = delete; + + int dev_id() const { return dev_id_; } + int index() const { return index_; } + XPUStream local_stream(int num) { return local_streams_[num]; } + XPUStream remote_stream(int num) { return remote_streams_[num]; } + XPUStream comm_stream(int num) { return comm_streams_[num]; } + + int dev_id_; + int index_; + std::vector dev_ids_; + std::vector remote_streams_; + std::vector local_streams_; + std::vector comm_streams_; +}; +#endif + +#if defined(PADDLE_WITH_CUDA) +using DevResource = GPUResource; +using DevPlace = platform::CUDAPlace; +using AnyDeviceGuard = platform::CUDADeviceGuard; +#elif defined(PADDLE_WITH_XPU_KP) +using DevResource = XPUResource; +using DevPlace = platform::XPUPlace; +using AnyDeviceGuard = platform::XPUDeviceGuard; +#endif class HeterPsResource { public: - HeterPsResource(const std::vector& dev_ids); + explicit HeterPsResource(const std::vector& dev_ids); HeterPsResource(const HeterPsResource&) = delete; HeterPsResource& operator=(const HeterPsResource&) = delete; virtual ~HeterPsResource() {} void enable_p2p(); - int total_gpu(); + int total_device(); int get_index_by_devid(int devid); int dev_id(int num); void set_multi_mf(int multi_mf_dim, int max_mf_dim); - gpuStream_t local_stream(int gpu_num, int stream_num); - gpuStream_t remote_stream(int gpu_num, int stream_num); - gpuStream_t comm_stream(int gpu_num, int stream_num); + ppStream local_stream(int dev_num, int stream_num); + ppStream remote_stream(int dev_num, int stream_num); + ppStream comm_stream(int dev_num, int stream_num); - std::vector> resources_; + std::vector> resources_; std::vector dev_ids_; std::map devid_2_index_; int multi_mf_dim_{0}; diff --git a/paddle/fluid/framework/fleet/heter_ps/mem_pool.h b/paddle/fluid/framework/fleet/heter_ps/mem_pool.h index 9189902c28ffb4796b970baab858d2c99918540d..a663d1bf764104533fe48ee4c17e2e142161d086 100644 --- a/paddle/fluid/framework/fleet/heter_ps/mem_pool.h +++ b/paddle/fluid/framework/fleet/heter_ps/mem_pool.h @@ -18,6 +18,7 @@ limitations under the License. 
*/ // #include // "paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h" #include +#ifdef PADDLE_WITH_CUDA #include "paddle/fluid/framework/fleet/heter_ps/cudf/managed.cuh" namespace paddle { @@ -111,3 +112,4 @@ class HBMMemoryPool : public managed { } // end namespace framework } // end namespace paddle #endif +#endif diff --git a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h index ff9976db5d875cc0c1ab01336389068d278ded7a..ebf7dd277c7d6de6923676386a9fa8a2b9edca33 100644 --- a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h +++ b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h @@ -13,16 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#ifdef PADDLE_WITH_HETERPS + +#if defined(PADDLE_WITH_CUDA) #include +#endif #include -#include "optimizer_conf.h" #include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" - -#ifdef PADDLE_WITH_HETERPS +#include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h" namespace paddle { namespace framework { +#if defined(PADDLE_WITH_CUDA) template class Optimizer { public: @@ -32,7 +35,8 @@ class Optimizer { void initialize() {} - __device__ void update_lr(float& w, float& g2sum, float g, float scale) { + __device__ void update_lr(float& w, float& g2sum, float g, // NOLINT + float scale) { double add_g2sum = 0; double ratio = optimizer_config::learning_rate * sqrt(optimizer_config::initial_g2sum / @@ -49,8 +53,8 @@ class Optimizer { g2sum += add_g2sum; } - __device__ void update_mf(int n, float* w, float& g2sum, const float* g, - float scale) { + __device__ void update_mf(int n, float* w, float& g2sum, // NOLINT + const float* g, float scale) { double add_g2sum = 0; double ratio = optimizer_config::mf_learning_rate * sqrt(optimizer_config::mf_initial_g2sum / @@ -69,7 +73,8 @@ class Optimizer { g2sum += add_g2sum / n; } - __device__ void update_value(ValType& val, const GradType& grad) { + + __device__ void update_value(ValType& val, const GradType& grad) { // NOLINT val.slot = grad.slot; val.show += grad.show; val.clk += grad.clk; @@ -132,6 +137,7 @@ class Optimizer { } }; +#endif } // end namespace framework } // end namespace paddle #endif diff --git a/paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h b/paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h index 55d0fc561c574dc62e5eeed7502ccaa02946bc8b..6d924a395e19ac063236a352c1145f29c84ded67 100644 --- a/paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h +++ b/paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h @@ -14,8 +14,16 @@ limitations under the License. 
*/ #pragma once +#if defined(PADDLE_WITH_XPU_KP) +#include "xpu/kernel/cluster_header.h" +#include "xpu/kernel/debug.h" +#include "xpu/kernel/math.h" +#endif + namespace optimizer_config { +#if defined(PADDLE_WITH_CUDA) + __constant__ float nonclk_coeff = 0.1; __constant__ float clk_coeff = 1; @@ -31,4 +39,24 @@ __constant__ float mf_initial_g2sum = 3.0; __constant__ float mf_initial_range = 1e-4; __constant__ float mf_min_bound = -10; __constant__ float mf_max_bound = 10; -} + +#elif defined(PADDLE_WITH_XPU_KP) + +_global_ptr_ float* nonclk_coeff; +_global_ptr_ float* clk_coeff; + +_global_ptr_ float* min_bound; +_global_ptr_ float* max_bound; +_global_ptr_ float* learning_rate; +_global_ptr_ float* initial_g2sum; +_global_ptr_ float* initial_range; + +_global_ptr_ float* mf_create_thresholds; +_global_ptr_ float* mf_learning_rate; +_global_ptr_ float* mf_initial_g2sum; +_global_ptr_ float* mf_initial_range; +_global_ptr_ float* mf_min_bound; +_global_ptr_ float* mf_max_bound; + +#endif +} // namespace optimizer_config diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index d9d29cc072dd7bc7beec62bedcd1d05ddac726ce..9145dda5f68c2ce03814e6e2017503746d72b663 100755 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -121,7 +121,7 @@ class PSGPUWrapper { is_initialized_ = true; resource_ = std::make_shared(dev_ids); resource_->enable_p2p(); - keys_tensor.resize(resource_->total_gpu()); + keys_tensor.resize(resource_->total_device()); #ifdef PADDLE_WITH_GLOO auto gloo = paddle::framework::GlooWrapper::GetInstance(); if (gloo->Size() > 1) { @@ -287,8 +287,8 @@ class PSGPUWrapper { for (size_t i = 0; i < num_of_dim; i++) { dim_index_map[index_dim_vec_[i]] = i; } - hbm_pools_.resize(resource_->total_gpu() * num_of_dim); - mem_pools_.resize(resource_->total_gpu() * num_of_dim); + hbm_pools_.resize(resource_->total_device() * num_of_dim); + mem_pools_.resize(resource_->total_device() * num_of_dim); max_mf_dim_ = index_dim_vec_.back(); multi_mf_dim_ = (dim_index_map.size() >= 1) ? dim_index_map.size() : 0; resource_->set_multi_mf(multi_mf_dim_, max_mf_dim_);
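The pattern running through this whole change is that every device-specific type is funneled through one set of compile-time aliases (ppStream, DevPlace, DevResource, AnyDeviceGuard, chosen by PADDLE_WITH_CUDA versus PADDLE_WITH_XPU_KP), and stream waits and host/device copies go through small wrappers such as sync_stream and memory_copy, so the shared HeterComm code paths stay free of per-backend #ifdefs. The sketch below is a minimal, self-contained mirror of that selection mechanism, not code from this patch: FakeCudaStream, FakeXpuStream and the SELECT_CUDA switch are illustrative stand-ins so the example compiles with a plain C++ compiler and no CUDA or XPU SDK.

// Minimal mirror of the backend-selection pattern in this patch.
// Real code aliases cudaStream_t / XPUStream, CUDAPlace / XPUPlace and
// CUDADeviceGuard / XPUDeviceGuard; the stubs and SELECT_CUDA below are
// hypothetical placeholders used only for illustration.
#include <iostream>

struct FakeCudaStream {
  void wait() const { std::cout << "cuda-style stream synchronize\n"; }
};
struct FakeXpuStream {
  void wait() const { std::cout << "xpu-style stream wait\n"; }
};

#if defined(SELECT_CUDA)
using ppStream = FakeCudaStream;  // patch equivalent: using ppStream = cudaStream_t;
#else
using ppStream = FakeXpuStream;   // patch equivalent: using ppStream = XPUStream;
#endif

// Device-agnostic helper in the spirit of sync_stream(stream): callers never
// name the concrete backend type, the alias resolves it at compile time.
inline void sync_stream(const ppStream& stream) { stream.wait(); }

int main() {
  ppStream stream;
  sync_stream(stream);
  return 0;
}

Built as-is this exercises the XPU-like branch; adding -DSELECT_CUDA flips the alias, which is the same mechanism that lets pull_sparse, push_sparse and end_pass above share one body across GPU and XPU builds via total_device(), DevPlace and AnyDeviceGuard.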