Unverified commit bb0713b2, authored by 石晓伟, committed by GitHub

changes calls from AllocShared to Alloc, test=develop (#38258)

Parent 2635cc86
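In short, the commit replaces memory::AllocShared, which returns a std::shared_ptr<memory::Allocation>, with memory::Alloc, which returns the move-only memory::allocation::AllocationPtr, so scratch GPU buffers no longer carry shared-ownership overhead. The sketch below illustrates the before/after usage pattern implied by the diff; the function and variable names (sketch, place, len, bufs) are illustrative only and not taken from the patch.

    // Minimal sketch of the substitution this commit applies; names here are
    // hypothetical, only the two memory:: calls and types come from the diff.
    #include <memory>
    #include <utility>
    #include <vector>
    #include "paddle/fluid/memory/allocation/allocator.h"  // memory::allocation::AllocationPtr
    #include "paddle/fluid/memory/memory.h"

    namespace memory = paddle::memory;

    void sketch(const paddle::platform::CUDAPlace& place, size_t len) {
      // Before: AllocShared hands out shared ownership of the device buffer.
      std::shared_ptr<memory::Allocation> shared_buf =
          memory::AllocShared(place, len * sizeof(int));

      // After: Alloc returns a unique-ownership handle; ->ptr() access is unchanged.
      memory::allocation::AllocationPtr buf = memory::Alloc(place, len * sizeof(int));
      int* dev_ptr = reinterpret_cast<int*>(buf->ptr());

      // AllocationPtr is move-only, so containers must take ownership via
      // std::move, as in the HeterComm::build_ps change below.
      std::vector<memory::allocation::AllocationPtr> bufs;
      bufs.push_back(std::move(buf));
      (void)dev_ptr;
      (void)shared_buf;
    }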
@@ -140,7 +140,7 @@ void BoxWrapper::CopyForPull(const paddle::platform::Place& place,
       platform::DeviceContextPool::Instance().Get(
           BOOST_GET_CONST(platform::CUDAPlace, place)))
           ->stream();
-  auto buf_value = memory::AllocShared(place, values.size() * sizeof(float*));
+  auto buf_value = memory::Alloc(place, values.size() * sizeof(float*));
   float** gpu_values = reinterpret_cast<float**>(buf_value->ptr());
 #ifdef PADDLE_WITH_HIP
   hipMemcpy(gpu_values, values.data(), values.size() * sizeof(float*),
@@ -233,11 +233,10 @@ void BoxWrapper::CopyForPush(const paddle::platform::Place& place,
     slot_lengths_lod[i] += slot_lengths_lod[i - 1];
   }
   auto buf_grad_value =
-      memory::AllocShared(place, grad_values.size() * sizeof(float*));
-  auto buf_length =
-      memory::AllocShared(place, slot_lengths.size() * sizeof(int64_t));
+      memory::Alloc(place, grad_values.size() * sizeof(float*));
+  auto buf_length = memory::Alloc(place, slot_lengths.size() * sizeof(int64_t));
   auto buf_slot_vector =
-      memory::AllocShared(place, slot_lengths_lod.size() * sizeof(int));
+      memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));
   float** gpu_values = reinterpret_cast<float**>(buf_grad_value->ptr());
   int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
...
@@ -32,7 +32,7 @@ void BoxWrapper::PullSparseCase(const paddle::platform::Place& place,
   int64_t total_length =
       std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
-  auto buf = memory::AllocShared(
+  auto buf = memory::Alloc(
       place, total_length *
                  sizeof(boxps::FeatureValueGpu<EMBEDX_DIM, EXPAND_EMBED_DIM>));
   boxps::FeatureValueGpu<EMBEDX_DIM, EXPAND_EMBED_DIM>* total_values_gpu =
@@ -55,9 +55,9 @@ void BoxWrapper::PullSparseCase(const paddle::platform::Place& place,
   for (size_t i = 1; i < slot_lengths_lod.size(); i++) {
     slot_lengths_lod[i] += slot_lengths_lod[i - 1];
   }
-  auto buf_key = memory::AllocShared(place, keys.size() * sizeof(uint64_t*));
+  auto buf_key = memory::Alloc(place, keys.size() * sizeof(uint64_t*));
   auto buf_length =
-      memory::AllocShared(place, slot_lengths.size() * sizeof(int64_t));
+      memory::Alloc(place, slot_lengths.size() * sizeof(int64_t));
   uint64_t** gpu_keys = reinterpret_cast<uint64_t**>(buf_key->ptr());
   int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
 #ifdef PADDLE_WITH_HIP
@@ -118,7 +118,7 @@ void BoxWrapper::PushSparseGradCase(
   all_timer.Start();
   int64_t total_length =
       std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
-  auto buf = memory::AllocShared(
+  auto buf = memory::Alloc(
       place,
       total_length *
           sizeof(boxps::FeaturePushValueGpu<EMBEDX_DIM, EXPAND_EMBED_DIM>));
...
@@ -17,9 +17,10 @@ limitations under the License. */
 #include <vector>
 #include "cub/cub.cuh"
 #include "cub/util_allocator.cuh"
-#include "hashtable.h"
-#include "heter_resource.h"
+#include "hashtable.h"       // NOLINT
+#include "heter_resource.h"  // NOLINT
 #include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
+#include "paddle/fluid/memory/allocation/allocator.h"
 #include "paddle/fluid/memory/memory.h"
 #include "paddle/fluid/platform/cuda_device_guard.h"
 #include "paddle/fluid/platform/dynload/nccl.h"
@@ -58,7 +59,7 @@ class HeterComm {
   void split_input_to_shard(KeyType* d_keys, int* d_idx_ptr, size_t len,
                             int* left, int* right, int gpu_num);
   void merge_grad(int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
-                  int& uniq_len);
+                  int& uniq_len);  // NOLINT
   void pull_sparse(int num, KeyType* d_keys, ValType* d_vals, size_t len);
   void build_ps(int num, KeyType* h_keys, ValType* h_vals, size_t len,
                 size_t chunk_size, int stream_num);
@@ -68,15 +69,15 @@ class HeterComm {
   template <typename Sgd>
   void push_sparse(int num, KeyType* d_keys, GradType* d_grads, size_t len,
-                   Sgd& sgd);
+                   Sgd& sgd);  // NOLINT
   template <typename Sgd>
   void push_sparse_multi_node(int num, KeyType* d_keys, GradType* d_grads,
-                              size_t len, Sgd& sgd);
+                              size_t len, Sgd& sgd);  // NOLINT
   template <typename Sgd>
   void update_one_table(int num, KeyType* d_keys, GradType* d_grads, size_t len,
-                        Sgd& sgd);
+                        Sgd& sgd);  // NOLINT
   int gather_one_node_grad(int num, KeyType* d_keys, GradType* d_grads,
                            int len);
@@ -136,16 +137,16 @@ class HeterComm {
     if (force || size > all_keys_mem->size()) {
       all_keys_mem.reset();
       all_grads_mem.reset();
-      all_keys_mem = memory::AllocShared(place_, size * sizeof(KeyType));
-      all_grads_mem = memory::AllocShared(place_, size * sizeof(GradType));
+      all_keys_mem = memory::Alloc(place_, size * sizeof(KeyType));
+      all_grads_mem = memory::Alloc(place_, size * sizeof(GradType));
       all_keys = reinterpret_cast<KeyType*>(all_keys_mem->ptr());
       all_grads = reinterpret_cast<GradType*>(all_grads_mem->ptr());
     }
     if (force || size > local_keys_mem->size()) {
       local_keys_mem.reset();
       local_grads_mem.reset();
-      local_keys_mem = memory::AllocShared(place_, size * sizeof(KeyType));
-      local_grads_mem = memory::AllocShared(place_, size * sizeof(GradType));
+      local_keys_mem = memory::Alloc(place_, size * sizeof(KeyType));
+      local_grads_mem = memory::Alloc(place_, size * sizeof(GradType));
       local_keys = reinterpret_cast<KeyType*>(local_keys_mem->ptr());
       local_grads = reinterpret_cast<GradType*>(local_grads_mem->ptr());
     }
...
@@ -28,7 +28,7 @@ __global__ void fill_idx(T* idx, size_t len) {
 template <typename T>
 void show_tensor(T* input, size_t len, gpuStream_t stream, std::string name) {
-  T tmp[len];
+  T tmp[len];  // NOLINT
   cudaMemcpyAsync(&tmp, input, sizeof(T) * len, cudaMemcpyDeviceToHost, stream);
   cudaStreamSynchronize(stream);
   std::cout << name;
@@ -101,7 +101,7 @@ HeterComm<KeyType, ValType, GradType>::HeterComm(
   for (int i = 0; i < resource_->total_gpu(); ++i) {
     platform::CUDADeviceGuard guard(resource_->dev_id(i));
     allocators_.push_back(std::make_shared<cub::CachingDeviceAllocator>(
-        8, 1, (unsigned int)-1, (size_t)-1, false, false));
+        8, 1, (unsigned int)-1, (size_t)-1, false, false));  // NOLINT
     auto table = new Table(capacity / load_factor_);
     tables_.push_back(table);
     if (multi_node_) {
@@ -174,10 +174,12 @@ void HeterComm<KeyType, ValType, GradType>::create_storage(int start_index,
   for (size_t i = 0; i < nodes.size(); ++i) {
     platform::CUDADeviceGuard guard(resource_->dev_id(nodes[i].gpu_num));
     allocator->DeviceAllocate(
-        resource_->dev_id(nodes[i].gpu_num), (void**)&(nodes[i].key_storage),
+        resource_->dev_id(nodes[i].gpu_num),
+        (void**)&(nodes[i].key_storage),  // NOLINT
         keylen, resource_->remote_stream(nodes[i].gpu_num, start_index));
     allocator->DeviceAllocate(
-        resource_->dev_id(nodes[i].gpu_num), (void**)&(nodes[i].val_storage),
+        resource_->dev_id(nodes[i].gpu_num),
+        (void**)&(nodes[i].val_storage),  // NOLINT
         vallen, resource_->remote_stream(nodes[i].gpu_num, start_index));
     nodes[i].key_bytes_len = keylen;
@@ -342,16 +344,16 @@ void HeterComm<KeyType, ValType, GradType>::build_ps(int num, KeyType* h_keys,
   platform::CUDAPlace place = platform::CUDAPlace(dev_id);
   platform::CUDADeviceGuard guard(dev_id);
-  std::vector<std::shared_ptr<memory::Allocation>> d_key_bufs;
-  std::vector<std::shared_ptr<memory::Allocation>> d_val_bufs;
-  gpuStream_t streams[stream_num];
+  std::vector<memory::allocation::AllocationPtr> d_key_bufs;
+  std::vector<memory::allocation::AllocationPtr> d_val_bufs;
+  gpuStream_t streams[stream_num];  // NOLINT
   for (int i = 0; i < stream_num; ++i) {
     PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&(streams[i])));
-    auto d_k_buf = memory::AllocShared(place, chunk_size * sizeof(KeyType));
-    auto d_v_buf = memory::AllocShared(place, chunk_size * sizeof(ValType));
-    d_key_bufs.push_back(d_k_buf);
-    d_val_bufs.push_back(d_v_buf);
+    auto d_k_buf = memory::Alloc(place, chunk_size * sizeof(KeyType));
+    auto d_v_buf = memory::Alloc(place, chunk_size * sizeof(ValType));
+    d_key_bufs.push_back(std::move(d_k_buf));
+    d_val_bufs.push_back(std::move(d_v_buf));
   }
   int cur_len = 0;
@@ -383,11 +385,9 @@ void HeterComm<KeyType, ValType, GradType>::build_ps(int num, KeyType* h_keys,
 }
 template <typename KeyType, typename ValType, typename GradType>
-void HeterComm<KeyType, ValType, GradType>::merge_grad(int gpu_num,
-                                                       KeyType* d_keys,
-                                                       GradType* d_grads,
-                                                       size_t len,
-                                                       int& uniq_len) {
+void HeterComm<KeyType, ValType, GradType>::merge_grad(
+    int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
+    int& uniq_len) {  // NOLINT
   int dev_id = resource_->dev_id(gpu_num);
   platform::CUDAPlace place = platform::CUDAPlace(dev_id);
   platform::CUDADeviceGuard guard(dev_id);
@@ -395,10 +395,10 @@ void HeterComm<KeyType, ValType, GradType>::merge_grad(int gpu_num,
   size_t temp_storage_bytes;
-  auto d_merge_keys = memory::AllocShared(place, len * sizeof(KeyType));
+  auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType));
   KeyType* d_merge_keys_ptr = reinterpret_cast<KeyType*>(d_merge_keys->ptr());
-  auto d_merge_grads = memory::AllocShared(place, len * sizeof(GradType));
+  auto d_merge_grads = memory::Alloc(place, len * sizeof(GradType));
   GradType* d_merge_grads_ptr =
       reinterpret_cast<GradType*>(d_merge_grads->ptr());
@@ -407,14 +407,14 @@ void HeterComm<KeyType, ValType, GradType>::merge_grad(int gpu_num,
       d_merge_grads_ptr, len, 0, 8 * sizeof(KeyType), stream, false));
   void* d_buff = NULL;
-  auto d_temp_storage = memory::AllocShared(place, temp_storage_bytes);
+  auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
   PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs(
       d_temp_storage->ptr(), temp_storage_bytes, d_keys, d_merge_keys_ptr,
       d_grads, d_merge_grads_ptr, len, 0, 8 * sizeof(KeyType), stream, false));
   temp_storage_bytes = 0;
-  auto d_num_runs_out_mem = memory::AllocShared(place, sizeof(int));
+  auto d_num_runs_out_mem = memory::Alloc(place, sizeof(int));
   int* d_num_runs_out = reinterpret_cast<int*>(d_num_runs_out_mem->ptr());
   PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::ReduceByKey(
@@ -423,7 +423,7 @@ void HeterComm<KeyType, ValType, GradType>::merge_grad(int gpu_num,
   if (d_temp_storage->size() < temp_storage_bytes) {
     d_temp_storage = NULL;
-    d_temp_storage = memory::AllocShared(place, temp_storage_bytes);
+    d_temp_storage = memory::Alloc(place, temp_storage_bytes);
   }
   PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::ReduceByKey(
@@ -445,13 +445,13 @@ void HeterComm<KeyType, ValType, GradType>::split_input_to_shard(
   platform::CUDADeviceGuard guard(dev_id);
   auto stream = resource_->local_stream(gpu_num, 0);
-  auto d_idx_tmp = memory::AllocShared(place, len * sizeof(int));
+  auto d_idx_tmp = memory::Alloc(place, len * sizeof(int));
   int* d_idx_tmp_ptr = reinterpret_cast<int*>(d_idx_tmp->ptr());
-  auto d_shard_index = memory::AllocShared(place, len * sizeof(int));
+  auto d_shard_index = memory::Alloc(place, len * sizeof(int));
   int* d_shard_index_ptr = reinterpret_cast<int*>(d_shard_index->ptr());
-  auto d_shard_index_tmp = memory::AllocShared(place, len * sizeof(int));
+  auto d_shard_index_tmp = memory::Alloc(place, len * sizeof(int));
   int* d_shard_index_tmp_ptr = reinterpret_cast<int*>(d_shard_index_tmp->ptr());
   int grid_size = (len - 1) / block_size_ + 1;
@@ -465,7 +465,7 @@ void HeterComm<KeyType, ValType, GradType>::split_input_to_shard(
       NULL, temp_storage_bytes, d_shard_index_tmp_ptr, d_shard_index_ptr,
       d_idx_tmp_ptr, d_idx_ptr, len, 0, num_bits, stream));
-  auto d_temp_storage = memory::AllocShared(place, temp_storage_bytes);
+  auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
   PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs(
       d_temp_storage->ptr(), temp_storage_bytes, d_shard_index_tmp_ptr,
       d_shard_index_ptr, d_idx_tmp_ptr, d_idx_ptr, len, 0, num_bits, stream));
@@ -491,23 +491,23 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
   int grid_size = (len - 1) / block_size_ + 1;
-  int h_left[total_gpu];
-  int h_right[total_gpu];
-  auto d_left = memory::AllocShared(place, total_gpu * sizeof(int));
-  auto d_right = memory::AllocShared(place, total_gpu * sizeof(int));
+  int h_left[total_gpu];   // NOLINT
+  int h_right[total_gpu];  // NOLINT
+  auto d_left = memory::Alloc(place, total_gpu * sizeof(int));
+  auto d_right = memory::Alloc(place, total_gpu * sizeof(int));
   int* d_left_ptr = reinterpret_cast<int*>(d_left->ptr());
   int* d_right_ptr = reinterpret_cast<int*>(d_right->ptr());
   cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream);
   cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream);
   //
-  auto d_idx = memory::AllocShared(place, len * sizeof(int));
+  auto d_idx = memory::Alloc(place, len * sizeof(int));
   int* d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());
-  auto d_shard_keys = memory::AllocShared(place, len * sizeof(KeyType));
+  auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType));
   KeyType* d_shard_keys_ptr = reinterpret_cast<KeyType*>(d_shard_keys->ptr());
-  auto d_shard_vals = memory::AllocShared(place, len * sizeof(ValType));
+  auto d_shard_vals = memory::Alloc(place, len * sizeof(ValType));
   ValType* d_shard_vals_ptr = reinterpret_cast<ValType*>(d_shard_vals->ptr());
   split_input_to_shard(d_keys, d_idx_ptr, len, d_left_ptr, d_right_ptr, num);
@@ -574,7 +574,8 @@ template <typename Sgd>
 void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
                                                         KeyType* d_keys,
                                                         GradType* d_grads,
-                                                        size_t len, Sgd& sgd) {
+                                                        size_t len,
+                                                        Sgd& sgd) {  // NOLINT
   if (len == 0) {
     return;
   }
@@ -585,23 +586,23 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
   platform::CUDADeviceGuard guard(dev_id);
   auto stream = resource_->local_stream(gpu_num, 0);
-  int h_left[total_gpu];
-  int h_right[total_gpu];
-  auto d_left = memory::AllocShared(place, total_gpu * sizeof(int));
-  auto d_right = memory::AllocShared(place, total_gpu * sizeof(int));
+  int h_left[total_gpu];   // NOLINT
+  int h_right[total_gpu];  // NOLINT
+  auto d_left = memory::Alloc(place, total_gpu * sizeof(int));
+  auto d_right = memory::Alloc(place, total_gpu * sizeof(int));
   int* d_left_ptr = reinterpret_cast<int*>(d_left->ptr());
   int* d_right_ptr = reinterpret_cast<int*>(d_right->ptr());
   cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream);
   cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream);
   //
-  auto d_idx = memory::AllocShared(place, len * sizeof(int));
+  auto d_idx = memory::Alloc(place, len * sizeof(int));
   int* d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());
-  auto d_shard_keys = memory::AllocShared(place, len * sizeof(KeyType));
+  auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType));
   KeyType* d_shard_keys_ptr = reinterpret_cast<KeyType*>(d_shard_keys->ptr());
-  auto d_shard_grads = memory::AllocShared(place, len * sizeof(GradType));
+  auto d_shard_grads = memory::Alloc(place, len * sizeof(GradType));
   GradType* d_shard_grads_ptr =
       reinterpret_cast<GradType*>(d_shard_grads->ptr());
@@ -664,7 +665,8 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
 template <typename KeyType, typename ValType, typename GradType>
 template <typename Sgd>
 void HeterComm<KeyType, ValType, GradType>::update_one_table(
-    int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len, Sgd& sgd) {
+    int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
+    Sgd& sgd) {  // NOLINT
   if (len == 0) {
     return;
   }
@@ -681,7 +683,8 @@ void HeterComm<KeyType, ValType, GradType>::update_one_table(
 template <typename KeyType, typename ValType, typename GradType>
 template <typename Sgd>
 void HeterComm<KeyType, ValType, GradType>::push_sparse_multi_node(
-    int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len, Sgd& sgd) {
+    int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
+    Sgd& sgd) {  // NOLINT
   if (len == 0) {
     return;
   }
@@ -711,8 +714,8 @@ int HeterComm<KeyType, ValType, GradType>::gather_one_node_grad(
   ncclComm_t nccl_inner_comm = nccl_inner_comms_[gpu_num];
   // alloc for size
-  int h_node_len[total_gpu];
-  auto d_node_len_mem = memory::AllocShared(place, total_gpu * sizeof(int));
+  int h_node_len[total_gpu];  // NOLINT
+  auto d_node_len_mem = memory::Alloc(place, total_gpu * sizeof(int));
   int* d_node_len = reinterpret_cast<int*>(d_node_len_mem->ptr());
   h_node_len[gpu_num] = len;
@@ -721,8 +724,9 @@ int HeterComm<KeyType, ValType, GradType>::gather_one_node_grad(
   // allgather grad len
   PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
-  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
-      (const void*)(d_node_len + gpu_num), (void*)d_node_len, 1, ncclInt,
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      platform::dynload::ncclAllGather((const void*)(d_node_len + gpu_num),
+                                       (void*)d_node_len, 1, ncclInt,  // NOLINT
       nccl_inner_comm, stream));
   PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
   PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
@@ -747,17 +751,17 @@ int HeterComm<KeyType, ValType, GradType>::gather_one_node_grad(
   PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
   PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
-  int h_left[total_gpu];
-  int h_right[total_gpu];
-  auto d_left = memory::AllocShared(place, total_gpu * sizeof(int));
-  auto d_right = memory::AllocShared(place, total_gpu * sizeof(int));
+  int h_left[total_gpu];   // NOLINT
+  int h_right[total_gpu];  // NOLINT
+  auto d_left = memory::Alloc(place, total_gpu * sizeof(int));
+  auto d_right = memory::Alloc(place, total_gpu * sizeof(int));
   int* d_left_ptr = reinterpret_cast<int*>(d_left->ptr());
   int* d_right_ptr = reinterpret_cast<int*>(d_right->ptr());
   int merge_num = 0;
   for (int i = 0; i < total_gpu; ++i) {
     int index = i * max_size;
-    auto d_idx = memory::AllocShared(place, h_node_len[i] * sizeof(int));
+    auto d_idx = memory::Alloc(place, h_node_len[i] * sizeof(int));
     int* d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());
     cudaMemset(d_left_ptr, -1, total_gpu * sizeof(int));
@@ -794,8 +798,8 @@ int HeterComm<KeyType, ValType, GradType>::gather_multi_node_grad(
   int max_size = 0;
   ncclComm_t nccl_inter_comm = nccl_inter_comms_[gpu_num];
   // alloc for size
-  int h_node_len[node_size_];
-  auto d_node_len_mem = memory::AllocShared(place, node_size_ * sizeof(int));
+  int h_node_len[node_size_];  // NOLINT
+  auto d_node_len_mem = memory::Alloc(place, node_size_ * sizeof(int));
   int* d_node_len = reinterpret_cast<int*>(d_node_len_mem->ptr());
   h_node_len[0] = len;
...
@@ -592,7 +592,7 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
   all_timer.Start();
   int64_t total_length =
       std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
-  auto buf = memory::AllocShared(place, total_length * sizeof(FeatureValue));
+  auto buf = memory::Alloc(place, total_length * sizeof(FeatureValue));
   FeatureValue* total_values_gpu = reinterpret_cast<FeatureValue*>(buf->ptr());
   if (platform::is_cpu_place(place)) {
     PADDLE_THROW(platform::errors::Unimplemented(
@@ -610,9 +610,9 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
   for (size_t i = 1; i < slot_lengths_lod.size(); i++) {
     slot_lengths_lod[i] += slot_lengths_lod[i - 1];
   }
-  auto buf_key = memory::AllocShared(place, keys.size() * sizeof(uint64_t*));
+  auto buf_key = memory::Alloc(place, keys.size() * sizeof(uint64_t*));
   auto buf_length =
-      memory::AllocShared(place, slot_lengths.size() * sizeof(int64_t));
+      memory::Alloc(place, slot_lengths.size() * sizeof(int64_t));
   uint64_t** gpu_keys = reinterpret_cast<uint64_t**>(buf_key->ptr());
   int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
   cudaMemcpy(gpu_keys, keys.data(), keys.size() * sizeof(uint64_t*),
@@ -660,8 +660,7 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place,
   all_timer.Start();
   int64_t total_length =
       std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
-  auto buf =
-      memory::AllocShared(place, total_length * sizeof(FeaturePushValue));
+  auto buf = memory::Alloc(place, total_length * sizeof(FeaturePushValue));
   FeaturePushValue* total_grad_values_gpu =
       reinterpret_cast<FeaturePushValue*>(buf->ptr());
   if (platform::is_cpu_place(place)) {
...
@@ -116,7 +116,7 @@ void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
       platform::DeviceContextPool::Instance().Get(
           BOOST_GET_CONST(platform::CUDAPlace, place)))
           ->stream();
-  auto buf_value = memory::AllocShared(place, values.size() * sizeof(float*));
+  auto buf_value = memory::Alloc(place, values.size() * sizeof(float*));
   float** gpu_values = reinterpret_cast<float**>(buf_value->ptr());
   cudaMemcpy(gpu_values, values.data(), values.size() * sizeof(float*),
              cudaMemcpyHostToDevice);
@@ -156,11 +156,10 @@ void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place,
     slot_lengths_lod[i] += slot_lengths_lod[i - 1];
   }
   auto buf_grad_value =
-      memory::AllocShared(place, grad_values.size() * sizeof(float*));
-  auto buf_length =
-      memory::AllocShared(place, slot_lengths.size() * sizeof(int64_t));
+      memory::Alloc(place, grad_values.size() * sizeof(float*));
+  auto buf_length = memory::Alloc(place, slot_lengths.size() * sizeof(int64_t));
   auto buf_slot_vector =
-      memory::AllocShared(place, slot_lengths_lod.size() * sizeof(int));
+      memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));
   float** gpu_values = reinterpret_cast<float**>(buf_grad_value->ptr());
   int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
...
@@ -102,8 +102,8 @@ struct TransposeNormal<platform::CUDADeviceContext, T> {
         BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace());
     platform::CPUPlace cpu_place = platform::CPUPlace();
     size_t size = 3 * rank * sizeof(int64_t);
-    auto cpu_buf_holder = memory::AllocShared(cpu_place, size);
-    auto cuda_buf_holder = memory::AllocShared(cuda_place, size);
+    auto cpu_buf_holder = memory::Alloc(cpu_place, size);
+    auto cuda_buf_holder = memory::Alloc(cuda_place, size);
     REINTERPRET(int64_t, cpu_buf, cpu_buf_holder->ptr());
     REINTERPRET(int64_t, cuda_buf, cuda_buf_holder->ptr());
     for (int i = 0; i < rank; ++i) {
...
@@ -69,8 +69,8 @@ struct TransposeNormal<CUDAContext, T> {
         BOOST_GET_CONST(paddle::platform::CUDAPlace, dev_ctx.GetPlace());
     paddle::platform::CPUPlace cpu_place = paddle::platform::CPUPlace();
     size_t size = 3 * rank * sizeof(int64_t);
-    auto cpu_buf_holder = paddle::memory::AllocShared(cpu_place, size);
-    auto cuda_buf_holder = paddle::memory::AllocShared(cuda_place, size);
+    auto cpu_buf_holder = paddle::memory::Alloc(cpu_place, size);
+    auto cuda_buf_holder = paddle::memory::Alloc(cuda_place, size);
     REINTERPRET(int64_t, cpu_buf, cpu_buf_holder->ptr());
     REINTERPRET(int64_t, cuda_buf, cuda_buf_holder->ptr());
     for (int i = 0; i < rank; ++i) {
...