未验证 提交 7fc2ce50 编写于 作者: T Thunderbrook 提交者: GitHub

add topo-aware in heter-ps (#30087) (#30117)

* add topo aware

* resource.h

* topo aware

* format
上级 faeee3c3
......@@ -1225,6 +1225,13 @@ void FleetWrapper::LoadModelOneTable(const uint64_t table_id,
void FleetWrapper::LoadWithWhitelist(const uint64_t table_id,
const std::string& path, const int mode) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->load_with_whitelist(table_id, path,
std::to_string(mode));
ret.wait();
if (ret.get() != 0) {
LOG(ERROR) << "load model of table id: " << table_id
<< ", from path: " << path << " failed";
}
#else
VLOG(0) << "FleetWrapper::LoadWhitelist does nothing when no pslib";
#endif
......@@ -1349,7 +1356,16 @@ int32_t FleetWrapper::SaveWithWhitelist(int table_id, const std::string& path,
const int mode,
const std::string& whitelist_path) {
#ifdef PADDLE_WITH_PSLIB
return 0;
auto ret = pslib_ptr_->_worker_ptr->save_with_whitelist(
table_id, path, std::to_string(mode), whitelist_path);
ret.wait();
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) {
LOG(ERROR) << "table save cache failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
}
return feasign_cnt;
#else
VLOG(0) << "FleetWrapper::SaveCache does nothing when no pslib";
return -1;
......
......@@ -765,7 +765,7 @@ x.second );
unsigned long long get_num_collisions() const { return m_collisions; }
void print() {
for (size_type i = 0; i < m_hashtbl_size; ++i) {
for (size_type i = 0; i < 10; ++i) {
std::cout << i << ": " << m_hashtbl_values[i].first << ","
<< m_hashtbl_values[i].second << std::endl;
}
......
......@@ -68,6 +68,34 @@ class HeterComm {
Sgd& sgd);
int log2i(int x);
bool need_transfer(int send_id, int receive_id) {
return ((send_id / 4 != receive_id / 4) && (send_id + 4) % 8 != receive_id);
}
int get_transfer_devid(int send_id) { return (send_id + 4) % 8; }
struct Node {
cudaStream_t in_stream;
cudaStream_t out_stream;
char* key_storage;
char* val_storage;
int sync;
int key_bytes_len;
int val_bytes_len;
int gpu_num;
};
struct Path {
std::vector<Node> nodes_;
};
void init_path();
void create_storage(
int start_index, int end_index, int keylen, int vallen,
std::vector<std::shared_ptr<memory::Allocation>>& local_strorage);
void walk_to_src(int start_index, int end_index, char* src_val);
void walk_to_dest(int start_index, int end_index, char* src_key,
char* src_val);
private:
using Table = HashTable<KeyType, ValType>;
......@@ -76,6 +104,8 @@ class HeterComm {
std::vector<Table*> tables_;
std::shared_ptr<HeterPsResource> resource_;
CustomGradMerger merger_;
int topo_aware_{1};
std::vector<std::vector<Path>> path_;
};
} // end namespace framework
......
......@@ -100,6 +100,131 @@ HeterComm<KeyType, ValType, GradType>::HeterComm(
auto table = new Table(capacity / load_factor_);
tables_.push_back(table);
}
init_path();
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::init_path() {
int total_gpu = resource_->total_gpu();
path_.resize(total_gpu);
if (!topo_aware_) {
VLOG(1) << "init path without topo aware";
for (int i = 0; i < total_gpu; ++i) {
path_[i].resize(total_gpu);
for (int j = 0; j < total_gpu; ++j) {
auto& nodes = path_[i][j].nodes_;
nodes.resize(1);
nodes[0].in_stream = resource_->comm_stream(i, j);
nodes[0].out_stream = resource_->comm_stream(j, i);
nodes[0].key_storage = NULL;
nodes[0].val_storage = NULL;
nodes[0].sync = 0;
nodes[0].gpu_num = j;
}
}
} else {
VLOG(1) << "init path with topo aware";
for (int i = 0; i < total_gpu; ++i) {
path_[i].resize(total_gpu);
for (int j = 0; j < total_gpu; ++j) {
auto& nodes = path_[i][j].nodes_;
int from = resource_->dev_id(i);
int to = resource_->dev_id(j);
int transfer_id = i;
if (need_transfer(from, to)) {
transfer_id = resource_->get_index_by_devid(get_transfer_devid(from));
nodes.push_back(Node());
Node& node = nodes.back();
node.in_stream = resource_->comm_stream(i, transfer_id);
node.out_stream = resource_->comm_stream(transfer_id, i);
node.key_storage = NULL;
node.val_storage = NULL;
node.sync = 1;
node.gpu_num = transfer_id;
}
nodes.push_back(Node());
Node& node = nodes.back();
node.in_stream = resource_->comm_stream(i, transfer_id);
node.out_stream = resource_->comm_stream(transfer_id, i);
node.key_storage = NULL;
node.val_storage = NULL;
node.sync = 0;
node.gpu_num = j;
}
}
}
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::create_storage(
int start_index, int end_index, int keylen, int vallen,
std::vector<std::shared_ptr<memory::Allocation>>& local_storage) {
auto& nodes = path_[start_index][end_index].nodes_;
for (size_t i = 0; i < nodes.size(); ++i) {
platform::CUDADeviceGuard guard(resource_->dev_id(nodes[i].gpu_num));
platform::CUDAPlace remote_place =
platform::CUDAPlace(resource_->dev_id(nodes[i].gpu_num));
auto key_mem = memory::AllocShared(remote_place, keylen);
local_storage.push_back(key_mem);
nodes[i].key_storage = reinterpret_cast<char*>(key_mem->ptr());
auto val_mem = memory::AllocShared(remote_place, vallen);
local_storage.push_back(val_mem);
nodes[i].val_storage = reinterpret_cast<char*>(val_mem->ptr());
nodes[i].key_bytes_len = keylen;
nodes[i].val_bytes_len = vallen;
}
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::walk_to_dest(int start_index,
int end_index,
char* src_key,
char* src_val) {
int need_copy_val = 0;
if (src_val) {
need_copy_val = 1;
}
auto& nodes = path_[start_index][end_index].nodes_;
for (size_t i = 0; i < nodes.size(); ++i) {
cudaMemcpyAsync(nodes[i].key_storage, src_key, nodes[i].key_bytes_len,
cudaMemcpyDefault, nodes[i].in_stream);
if (need_copy_val) {
cudaMemcpyAsync(nodes[i].val_storage, src_val, nodes[i].val_bytes_len,
cudaMemcpyDefault, nodes[i].in_stream);
}
if (nodes[i].sync) {
cudaStreamSynchronize(nodes[i].in_stream);
}
// cudaStreamSynchronize(nodes[i].in_stream);
src_key = nodes[i].key_storage;
src_val = nodes[i].val_storage;
}
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::walk_to_src(int start_index,
int end_index,
char* src_val) {
auto& nodes = path_[start_index][end_index].nodes_;
int len = nodes.size();
char* start = NULL;
for (int i = len - 1; i >= 0; --i) {
if (start == NULL) {
start = nodes[i].val_storage;
continue;
}
cudaMemcpyAsync(nodes[i].val_storage, start, nodes[i].val_bytes_len,
cudaMemcpyDefault, nodes[i].out_stream);
if (nodes[i].sync) {
cudaStreamSynchronize(nodes[i].out_stream);
}
start = nodes[i].val_storage;
}
cudaMemcpyAsync(src_val, nodes[0].val_storage, nodes[0].val_bytes_len,
cudaMemcpyDefault, nodes[0].out_stream);
// cudaStreamSynchronize(nodes[0].out_stream);
}
template <typename KeyType, typename ValType, typename GradType>
......@@ -131,9 +256,10 @@ int HeterComm<KeyType, ValType, GradType>::get_index_by_devid(int devid) {
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::build_ps(int num, KeyType* h_keys,
ValType* h_vals, size_t len,
size_t chunk_size,
int stream_num) {
ValType* h_vals,
size_t len,
size_t chunk_size,
int stream_num) {
if (len <= 0) {
return;
}
......@@ -182,13 +308,15 @@ void HeterComm<KeyType, ValType, GradType>::build_ps(int num, KeyType* h_keys,
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::merge_grad(int gpu_num, KeyType* d_keys,
GradType* d_grads,
size_t len, int& uniq_len) {
void HeterComm<KeyType, ValType, GradType>::merge_grad(int gpu_num,
KeyType* d_keys,
GradType* d_grads,
size_t len,
int& uniq_len) {
int dev_id = resource_->dev_id(gpu_num);
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
platform::CUDADeviceGuard guard(dev_id);
auto stream = resource_->stream(gpu_num);
auto stream = resource_->local_stream(gpu_num, 0);
size_t temp_storage_bytes;
......@@ -240,7 +368,7 @@ void HeterComm<KeyType, ValType, GradType>::split_input_to_shard(
int dev_id = resource_->dev_id(gpu_num);
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
platform::CUDADeviceGuard guard(dev_id);
auto stream = resource_->stream(gpu_num);
auto stream = resource_->local_stream(gpu_num, 0);
auto d_idx_tmp = memory::AllocShared(place, len * sizeof(int));
int* d_idx_tmp_ptr = reinterpret_cast<int*>(d_idx_tmp->ptr());
......@@ -272,9 +400,10 @@ void HeterComm<KeyType, ValType, GradType>::split_input_to_shard(
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num, KeyType* d_keys,
ValType* d_vals,
size_t len) {
void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
KeyType* d_keys,
ValType* d_vals,
size_t len) {
if (len == 0) {
return;
}
......@@ -283,7 +412,7 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num, KeyType* d_keys
int dev_id = resource_->dev_id(num);
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
platform::CUDADeviceGuard guard(dev_id);
auto stream = resource_->stream(num);
auto stream = resource_->local_stream(num, 0);
int grid_size = (len - 1) / block_size_ + 1;
......@@ -318,28 +447,15 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num, KeyType* d_keys
cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int),
cudaMemcpyDeviceToHost);
std::vector<KeyType*> d_remote_shard_keys_ptr;
std::vector<ValType*> d_remote_shard_vals_ptr;
std::vector<std::shared_ptr<memory::Allocation>> d_remote_shard_keys;
std::vector<std::shared_ptr<memory::Allocation>> d_remote_shard_vals;
std::vector<std::shared_ptr<memory::Allocation>> local_storage;
for (int i = 0; i < total_gpu; ++i) {
int shard_len = h_right[i] - h_left[i] + 1;
if (shard_len == 0) {
continue;
}
platform::CUDADeviceGuard guard(resource_->dev_id(i));
platform::CUDAPlace remote_place =
platform::CUDAPlace(resource_->dev_id(i));
d_remote_shard_keys.push_back(
memory::AllocShared(remote_place, shard_len * sizeof(KeyType)));
d_remote_shard_keys_ptr.push_back(
reinterpret_cast<KeyType*>(d_remote_shard_keys[i]->ptr()));
d_remote_shard_vals.push_back(
memory::AllocShared(remote_place, shard_len * sizeof(ValType)));
d_remote_shard_vals_ptr.push_back(
reinterpret_cast<ValType*>(d_remote_shard_vals[i]->ptr()));
create_storage(num, i, shard_len * sizeof(KeyType),
shard_len * sizeof(ValType), local_storage);
}
for (int i = 0; i < total_gpu; ++i) {
......@@ -347,21 +463,23 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num, KeyType* d_keys
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
cudaMemcpyAsync(d_remote_shard_keys_ptr[i], d_shard_keys_ptr + h_left[i],
shard_len * sizeof(KeyType), cudaMemcpyDefault, stream);
walk_to_dest(num, i, reinterpret_cast<char*>(d_shard_keys_ptr + h_left[i]),
NULL);
}
cudaStreamSynchronize(stream);
for (int i = 0; i < total_gpu; ++i) {
if (h_left[i] == -1) {
continue;
}
auto& node = path_[num][i].nodes_.back();
cudaStreamSynchronize(node.in_stream);
platform::CUDADeviceGuard guard(resource_->dev_id(i));
tables_[i]->get(d_remote_shard_keys_ptr[i], d_remote_shard_vals_ptr[i],
h_right[i] - h_left[i] + 1, resource_->stream(i));
tables_[i]->get(reinterpret_cast<KeyType*>(node.key_storage),
reinterpret_cast<ValType*>(node.val_storage),
h_right[i] - h_left[i] + 1, resource_->remote_stream(i));
}
for (int i = 0; i < total_gpu; ++i) {
cudaStreamSynchronize(resource_->stream(i));
cudaStreamSynchronize(resource_->remote_stream(i));
}
for (int i = 0; i < total_gpu; ++i) {
......@@ -370,13 +488,12 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num, KeyType* d_keys
continue;
}
platform::CUDADeviceGuard guard(resource_->dev_id(i));
cudaMemcpyAsync(d_shard_vals_ptr + h_left[i], d_remote_shard_vals_ptr[i],
shard_len * sizeof(ValType), cudaMemcpyDefault,
resource_->stream(i));
walk_to_src(num, i, reinterpret_cast<char*>(d_shard_vals_ptr + h_left[i]));
}
for (int i = 0; i < total_gpu; ++i) {
cudaStreamSynchronize(resource_->stream(i));
auto& node = path_[num][i].nodes_.front();
cudaStreamSynchronize(node.out_stream);
}
fill_dvals<<<grid_size, block_size_, 0, stream>>>(d_shard_vals_ptr, d_vals,
......@@ -387,9 +504,9 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num, KeyType* d_keys
template <typename KeyType, typename ValType, typename GradType>
template <typename Sgd>
void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
KeyType* d_keys,
GradType* d_grads,
size_t len, Sgd& sgd) {
KeyType* d_keys,
GradType* d_grads,
size_t len, Sgd& sgd) {
if (len == 0) {
return;
}
......@@ -398,7 +515,7 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
int dev_id = resource_->dev_id(gpu_num);
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
platform::CUDADeviceGuard guard(dev_id);
auto stream = resource_->stream(gpu_num);
auto stream = resource_->local_stream(gpu_num, 0);
int h_left[total_gpu];
int h_right[total_gpu];
......@@ -439,28 +556,15 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int),
cudaMemcpyDeviceToHost);
std::vector<KeyType*> d_remote_shard_keys_ptr;
std::vector<GradType*> d_remote_shard_grads_ptr;
std::vector<std::shared_ptr<memory::Allocation>> d_remote_shard_keys;
std::vector<std::shared_ptr<memory::Allocation>> d_remote_shard_grads;
std::vector<std::shared_ptr<memory::Allocation>> local_storage;
for (int i = 0; i < total_gpu; ++i) {
int shard_len = h_right[i] - h_left[i] + 1;
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
platform::CUDADeviceGuard guard(resource_->dev_id(i));
platform::CUDAPlace remote_place =
platform::CUDAPlace(resource_->dev_id(i));
d_remote_shard_keys.push_back(
memory::AllocShared(remote_place, shard_len * sizeof(KeyType)));
d_remote_shard_keys_ptr.push_back(
reinterpret_cast<KeyType*>(d_remote_shard_keys[i]->ptr()));
d_remote_shard_grads.push_back(
memory::AllocShared(remote_place, shard_len * sizeof(GradType)));
d_remote_shard_grads_ptr.push_back(
reinterpret_cast<GradType*>(d_remote_shard_grads[i]->ptr()));
create_storage(gpu_num, i, shard_len * sizeof(KeyType),
shard_len * sizeof(GradType), local_storage);
}
for (int i = 0; i < total_gpu; ++i) {
......@@ -468,24 +572,26 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
cudaMemcpyAsync(d_remote_shard_keys_ptr[i], d_shard_keys_ptr + h_left[i],
shard_len * sizeof(KeyType), cudaMemcpyDefault, stream);
cudaMemcpyAsync(d_remote_shard_grads_ptr[i], d_shard_grads_ptr + h_left[i],
shard_len * sizeof(GradType), cudaMemcpyDefault, stream);
walk_to_dest(gpu_num, i,
reinterpret_cast<char*>(d_shard_keys_ptr + h_left[i]),
reinterpret_cast<char*>(d_shard_grads_ptr + h_left[i]));
}
cudaStreamSynchronize(stream);
for (int i = 0; i < total_gpu; ++i) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
auto& node = path_[gpu_num][i].nodes_.back();
cudaStreamSynchronize(node.in_stream);
platform::CUDADeviceGuard guard(resource_->dev_id(i));
tables_[i]->update(d_remote_shard_keys_ptr[i], d_remote_shard_grads_ptr[i],
h_right[i] - h_left[i] + 1, sgd, resource_->stream(i));
tables_[i]->update(reinterpret_cast<KeyType*>(node.key_storage),
reinterpret_cast<GradType*>(node.val_storage),
h_right[i] - h_left[i] + 1, sgd,
resource_->remote_stream(i));
}
for (int i = 0; i < total_gpu; ++i) {
cudaStreamSynchronize(resource_->stream(i));
cudaStreamSynchronize(resource_->remote_stream(i));
}
}
......
......@@ -19,23 +19,35 @@ limitations under the License. */
namespace paddle {
namespace framework {
GPUResource::GPUResource(int dev_id, int index) {
GPUResource::GPUResource(std::vector<int>& dev_ids, int index) {
index_ = index;
dev_id_ = dev_id;
dev_ids_ = dev_ids;
dev_id_ = dev_ids_[index];
platform::CUDADeviceGuard guard(dev_id_);
local_streams_.resize(dev_ids_.size());
comm_streams_.resize(dev_ids_.size());
for (size_t i = 0; i < dev_ids_.size(); ++i) {
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamCreateWithFlags(&local_streams_[i], cudaStreamNonBlocking));
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamCreateWithFlags(&comm_streams_[i], cudaStreamNonBlocking));
}
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamCreateWithFlags(&copy_stream_, cudaStreamNonBlocking));
cudaStreamCreateWithFlags(&remote_stream_, cudaStreamNonBlocking));
}
GPUResource::~GPUResource() {
platform::CUDADeviceGuard guard(dev_id_);
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(stream_));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(copy_stream_));
for (size_t i = 0; i < local_streams_.size(); ++i) {
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(local_streams_[i]));
}
for (size_t i = 0; i < comm_streams_.size(); ++i) {
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(comm_streams_[i]));
}
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(remote_stream_));
}
void HeterPsResource::enable_p2p() {
......@@ -64,18 +76,22 @@ HeterPsResource::HeterPsResource(const std::vector<int>& dev_ids) {
dev_ids_ = dev_ids;
for (size_t i = 0; i < dev_ids_.size(); ++i) {
std::shared_ptr<GPUResource> resource =
std::make_shared<GPUResource>(dev_ids_[i], i);
std::make_shared<GPUResource>(dev_ids_, i);
resources_.push_back(resource);
devid_2_index_[dev_ids_[i]] = i;
}
}
cudaStream_t HeterPsResource::copy_stream(int num) {
return resources_[num]->copy_stream();
cudaStream_t HeterPsResource::comm_stream(int gpu_num, int stream_num) {
return resources_[gpu_num]->comm_stream(stream_num);
}
cudaStream_t HeterPsResource::local_stream(int gpu_num, int stream_num) {
return resources_[gpu_num]->local_stream(stream_num);
}
cudaStream_t HeterPsResource::stream(int num) {
return resources_[num]->stream();
cudaStream_t HeterPsResource::remote_stream(int gpu_num) {
return resources_[gpu_num]->remote_stream();
}
int HeterPsResource::dev_id(int num) { return dev_ids_[num]; }
......
......@@ -27,20 +27,23 @@ namespace framework {
class GPUResource {
public:
GPUResource(int device_id, int index);
GPUResource(std::vector<int>& device_id, int index);
virtual ~GPUResource();
GPUResource(const GPUResource&) = delete;
GPUResource& operator=(const GPUResource&) = delete;
int dev_id() const { return dev_id_; }
int index() const { return index_; }
cudaStream_t stream() { return stream_; }
cudaStream_t copy_stream() { return copy_stream_; }
cudaStream_t local_stream(int num) { return local_streams_[num]; }
cudaStream_t remote_stream() { return remote_stream_; }
cudaStream_t comm_stream(int num) { return comm_streams_[num]; }
int dev_id_;
int index_;
cudaStream_t stream_;
cudaStream_t copy_stream_;
std::vector<int> dev_ids_;
cudaStream_t remote_stream_;
std::vector<cudaStream_t> local_streams_;
std::vector<cudaStream_t> comm_streams_;
};
class HeterPsResource {
......@@ -52,9 +55,10 @@ class HeterPsResource {
void enable_p2p();
int total_gpu();
int get_index_by_devid(int devid);
cudaStream_t stream(int num);
cudaStream_t copy_stream(int num);
int dev_id(int num);
cudaStream_t local_stream(int gpu_num, int stream_num);
cudaStream_t remote_stream(int gpu_num);
cudaStream_t comm_stream(int gpu_num, int stream_num);
std::vector<std::shared_ptr<GPUResource>> resources_;
std::vector<int> dev_ids_;
......
......@@ -15,18 +15,19 @@ limitations under the License. */
#pragma once
namespace optimizer_config {
__constant__ float mf_create_thresholds = 1;
__constant__ float nonclk_coeff = 1;
__constant__ float mf_create_thresholds = 0;
__constant__ float nonclk_coeff = 0.1;
__constant__ float clk_coeff = 1;
__constant__ float min_bound = -10000;
__constant__ float max_bound = 10000;
__constant__ float learning_rate = 1;
__constant__ float initial_g2sum = 1;
__constant__ float initial_range = 1;
__constant__ float min_bound = -10;
__constant__ float max_bound = 10;
__constant__ float learning_rate = 0.05;
__constant__ float initial_g2sum = 3.0;
__constant__ float initial_range = 1e-4;
__constant__ float mf_learning_rate = 1;
__constant__ float mf_initial_g2sum = 1;
__constant__ float mf_initial_range = 1;
__constant__ float mf_min_bound = 1;
__constant__ float mf_max_bound = 1;
__constant__ float mf_learning_rate = 0.05;
__constant__ float mf_initial_g2sum = 3.0;
__constant__ float mf_initial_range = 1e-4;
__constant__ float mf_min_bound = -10;
__constant__ float mf_max_bound = 10;
}
......@@ -143,16 +143,17 @@ void PSGPUWorker::SetNeedDump(bool need_dump_field) {
void PSGPUWorker::DumpParam() {}
void PSGPUWorker::TrainFiles() {
VLOG(3) << "train file A";
platform::SetNumThreads(1);
platform::Timer timeline;
timeline.Start();
int total_ins_num = 0;
VLOG(3) << "train file B";
// how to accumulate fetched values here
device_reader_->Start();
VLOG(3) << "train file C";
int cur_batch;
while ((cur_batch = device_reader_->Next()) > 0) {
VLOG(3) << "train file D";
total_ins_num += cur_batch;
for (auto& op : ops_) {
bool need_skip = false;
for (auto t = 0u; t < skip_ops_.size(); ++t) {
......@@ -169,6 +170,9 @@ void PSGPUWorker::TrainFiles() {
PrintFetchVars();
thread_scope_->DropKids();
}
timeline.Pause();
VLOG(1) << "GpuPs worker " << thread_id_ << " train cost "
<< timeline.ElapsedSec() << " seconds, ins_num: " << total_ins_num;
return;
}
......
......@@ -57,7 +57,11 @@ void BindFleetWrapper(py::module* m) {
.def("get_cache_threshold", &framework::FleetWrapper::GetCacheThreshold)
.def("cache_shuffle", &framework::FleetWrapper::CacheShuffle)
.def("save_cache", &framework::FleetWrapper::SaveCache)
.def("save_model_with_whitelist",
&framework::FleetWrapper::SaveWithWhitelist)
.def("load_model", &framework::FleetWrapper::LoadModel)
.def("load_table_with_whitelist",
&framework::FleetWrapper::LoadWithWhitelist)
.def("clear_model", &framework::FleetWrapper::ClearModel)
.def("clear_one_table", &framework::FleetWrapper::ClearOneTable)
.def("stop_server", &framework::FleetWrapper::StopServer)
......
......@@ -101,15 +101,16 @@ class PSLib(Fleet):
# barrier_all for init_worker
self._role_maker._barrier_all()
# prepare for client to client communication
if self._role_maker.is_worker():
info = self._fleet_ptr.get_clients_info()
all_info = self._role_maker._worker_gather(info[0])
self._fleet_ptr.gather_clients(all_info)
self._fleet_ptr.set_client2client_config(
self._client2client_request_timeout_ms,
self._client2client_connect_timeout_ms,
self._client2client_max_retry)
self._fleet_ptr.create_client2client_connection()
if not self._opt_info["use_ps_gpu"]:
if self._role_maker.is_worker():
info = self._fleet_ptr.get_clients_info()
all_info = self._role_maker._worker_gather(info[0])
self._fleet_ptr.gather_clients(all_info)
self._fleet_ptr.set_client2client_config(
self._client2client_request_timeout_ms,
self._client2client_connect_timeout_ms,
self._client2client_max_retry)
self._fleet_ptr.create_client2client_connection()
# barrier for init model
self._role_maker._barrier_worker()
if self._role_maker.is_first_worker():
......@@ -137,9 +138,10 @@ class PSLib(Fleet):
"var " + var_name + " not found in scope, "
+ "you should run startup program first")
var_name_list.append(var_name)
self._fleet_ptr.init_model(scope,
int(table.table_id),
var_name_list)
if not self._opt_info["use_ps_gpu"]:
self._fleet_ptr.init_model(scope,
int(table.table_id),
var_name_list)
# barrier for init model done
self._role_maker._barrier_worker()
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册