未验证 提交 d1075df2 编写于 作者: D danleifeng 提交者: GitHub

topo and memory performance for heterps (#30440)

* topo and memory performance for heterps; test=develop
* add trainwithprofiler in heter trainier; test=develop
上级 72d99c5d
......@@ -547,6 +547,7 @@ class PSGPUWorker : public HogwildWorker {
virtual ~PSGPUWorker() {}
virtual void Initialize(const TrainerDesc& desc);
virtual void TrainFiles();
virtual void TrainFilesWithProfiler();
virtual void SetNeedDump(bool need_dump_field);
virtual void SetChannelWriter(ChannelObject<std::string>* queue);
virtual void SetWorkerNum(int num) { worker_num_ = num; }
......@@ -556,7 +557,6 @@ class PSGPUWorker : public HogwildWorker {
virtual void ProduceTasks() override;
virtual void SetStream(const gpuStream_t stream) { copy_stream_ = stream; }
virtual void SetEvent(const gpuEvent_t event) { event_ = event; }
virtual void TrainFilesWithProfiler() {}
void ResetStat();
protected:
......@@ -618,6 +618,7 @@ class PSGPUWorker : public HogwildWorker {
gpuStream_t copy_stream_;
int batch_cnt_{0};
std::atomic<int> done_cnt_{0};
platform::DeviceContext* dev_ctx_ = nullptr;
double total_time_;
double read_time_;
......
......@@ -66,18 +66,19 @@ class HeterContext {
mutex_[i] = new std::mutex();
}
}
void batch_add_keys(const std::vector<std::vector<uint64_t>>& thread_keys) {
void batch_add_keys(
const std::vector<std::unordered_set<uint64_t>>& thread_keys) {
assert(thread_keys.size() == feature_keys_.size());
for (uint32_t i = 0; i < shard_num_; i++) {
int idx = 0;
idx = feature_keys_[i].size();
feature_keys_[i].resize(feature_keys_[i].size() + thread_keys[i].size());
for (uint64_t j = 0; j < thread_keys[i].size(); j++) {
feature_keys_[i][idx + j] = thread_keys[i][j];
}
std::copy(thread_keys[i].begin(), thread_keys[i].end(),
feature_keys_[i].begin() + idx);
}
}
void UniqueKeys() {
std::vector<std::thread> threads;
auto unique_func = [this](int i) {
......
......@@ -55,6 +55,8 @@ class HashTable {
void update(const KeyType* d_keys, const GradType* d_grads, size_t len,
Sgd sgd, gpuStream_t stream);
int size() { return container_->size(); }
private:
TableContainer<KeyType, ValType>* container_;
int BLOCK_SIZE_{256};
......
......@@ -118,6 +118,12 @@ class HeterComm {
std::vector<Node> nodes_;
};
struct CopyTask {
Path* path;
int step;
CopyTask(Path* path_, int step_) : path(path_), step(step_) {}
};
struct LocalStorage {
LocalStorage() {}
void init(int size, int dev_id) {
......@@ -160,9 +166,10 @@ class HeterComm {
void create_storage(
int start_index, int end_index, int keylen, int vallen,
std::vector<std::shared_ptr<memory::Allocation>>& local_strorage);
void walk_to_src(int start_index, int end_index, char* src_val);
void walk_to_dest(int start_index, int end_index, char* src_key,
char* src_val);
void walk_to_dest(int start_index, int gpu_num, int* h_left, int* h_right,
KeyType* src_key, GradType* src_val);
void walk_to_src(int start_index, int gpu_num, int* h_left, int* h_right,
ValType* src_val);
private:
using Table = HashTable<KeyType, ValType>;
......
......@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <queue>
#ifdef PADDLE_WITH_PSLIB
namespace paddle {
namespace framework {
......@@ -182,53 +184,105 @@ void HeterComm<KeyType, ValType, GradType>::create_storage(
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::walk_to_dest(int start_index,
int end_index,
char* src_key,
char* src_val) {
void HeterComm<KeyType, ValType, GradType>::walk_to_dest(
int start_index, int gpu_num, int* h_left, int* h_right, KeyType* src_key,
GradType* src_val) {
int need_copy_val = 0;
if (src_val) {
need_copy_val = 1;
}
auto& nodes = path_[start_index][end_index].nodes_;
for (size_t i = 0; i < nodes.size(); ++i) {
cudaMemcpyAsync(nodes[i].key_storage, src_key, nodes[i].key_bytes_len,
cudaMemcpyDefault, nodes[i].in_stream);
std::queue<CopyTask> que;
for (int i = 0; i < gpu_num; i++) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
int size = path_[start_index][i].nodes_.size();
auto& node = path_[start_index][i].nodes_[0];
CopyTask t(&path_[start_index][i], 0);
que.push(t);
cudaMemcpyAsync(node.key_storage,
reinterpret_cast<char*>(src_key + h_left[i]),
node.key_bytes_len, cudaMemcpyDefault, node.in_stream);
if (need_copy_val) {
cudaMemcpyAsync(nodes[i].val_storage, src_val, nodes[i].val_bytes_len,
cudaMemcpyDefault, nodes[i].in_stream);
cudaMemcpyAsync(node.val_storage,
reinterpret_cast<char*>(src_val + h_left[i]),
node.val_bytes_len, cudaMemcpyDefault, node.in_stream);
}
}
while (!que.empty()) {
CopyTask& cur_task = que.front();
que.pop();
if (cur_task.path->nodes_[cur_task.step].sync) {
cudaStreamSynchronize(cur_task.path->nodes_[cur_task.step].in_stream);
}
if (nodes[i].sync) {
cudaStreamSynchronize(nodes[i].in_stream);
if (cur_task.step != cur_task.path->nodes_.size() - 1) {
int cur_step = cur_task.step;
CopyTask c(cur_task.path, cur_step + 1);
que.push(c);
cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].key_storage,
cur_task.path->nodes_[cur_step].key_storage,
cur_task.path->nodes_[cur_step + 1].key_bytes_len,
cudaMemcpyDefault,
cur_task.path->nodes_[cur_step + 1].in_stream);
if (need_copy_val) {
cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].val_storage,
cur_task.path->nodes_[cur_step].val_storage,
cur_task.path->nodes_[cur_step + 1].val_bytes_len,
cudaMemcpyDefault,
cur_task.path->nodes_[cur_step + 1].in_stream);
}
}
// cudaStreamSynchronize(nodes[i].in_stream);
src_key = nodes[i].key_storage;
src_val = nodes[i].val_storage;
}
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::walk_to_src(int start_index,
int end_index,
char* src_val) {
auto& nodes = path_[start_index][end_index].nodes_;
int len = nodes.size();
char* start = NULL;
for (int i = len - 1; i >= 0; --i) {
if (start == NULL) {
start = nodes[i].val_storage;
void HeterComm<KeyType, ValType, GradType>::walk_to_src(
int start_index, int gpu_num, int* h_left, int* h_right, ValType* src_val) {
std::queue<CopyTask> que;
for (int i = 0; i < gpu_num; i++) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
cudaMemcpyAsync(nodes[i].val_storage, start, nodes[i].val_bytes_len,
cudaMemcpyDefault, nodes[i].out_stream);
if (nodes[i].sync) {
cudaStreamSynchronize(nodes[i].out_stream);
int cur_step = path_[start_index][i].nodes_.size() - 1;
auto& node = path_[start_index][i].nodes_[cur_step];
if (cur_step == 0) {
cudaMemcpyAsync(reinterpret_cast<char*>(src_val + h_left[i]),
node.val_storage, node.val_bytes_len, cudaMemcpyDefault,
node.out_stream);
} else {
CopyTask t(&path_[start_index][i], cur_step - 1);
que.push(t);
cudaMemcpyAsync(path_[start_index][i].nodes_[cur_step - 1].val_storage,
node.val_storage,
path_[start_index][i].nodes_[cur_step - 1].val_bytes_len,
cudaMemcpyDefault,
path_[start_index][i].nodes_[cur_step - 1].out_stream);
}
}
while (!que.empty()) {
CopyTask& cur_task = que.front();
que.pop();
int cur_step = cur_task.step;
if (cur_task.path->nodes_[cur_step].sync) {
cudaStreamSynchronize(cur_task.path->nodes_[cur_step].out_stream);
}
if (cur_step > 0) {
CopyTask c(cur_task.path, cur_step - 1);
que.push(c);
cudaMemcpyAsync(cur_task.path->nodes_[cur_step - 1].val_storage,
cur_task.path->nodes_[cur_step].val_storage,
cur_task.path->nodes_[cur_step - 1].val_bytes_len,
cudaMemcpyDefault,
cur_task.path->nodes_[cur_step - 1].out_stream);
} else if (cur_step == 0) {
int end_index = cur_task.path->nodes_.back().gpu_num;
cudaMemcpyAsync(reinterpret_cast<char*>(src_val + h_left[end_index]),
cur_task.path->nodes_[cur_step].val_storage,
cur_task.path->nodes_[cur_step].val_bytes_len,
cudaMemcpyDefault,
cur_task.path->nodes_[cur_step].out_stream);
}
start = nodes[i].val_storage;
}
cudaMemcpyAsync(src_val, nodes[0].val_storage, nodes[0].val_bytes_len,
cudaMemcpyDefault, nodes[0].out_stream);
// cudaStreamSynchronize(nodes[0].out_stream);
}
template <typename KeyType, typename ValType, typename GradType>
......@@ -462,14 +516,7 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
shard_len * sizeof(ValType), local_storage);
}
for (int i = 0; i < total_gpu; ++i) {
int shard_len = h_right[i] - h_left[i] + 1;
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
walk_to_dest(num, i, reinterpret_cast<char*>(d_shard_keys_ptr + h_left[i]),
NULL);
}
walk_to_dest(num, total_gpu, h_left, h_right, d_shard_keys_ptr, NULL);
for (int i = 0; i < total_gpu; ++i) {
if (h_left[i] == -1) {
......@@ -486,14 +533,7 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
cudaStreamSynchronize(resource_->remote_stream(i));
}
for (int i = 0; i < total_gpu; ++i) {
int shard_len = h_right[i] - h_left[i] + 1;
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
platform::CUDADeviceGuard guard(resource_->dev_id(i));
walk_to_src(num, i, reinterpret_cast<char*>(d_shard_vals_ptr + h_left[i]));
}
walk_to_src(num, total_gpu, h_left, h_right, d_shard_vals_ptr);
for (int i = 0; i < total_gpu; ++i) {
auto& node = path_[num][i].nodes_.front();
......@@ -561,7 +601,6 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
cudaMemcpyDeviceToHost);
std::vector<std::shared_ptr<memory::Allocation>> local_storage;
for (int i = 0; i < total_gpu; ++i) {
int shard_len = h_right[i] - h_left[i] + 1;
if (h_left[i] == -1 || h_right[i] == -1) {
......@@ -571,15 +610,8 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
shard_len * sizeof(GradType), local_storage);
}
for (int i = 0; i < total_gpu; ++i) {
int shard_len = h_right[i] - h_left[i] + 1;
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
walk_to_dest(gpu_num, i,
reinterpret_cast<char*>(d_shard_keys_ptr + h_left[i]),
reinterpret_cast<char*>(d_shard_grads_ptr + h_left[i]));
}
walk_to_dest(gpu_num, total_gpu, h_left, h_right, d_shard_keys_ptr,
d_shard_grads_ptr);
for (int i = 0; i < total_gpu; ++i) {
if (h_left[i] == -1 || h_right[i] == -1) {
......
......@@ -65,9 +65,6 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task,
for (int i = 0; i < thread_keys_thread_num_; i++) {
thread_keys_[i].resize(thread_keys_shard_num_);
for (int j = 0; j < thread_keys_shard_num_; j++) {
thread_keys_[i][j].reserve(2 * max_fea_num_per_pass_ /
thread_keys_shard_num_ /
thread_keys_thread_num_);
}
}
const std::deque<Record>& vec_data = input_channel->GetData();
......@@ -84,7 +81,7 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task,
for (const auto feasign : feasign_v) {
uint64_t cur_key = feasign.sign().uint64_feasign_;
int shard_id = cur_key % thread_keys_shard_num_;
this->thread_keys_[i][shard_id].push_back(cur_key);
this->thread_keys_[i][shard_id].insert(cur_key);
}
}
};
......@@ -123,7 +120,7 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task,
VLOG(3) << "GpuPs shard: " << i << " key len: " << local_keys[i].size();
local_ptr[i].resize(local_keys[i].size());
}
timeline.Start();
auto ptl_func = [this, &local_keys, &local_ptr, &table_id,
&fleet_ptr](int i) {
size_t key_size = local_keys[i].size();
......@@ -149,7 +146,8 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task,
t.join();
}
timeline.Pause();
VLOG(1) << "GpuPs pull sparse cost " << timeline.ElapsedSec() << " seconds.";
VLOG(1) << "pull sparse from CpuPS into GpuPS cost " << timeline.ElapsedSec()
<< " seconds.";
timeline.Start();
auto build_func = [device_num, &local_keys, &local_ptr, &device_keys,
......@@ -225,6 +223,7 @@ void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim) {
size_t size_max = 0;
for (int i = 0; i < device_num; i++) {
feature_keys_count[i] = gpu_task->device_keys_[i].size();
VLOG(1) << i << " card contains feasign nums: " << feature_keys_count[i];
size_max = std::max(size_max, feature_keys_count[i]);
}
if (HeterPs_) {
......@@ -314,7 +313,7 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
"GpuPs: PullSparse Only Support CUDAPlace Now."));
}
all_timer.Pause();
VLOG(1) << "GpuPs PullSparse total costs: " << all_timer.ElapsedSec()
VLOG(3) << "GpuPs PullSparse total costs: " << all_timer.ElapsedSec()
<< " s, of which GPUPS costs: " << pull_gpups_timer.ElapsedSec()
<< " s";
VLOG(3) << "End PullSparse";
......@@ -360,7 +359,7 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place,
"GPUPS: PushSparseGrad Only Support CUDAPlace Now."));
}
all_timer.Pause();
VLOG(1) << "PushSparseGrad total cost: " << all_timer.ElapsedSec()
VLOG(3) << "PushSparseGrad total cost: " << all_timer.ElapsedSec()
<< " s, of which GPUPS cost: " << push_gpups_timer.ElapsedSec()
<< " s";
VLOG(3) << "End PushSparseGrad";
......
......@@ -227,7 +227,7 @@ class PSGPUWrapper {
std::vector<int> heter_devices_;
std::unordered_set<std::string> gpu_ps_config_keys_;
HeterObjectPool<HeterContext> gpu_task_pool_;
std::vector<std::vector<std::vector<uint64_t>>> thread_keys_;
std::vector<std::vector<std::unordered_set<uint64_t>>> thread_keys_;
int thread_keys_thread_num_ = 37;
int thread_keys_shard_num_ = 37;
uint64_t max_fea_num_per_pass_ = 5000000000;
......
......@@ -131,8 +131,13 @@ void PSGPUTrainer::InitOtherEnv(const ProgramDesc& main_program) {
void PSGPUTrainer::Run() {
for (size_t thidx = 0; thidx < places_.size(); ++thidx) {
threads_.push_back(
std::thread(&DeviceWorker::TrainFiles, workers_[thidx].get()));
if (!debug_) {
threads_.push_back(
std::thread(&DeviceWorker::TrainFiles, workers_[thidx].get()));
} else {
threads_.push_back(std::thread(&DeviceWorker::TrainFilesWithProfiler,
workers_[thidx].get()));
}
}
}
......
......@@ -33,6 +33,7 @@ namespace framework {
void PSGPUWorker::Initialize(const TrainerDesc& desc) {
param_ = desc.downpour_param();
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
mpi_rank_ = desc.mpi_rank();
trainer_desc_ = desc;
/*
......@@ -177,6 +178,81 @@ void PSGPUWorker::TrainFiles() {
return;
}
void PSGPUWorker::TrainFilesWithProfiler() {
platform::SetNumThreads(1);
VLOG(1) << "Begin to train files with profiler";
device_reader_->Start();
std::vector<double> op_total_time;
std::vector<std::string> op_name;
for (auto& op : ops_) {
bool need_skip = false;
for (auto t = 0u; t < skip_ops_.size(); ++t) {
if (op->Type().find(skip_ops_[t]) != std::string::npos) {
need_skip = true;
break;
}
}
if (!need_skip) {
op_name.push_back(op->Type());
}
}
VLOG(3) << "op name size: " << op_name.size();
op_total_time.resize(op_name.size());
for (size_t i = 0; i < op_total_time.size(); ++i) {
op_total_time[i] = 0.0;
}
platform::Timer timeline;
double total_time = 0.0;
double read_time = 0.0;
int total_ins_num = 0;
int cur_batch;
timeline.Start();
while ((cur_batch = device_reader_->Next()) > 0) {
total_ins_num += cur_batch;
timeline.Pause();
read_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
int run_op_idx = 0;
dev_ctx_->Wait();
for (auto& op : ops_) {
bool need_skip = false;
for (auto t = 0u; t < skip_ops_.size(); ++t) {
if (op->Type().find(skip_ops_[t]) != std::string::npos) {
need_skip = true;
break;
}
}
if (!need_skip) {
timeline.Start();
VLOG(3) << "Going to run op " << op_name[run_op_idx];
op->Run(*thread_scope_, place_);
dev_ctx_->Wait();
VLOG(3) << "Op " << op_name[run_op_idx] << " Finished";
timeline.Pause();
op_total_time[run_op_idx++] += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
}
}
timeline.Start();
PrintFetchVars();
thread_scope_->DropKids();
dev_ctx_->Wait();
timeline.Pause();
total_time += timeline.ElapsedSec();
timeline.Start();
}
VLOG(1) << "GpuPs worker " << thread_id_ << " train cost " << total_time
<< " seconds, ins_num: " << total_ins_num;
for (size_t i = 0; i < op_name.size(); ++i) {
VLOG(1) << "card:" << thread_id_ << ", op: " << op_name[i]
<< ", mean time: " << op_total_time[i] / total_ins_num
<< "s, totol time:" << op_total_time[i] << "sec";
}
return;
}
void PSGPUWorker::ResetStat() {
total_time_ = 0;
read_time_ = 0;
......
......@@ -447,9 +447,6 @@ class HDFSClient(FS):
configs,
time_out=5 * 60 * 1000, # ms
sleep_inter=1000): # ms
# Raise exception if JAVA_HOME not exists.
java_home = os.environ["JAVA_HOME"]
self.pre_commands = []
hadoop_bin = '%s/bin/hadoop' % hadoop_home
self.pre_commands.append(hadoop_bin)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册