未验证 提交 0ef3ef28 编写于 作者: F Fan Zhang 提交者: GitHub

XPUPS Adaptation (#40991)

* Adapt XPUPS - 1st version - 3.24

* Adapt XPUPS - update XPU PushSparse -  2nd version - 3.24

* Adapt XPUPS - add XPU PullSparseOp - 3nd version - 3.25

* refactor heter comm kernel

* update. test=develop

* Adapt XPUPS - modify by compilation - 4th version - 3.27

* update calc_shard_offset. test=develop

* update xpu kernel. test=develop

* update args of calc_shard_offset

* update. test=develop

* remove customGradMerger

* update. test=develop

* heter_comm update

* heter_comm update

* update calc_shard_offset. test=develop

* heter_comm update

* update args of calc_shard_offset

* update. test=develop

* remove customGradMerger

* update. test=develop

* fix. test=develop

* update. test=develop

* update. test=develop

* update optimizer kernel

* Adapt XPUPS - use WITH_XPU_KP and modify wrapper kernel function - 5th version - 3.30

* update. test=develop

* update pslib.cmake

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* Adapt XPUPS - modify by kp compilation  - 6th version - 3.30

* update. test=develop

* update. test=develop

* update. test=develop

* update optimizer kernel

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* fix. test=develop

* fix. test=develop

* used by minxu

* update heter_comm_inl

* fix. test=develop

* Adapt XPUPS - modify by kp compilation  - 7th version - 3.30

* fix. test=develop

* add optimizer kernel. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* 3.31 update

* Adapt XPUPS - update kp compilation path  - 8th version - 3.31

* add optimizer kernel. test=develop

* fix kunlun not support size_t. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix kunlun not support size_t. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* update heter_comm_kernel.kps 3.31

* fix. test=develop

* fix. test=develop

* update heter_comm_kernel.kps 3.31

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* update heter_comm.h 3.31

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* update hashtable. test=develop

* update. test=develop

* Adapt XPUPS - update by kp compilation  - 9th version - 4.1

* update hashtable. test=develop

* fix. test=develop

* update hashtable 4.1

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* Adapt XPUPS - update by kp compilation  - 10th version - 4.1

* fix. test=develop

* fix. test=develop

* fix. test=develop

* update. test=develop

* modify by compilation 4.1

* update. test=develop

* update. test=develop

* fix. test=develop

* modify by compilation 4.1

* update. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* modify by compilation 4.1

* fix. test=develop

* fix. test=develop

* fix. test=develop

* modify by compilation 4.1 19:30

* fix. test=develop

* update ps_gpu_wrapper.kps 4.1

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* Adapt XPUPS - update by kp compilation  - 11th version - 4.1

* fix. test=develop

* Adapt XPUPS - update by kp compilation  - 12nd version - 4.2

* fix. test=develop

* fix. test=develop

* modify by compilation 4.2

* 4.2 update

* fix. test=develop

* template init. test=develop

* update 4.6

* fix. test=develop

* template init. test=develop

* 4.6 modify by compilation

* hashtable template init. test=develop

* hashtable template init. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=devlop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=devlop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* Adapt XPUPS - update by kp compilation  - 13nd version - 4.7

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* 4.11 update

* fix. test=develop

* fix. test=develop

* 4.11 update

* update by pre-commit

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* 4.12 update

* fix. test=develop

* Adapt XPUPS - update by kp compilation  - 14th version - 4.13

* 4.13 update

* 4.14 update

* 4.14 update

* 4.14 update

* 4.14 modify by merged latest compilation

* retry CI 4.14

* 4.15 pass static check

* 4.15 modify by gpups CI

* 3.16 update by gpups CI - modify ps_gpu_wrapper.h

* 4.16 update

* 4.16 pass xpu compile

* 4.16 retry CI

* 4.16 update
Co-authored-by: Nzmxdream <zhangminxu01@baidu.com>
上级 7ee31a96
......@@ -12,15 +12,19 @@ else()
endif(WITH_PSLIB)
if(WITH_HETERPS)
if(WITH_NCCL)
if(WITH_NCCL AND WITH_GPU)
nv_library(ps_gpu_wrapper SRCS ps_gpu_wrapper.cu ps_gpu_wrapper.cc
DEPS heter_ps gloo_wrapper ${BRPC_DEPS})
add_subdirectory(heter_ps)
elseif(WITH_XPU_KP)
xpu_library(ps_gpu_wrapper SRCS ps_gpu_wrapper.kps ps_gpu_wrapper.cc
DEPS heter_ps gloo_wrapper ${BRPC_DEPS})
add_subdirectory(heter_ps)
elseif(WITH_RCCL)
hip_library(ps_gpu_wrapper SRCS ps_gpu_wrapper.cu ps_gpu_wrapper.cc
DEPS heter_ps gloo_wrapper ${BRPC_DEPS})
add_subdirectory(heter_ps)
endif(WITH_NCCL)
endif()
else()
cc_library(ps_gpu_wrapper SRCS ps_gpu_wrapper.cc DEPS gloo_wrapper)
endif(WITH_HETERPS)
......
......@@ -24,6 +24,14 @@ IF(WITH_GPU)
endif()
ENDIF()
IF(WITH_XPU_KP)
SET(HETERPS_DEPS device_context)
xpu_library(heter_comm_kernel SRCS heter_comm_kernel.h heter_comm_kernel.kps feature_value.h)
xpu_library(hashtable_kernel SRCS hashtable.h hashtable_kernel.kps)
cc_library(heter_comm SRCS heter_comm.h heter_resource.cc DEPS ${HETERPS_DEPS} heter_comm_kernel hashtable_kernel)
cc_library(heter_ps SRCS heter_ps.cc DEPS heter_comm)
# xpu_library(heter_comm SRCS heter_comm.h heter_comm_kernel.kps feature_value.h heter_resource.cc heter_resource.h hashtable.h mem_pool.h DEPS ${HETERPS_DEPS})
ENDIF()
IF(WITH_ROCM)
hip_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h DEPS cub device_context)
hip_test(test_heter_comm SRCS feature_value.h DEPS heter_comm)
......
......@@ -48,7 +48,7 @@ __device__ void update_lr(float& w, float& g2sum, float g, // NOLINT
GM2LM(optimizer_config::learning_rate, &local_learning_rate, sizeof(float));
GM2LM(optimizer_config::initial_g2sum, &local_initial_g2sum, sizeof(float));
GM2LM(optimizer_config::min_bound, &local_min_bound, sizeof(float));
GM2LM(optimizr_config::max_bound, &local_max_bound, sizeof(float));
GM2LM(optimizer_config::max_bound, &local_max_bound, sizeof(float));
double add_g2sum = 0;
double ratio = local_learning_rate *
......@@ -136,7 +136,7 @@ __device__ void update_value(ValType& val, const GradType& grad) { // NOLINT
template <typename KeyType, typename ValType, typename Table>
__global__ void insert_kernel(Table* table, const KeyType* const keys,
const ValType* const vals, size_t len) {
const ValType* const vals, long long len) {
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
......@@ -164,7 +164,7 @@ __global__ void insert_kernel(Table* table, const KeyType* const keys,
template <typename KeyType, typename ValType, typename Table>
__global__ void search_kernel(Table* table, const KeyType* const keys,
ValType* const vals, size_t len) {
ValType* const vals, long long len) {
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
......@@ -194,7 +194,7 @@ __global__ void search_kernel(Table* table, const KeyType* const keys,
template <typename KeyType, typename ValType, typename Table, typename GradType>
__global__ void update_kernel(Table* table, const KeyType* const keys,
const GradType* const grads, size_t len) {
const GradType* const grads, long long len) {
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
......@@ -251,7 +251,10 @@ void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
if (len == 0) {
return;
}
search_kernel<<<4, 64, stream>>>(container_, d_keys, d_vals, len);
long long c_len = (long long)len;
search_kernel<KeyType, ValType,
XPUCacheArray<KeyType, ValType>><<<4, 64, stream>>>(
container_, d_keys, d_vals, c_len);
}
template <typename KeyType, typename ValType>
......@@ -272,7 +275,10 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
if (len == 0) {
return;
}
insert_kernel<<<4, 64, stream>>>(container_, d_keys, d_vals, len);
long long c_len = (long long)len;
insert_kernel<KeyType, ValType,
XPUCacheArray<KeyType, ValType>><<<4, 64, stream>>>(
container_, d_keys, d_vals, c_len);
}
template <typename KeyType, typename ValType>
......@@ -289,7 +295,10 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
if (len == 0) {
return;
}
update_kernel<<<4, 64, stream>>>(container_, d_keys, d_grads, len);
long long c_len = (long long)len;
update_kernel<KeyType, ValType, XPUCacheArray<KeyType, ValType>,
GradType><<<4, 64, stream>>>(container_, d_keys, d_grads,
c_len);
}
template <typename KeyType, typename ValType>
......
......@@ -153,11 +153,13 @@ class HeterComm {
#if defined(PADDLE_WITH_CUDA)
platform::CUDAPlace place_;
#elif defined(PADDLE_WITH_XPU_KP)
platform::XPUPlace place_;
#endif
std::shared_ptr<memory::Allocation> all_keys_mem;
std::shared_ptr<memory::Allocation> all_grads_mem;
KeyType* all_keys;
GradType* all_grads;
......@@ -228,5 +230,7 @@ class HeterComm {
} // end namespace framework
} // end namespace paddle
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h"
#endif
......@@ -411,7 +411,6 @@ void HeterComm<KeyType, ValType, GradType>::merge_grad(
auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType));
KeyType* d_merge_keys_ptr = reinterpret_cast<KeyType*>(d_merge_keys->ptr());
auto d_merge_grads = memory::Alloc(place, len * sizeof(GradType));
GradType* d_merge_grads_ptr =
reinterpret_cast<GradType*>(d_merge_grads->ptr());
......@@ -1035,7 +1034,6 @@ int HeterComm<KeyType, ValType, GradType>::gather_multi_node_grad(
merge_grad(gpu_num, storage.local_keys, storage.local_grads, merge_num, ret);
return ret;
}
#endif
template <typename KeyType, typename ValType, typename GradType>
......@@ -1065,7 +1063,6 @@ void HeterComm<KeyType, ValType, GradType>::end_pass() {
// platform::CUDADeviceGuard guard(dev_id);
// tables_[index]->dump_to_cpu(dev_id, stream);
//}
} // end namespace framework
} // end namespace paddle
#endif
......@@ -41,6 +41,7 @@ class HeterCommKernel {
template <typename KeyType, typename T, typename StreamType>
void calc_shard_index(KeyType* d_keys, long long len, T* shard_index,
int total_devs, const StreamType& stream);
template <typename KeyType, typename T, typename StreamType>
......@@ -62,6 +63,7 @@ class HeterCommKernel {
const KeyT* d_keys_in, KeyT* d_keys_out,
const ValueT* d_values_in, ValueT* d_values_out,
int num_items, int begin_bit = 0,
int end_bit = sizeof(KeyT) * 8, StreamType stream = NULL,
bool debug_synchronous = false);
......@@ -75,6 +77,7 @@ class HeterCommKernel {
ValuesInputIteratorT d_values_in,
AggregatesOutputIteratorT d_aggregates_out,
NumRunsOutputIteratorT d_num_runs_out, int num_items,
StreamType stream = NULL, bool debug_synchronous = false);
private:
......
......@@ -233,8 +233,6 @@ __global__ void fill_dvals_kernel(ValType* d_shard_vals, ValType* d_vals,
}
}
// xpu implementation of heter_comm_kernel.h
template <typename T, typename StreamType>
void HeterCommKernel::fill_idx(T* idx, long long len,
const StreamType& stream) {
......@@ -291,17 +289,21 @@ void HeterCommKernel::sort_pairs(void* d_temp_storage,
bool debug_synchronous) {}
template <typename KeysInputIteratorT, typename UniqueOutputIteratorT,
void HeterCommKernel::reduce_by_key(
void* d_temp_storage,
size_t& temp_storage_bytes, // NOLINT
KeysInputIteratorT d_keys_in, UniqueOutputIteratorT d_unique_out,
ValuesInputIteratorT d_values_in,
AggregatesOutputIteratorT d_aggregates_out,
NumRunsOutputIteratorT d_num_runs_out, int num_items,
StreamType stream, bool debug_synchronous) {}
typename ValuesInputIteratorT, typename AggregatesOutputIteratorT,
typename NumRunsOutputIteratorT, typename StreamType>
void HeterCommKernel::reduce_by_key(void* d_temp_storage,
size_t& temp_storage_bytes, // NOLINT
KeysInputIteratorT d_keys_in,
UniqueOutputIteratorT d_unique_out,
ValuesInputIteratorT d_values_in,
AggregatesOutputIteratorT d_aggregates_out,
NumRunsOutputIteratorT d_num_runs_out,
int num_items, StreamType stream,
bool debug_synchronous) {}
template void HeterCommKernel::fill_idx<int, XPUStream>(
int* idx, long long len, const XPUStream& stream);
template void HeterCommKernel::calc_shard_offset<int, XPUStream>(
int* idx, int* left, int* right, long long len, int total_devs,
const XPUStream& stream);
......@@ -312,12 +314,14 @@ template void HeterCommKernel::calc_shard_index<unsigned long, int, XPUStream>(
template void HeterCommKernel::fill_shard_key<unsigned long, int, XPUStream>(
unsigned long* d_shard_keys, unsigned long* d_keys, int* idx, long long len,
const XPUStream& stream);
template void HeterCommKernel::fill_shard_grads<
unsigned long, paddle::framework::FeaturePushValue, int, XPUStream>(
unsigned long* d_shard_keys, unsigned long* d_keys,
paddle::framework::FeaturePushValue* d_shard_grads,
paddle::framework::FeaturePushValue* d_grads, int* idx, long long len,
const XPUStream& stream);
template void
HeterCommKernel::fill_dvals<paddle::framework::FeatureValue, int, XPUStream>(
paddle::framework::FeatureValue* d_shard_vals,
......@@ -348,9 +352,8 @@ template void HeterCommKernel::reduce_by_key<
size_t& temp_storage_bytes, // NOLINT
unsigned long* d_keys_in, unsigned long* d_unique_out,
paddle::framework::FeaturePushValue* d_values_in,
paddle::framework::FeaturePushValue* d_aggregates_out,
int* d_num_runs_out int num_items, XPUStream stream,
bool debug_synchronous);
paddle::framework::FeaturePushValue* d_aggregates_out, int* d_num_runs_out,
int num_items, XPUStream stream, bool debug_synchronous);
#endif
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/fleet/heter_ps/heter_ps.h"
#include <vector>
#ifdef PADDLE_WITH_HETERPS
namespace paddle {
namespace framework {
HeterPsBase* HeterPsBase::get_instance(
size_t capacity, std::shared_ptr<HeterPsResource> resource) {
return new HeterPs(capacity, resource);
}
HeterPs::HeterPs(size_t capacity, std::shared_ptr<HeterPsResource> resource) {
comm_ =
std::make_shared<HeterComm<FeatureKey, FeatureValue, FeaturePushValue>>(
capacity, resource);
}
HeterPs::~HeterPs() {}
void HeterPs::pull_sparse(int num, FeatureKey* d_keys, FeatureValue* d_vals,
size_t len) {
comm_->pull_sparse(num, d_keys, d_vals, len);
}
void HeterPs::build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals,
size_t len, size_t chunk_size, int stream_num) {
comm_->build_ps(num, h_keys, h_vals, len, chunk_size, stream_num);
}
int HeterPs::get_index_by_devid(int devid) {
return comm_->get_index_by_devid(devid);
}
void HeterPs::end_pass() { comm_->end_pass(); }
void HeterPs::show_one_table(int gpu_num) { comm_->show_one_table(gpu_num); }
void HeterPs::push_sparse(int num, FeatureKey* d_keys,
FeaturePushValue* d_grads, size_t len) {
// comm_->push_sparse_multi_node(num, d_keys, d_grads, len, opt_);
}
} // end namespace framework
} // end namespace paddle
#endif
......@@ -69,6 +69,7 @@ XPUResource::XPUResource(std::vector<int>& dev_ids, int index) {
platform::XPUDeviceGuard guard(dev_id_);
local_streams_.resize(dev_ids_.size());
comm_streams_.resize(dev_ids_.size(), NULL);
remote_streams_.resize(dev_ids_.size());
......@@ -84,6 +85,7 @@ XPUResource::~XPUResource() {
for (size_t i = 0; i < local_streams_.size(); ++i) {
PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_destroy(local_streams_[i]));
}
// for (size_t i = 0; i < comm_streams_.size(); ++i) {
// PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_destroy(comm_streams_[i]));
// }
......
......@@ -36,6 +36,7 @@ namespace framework {
#if defined(PADDLE_WITH_CUDA)
using ppStream = cudaStream_t;
#elif defined(PADDLE_WITH_XPU_KP)
using ppStream = XPUStream;
#endif
......@@ -61,6 +62,7 @@ class GPUResource {
std::vector<gpuStream_t> local_streams_;
std::vector<gpuStream_t> comm_streams_;
};
#elif defined(PADDLE_WITH_XPU_KP)
class XPUResource {
public:
......@@ -105,6 +107,7 @@ class HeterPsResource {
int get_index_by_devid(int devid);
int dev_id(int num);
void set_multi_mf(int multi_mf_dim, int max_mf_dim);
ppStream local_stream(int dev_num, int stream_num);
ppStream remote_stream(int dev_num, int stream_num);
ppStream comm_stream(int dev_num, int stream_num);
......
......@@ -31,6 +31,7 @@ limitations under the License. */
#include <algorithm>
#include <deque>
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h"
#include "paddle/fluid/platform/timer.h"
......@@ -690,7 +691,6 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
}
#endif
VLOG(3) << "GpuPs build hbmps done";
};
if (multi_mf_dim_) {
......@@ -753,7 +753,9 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
}
std::vector<std::thread> threads(device_num);
HeterPs_ = HeterPsBase::get_instance(size_max, resource_);
#ifdef PADDLE_WITH_CUDA
HeterPs_->set_nccl_comm_and_size(inner_comms_, inter_comms_, node_size_);
#endif
auto build_func = [this, &gpu_task, &feature_keys_count](int i) {
VLOG(3) << "building table: " << i;
this->HeterPs_->build_ps(i, gpu_task->device_keys_[i].data(),
......@@ -891,18 +893,27 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
const std::vector<float*>& values,
const std::vector<int64_t>& slot_lengths,
const int hidden_size) {
VLOG(3) << "Begine Gpu Ps PullSparse";
platform::Timer all_timer;
platform::Timer pull_gpups_timer;
all_timer.Start();
int64_t total_length =
std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
#ifdef PADDLE_WITH_CUDA
VLOG(3) << "Begine Gpu Ps PullSparse";
auto buf = memory::Alloc(place, total_length * sizeof(FeatureValue));
FeatureValue* total_values_gpu = reinterpret_cast<FeatureValue*>(buf->ptr());
#endif
#ifdef PADDLE_WITH_XPU_KP
VLOG(3) << "Begine Xpu Ps PullSparse";
FeatureValue* total_values_gpu = nullptr;
xpu_malloc(reinterpret_cast<void**>(&total_values_gpu),
total_length * sizeof(FeatureValue));
#endif
if (platform::is_cpu_place(place)) {
PADDLE_THROW(platform::errors::Unimplemented(
"Warning:: CPUPlace is not supported in GpuPs now."));
} else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
VLOG(3) << "Begin copy keys, key_num[" << total_length << "]";
int device_id = place.GetDeviceId();
int devid_2_index = HeterPs_->get_index_by_devid(device_id);
......@@ -942,9 +953,63 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
this->CopyForPull(place, gpu_keys, values, total_values_gpu, gpu_len,
static_cast<int>(slot_lengths.size()), hidden_size,
total_length);
#endif
} else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU_KP
VLOG(3) << "Begin copy keys, key_num[" << total_length << "]";
int device_id = place.GetDeviceId();
int devid_2_index = HeterPs_->get_index_by_devid(device_id);
LoDTensor& total_keys_tensor = keys_tensor[devid_2_index];
uint64_t* total_keys = reinterpret_cast<uint64_t*>(
total_keys_tensor.mutable_data<int64_t>({total_length, 1}, place));
// construct slot_level lod info
auto slot_lengths_lod = slot_lengths;
for (size_t i = 1; i < slot_lengths_lod.size(); i++) {
slot_lengths_lod[i] += slot_lengths_lod[i - 1];
}
uint64_t* buf_key = nullptr;
int64_t* buf_length = nullptr;
PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void**>(&buf_key),
keys.size() * sizeof(uint64_t*)),
XPU_SUCCESS, platform::errors::ResourceExhausted(
"XPU has no enough memory"));
PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void**>(&buf_length),
slot_lengths.size() * sizeof(int64_t)),
XPU_SUCCESS, platform::errors::ResourceExhausted(
"XPU has no enough memory"));
uint64_t** xpu_keys = reinterpret_cast<uint64_t**>(&buf_key);
int64_t* xpu_len = reinterpret_cast<int64_t*>(buf_length);
PADDLE_ENFORCE_XPU_SUCCESS(xpu_memcpy(xpu_keys, keys.data(),
keys.size() * sizeof(uint64_t*),
XPU_HOST_TO_DEVICE));
PADDLE_ENFORCE_XPU_SUCCESS(xpu_memcpy(xpu_len, slot_lengths_lod.data(),
slot_lengths.size() * sizeof(int64_t),
XPU_HOST_TO_DEVICE));
this->CopyKeys(place, xpu_keys, total_keys, xpu_len,
static_cast<int>(slot_lengths.size()),
static_cast<int>(total_length));
VLOG(3) << "Begin call PullSparseGPU in GPUPS, dev: " << devid_2_index
<< " len: " << total_length;
pull_gpups_timer.Start();
HeterPs_->pull_sparse(devid_2_index, total_keys, total_values_gpu,
static_cast<int>(total_length));
// PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
// "PullSparseGPU failed in GPUPS."));
pull_gpups_timer.Pause();
VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length
<< "]";
this->CopyForPull(place, xpu_keys, values, total_values_gpu, xpu_len,
static_cast<int>(slot_lengths.size()), hidden_size,
total_length);
#endif
} else {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"GpuPs: PullSparse Only Support CUDAPlace Now."));
"GpuPs/XpuPs: PullSparse Only Support CUDAPlace or XPUPlace Now."));
}
all_timer.Pause();
VLOG(3) << "GpuPs PullSparse total costs: " << all_timer.ElapsedSec()
......@@ -959,15 +1024,23 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place,
const std::vector<const float*>& grad_values,
const std::vector<int64_t>& slot_lengths,
const int hidden_size, const int batch_size) {
VLOG(3) << "Begin GPUPS PushSparseGrad";
platform::Timer all_timer;
platform::Timer push_gpups_timer;
all_timer.Start();
int64_t total_length =
std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
#ifdef PADDLE_WITH_CUDA
VLOG(3) << "Begin GPUPS PushSparseGrad";
auto buf = memory::Alloc(place, total_length * sizeof(FeaturePushValue));
FeaturePushValue* total_grad_values_gpu =
reinterpret_cast<FeaturePushValue*>(buf->ptr());
#endif
#ifdef PADDLE_WITH_XPU_KP
VLOG(3) << "Begine Xpu Ps PushSparseGrad";
FeaturePushValue* total_grad_values_gpu = nullptr;
xpu_malloc(reinterpret_cast<void**>(&total_grad_values_gpu),
total_length * sizeof(FeaturePushValue));
#endif
if (platform::is_cpu_place(place)) {
PADDLE_THROW(platform::errors::Unimplemented(
"Warning:: CPUPlace is not supported in GPUPS now."));
......@@ -987,6 +1060,22 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place,
HeterPs_->push_sparse(devid_2_index, total_keys, total_grad_values_gpu,
static_cast<int>(total_length));
push_gpups_timer.Pause();
} else if (platform::is_xpu_place(place)) {
int device_id = place.GetDeviceId();
int devid_2_index = HeterPs_->get_index_by_devid(device_id);
LoDTensor& cached_total_keys_tensor = keys_tensor[devid_2_index];
uint64_t* total_keys =
reinterpret_cast<uint64_t*>(cached_total_keys_tensor.data<int64_t>());
VLOG(3) << "Begin copy grad tensor to xpups struct";
this->CopyForPush(place, grad_values, total_grad_values_gpu, slot_lengths,
hidden_size, total_length, batch_size);
VLOG(3) << "Begin call PushSparseXPU in XPUPS, dev: " << devid_2_index
<< " len: " << total_length;
push_gpups_timer.Start();
HeterPs_->push_sparse(devid_2_index, total_keys, total_grad_values_gpu,
static_cast<int>(total_length));
push_gpups_timer.Pause();
} else {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"GPUPS: PushSparseGrad Only Support CUDAPlace Now."));
......
......@@ -105,6 +105,8 @@ __global__ void PushCopy(FeaturePushValue* dest, float** src, int64_t* len,
}
}
PSGPUWrapper::~PSGPUWrapper() { delete HeterPs_; }
void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
uint64_t** gpu_keys,
const std::vector<float*>& values,
......
......@@ -30,16 +30,22 @@ limitations under the License. */
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
#include "paddle/fluid/distributed/ps/thirdparty/round_robin.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/fleet/heter_context.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
#include "paddle/fluid/framework/heter_util.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/fleet/heter_ps/mem_pool.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#endif
#ifdef PADDLE_WITH_XPU_KP
#include "paddle/fluid/platform/device/xpu/enforce_xpu.h"
#endif
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_PSCORE
......@@ -55,6 +61,8 @@ namespace framework {
#define TYPEALIGN(ALIGNVAL, LEN) \
(((uint64_t)(LEN) + ((ALIGNVAL)-1)) & ~((uint64_t)((ALIGNVAL)-1)))
class Dataset;
#ifdef PADDLE_WITH_PSLIB
class AfsWrapper {
public:
......@@ -82,7 +90,7 @@ class AfsWrapper {
class PSGPUWrapper {
public:
virtual ~PSGPUWrapper() { delete HeterPs_; }
virtual ~PSGPUWrapper();
PSGPUWrapper() {
HeterPs_ = NULL;
......@@ -160,6 +168,7 @@ class PSGPUWrapper {
PADDLE_THROW(
platform::errors::Unavailable("heter ps need compile with GLOO"));
#endif
#ifdef PADDLE_WITH_CUDA
if (multi_node_) {
int dev_size = dev_ids.size();
// init inner comm
......@@ -195,6 +204,7 @@ class PSGPUWrapper {
platform::errors::Unavailable("heter ps need compile with GLOO"));
#endif
}
#endif
heter_devices_ = dev_ids;
data_ready_channel_->Open();
data_ready_channel_->SetCapacity(3);
......@@ -262,7 +272,11 @@ class PSGPUWrapper {
? 1.0
: config["mf_max_bound"];
for (size_t i = 0; i < heter_devices_.size(); i++) {
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(heter_devices_[i]));
#elif defined(PADDLE_WITH_XPU_KP)
PADDLE_ENFORCE_XPU_SUCCESS(xpu_set_device(heter_devices_[i]));
#endif
this->SetSparseSGD(nonclk_coeff, clk_coeff, min_bound, max_bound,
learning_rate, initial_g2sum, initial_range);
this->SetEmbedxSGD(mf_create_thresholds, mf_learning_rate,
......@@ -270,6 +284,7 @@ class PSGPUWrapper {
mf_max_bound);
}
}
void SetDate(int year, int month, int day) {
year_ = year;
month_ = month;
......@@ -297,6 +312,7 @@ class PSGPUWrapper {
slot_offset_vector_ = slot_offset_vector;
}
#ifdef PADDLE_WITH_CUDA
void SetSlotDimVector(const std::vector<int>& slot_mf_dim_vector) {
slot_mf_dim_vector_ = slot_mf_dim_vector;
assert(slot_mf_dim_vector_.size() == slot_vector_.size());
......@@ -330,6 +346,7 @@ class PSGPUWrapper {
grad_type_size_ =
TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float)));
}
#endif
void ShowOneTable(int index) { HeterPs_->show_one_table(index); }
......@@ -371,9 +388,11 @@ class PSGPUWrapper {
int multi_node_{0};
int node_size_;
uint64_t table_id_;
#ifdef PADDLE_WITH_CUDA
std::vector<ncclComm_t> inner_comms_;
std::vector<ncclComm_t> inter_comms_;
std::vector<ncclUniqueId> inter_ncclids_;
#endif
std::vector<int> heter_devices_;
std::unordered_set<std::string> gpu_ps_config_keys_;
HeterObjectPool<HeterContext> gpu_task_pool_;
......@@ -388,9 +407,11 @@ class PSGPUWrapper {
int day_;
int use_afs_api_ = 0;
#ifdef PADDLE_WITH_CUDA
std::vector<MemoryPool*> mem_pools_;
std::vector<HBMMemoryPool*> hbm_pools_; // in multi mfdim, one table need hbm
// pools of totol dims number
#endif
std::shared_ptr<
paddle::framework::ChannelObject<std::shared_ptr<HeterContext>>>
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_HETERPS
#include <xpu/runtime.h> // NOLINT
#include <algorithm>
#include <ctime>
#include <memory>
#include <numeric>
#include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"
#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "xpu/kernel/cluster_header.h" // NOLINT
#include "xpu/kernel/debug.h" // NOLINT
#include "xpu/kernel/math.h" // NOLINT
#include "xpu/kernel/simd.h"
namespace paddle {
namespace framework {
__global__ void PullCopy(float** dest, const FeatureValue* src,
const long long* len, int hidden, int slot_num,
int total_len, unsigned long long** keys) {
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
return;
}
int thread_id = ncores * cluster_id() + cid;
int nthreads = ncores * cluster_num();
__local__ int64_t local_len[slot_num];
GM2LM(len, local_len, slot_num * sizeof(int64_t));
for (int i = thread_id; i < slot_num; i += nthreads) {
// max core local memory = 8KB
// slot's max memory size = slot_len * sizeof(FeatureValue)
int slot_len = i ? local_len[i] - local_len[i - 1] : local_len[0];
int read_len = min(roundup_div(1024 * 8, sizeof(FeatureValue)), slot_len);
int dest_len = i ? local_len[i - 1] : 0;
__local__ FeatureValue local_slot_vals[read_len];
__local__ float local_dest_vals[read_len * hidden];
__local__ uint64_t local_slot_keys[read_len];
// copy read_len (length) of slots' val to LM
for (int k = 0; k < slot_len; k += read_len) {
int real_read_len = min(read_len, slot_len - k);
GM2LM(src + dest_len + k, local_slot_vals,
real_read_len * sizeof(FeatureValue));
GM2LM(keys[i] + k, local_slot_keys, real_read_len * sizeof(uint64_t));
for (int j = 0; j < real_read_len; j++) {
if (local_slot_keys[j] == 0) {
local_dest_vals[j * hidden] = 0;
local_dest_vals[j * hidden + 1] = 0;
local_dest_vals[j * hidden + 2] = 0;
} else {
local_dest_vals[j * hidden] = local_slot_vals[j].show;
local_dest_vals[j * hidden + 1] = local_slot_vals[j].clk;
local_dest_vals[j * hidden + 2] = local_slot_vals[j].lr;
}
if (local_slot_vals[j].mf_size == 0 || local_slot_keys[j] == 0) {
for (int m = 0; m < hidden - 3; m++) {
local_dest_vals[j * hidden + 3 + m] = 0;
}
} else {
for (int m = 0; m < hidden - 3; m++) {
local_dest_vals[j * hidden + 3 + m] = local_slot_vals[j].mf[1 + m];
}
}
}
LM2GM(local_dest_vals, dest[i] + k * hidden,
real_read_len * hidden * sizeof(float));
}
}
}
__global__ void CopyKeysKernel(unsigned long long** src_keys,
unsigned long long* dest_total_keys,
const long long* len, int slot_num,
int total_len) {
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
return;
}
int thread_id = ncores * cluster_id() + cid;
int nthreads = ncores * cluster_num();
__local__ int64_t local_len[slot_num];
GM2LM(len, local_len, slot_num * sizeof(int64_t));
for (int i = thread_id; i < slot_num; i += nthreads) {
// max core local memory = 8KB
int slot_len = i ? local_len[i] - local_len[i - 1] : local_len[0];
int read_len = min(slot_len, 1024);
int dest_len = i ? local_len[i - 1] : 0;
__local__ uint64_t local_slot_keys[read_len];
for (int k = 0; k < slot_len; k += read_len) {
int real_read_len = min(read_len, slot_len - k);
GM2LM(src_keys[i] + k, local_slot_keys, real_read_len * sizeof(uint64_t));
LM2GM(local_slot_keys, dest_total_keys + dest_len + k,
real_read_len * sizeof(uint64_t));
}
}
}
__global__ void PushCopy(FeaturePushValue* dest, float** src, long long* len,
int hidden, int slot_num, int total_len, int bs,
int* slot_vector) {
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
return;
}
int thread_id = ncores * cluster_id() + cid;
int nthreads = ncores * cluster_num();
__local__ int64_t local_len[slot_num];
__local__ int local_slot[slot_num];
GM2LM(len, local_len, slot_num * sizeof(int64_t));
GM2LM(slot_vector, local_slot, slot_num * sizeof(int));
for (int i = thread_id; i < slot_num; i += nthreads) {
int slot_len = i ? local_len[i] - local_len[i - 1] : local_len[0];
// max core local memory = 8KB
// slot's max memory size = slot_len * hidden * 8
int read_len = min(roundup_div(1024, hidden), slot_len);
int dest_len = i ? local_len[i - 1] : 0;
__local__ float local_slot_grads[read_len * hidden];
__local__ FeaturePushValue local_dest_grads[read_len];
// copy read_len(length) of slots' grad to LM
for (int k = 0; k < slot_len; k += read_len) {
int real_read_len = min(read_len, slot_len - k);
GM2LM(src[i] + k * hidden, local_slot_grads,
real_read_len * hidden * sizeof(float));
// copy from slots' grad to total grad
for (int j = 0; j < real_read_len; j++) {
local_dest_grads[j].slot = local_slot[i];
local_dest_grads[j].show = local_slot_grads[j * hidden];
local_dest_grads[j].clk = local_slot_grads[j * hidden + 1];
local_dest_grads[j].lr_g = local_slot_grads[j * hidden + 2] * -1. * bs;
for (int m = 0; m < hidden - 3; m++) {
local_dest_grads[j].mf_g[m] =
local_slot_grads[j * hidden + 3 + m] * -1. * bs;
}
}
LM2GM(local_dest_grads, dest + dest_len + k,
real_read_len * sizeof(FeaturePushValue));
}
}
}
PSGPUWrapper::~PSGPUWrapper() {
delete HeterPs_;
xpu_free((void*)optimizer_config::nonclk_coeff);
xpu_free((void*)optimizer_config::clk_coeff);
xpu_free((void*)optimizer_config::min_bound);
xpu_free((void*)optimizer_config::max_bound);
xpu_free((void*)optimizer_config::learning_rate);
xpu_free((void*)optimizer_config::initial_g2sum);
xpu_free((void*)optimizer_config::initial_range);
xpu_free((void*)optimizer_config::mf_create_thresholds);
xpu_free((void*)optimizer_config::mf_learning_rate);
xpu_free((void*)optimizer_config::mf_initial_g2sum);
xpu_free((void*)optimizer_config::mf_initial_range);
xpu_free((void*)optimizer_config::mf_min_bound);
xpu_free((void*)optimizer_config::mf_max_bound);
}
void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
uint64_t** gpu_keys,
const std::vector<float*>& values,
const FeatureValue* total_values_gpu,
const int64_t* gpu_len, const int slot_num,
const int hidden_size,
const int64_t total_length) {
XPUStream stream = nullptr;
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::XPUDeviceContext*>(dev_ctx)
->x_context()
->xpu_stream;
float* buf_value = nullptr;
xpu_malloc(reinterpret_cast<void**>(&buf_value),
values.size() * sizeof(float*));
float** gpu_values = reinterpret_cast<float**>(&buf_value);
xpu_memcpy(gpu_values, values.data(), values.size() * sizeof(float*),
XPU_HOST_TO_DEVICE);
unsigned long long** c_keys = (unsigned long long**)gpu_keys;
const long long* c_len = (const long long*)gpu_len;
PullCopy<<<2, 64, stream>>>(gpu_values, total_values_gpu, c_len, hidden_size,
slot_num, total_length, c_keys);
xpu_wait(stream);
}
void PSGPUWrapper::CopyKeys(const paddle::platform::Place& place,
uint64_t** origin_keys, uint64_t* total_keys,
const int64_t* gpu_len, int slot_num,
int total_len) {
XPUStream stream = nullptr;
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::XPUDeviceContext*>(dev_ctx)
->x_context()
->xpu_stream;
unsigned long long** o_keys = (unsigned long long**)origin_keys;
unsigned long long* t_keys = (unsigned long long*)total_keys;
const long long* c_len = (const long long*)gpu_len;
CopyKeysKernel<<<2, 64, stream>>>(o_keys, t_keys, c_len, slot_num, total_len);
xpu_wait(stream);
}
void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place,
const std::vector<const float*>& grad_values,
FeaturePushValue* total_grad_values_gpu,
const std::vector<int64_t>& slot_lengths,
const int hidden_size,
const int64_t total_length,
const int batch_size) {
XPUStream stream = nullptr;
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::XPUDeviceContext*>(dev_ctx)
->x_context()
->xpu_stream;
auto slot_lengths_lod = slot_lengths;
for (size_t i = 1; i < slot_lengths_lod.size(); i++) {
slot_lengths_lod[i] += slot_lengths_lod[i - 1];
}
float* buf_grad_value = nullptr;
int64_t* buf_length = nullptr;
int* buf_slot_vector = nullptr;
xpu_malloc(reinterpret_cast<void**>(&buf_grad_value),
grad_values.size() * sizeof(float*));
xpu_malloc(reinterpret_cast<void**>(&buf_length),
slot_lengths.size() * sizeof(int64_t));
xpu_malloc(reinterpret_cast<void**>(&buf_slot_vector),
slot_lengths_lod.size() * sizeof(int));
float** gpu_values = reinterpret_cast<float**>(&buf_grad_value);
int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length);
int* d_slot_vector = reinterpret_cast<int*>(buf_slot_vector);
xpu_memcpy(gpu_values, grad_values.data(),
grad_values.size() * sizeof(float*), XPU_HOST_TO_DEVICE);
xpu_memcpy(gpu_len, slot_lengths_lod.data(),
slot_lengths.size() * sizeof(int64_t), XPU_HOST_TO_DEVICE);
xpu_memcpy(d_slot_vector, slot_vector_.data(),
slot_lengths_lod.size() * sizeof(int), XPU_HOST_TO_DEVICE);
long long* c_len = (long long*)gpu_len;
PushCopy<<<2, 64, stream>>>(total_grad_values_gpu, gpu_values, c_len,
hidden_size, slot_lengths.size(), total_length,
batch_size, d_slot_vector);
xpu_wait(stream);
}
void PSGPUWrapper::SetSparseSGD(float nonclk_coeff, float clk_coeff,
float min_bound, float max_bound,
float learning_rate, float initial_g2sum,
float initial_range) {
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::nonclk_coeff),
sizeof(float));
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::clk_coeff),
sizeof(float));
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::min_bound),
sizeof(float));
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::max_bound),
sizeof(float));
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::learning_rate),
sizeof(float));
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::initial_g2sum),
sizeof(float));
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::initial_range),
sizeof(float));
xpu_memcpy((void*)optimizer_config::nonclk_coeff, &nonclk_coeff,
sizeof(float), XPU_HOST_TO_DEVICE);
xpu_memcpy((void*)optimizer_config::clk_coeff, &clk_coeff, sizeof(float),
XPU_HOST_TO_DEVICE);
xpu_memcpy((void*)optimizer_config::min_bound, &min_bound, sizeof(float),
XPU_HOST_TO_DEVICE);
xpu_memcpy((void*)optimizer_config::max_bound, &max_bound, sizeof(float),
XPU_HOST_TO_DEVICE);
xpu_memcpy((void*)optimizer_config::learning_rate, &learning_rate,
sizeof(float), XPU_HOST_TO_DEVICE);
xpu_memcpy((void*)optimizer_config::initial_g2sum, &initial_g2sum,
sizeof(float), XPU_HOST_TO_DEVICE);
xpu_memcpy((void*)optimizer_config::initial_range, &initial_range,
sizeof(float), XPU_HOST_TO_DEVICE);
}
void PSGPUWrapper::SetEmbedxSGD(float mf_create_thresholds,
float mf_learning_rate, float mf_initial_g2sum,
float mf_initial_range, float mf_min_bound,
float mf_max_bound) {
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::mf_create_thresholds),
sizeof(float));
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::mf_learning_rate),
sizeof(float));
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::mf_initial_g2sum),
sizeof(float));
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::mf_initial_range),
sizeof(float));
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::mf_min_bound),
sizeof(float));
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::mf_max_bound),
sizeof(float));
xpu_memcpy((void*)optimizer_config::mf_create_thresholds,
&mf_create_thresholds, sizeof(float), XPU_HOST_TO_DEVICE);
xpu_memcpy((void*)optimizer_config::mf_initial_g2sum, &mf_initial_g2sum,
sizeof(float), XPU_HOST_TO_DEVICE);
xpu_memcpy((void*)optimizer_config::mf_initial_range, &mf_initial_range,
sizeof(float), XPU_HOST_TO_DEVICE);
xpu_memcpy((void*)optimizer_config::mf_min_bound, &mf_min_bound,
sizeof(float), XPU_HOST_TO_DEVICE);
xpu_memcpy((void*)optimizer_config::mf_max_bound, &mf_max_bound,
sizeof(float), XPU_HOST_TO_DEVICE);
xpu_memcpy((void*)optimizer_config::mf_learning_rate, &mf_learning_rate,
sizeof(float), XPU_HOST_TO_DEVICE);
}
} // end namespace framework
} // end namespace paddle
#endif
......@@ -25,7 +25,9 @@ limitations under the License. */
#include "paddle/fluid/framework/trainer.h"
#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL) && \
(defined PADDLE_WITH_PSLIB)
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
namespace paddle {
namespace framework {
......@@ -56,7 +58,12 @@ void PSGPUTrainer::Initialize(const TrainerDesc& trainer_desc,
std::vector<int> dev_ids;
for (int i = 0; i < place_num; ++i) {
int num = trainer_desc.worker_places(i);
#ifdef PADDLE_WITH_CUDA
platform::CUDAPlace place = platform::CUDAPlace(num);
#endif
#ifdef PADDLE_WITH_XPU_KP
platform::XPUPlace place = platform::XPUPlace(num);
#endif
places_.push_back(place);
dev_ids.push_back(num);
}
......
......@@ -20,7 +20,9 @@ limitations under the License. */
#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL) && \
(defined PADDLE_WITH_PSLIB)
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
#if defined _WIN32 || defined __APPLE__
#else
......
......@@ -132,5 +132,7 @@ REGISTER_OPERATOR(pull_box_sparse, ops::PullBoxSparseOp,
ops::PushBoxSparseOpMaker<paddle::framework::OpDesc>,
ops::PushBoxSparseOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(push_box_sparse, ops::PushBoxSparseOp);
REGISTER_OP_CPU_KERNEL(pull_box_sparse, ops::PullBoxSparseCPUKernel<float>)
REGISTER_OP_CPU_KERNEL(push_box_sparse, ops::PushBoxSparseCPUKernel<float>)
REGISTER_OP_CPU_KERNEL(pull_box_sparse, ops::PullBoxSparseCPUKernel<float>);
REGISTER_OP_CPU_KERNEL(push_box_sparse, ops::PushBoxSparseCPUKernel<float>);
REGISTER_OP_XPU_KERNEL(pull_box_sparse, ops::PullBoxSparseXPUKernel<float>);
REGISTER_OP_XPU_KERNEL(push_box_sparse, ops::PushBoxSparseXPUKernel<float>);
......@@ -11,7 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/pull_box_sparse_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
......@@ -38,7 +37,7 @@ class PushBoxSparseCUDAKernel : public framework::OpKernel<T> {
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(pull_box_sparse, ops::PullBoxSparseCUDAKernel<float>)
REGISTER_OP_CUDA_KERNEL(push_box_sparse, ops::PushBoxSparseCUDAKernel<float>)
REGISTER_OP_CUDA_KERNEL(pull_box_sparse, ops::PullBoxSparseCUDAKernel<float>);
REGISTER_OP_CUDA_KERNEL(push_box_sparse, ops::PushBoxSparseCUDAKernel<float>);
......@@ -114,5 +114,21 @@ class PushBoxSparseCPUKernel : public framework::OpKernel<T> {
PushBoxSparseFunctor<T>(ctx);
}
};
template <typename T>
class PullBoxSparseXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PullBoxSparseFunctor<T>(ctx);
}
};
template <typename T>
class PushBoxSparseXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PushBoxSparseFunctor<T>(ctx);
}
};
} // namespace operators
} // namespace paddle
......@@ -25,6 +25,7 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h"
#include "paddle/fluid/pybind/ps_gpu_wrapper_py.h"
......@@ -39,8 +40,10 @@ void BindPSGPUWrapper(py::module* m) {
.def(py::init([]() { return framework::PSGPUWrapper::GetInstance(); }))
.def("set_slot_vector", &framework::PSGPUWrapper::SetSlotVector,
py::call_guard<py::gil_scoped_release>())
#ifdef PADDLE_WITH_CUDA
.def("set_slot_dim_vector", &framework::PSGPUWrapper::SetSlotDimVector,
py::call_guard<py::gil_scoped_release>())
#endif
.def("set_slot_offset_vector",
&framework::PSGPUWrapper::SetSlotOffsetVector,
py::call_guard<py::gil_scoped_release>())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册