From cb66c53c2db3a6a0f909917fc3cc498cf28bf489 Mon Sep 17 00:00:00 2001
From: Thunderbrook <52529258+Thunderbrook@users.noreply.github.com>
Date: Mon, 1 Feb 2021 17:06:59 +0800
Subject: [PATCH] dump to cpu (#30750)

* dump to cpu

* format

* format

* format
---
 .../framework/fleet/heter_ps/feature_value.h  |  1 +
 .../framework/fleet/heter_ps/hashtable.h      |  5 ++-
 .../{hashtable.tpp => hashtable_inl.h}        | 35 +++++++++++++++++++
 .../framework/fleet/heter_ps/heter_comm.h     |  9 +++--
 .../{heter_comm.tpp => heter_comm_inl.h}      | 28 +++++++++++++++
 .../framework/fleet/heter_ps/heter_ps.cu      |  2 +-
 .../fluid/framework/fleet/heter_ps/heter_ps.h |  4 +--
 .../framework/fleet/heter_ps/heter_ps_base.h  |  2 +-
 .../{optimizer.cuh => optimizer.cuh.h}        |  6 ++--
 .../framework/fleet/heter_ps/test_comm.cu     |  2 +-
 .../fluid/framework/fleet/ps_gpu_wrapper.cc   |  1 +
 paddle/fluid/framework/fleet/ps_gpu_wrapper.h |  3 ++
 paddle/fluid/pybind/ps_gpu_wrapper_py.cc      |  2 ++
 13 files changed, 89 insertions(+), 11 deletions(-)
 rename paddle/fluid/framework/fleet/heter_ps/{hashtable.tpp => hashtable_inl.h} (78%)
 rename paddle/fluid/framework/fleet/heter_ps/{heter_comm.tpp => heter_comm_inl.h} (96%)
 rename paddle/fluid/framework/fleet/heter_ps/{optimizer.cuh => optimizer.cuh.h} (96%)

diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h b/paddle/fluid/framework/fleet/heter_ps/feature_value.h
index efdb90b3362..698ece09de6 100644
--- a/paddle/fluid/framework/fleet/heter_ps/feature_value.h
+++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h
@@ -33,6 +33,7 @@ struct FeatureValue {
   float lr_g2sum;
   int mf_size;
   float mf[MF_DIM + 1];
+  uint64_t cpu_ptr;
 
   friend std::ostream& operator<<(std::ostream& out, FeatureValue& val) {
     out << "show: " << val.show << " clk: " << val.clk << " slot: " << val.slot
diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable.h b/paddle/fluid/framework/fleet/heter_ps/hashtable.h
index 0c45edb57f8..11bd6e7aa69 100644
--- a/paddle/fluid/framework/fleet/heter_ps/hashtable.h
+++ b/paddle/fluid/framework/fleet/heter_ps/hashtable.h
@@ -13,9 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <limits>
 #include <memory>
 #include <unordered_map>
 #include <vector>
+#include "common_value.h"  // NOLINT
 #include "thrust/pair.h"
 //#include "cudf/concurrent_unordered_map.cuh.h"
 #include "paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h"
@@ -47,6 +49,7 @@ class HashTable {
   void get(const KeyType* d_keys, ValType* d_vals, size_t len,
            cudaStream_t stream);
   void show();
+  void dump_to_cpu(int devid, cudaStream_t stream);
 
   template <typename GradType, typename Sgd>
   void update(const KeyType* d_keys, const GradType* d_grads, size_t len,
@@ -60,5 +63,5 @@ class HashTable {
 };
 }  // end namespace framework
 }  // end namespace paddle
-#include "hashtable.tpp"
+#include "hashtable_inl.h"
 #endif
diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable.tpp b/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h
similarity index 78%
rename from paddle/fluid/framework/fleet/heter_ps/hashtable.tpp
rename to paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h
index 3c125701c6b..ef37ed64c2a 100644
--- a/paddle/fluid/framework/fleet/heter_ps/hashtable.tpp
+++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h
@@ -108,6 +108,41 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
                                              d_vals, len);
 }
 
+template <typename KeyType, typename ValType>
+void HashTable<KeyType, ValType>::dump_to_cpu(int devid, cudaStream_t stream) {
+  container_->prefetch(cudaCpuDeviceId, stream);
+  size_t num = container_->size();
+  KeyType unuse_key = std::numeric_limits<KeyType>::max();
+  thrust::pair<KeyType, ValType>* kv = container_->data();
+  for (size_t i = 0; i < num; ++i) {
+    if (kv[i].first == unuse_key) {
+      continue;
+    }
+    ValType& gpu_val = kv[i].second;
+    auto* downpour_value =
+        (paddle::ps::DownpourFixedFeatureValue*)(gpu_val.cpu_ptr);
+    int downpour_value_size = downpour_value->size();
+    if (gpu_val.mf_size > 0 && downpour_value_size == 7) {
+      downpour_value->resize(gpu_val.mf_size + downpour_value_size);
+    }
+    float* cpu_val = downpour_value->data();
+    cpu_val[0] = 0;
+    cpu_val[1] = gpu_val.delta_score;
+    cpu_val[2] = gpu_val.show;
+    cpu_val[3] = gpu_val.clk;
+    cpu_val[4] = gpu_val.lr;
+    cpu_val[5] = gpu_val.lr_g2sum;
+    cpu_val[6] = gpu_val.slot;
+    if (gpu_val.mf_size > 0) {
+      for (int x = 0; x < gpu_val.mf_size; x++) {
+        cpu_val[x + 7] = gpu_val.mf[x];
+      }
+    }
+  }
+
+  container_->prefetch(devid, stream);
+}
+
 template <typename KeyType, typename ValType>
 template <typename GradType, typename Sgd>
 void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h
index a544d8f44f1..5d299998534 100644
--- a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h
@@ -13,11 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <thread>
 #include "cub/cub.cuh"
 #include "hashtable.h"
 #include "heter_resource.h"
-#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh"
+#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
 #include "paddle/fluid/memory/memory.h"
 #include "paddle/fluid/platform/cuda_device_guard.h"
 #include "paddle/fluid/platform/place.h"
 #ifdef PADDLE_WITH_PSLIB
@@ -72,6 +73,10 @@ class HeterComm {
     return ((send_id / 4 != receive_id / 4) &&
             (send_id + 4) % 8 != receive_id);
   }
 
+  // void dump_to_cpu(int index);
+
+  void end_pass();
+
   int get_transfer_devid(int send_id) { return (send_id + 4) % 8; }
 
   struct Node {
@@ -110,5 +115,5 @@ class HeterComm {
 
 }  // end namespace framework
 }  // end namespace paddle
-#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.tpp"
+#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h"
 #endif
diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm.tpp b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
similarity index 96%
rename from paddle/fluid/framework/fleet/heter_ps/heter_comm.tpp
rename to paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
index e280397b2a2..f95d4d3948b 100644
--- a/paddle/fluid/framework/fleet/heter_ps/heter_comm.tpp
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
@@ -595,6 +595,34 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
   }
 }
 
