diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h b/paddle/fluid/framework/fleet/heter_ps/feature_value.h index efdb90b3362d64f364759c4264902d10e4f62715..698ece09de6c50781662659e317f4b1fc8f340b1 100644 --- a/paddle/fluid/framework/fleet/heter_ps/feature_value.h +++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h @@ -33,6 +33,7 @@ struct FeatureValue { float lr_g2sum; int mf_size; float mf[MF_DIM + 1]; + uint64_t cpu_ptr; friend std::ostream& operator<<(std::ostream& out, FeatureValue& val) { out << "show: " << val.show << " clk: " << val.clk << " slot: " << val.slot diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable.h b/paddle/fluid/framework/fleet/heter_ps/hashtable.h index 0c45edb57f876bcc7051c71c8994977fe859e9f1..11bd6e7aa69c3b720609c4f1bd4e90f952ebe866 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable.h +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable.h @@ -13,9 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include #include #include +#include "common_value.h" // NOLINT #include "thrust/pair.h" //#include "cudf/concurrent_unordered_map.cuh.h" #include "paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h" @@ -47,6 +49,7 @@ class HashTable { void get(const KeyType* d_keys, ValType* d_vals, size_t len, cudaStream_t stream); void show(); + void dump_to_cpu(int devid, cudaStream_t stream); template void update(const KeyType* d_keys, const GradType* d_grads, size_t len, @@ -60,5 +63,5 @@ class HashTable { }; } // end namespace framework } // end namespace paddle -#include "hashtable.tpp" +#include "hashtable_inl.h" #endif diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable.tpp b/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h similarity index 78% rename from paddle/fluid/framework/fleet/heter_ps/hashtable.tpp rename to paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h index 3c125701c6b77ee03db8057118173d4df1151f1f..ef37ed64c2a5f785b9ec79c4b971df214581c5df 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable.tpp +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h @@ -108,6 +108,41 @@ void HashTable::insert(const KeyType* d_keys, d_vals, len); } +template +void HashTable::dump_to_cpu(int devid, cudaStream_t stream) { + container_->prefetch(cudaCpuDeviceId, stream); + size_t num = container_->size(); + KeyType unuse_key = std::numeric_limits::max(); + thrust::pair* kv = container_->data(); + for (size_t i = 0; i < num; ++i) { + if (kv[i].first == unuse_key) { + continue; + } + ValType& gpu_val = kv[i].second; + auto* downpour_value = + (paddle::ps::DownpourFixedFeatureValue*)(gpu_val.cpu_ptr); + int downpour_value_size = downpour_value->size(); + if (gpu_val.mf_size > 0 && downpour_value_size == 7) { + downpour_value->resize(gpu_val.mf_size + downpour_value_size); + } + float* cpu_val = downpour_value->data(); + cpu_val[0] = 0; + cpu_val[1] = gpu_val.delta_score; + cpu_val[2] = gpu_val.show; + cpu_val[3] = gpu_val.clk; + cpu_val[4] = gpu_val.lr; + cpu_val[5] = gpu_val.lr_g2sum; + cpu_val[6] = gpu_val.slot; + if (gpu_val.mf_size > 0) { + for (int x = 0; x < gpu_val.mf_size; x++) { + cpu_val[x + 7] = gpu_val.mf[x]; + } + } + } + + container_->prefetch(devid, stream); +} + template template void HashTable::update(const KeyType* d_keys, diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h index a544d8f44f176ab10602866ebfbed8d0ac757b7a..5d299998534d15b8272063ebf50c9ad227728fac 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h @@ -13,11 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include #include "cub/cub.cuh" #include "hashtable.h" #include "heter_resource.h" -#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh" +#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h" #include "paddle/fluid/memory/memory.h" #include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/place.h" @@ -72,6 +73,10 @@ class HeterComm { return ((send_id / 4 != receive_id / 4) && (send_id + 4) % 8 != receive_id); } + // void dump_to_cpu(int index); + + void end_pass(); + int get_transfer_devid(int send_id) { return (send_id + 4) % 8; } struct Node { @@ -110,5 +115,5 @@ class HeterComm { } // end namespace framework } // end namespace paddle -#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.tpp" +#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h" #endif diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm.tpp b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h similarity index 96% rename from paddle/fluid/framework/fleet/heter_ps/heter_comm.tpp rename to paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index e280397b2a244732e34a5c46db71b78424fc7798..f95d4d3948b1924497f62689b97b13e7917aaf2b 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm.tpp +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -595,6 +595,34 @@ void HeterComm::push_sparse(int gpu_num, } } +template +void HeterComm::end_pass() { + int total_gpu = resource_->total_gpu(); + std::vector threads; + + auto dump_to_cpu_func = [this](int index) { + auto stream = resource_->local_stream(index, 0); + int dev_id = resource_->dev_id(index); + platform::CUDADeviceGuard guard(dev_id); + tables_[index]->dump_to_cpu(dev_id, stream); + }; + + for (int i = 0; i < total_gpu; ++i) { + threads.push_back(std::thread(dump_to_cpu_func, i)); + } + for (auto& t : threads) { + t.join(); + } +} + +// template +// void HeterComm::dump_to_cpu(int index) { +// auto stream = resource_->local_stream(index, 0); +// int dev_id = resource_->dev_id(index); +// platform::CUDADeviceGuard guard(dev_id); +// tables_[index]->dump_to_cpu(dev_id, stream); +//} + } // end namespace framework } // end namespace paddle #endif diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu index a3f306f6100ce8a6576b8f7c59f78edabbd1180e..a9db1a562945356214cd865bf505bed2c1dda0f9 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu @@ -48,7 +48,7 @@ int HeterPs::get_index_by_devid(int devid) { return comm_->get_index_by_devid(devid); } -void HeterPs::dump() {} +void HeterPs::end_pass() { comm_->end_pass(); } void HeterPs::show_one_table(int gpu_num) { comm_->show_one_table(gpu_num); } diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.h b/paddle/fluid/framework/fleet/heter_ps/heter_ps.h index 6c6d408a53b321801adcdad959e1de86d80c8c46..74d24fe43ebfd84beca65132242ce8a24b391841 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.h @@ -16,7 +16,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h" -#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh" +#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h" #ifdef PADDLE_WITH_PSLIB @@ -35,7 +35,7 @@ class HeterPs : public HeterPsBase { size_t len) override; virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals, size_t len, size_t chunk_size, int stream_num) override; - virtual void dump() override; + virtual void end_pass() override; virtual int get_index_by_devid(int devid) override; virtual void show_one_table(int gpu_num) override; virtual void push_sparse(int num, FeatureKey* d_keys, diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h b/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h index 3bda03359f6a5c5c3b54719c8a258d5e870af5ad..29c2f68fc4abac4e1f051b6455b561425e07534c 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h @@ -35,7 +35,7 @@ class HeterPsBase { virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals, size_t len, size_t chunk_size, int stream_num) = 0; virtual int get_index_by_devid(int devid) = 0; - virtual void dump() = 0; + virtual void end_pass() = 0; virtual void show_one_table(int gpu_num) = 0; virtual void push_sparse(int num, FeatureKey* d_keys, FeaturePushValue* d_grads, size_t len) = 0; diff --git a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h similarity index 96% rename from paddle/fluid/framework/fleet/heter_ps/optimizer.cuh rename to paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h index e8e027f383f6491a8de3f0d818ff1fd954b8d4db..b3ec9e752e62bb01a73f2d2070f94a47f8fe0730 100644 --- a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh +++ b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include #include +#include #include "optimizer_conf.h" #include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" @@ -111,8 +111,8 @@ class Optimizer { curandState state; curand_init(clock64(), tid_x, 0, &state); for (int i = 0; i < MF_DIM; ++i) { - val.mf[i + 1] = (curand_uniform(&state)) * - optimizer_config::mf_initial_range; + val.mf[i + 1] = + (curand_uniform(&state)) * optimizer_config::mf_initial_range; } } } else { diff --git a/paddle/fluid/framework/fleet/heter_ps/test_comm.cu b/paddle/fluid/framework/fleet/heter_ps/test_comm.cu index 88b02a6947f9417270bcfbe5628f397b06a878b6..3a6ed50ad8e70229ee7dfa97c1a222b1abe296df 100644 --- a/paddle/fluid/framework/fleet/heter_ps/test_comm.cu +++ b/paddle/fluid/framework/fleet/heter_ps/test_comm.cu @@ -17,7 +17,7 @@ limitations under the License. */ #include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h" -#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh" +#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h" #include "paddle/fluid/platform/cuda_device_guard.h" using namespace paddle::framework; diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index 67b24a3b037665d90dcb5d060f7bce4156bf515b..32eb9418b659b9b5b35d6de081dc4f2b6fc733f5 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -183,6 +183,7 @@ void PSGPUWrapper::BuildTask(std::shared_ptr gpu_task, val.slot = ptr_val[6]; val.lr = ptr_val[4]; val.lr_g2sum = ptr_val[5]; + val.cpu_ptr = (uint64_t)(task_ptrs[dev][j]); if (dim > 7) { val.mf_size = MF_DIM + 1; diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index 631c8456c562976cdb8e6b8af86ed94d8855d829..98e0028e42758a24e2b301b4f33d072d19c9f9ed 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -162,6 +162,9 @@ class PSGPUWrapper { slot_vector_ = slot_vector; } + void EndPass() { HeterPs_->end_pass(); } + void ShowOneTable(int index) { HeterPs_->show_one_table(index); } + private: static std::shared_ptr s_instance_; Dataset* dataset_; diff --git a/paddle/fluid/pybind/ps_gpu_wrapper_py.cc b/paddle/fluid/pybind/ps_gpu_wrapper_py.cc index b8ecdfe9a56a3883888cebd5373750d1f60f256c..96acfd7bc040439852bb8f02fce44bfdfebae335 100644 --- a/paddle/fluid/pybind/ps_gpu_wrapper_py.cc +++ b/paddle/fluid/pybind/ps_gpu_wrapper_py.cc @@ -45,6 +45,8 @@ void BindPSGPUWrapper(py::module* m) { py::call_guard()) .def("init_gpu_ps", &framework::PSGPUWrapper::InitializeGPU, py::call_guard()) + .def("end_pass", &framework::PSGPUWrapper::EndPass, + py::call_guard()) .def("build_gpu_ps", &framework::PSGPUWrapper::BuildGPUPS, py::call_guard()); } // end PSGPUWrapper