+template <typename KeyType, typename ValType, typename GradType>
+void HeterComm<KeyType, ValType, GradType>::end_pass() {
+  int total_gpu = resource_->total_gpu();
+  std::vector<std::thread> threads;
+
+  auto dump_to_cpu_func = [this](int index) {
+    auto stream = resource_->local_stream(index, 0);
+    int dev_id = resource_->dev_id(index);
+    platform::CUDADeviceGuard guard(dev_id);
+    tables_[index]->dump_to_cpu(dev_id, stream);
+  };
+
+  for (int i = 0; i < total_gpu; ++i) {
+    threads.push_back(std::thread(dump_to_cpu_func, i));
+  }
+  for (auto& t : threads) {
+    t.join();
+  }
+}
+
+// template <typename KeyType, typename ValType, typename GradType>
+// void HeterComm<KeyType, ValType, GradType>::dump_to_cpu(int index) {
+//  auto stream = resource_->local_stream(index, 0);
+//  int dev_id = resource_->dev_id(index);
+//  platform::CUDADeviceGuard guard(dev_id);
+//  tables_[index]->dump_to_cpu(dev_id, stream);
+//}
+
 }  // end namespace framework
 }  // end namespace paddle
 #endif
diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu
index a3f306f6100..a9db1a56294 100644
--- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu
@@ -48,7 +48,7 @@ int HeterPs::get_index_by_devid(int devid) {
   return comm_->get_index_by_devid(devid);
 }
 
-void HeterPs::dump() {}
+void HeterPs::end_pass() { comm_->end_pass(); }
 
 void HeterPs::show_one_table(int gpu_num) { comm_->show_one_table(gpu_num); }
 
diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.h b/paddle/fluid/framework/fleet/heter_ps/heter_ps.h
index 6c6d408a53b..74d24fe43eb 100644
--- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.h
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.h
@@ -16,7 +16,7 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h"
 #include "paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h"
-#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh"
+#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
 
 #ifdef PADDLE_WITH_PSLIB
 
@@ -35,7 +35,7 @@ class HeterPs : public HeterPsBase {
                       size_t len) override;
   virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals,
                         size_t len, size_t chunk_size, int stream_num) override;
-  virtual void dump() override;
+  virtual void end_pass() override;
   virtual int get_index_by_devid(int devid) override;
   virtual void show_one_table(int gpu_num) override;
   virtual void push_sparse(int num, FeatureKey* d_keys,
diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h b/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h
index 3bda03359f6..29c2f68fc4a 100644
--- a/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h
@@ -35,7 +35,7 @@ class HeterPsBase {
   virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals,
                         size_t len, size_t chunk_size, int stream_num) = 0;
   virtual int get_index_by_devid(int devid) = 0;
-  virtual void dump() = 0;
+  virtual void end_pass() = 0;
   virtual void show_one_table(int gpu_num) = 0;
   virtual void push_sparse(int num, FeatureKey* d_keys,
                            FeaturePushValue* d_grads, size_t len) = 0;
diff --git a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h
similarity index 96%
rename from paddle/fluid/framework/fleet/heter_ps/optimizer.cuh
rename to paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h
index e8e027f383f..b3ec9e752e6 100644
--- a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh
+++ b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
-#include <vector>
 #include <curand_kernel.h>
+#include <vector>
 #include "optimizer_conf.h"
 #include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
 
@@ -111,8 +111,8 @@ class Optimizer {
         curandState state;
         curand_init(clock64(), tid_x, 0, &state);
         for (int i = 0; i < MF_DIM; ++i) {
-          val.mf[i + 1] = (curand_uniform(&state)) *
-                          optimizer_config::mf_initial_range;
+          val.mf[i + 1] =
+              (curand_uniform(&state)) * optimizer_config::mf_initial_range;
         }
       }
     } else {
diff --git a/paddle/fluid/framework/fleet/heter_ps/test_comm.cu b/paddle/fluid/framework/fleet/heter_ps/test_comm.cu
index 88b02a6947f..3a6ed50ad8e 100644
--- a/paddle/fluid/framework/fleet/heter_ps/test_comm.cu
+++ b/paddle/fluid/framework/fleet/heter_ps/test_comm.cu
@@ -17,7 +17,7 @@ limitations under the License. */
*/ #include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h" -#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh" +#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h" #include "paddle/fluid/platform/cuda_device_guard.h" using namespace paddle::framework; diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index 67b24a3b037..32eb9418b65 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -183,6 +183,7 @@ void PSGPUWrapper::BuildTask(std::shared_ptr gpu_task, val.slot = ptr_val[6]; val.lr = ptr_val[4]; val.lr_g2sum = ptr_val[5]; + val.cpu_ptr = (uint64_t)(task_ptrs[dev][j]); if (dim > 7) { val.mf_size = MF_DIM + 1; diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index 631c8456c56..98e0028e427 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -162,6 +162,9 @@ class PSGPUWrapper { slot_vector_ = slot_vector; } + void EndPass() { HeterPs_->end_pass(); } + void ShowOneTable(int index) { HeterPs_->show_one_table(index); } + private: static std::shared_ptr s_instance_; Dataset* dataset_; diff --git a/paddle/fluid/pybind/ps_gpu_wrapper_py.cc b/paddle/fluid/pybind/ps_gpu_wrapper_py.cc index b8ecdfe9a56..96acfd7bc04 100644 --- a/paddle/fluid/pybind/ps_gpu_wrapper_py.cc +++ b/paddle/fluid/pybind/ps_gpu_wrapper_py.cc @@ -45,6 +45,8 @@ void BindPSGPUWrapper(py::module* m) { py::call_guard()) .def("init_gpu_ps", &framework::PSGPUWrapper::InitializeGPU, py::call_guard()) + .def("end_pass", &framework::PSGPUWrapper::EndPass, + py::call_guard()) .def("build_gpu_ps", &framework::PSGPUWrapper::BuildGPUPS, py::call_guard()); } // end PSGPUWrapper -- GitLab