From 09b6e71928b6faecbce473c1294146a1bc00f7d6 Mon Sep 17 00:00:00 2001 From: Thunderbrook <52529258+Thunderbrook@users.noreply.github.com> Date: Wed, 23 Dec 2020 11:01:11 +0800 Subject: [PATCH] heter box (#29734) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit *  add heter box * add trainer, worker, wrapper... * format * for ci * format * remove boost get * boost & copyright * rename *  rename * format * format * format Co-authored-by: yaoxuefeng6 --- paddle/fluid/distributed/CMakeLists.txt | 3 + paddle/fluid/framework/CMakeLists.txt | 24 +- paddle/fluid/framework/data_feed.cc | 8 +- paddle/fluid/framework/data_feed.h | 21 +- paddle/fluid/framework/data_set.h | 14 + paddle/fluid/framework/device_worker.h | 96 ++ .../fluid/framework/device_worker_factory.cc | 8 + paddle/fluid/framework/fleet/CMakeLists.txt | 9 + paddle/fluid/framework/fleet/fleet_wrapper.cc | 21 +- paddle/fluid/framework/fleet/heter_context.h | 47 + .../framework/fleet/heter_ps/CMakeLists.txt | 6 + .../framework/fleet/heter_ps/cudf/LICENSE | 201 +++++ .../cudf/concurrent_unordered_map.cuh.h | 830 ++++++++++++++++++ .../fleet/heter_ps/cudf/hash_functions.cuh | 121 +++ .../framework/fleet/heter_ps/cudf/managed.cuh | 33 + .../fleet/heter_ps/cudf/managed_allocator.cuh | 54 ++ .../framework/fleet/heter_ps/feature_value.h | 76 ++ .../framework/fleet/heter_ps/hashtable.h | 64 ++ .../framework/fleet/heter_ps/hashtable.tpp | 126 +++ .../framework/fleet/heter_ps/heter_comm.h | 84 ++ .../framework/fleet/heter_ps/heter_comm.tpp | 494 +++++++++++ .../framework/fleet/heter_ps/heter_ps.cu | 62 ++ .../fluid/framework/fleet/heter_ps/heter_ps.h | 51 ++ .../framework/fleet/heter_ps/heter_ps_base.h | 47 + .../fleet/heter_ps/heter_resource.cc | 91 ++ .../framework/fleet/heter_ps/heter_resource.h | 66 ++ .../framework/fleet/heter_ps/optimizer.cuh | 122 +++ .../framework/fleet/heter_ps/optimizer_conf.h | 32 + .../framework/fleet/heter_ps/test_comm.cu | 112 +++ .../fluid/framework/fleet/ps_gpu_wrapper.cc | 194 ++++ .../fluid/framework/fleet/ps_gpu_wrapper.cu | 182 ++++ paddle/fluid/framework/fleet/ps_gpu_wrapper.h | 118 +++ paddle/fluid/framework/ps_gpu_trainer.cc | 404 +++++++++ paddle/fluid/framework/ps_gpu_worker.cc | 196 +++++ paddle/fluid/framework/trainer.h | 49 ++ paddle/fluid/framework/trainer_factory.cc | 3 + paddle/fluid/operators/pull_box_sparse_op.cc | 11 + paddle/fluid/operators/pull_box_sparse_op.h | 13 + paddle/fluid/pybind/CMakeLists.txt | 3 +- paddle/fluid/pybind/fleet_wrapper_py.cc | 4 - paddle/fluid/pybind/ps_gpu_wrapper_py.cc | 44 + paddle/fluid/pybind/ps_gpu_wrapper_py.h | 29 + paddle/fluid/pybind/pybind.cc | 5 + python/paddle/fluid/executor.py | 2 - .../pslib/optimizer_factory.py | 2 +- python/paddle/fluid/layers/nn.py | 17 +- python/paddle/fluid/trainer_desc.py | 24 + python/paddle/fluid/trainer_factory.py | 2 +- 48 files changed, 4171 insertions(+), 54 deletions(-) create mode 100644 paddle/fluid/framework/fleet/heter_context.h create mode 100644 paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt create mode 100644 paddle/fluid/framework/fleet/heter_ps/cudf/LICENSE create mode 100644 paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h create mode 100644 paddle/fluid/framework/fleet/heter_ps/cudf/hash_functions.cuh create mode 100644 paddle/fluid/framework/fleet/heter_ps/cudf/managed.cuh create mode 100644 paddle/fluid/framework/fleet/heter_ps/cudf/managed_allocator.cuh create mode 100644 paddle/fluid/framework/fleet/heter_ps/feature_value.h create mode 100644 
paddle/fluid/framework/fleet/heter_ps/hashtable.h create mode 100644 paddle/fluid/framework/fleet/heter_ps/hashtable.tpp create mode 100644 paddle/fluid/framework/fleet/heter_ps/heter_comm.h create mode 100644 paddle/fluid/framework/fleet/heter_ps/heter_comm.tpp create mode 100644 paddle/fluid/framework/fleet/heter_ps/heter_ps.cu create mode 100644 paddle/fluid/framework/fleet/heter_ps/heter_ps.h create mode 100644 paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h create mode 100644 paddle/fluid/framework/fleet/heter_ps/heter_resource.cc create mode 100644 paddle/fluid/framework/fleet/heter_ps/heter_resource.h create mode 100644 paddle/fluid/framework/fleet/heter_ps/optimizer.cuh create mode 100644 paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h create mode 100644 paddle/fluid/framework/fleet/heter_ps/test_comm.cu create mode 100644 paddle/fluid/framework/fleet/ps_gpu_wrapper.cc create mode 100644 paddle/fluid/framework/fleet/ps_gpu_wrapper.cu create mode 100644 paddle/fluid/framework/fleet/ps_gpu_wrapper.h create mode 100644 paddle/fluid/framework/ps_gpu_trainer.cc create mode 100644 paddle/fluid/framework/ps_gpu_worker.cc create mode 100644 paddle/fluid/pybind/ps_gpu_wrapper_py.cc create mode 100644 paddle/fluid/pybind/ps_gpu_wrapper_py.h diff --git a/paddle/fluid/distributed/CMakeLists.txt b/paddle/fluid/distributed/CMakeLists.txt index e99b8b76534..5367986491d 100644 --- a/paddle/fluid/distributed/CMakeLists.txt +++ b/paddle/fluid/distributed/CMakeLists.txt @@ -1,3 +1,6 @@ +if (WITH_PSLIB) + return() +endif() if(NOT WITH_DISTRIBUTE) return() endif() diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 93afbbf3236..f67d988536f 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -204,11 +204,11 @@ if(WITH_DISTRIBUTE) cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc heterxpu_trainer.cc - data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc - heterbox_worker.cc heterbox_trainer.cc downpour_worker.cc downpour_worker_opt.cc + data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc ps_gpu_worker.cc + heterbox_worker.cc heterbox_trainer.cc ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry device_context scope framework_proto trainer_desc_proto glog fs shell - fleet_wrapper heter_wrapper box_wrapper lodtensor_printer + fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer lod_rank_table feed_fetch_method sendrecvop_rpc communicator collective_helper ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper data_feed_proto timer monitor heter_service_proto pslib_brpc) @@ -218,11 +218,11 @@ if(WITH_DISTRIBUTE) cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc heterxpu_trainer.cc - data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc - heterbox_worker.cc heterbox_trainer.cc downpour_worker.cc downpour_worker_opt.cc + data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc ps_gpu_worker.cc + heterbox_worker.cc heterbox_trainer.cc ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry device_context 
scope framework_proto trainer_desc_proto glog fs shell - fleet_wrapper heter_wrapper box_wrapper lodtensor_printer + fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer lod_rank_table feed_fetch_method sendrecvop_rpc communicator collective_helper ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper data_feed_proto timer monitor heter_service_proto) @@ -233,11 +233,11 @@ elseif(WITH_PSLIB) cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc heterxpu_trainer.cc - data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc - heterbox_worker.cc heterbox_trainer.cc downpour_worker.cc downpour_worker_opt.cc + data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc ps_gpu_worker.cc + heterbox_worker.cc heterbox_trainer.cc ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog - lod_rank_table fs shell fleet_wrapper heter_wrapper box_wrapper lodtensor_printer feed_fetch_method + lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method graph_to_program_pass variable_helper timer monitor pslib_brpc ) # TODO: Fix these unittest failed on Windows # This unittest will always failed, now no CI will run this unittest @@ -248,11 +248,11 @@ else() cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc heterxpu_trainer.cc - data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc - heterbox_worker.cc heterbox_trainer.cc downpour_worker.cc downpour_worker_opt.cc + data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc ps_gpu_worker.cc + heterbox_worker.cc heterbox_trainer.cc ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog - lod_rank_table fs shell fleet_wrapper heter_wrapper box_wrapper lodtensor_printer feed_fetch_method + lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method graph_to_program_pass variable_helper timer monitor) # TODO: Fix these unittest failed on Windows # This unittest will always failed, now no CI will run this unittest diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index e006bf7c33f..176dd3c25c4 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -968,7 +968,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) { if (fabs(feasign) < 1e-6 && !use_slots_is_dense_[i]) { continue; } - FeatureKey f; + FeatureFeasign f; f.float_feasign_ = feasign; instance->float_feasigns_.push_back(FeatureItem(f, idx)); } @@ -980,7 +980,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) { if (feasign == 0 && !use_slots_is_dense_[i]) { continue; } - FeatureKey f; + FeatureFeasign f; f.uint64_feasign_ = feasign; instance->uint64_feasigns_.push_back(FeatureItem(f, idx)); } @@ -1038,7 +1038,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(Record* instance) { if (fabs(feasign) 
< 1e-6) { continue; } - FeatureKey f; + FeatureFeasign f; f.float_feasign_ = feasign; instance->float_feasigns_.push_back(FeatureItem(f, idx)); } @@ -1048,7 +1048,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(Record* instance) { if (feasign == 0) { continue; } - FeatureKey f; + FeatureFeasign f; f.uint64_feasign_ = feasign; instance->uint64_feasigns_.push_back(FeatureItem(f, idx)); } diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index da156bfc5c7..a89e6f8f14f 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -69,20 +69,23 @@ namespace framework { // while (reader->Next()) { // // trainer do something // } -union FeatureKey { +union FeatureFeasign { uint64_t uint64_feasign_; float float_feasign_; }; struct FeatureItem { FeatureItem() {} - FeatureItem(FeatureKey sign, uint16_t slot) { + FeatureItem(FeatureFeasign sign, uint16_t slot) { this->sign() = sign; this->slot() = slot; } - FeatureKey& sign() { return *(reinterpret_cast(sign_buffer())); } - const FeatureKey& sign() const { - const FeatureKey* ret = reinterpret_cast(sign_buffer()); + FeatureFeasign& sign() { + return *(reinterpret_cast(sign_buffer())); + } + const FeatureFeasign& sign() const { + const FeatureFeasign* ret = + reinterpret_cast(sign_buffer()); return *ret; } uint16_t& slot() { return slot_; } @@ -90,7 +93,7 @@ struct FeatureItem { private: char* sign_buffer() const { return const_cast(sign_); } - char sign_[sizeof(FeatureKey)]; + char sign_[sizeof(FeatureFeasign)]; uint16_t slot_; }; @@ -514,7 +517,7 @@ paddle::framework::Archive& operator>>(paddle::framework::Archive& ar, struct RecordCandidate { std::string ins_id_; - std::unordered_multimap feas_; + std::unordered_multimap feas_; size_t shadow_index_ = -1; // Optimization for Reservoir Sample RecordCandidate() {} @@ -606,7 +609,7 @@ class RecordCandidateList { template paddle::framework::Archive& operator<<(paddle::framework::Archive& ar, - const FeatureKey& fk) { + const FeatureFeasign& fk) { ar << fk.uint64_feasign_; ar << fk.float_feasign_; return ar; @@ -614,7 +617,7 @@ paddle::framework::Archive& operator<<(paddle::framework::Archive& ar, template paddle::framework::Archive& operator>>(paddle::framework::Archive& ar, - FeatureKey& fk) { + FeatureFeasign& fk) { ar >> fk.uint64_feasign_; ar >> fk.float_feasign_; return ar; diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index 462f6771a01..1c9869fa5af 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -229,6 +229,20 @@ class DatasetImpl : public Dataset { virtual void DynamicAdjustReadersNum(int thread_num); virtual void SetFleetSendSleepSeconds(int seconds); + std::vector>& GetMultiOutputChannel() { + return multi_output_channel_; + } + + std::vector>& GetCurOutputChannel() { + if (cur_channel_ == 0) { + return multi_output_channel_; + } else { + return multi_consume_channel_; + } + } + + Channel& GetInputChannelRef() { return input_channel_; } + protected: virtual int ReceiveFromClient(int msg_type, int client_id, const std::string& msg); diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index e81e0c66f98..6ecc02bbae6 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -537,6 +537,102 @@ class HeterBoxWorker : public HogwildWorker { }; #endif +#if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB) +class PSGPUWorker : public HogwildWorker { + public: + 
PSGPUWorker() {} + virtual ~PSGPUWorker() {} + virtual void Initialize(const TrainerDesc& desc); + virtual void TrainFiles(); + virtual void SetNeedDump(bool need_dump_field); + virtual void SetChannelWriter(ChannelObject* queue); + virtual void SetWorkerNum(int num) { worker_num_ = num; } + virtual void CacheProgram(const ProgramDesc& main_program) { + new (&program_) ProgramDesc(main_program); + } + virtual void ProduceTasks() override; + virtual void SetStream(const cudaStream_t stream) { copy_stream_ = stream; } + virtual void SetEvent(const cudaEvent_t event) { event_ = event; } + virtual void TrainFilesWithProfiler() {} + void ResetStat(); + + protected: + std::shared_ptr fleet_ptr_; + void PushGradients(); + void DumpParam(); + void CopySparseTable(); + void CopyDenseTable(); + void CopyDenseVars(); + + private: + int mpi_rank_; + std::mutex mutex_; + std::vector send_var_list_; + int worker_num_; + ProgramDesc program_; + HeterObjectPool object_pool_; + bool need_dump_param_; + std::vector dump_param_; + bool need_to_push_dense_; + bool need_dump_field_; + bool dump_slot_; + bool need_to_push_sparse_; + std::vector dump_fields_; + ChannelWriter writer_; + DownpourWorkerParameter param_; + float scale_datanorm_; + // just save the value in param_ for easy access + std::map label_var_name_; + std::map> sparse_key_names_; + std::map> sparse_value_names_; + std::map> sparse_grad_names_; + std::map> dense_value_names_; + std::map> dense_grad_names_; + platform::Place root_place_; + // actually pushed feasign of each table + std::map> sparse_push_keys_; + + // skipped ops + std::vector skip_ops_; + + std::vector<::std::future> push_sparse_status_; + std::vector<::std::future> push_dense_status_; + + // adjust ins weight + AdjustInsWeightConfig adjust_ins_weight_config_; + std::vector nid_show_; + // check nan and inf during training + std::vector check_nan_var_names_; + // copy table + CopyTableConfig copy_table_config_; + std::map table_dependency_; + std::vector> copy_sparse_tables_; + std::vector> copy_dense_tables_; + std::unordered_map> feasign_set_; + paddle::framework::Channel> pull_queue_; + paddle::framework::Channel> push_queue_; + cudaEvent_t event_; + cudaStream_t copy_stream_; + int batch_cnt_{0}; + std::atomic done_cnt_{0}; + + double total_time_; + double read_time_; + double pack_time_; + double pull_sparse_local_time_; + double op_all_time_; + double xpu_op_time_; + double xpu_wait_time_; + double cpu_op_time_; + double collect_label_time_; + double fill_sparse_time_; + double push_sparse_time_; + double gpu_2_cpu_time_; + double cpu_2_gpu_time_; + uint64_t total_inst_; +}; +#endif + #if defined(PADDLE_WITH_NCCL) class SectionWorker : public DeviceWorker { public: diff --git a/paddle/fluid/framework/device_worker_factory.cc b/paddle/fluid/framework/device_worker_factory.cc index ca5a035b4ab..109b520f5a7 100644 --- a/paddle/fluid/framework/device_worker_factory.cc +++ b/paddle/fluid/framework/device_worker_factory.cc @@ -66,8 +66,16 @@ REGISTER_DEVICE_WORKER_CLASS(DownpourWorker); REGISTER_DEVICE_WORKER_CLASS(DownpourWorkerOpt); #ifdef PADDLE_WITH_PSLIB REGISTER_DEVICE_WORKER_CLASS(HeterCpuWorker); +#endif + +#if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB) REGISTER_DEVICE_WORKER_CLASS(HeterBoxWorker); #endif + +#if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB) +REGISTER_DEVICE_WORKER_CLASS(PSGPUWorker); +#endif + #if defined(PADDLE_WITH_NCCL) REGISTER_DEVICE_WORKER_CLASS(SectionWorker); #endif diff --git 
a/paddle/fluid/framework/fleet/CMakeLists.txt b/paddle/fluid/framework/fleet/CMakeLists.txt index 3eee0a1abba..106685cdd9d 100644 --- a/paddle/fluid/framework/fleet/CMakeLists.txt +++ b/paddle/fluid/framework/fleet/CMakeLists.txt @@ -1,7 +1,15 @@ if(WITH_PSLIB) cc_library(fleet_wrapper SRCS fleet_wrapper.cc DEPS framework_proto variable_helper scope pslib_brpc pslib) + if(WITH_NCCL) + nv_library(ps_gpu_wrapper SRCS ps_gpu_wrapper.cu ps_gpu_wrapper.cc + DEPS heter_ps) + add_subdirectory(heter_ps) + else() + cc_library(ps_gpu_wrapper SRCS ps_gpu_wrapper.cc) + endif(WITH_NCCL) else() cc_library(fleet_wrapper SRCS fleet_wrapper.cc DEPS framework_proto variable_helper scope) + cc_library(ps_gpu_wrapper SRCS ps_gpu_wrapper.cc) endif(WITH_PSLIB) if(WITH_NCCL) @@ -13,6 +21,7 @@ else() cc_library(box_wrapper SRCS box_wrapper.cc DEPS framework_proto lod_tensor) endif(WITH_BOX_PS) + if(WITH_GLOO) cc_library(gloo_wrapper SRCS gloo_wrapper.cc DEPS framework_proto variable_helper scope gloo) else() diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index 84683b76e98..d073b08ae92 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -198,6 +198,7 @@ void FleetWrapper::HeterPullSparseVars( for (auto& t : fea_values) { pull_result_ptr.push_back(t.data()); } + /* auto status = pslib_ptr_->_worker_ptr->heter_pull_sparse( workerid, pull_result_ptr.data(), table_id, fea_keys.data(), fea_keys.size(), task->taskid_); @@ -211,6 +212,7 @@ void FleetWrapper::HeterPullSparseVars( exit(-1); } } + */ } void FleetWrapper::HeterPushSparseVars( @@ -359,6 +361,7 @@ int FleetWrapper::RegisterHeterCallback(HeterCallBackFunc handler) { VLOG(3) << "pslib_ptr_=" << pslib_ptr_; VLOG(3) << "_worker_ptr=" << pslib_ptr_->_worker_ptr; return pslib_ptr_->_worker_ptr->registe_heter_callback(handler); + #else VLOG(0) << "FleetWrapper::RegisterHeterCallback" << " does nothing when no pslib"; @@ -1222,13 +1225,6 @@ void FleetWrapper::LoadModelOneTable(const uint64_t table_id, void FleetWrapper::LoadWithWhitelist(const uint64_t table_id, const std::string& path, const int mode) { #ifdef PADDLE_WITH_PSLIB - auto ret = pslib_ptr_->_worker_ptr->load_with_whitelist(table_id, path, - std::to_string(mode)); - ret.wait(); - if (ret.get() != 0) { - LOG(ERROR) << "load model of table id: " << table_id - << ", from path: " << path << " failed"; - } #else VLOG(0) << "FleetWrapper::LoadWhitelist does nothing when no pslib"; #endif @@ -1353,16 +1349,7 @@ int32_t FleetWrapper::SaveWithWhitelist(int table_id, const std::string& path, const int mode, const std::string& whitelist_path) { #ifdef PADDLE_WITH_PSLIB - auto ret = pslib_ptr_->_worker_ptr->save_with_whitelist( - table_id, path, std::to_string(mode), whitelist_path); - ret.wait(); - int32_t feasign_cnt = ret.get(); - if (feasign_cnt == -1) { - LOG(ERROR) << "table save cache failed"; - sleep(sleep_seconds_before_fail_exit_); - exit(-1); - } - return feasign_cnt; + return 0; #else VLOG(0) << "FleetWrapper::SaveCache does nothing when no pslib"; return -1; diff --git a/paddle/fluid/framework/fleet/heter_context.h b/paddle/fluid/framework/fleet/heter_context.h new file mode 100644 index 00000000000..3fad689c17d --- /dev/null +++ b/paddle/fluid/framework/fleet/heter_context.h @@ -0,0 +1,47 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB) + +#include +#include +#include + +#include "common_value.h" // NOLINT +#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" +#include "paddle/fluid/framework/scope.h" + +namespace paddle { +namespace framework { + +class HeterContext { + public: + Scope* scope_{nullptr}; + std::vector> feature_keys_; + std::vector> value_ptr_; + std::vector> feature_values_; + uint64_t size() { + uint64_t total_size = 0; + for (auto& keys : feature_keys_) { + total_size += keys.size(); + } + return total_size; + } +}; + +} // end namespace framework +} // end namespace paddle +#endif diff --git a/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt b/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt new file mode 100644 index 00000000000..2eed13c530d --- /dev/null +++ b/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt @@ -0,0 +1,6 @@ +nv_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc +heter_resource.h hashtable.h DEPS cub device_context) +nv_test(test_heter_comm SRCS test_heter_comm.cu feature_value.h DEPS +heter_comm) + +nv_library(heter_ps SRCS heter_ps.cu DEPS heter_comm) diff --git a/paddle/fluid/framework/fleet/heter_ps/cudf/LICENSE b/paddle/fluid/framework/fleet/heter_ps/cudf/LICENSE new file mode 100644 index 00000000000..18bcb4316e6 --- /dev/null +++ b/paddle/fluid/framework/fleet/heter_ps/cudf/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2018 NVIDIA Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h b/paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h new file mode 100644 index 00000000000..a884929223b --- /dev/null +++ b/paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h @@ -0,0 +1,830 @@ +/* + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef CONCURRENT_UNORDERED_MAP_CUH +#define CONCURRENT_UNORDERED_MAP_CUH + +#include +#include +#include +#include +#include + +#include "hash_functions.cuh" +#include "managed.cuh" +#include "managed_allocator.cuh" + +// TODO: replace this with CUDA_TRY and propagate the error +#ifndef CUDA_RT_CALL +#define CUDA_RT_CALL(call) \ + { \ + cudaError_t cudaStatus = call; \ + if (cudaSuccess != cudaStatus) { \ + fprintf(stderr, \ + "ERROR: CUDA RT call \"%s\" in line %d of file %s failed with " \ + "%s (%d).\n", \ + #call, __LINE__, __FILE__, cudaGetErrorString(cudaStatus), \ + cudaStatus); \ + exit(1); \ + } \ + } +#endif + +// TODO: can we do this more efficiently? +__inline__ __device__ int8_t atomicCAS(int8_t* address, int8_t compare, + int8_t val) { + int32_t* base_address = (int32_t*)((char*)address - ((size_t)address & 3)); + int32_t int_val = (int32_t)val << (((size_t)address & 3) * 8); + int32_t int_comp = (int32_t)compare << (((size_t)address & 3) * 8); + return (int8_t)atomicCAS(base_address, int_comp, int_val); +} + +// TODO: can we do this more efficiently? +__inline__ __device__ int16_t atomicCAS(int16_t* address, int16_t compare, + int16_t val) { + int32_t* base_address = (int32_t*)((char*)address - ((size_t)address & 2)); + int32_t int_val = (int32_t)val << (((size_t)address & 2) * 8); + int32_t int_comp = (int32_t)compare << (((size_t)address & 2) * 8); + return (int16_t)atomicCAS(base_address, int_comp, int_val); +} + +__inline__ __device__ int64_t atomicCAS(int64_t* address, int64_t compare, + int64_t val) { + return (int64_t)atomicCAS((unsigned long long*)address, + (unsigned long long)compare, + (unsigned long long)val); +} + +__inline__ __device__ uint64_t atomicCAS(uint64_t* address, uint64_t compare, + uint64_t val) { + return (uint64_t)atomicCAS((unsigned long long*)address, + (unsigned long long)compare, + (unsigned long long)val); +} + +__inline__ __device__ long long int atomicCAS(long long int* address, + long long int compare, + long long int val) { + return (long long int)atomicCAS((unsigned long long*)address, + (unsigned long long)compare, + (unsigned long long)val); +} + +__inline__ __device__ double atomicCAS(double* address, double compare, + double val) { + return __longlong_as_double(atomicCAS((unsigned long long int*)address, + __double_as_longlong(compare), + __double_as_longlong(val))); +} + +__inline__ __device__ float atomicCAS(float* address, float compare, + float val) { + return __int_as_float( + atomicCAS((int*)address, __float_as_int(compare), __float_as_int(val))); +} + +__inline__ __device__ int64_t atomicAdd(int64_t* address, int64_t val) { + return (int64_t)atomicAdd((unsigned long long*)address, + (unsigned long long)val); +} + +__inline__ __device__ uint64_t atomicAdd(uint64_t* address, uint64_t val) { + return (uint64_t)atomicAdd((unsigned long long*)address, + (unsigned long long)val); +} + +template +__forceinline__ __device__ pair_type +load_pair_vectorized(const pair_type* __restrict__ const ptr) { + if (sizeof(uint4) == sizeof(pair_type)) { + union pair_type2vec_type { + uint4 vec_val; + pair_type pair_val; + }; + pair_type2vec_type converter = {0, 0, 0, 0}; + converter.vec_val = *reinterpret_cast(ptr); + return converter.pair_val; + } else if (sizeof(uint2) == sizeof(pair_type)) { + union pair_type2vec_type { + uint2 vec_val; + pair_type pair_val; + }; + pair_type2vec_type converter = {0, 0}; + converter.vec_val = *reinterpret_cast(ptr); + return converter.pair_val; + } else if (sizeof(int) == sizeof(pair_type)) { + union 
pair_type2vec_type { + int vec_val; + pair_type pair_val; + }; + pair_type2vec_type converter = {0}; + converter.vec_val = *reinterpret_cast(ptr); + return converter.pair_val; + } else if (sizeof(short) == sizeof(pair_type)) { + union pair_type2vec_type { + short vec_val; + pair_type pair_val; + }; + pair_type2vec_type converter = {0}; + converter.vec_val = *reinterpret_cast(ptr); + return converter.pair_val; + } else { + return *ptr; + } +} + +template +__forceinline__ __device__ void store_pair_vectorized( + pair_type* __restrict__ const ptr, const pair_type val) { + if (sizeof(uint4) == sizeof(pair_type)) { + union pair_type2vec_type { + uint4 vec_val; + pair_type pair_val; + }; + pair_type2vec_type converter = {0, 0, 0, 0}; + converter.pair_val = val; + *reinterpret_cast(ptr) = converter.vec_val; + } else if (sizeof(uint2) == sizeof(pair_type)) { + union pair_type2vec_type { + uint2 vec_val; + pair_type pair_val; + }; + pair_type2vec_type converter = {0, 0}; + converter.pair_val = val; + *reinterpret_cast(ptr) = converter.vec_val; + } else if (sizeof(int) == sizeof(pair_type)) { + union pair_type2vec_type { + int vec_val; + pair_type pair_val; + }; + pair_type2vec_type converter = {0}; + converter.pair_val = val; + *reinterpret_cast(ptr) = converter.vec_val; + } else if (sizeof(short) == sizeof(pair_type)) { + union pair_type2vec_type { + short vec_val; + pair_type pair_val; + }; + pair_type2vec_type converter = {0}; + converter.pair_val = val; + *reinterpret_cast(ptr) = converter.vec_val; + } else { + *ptr = val; + } +} + +template +__global__ void init_hashtbl( // Init every entry of the table with + // pair + value_type* __restrict__ const hashtbl_values, const size_type n, + const key_type key_val, const elem_type elem_val) { + const size_type idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + store_pair_vectorized( + hashtbl_values + idx, + thrust::make_pair( + key_val, elem_val)); // Simply store every element a pair + } +} + +template +struct equal_to { + using result_type = bool; + using first_argument_type = T; + using second_argument_type = T; + __forceinline__ __host__ __device__ constexpr bool operator()( + const first_argument_type& lhs, const second_argument_type& rhs) const { + return lhs == rhs; + } +}; + +template +class cycle_iterator_adapter { + public: + using value_type = typename std::iterator_traits::value_type; + using difference_type = + typename std::iterator_traits::difference_type; + using pointer = typename std::iterator_traits::pointer; + using reference = typename std::iterator_traits::reference; + using iterator_type = Iterator; + + cycle_iterator_adapter() = delete; + + __host__ __device__ explicit cycle_iterator_adapter( + const iterator_type& begin, const iterator_type& end, + const iterator_type& current) + : m_begin(begin), m_end(end), m_current(current) {} + + __host__ __device__ cycle_iterator_adapter& operator++() { + if (m_end == (m_current + 1)) + m_current = m_begin; + else + ++m_current; + return *this; + } + + __host__ __device__ const cycle_iterator_adapter& operator++() const { + if (m_end == (m_current + 1)) + m_current = m_begin; + else + ++m_current; + return *this; + } + + __host__ __device__ cycle_iterator_adapter& operator++(int) { + cycle_iterator_adapter old(m_begin, m_end, m_current); + if (m_end == (m_current + 1)) + m_current = m_begin; + else + ++m_current; + return old; + } + + __host__ __device__ const cycle_iterator_adapter& operator++(int)const { + cycle_iterator_adapter old(m_begin, m_end, m_current); + if 
(m_end == (m_current + 1)) + m_current = m_begin; + else + ++m_current; + return old; + } + + __host__ __device__ bool equal( + const cycle_iterator_adapter& other) const { + return m_current == other.m_current && m_begin == other.m_begin && + m_end == other.m_end; + } + + __host__ __device__ reference& operator*() { return *m_current; } + + __host__ __device__ const reference& operator*() const { return *m_current; } + + __host__ __device__ const pointer operator->() const { + return m_current.operator->(); + } + + __host__ __device__ pointer operator->() { return m_current; } + + __host__ __device__ iterator_type getter() const { return m_current; } + + private: + iterator_type m_current; + iterator_type m_begin; + iterator_type m_end; +}; + +template +__host__ __device__ bool operator==(const cycle_iterator_adapter& lhs, + const cycle_iterator_adapter& rhs) { + return lhs.equal(rhs); +} + +template +__host__ __device__ bool operator!=(const cycle_iterator_adapter& lhs, + const cycle_iterator_adapter& rhs) { + return !lhs.equal(rhs); +} + +/** + * Does support concurrent insert, but not concurrent insert and probping. + * + * TODO: + * - add constructor that takes pointer to hash_table to avoid allocations + * - extend interface to accept streams + */ +template , + typename Equality = equal_to, + typename Allocator = managed_allocator>, + bool count_collisions = false> +class concurrent_unordered_map : public managed { + public: + using size_type = size_t; + using hasher = Hasher; + using key_equal = Equality; + using allocator_type = Allocator; + using key_type = Key; + using value_type = thrust::pair; + using mapped_type = Element; + using iterator = cycle_iterator_adapter; + using const_iterator = const cycle_iterator_adapter; + + private: + union pair2longlong { + unsigned long long int longlong; + value_type pair; + }; + + public: + concurrent_unordered_map(const concurrent_unordered_map&) = delete; + concurrent_unordered_map& operator=(const concurrent_unordered_map&) = delete; + explicit concurrent_unordered_map(size_type n, + const mapped_type unused_element, + const Hasher& hf = hasher(), + const Equality& eql = key_equal(), + const allocator_type& a = allocator_type()) + : m_hf(hf), + m_equal(eql), + m_allocator(a), + m_hashtbl_size(n), + m_hashtbl_capacity(n), + m_collisions(0), + m_unused_element( + unused_element) { // allocate the raw data of hash table: + // m_hashtbl_values,pre-alloc it on current GPU if UM. 
+ m_hashtbl_values = m_allocator.allocate(m_hashtbl_capacity); + constexpr int block_size = 128; + { + cudaPointerAttributes hashtbl_values_ptr_attributes; + cudaError_t status = cudaPointerGetAttributes( + &hashtbl_values_ptr_attributes, m_hashtbl_values); + +#if CUDART_VERSION >= 10000 + if (cudaSuccess == status && + hashtbl_values_ptr_attributes.type == cudaMemoryTypeManaged) +#else + if (cudaSuccess == status && hashtbl_values_ptr_attributes.isManaged) +#endif + { + int dev_id = 0; + CUDA_RT_CALL(cudaGetDevice(&dev_id)); + CUDA_RT_CALL(cudaMemPrefetchAsync( + m_hashtbl_values, m_hashtbl_size * sizeof(value_type), dev_id, 0)); + } + } + // Initialize kernel, set all entry to unused + init_hashtbl<<<((m_hashtbl_size - 1) / block_size) + 1, block_size>>>( + m_hashtbl_values, m_hashtbl_size, unused_key, m_unused_element); + // CUDA_RT_CALL( cudaGetLastError() ); + CUDA_RT_CALL(cudaStreamSynchronize(0)); + CUDA_RT_CALL(cudaGetLastError()); + } + + ~concurrent_unordered_map() { + m_allocator.deallocate(m_hashtbl_values, m_hashtbl_capacity); + } + + __host__ __device__ iterator begin() { + return iterator(m_hashtbl_values, m_hashtbl_values + m_hashtbl_size, + m_hashtbl_values); + } + __host__ __device__ const_iterator begin() const { + return const_iterator(m_hashtbl_values, m_hashtbl_values + m_hashtbl_size, + m_hashtbl_values); + } + __host__ __device__ iterator end() { + return iterator(m_hashtbl_values, m_hashtbl_values + m_hashtbl_size, + m_hashtbl_values + m_hashtbl_size); + } + __host__ __device__ const_iterator end() const { + return const_iterator(m_hashtbl_values, m_hashtbl_values + m_hashtbl_size, + m_hashtbl_values + m_hashtbl_size); + } + __host__ __device__ size_type size() const { return m_hashtbl_size; } + __host__ __device__ value_type* data() const { return m_hashtbl_values; } + + __forceinline__ static constexpr __host__ __device__ key_type + get_unused_key() { + return unused_key; + } + + // Generic update of a hash table value for any aggregator + template + __forceinline__ __device__ void update_existing_value( + mapped_type& existing_value, value_type const& insert_pair, + aggregation_type) { + // update without CAS + existing_value = insert_pair.second; + } + + __forceinline__ __device__ void accum_existing_value_atomic( + mapped_type& existing_value, value_type const& accum_pair) { + // update with CAS + // existing_value = insert_pair.second; + int num_element = + sizeof(existing_value.data) / sizeof(*(existing_value.data)); + const mapped_type& accumulator = accum_pair.second; + + for (int i = 0; i < num_element; i++) { + atomicAdd(existing_value.data + i, accumulator.data[i]); + } + + // atomicAdd(&existing_value, double val) + } + + // TODO Overload atomicAdd for 1 byte and 2 byte types, until then, overload + // specifically for the + // types where atomicAdd already has an overload. Otherwise the generic + // update_existing_value will + // be used. 
Specialization for COUNT aggregator + /* + __forceinline__ __host__ __device__ + void update_existing_value(mapped_type & existing_value, value_type const & + insert_pair, + count_op op) + { + atomicAdd(&existing_value, static_cast(1)); + } + // Specialization for COUNT aggregator + __forceinline__ __host__ __device__ + void update_existing_value(mapped_type & existing_value, value_type const & + insert_pair, + count_op op) + { + atomicAdd(&existing_value, static_cast(1)); + } + // Specialization for COUNT aggregator + __forceinline__ __host__ __device__ + void update_existing_value(mapped_type & existing_value, value_type const & + insert_pair, + count_op op) + { + atomicAdd(&existing_value, static_cast(1)); + } + // Specialization for COUNT aggregator + __forceinline__ __host__ __device__ + void update_existing_value(mapped_type & existing_value, value_type const & + insert_pair, + count_op op) + { + atomicAdd(&existing_value, static_cast(1)); + } + */ + + /* --------------------------------------------------------------------------*/ + /** + * @Synopsis Inserts a new (key, value) pair. If the key already exists in + the map + an aggregation operation is performed with the new value and + existing value. + E.g., if the aggregation operation is 'max', then the maximum is + computed + between the new value and existing value and the result is + stored in the map. + * + * @Param[in] x The new (key, value) pair to insert + * @Param[in] op The aggregation operation to perform + * @Param[in] keys_equal An optional functor for comparing two keys + * @Param[in] precomputed_hash Indicates if a precomputed hash value is being + passed in to use + * to determine the write location of the new key + * @Param[in] precomputed_hash_value The precomputed hash value + * @tparam aggregation_type A functor for a binary operation that performs the + aggregation + * @tparam comparison_type A functor for comparing two keys + * + * @Returns An iterator to the newly inserted key,value pair + */ + /* ----------------------------------------------------------------------------*/ + template + __forceinline__ __device__ iterator insert( + const value_type& x, aggregation_type op, + comparison_type keys_equal = key_equal(), bool precomputed_hash = false, + hash_value_type precomputed_hash_value = 0) { + const size_type hashtbl_size = m_hashtbl_size; + value_type* hashtbl_values = m_hashtbl_values; + + hash_value_type hash_value{0}; + + // If a precomputed hash value has been passed in, then use it to determine + // the write location of the new key + if (true == precomputed_hash) { + hash_value = precomputed_hash_value; + } + // Otherwise, compute the hash value from the new key + else { + hash_value = m_hf(x.first); + } + + size_type current_index = hash_value % hashtbl_size; + value_type* current_hash_bucket = &(hashtbl_values[current_index]); + + const key_type insert_key = x.first; + + bool insert_success = false; + + size_type counter = 0; + while (false == insert_success) { + if (counter++ >= hashtbl_size) { + return end(); + } + + key_type& existing_key = current_hash_bucket->first; + mapped_type& existing_value = current_hash_bucket->second; + + // Try and set the existing_key for the current hash bucket to insert_key + const key_type old_key = atomicCAS(&existing_key, unused_key, insert_key); + + // If old_key == unused_key, the current hash bucket was empty + // and existing_key was updated to insert_key by the atomicCAS. + // If old_key == insert_key, this key has already been inserted. 
+ // In either case, perform the atomic aggregation of existing_value and + // insert_value + // Because the hash table is initialized with the identity value of the + // aggregation + // operation, it is safe to perform the operation when the existing_value + // still + // has its initial value + // TODO: Use template specialization to make use of native atomic + // functions + // TODO: How to handle data types less than 32 bits? + if (keys_equal(unused_key, old_key) || keys_equal(insert_key, old_key)) { + update_existing_value(existing_value, x, op); + + insert_success = true; + } + + current_index = (current_index + 1) % hashtbl_size; + current_hash_bucket = &(hashtbl_values[current_index]); + } + + return iterator(m_hashtbl_values, m_hashtbl_values + hashtbl_size, + current_hash_bucket); + } + + /* This function is not currently implemented + __forceinline__ + __host__ __device__ iterator insert(const value_type& x) + { + const size_type hashtbl_size = m_hashtbl_size; + value_type* hashtbl_values = m_hashtbl_values; + const size_type key_hash = m_hf( x.first ); + size_type hash_tbl_idx = key_hash%hashtbl_size; + + value_type* it = 0; + + while (0 == it) { + value_type* tmp_it = hashtbl_values + hash_tbl_idx; +#ifdef __CUDA_ARCH__ + if ( std::numeric_limits::is_integer && +std::numeric_limits::is_integer && sizeof(unsigned long long int) +== sizeof(value_type) +) + { + pair2longlong converter = {0ull}; + converter.pair = thrust::make_pair( unused_key, m_unused_element +); + const unsigned long long int unused = converter.longlong; + converter.pair = x; + const unsigned long long int value = converter.longlong; + const unsigned long long int old_val = atomicCAS( +reinterpret_cast(tmp_it), unused, value ); if ( old_val == unused ) { it = tmp_it; + } + else if ( count_collisions ) + { + atomicAdd( &m_collisions, 1 ); + } + } else { + const key_type old_key = atomicCAS( &(tmp_it->first), unused_key, +x.first ); + if ( m_equal( unused_key, old_key ) ) { + (m_hashtbl_values+hash_tbl_idx)->second = x.second; + it = tmp_it; + } + else if ( count_collisions ) + { + atomicAdd( &m_collisions, 1 ); + } + } +#else + + #pragma omp critical + { + if ( m_equal( unused_key, tmp_it->first ) ) { + hashtbl_values[hash_tbl_idx] = thrust::make_pair( x.first, +x.second ); + it = tmp_it; + } + } +#endif + hash_tbl_idx = (hash_tbl_idx+1)%hashtbl_size; + } + + return iterator( m_hashtbl_values,m_hashtbl_values+hashtbl_size,it); + } + */ + + __forceinline__ __host__ __device__ const_iterator + find(const key_type& k) const { + size_type key_hash = m_hf(k); + size_type hash_tbl_idx = key_hash % m_hashtbl_size; + + value_type* begin_ptr = 0; + + size_type counter = 0; + while (0 == begin_ptr) { + value_type* tmp_ptr = m_hashtbl_values + hash_tbl_idx; + const key_type tmp_val = tmp_ptr->first; + if (m_equal(k, tmp_val)) { + begin_ptr = tmp_ptr; + break; + } + if (m_equal(unused_key, tmp_val) || counter > m_hashtbl_size) { + begin_ptr = m_hashtbl_values + m_hashtbl_size; + break; + } + hash_tbl_idx = (hash_tbl_idx + 1) % m_hashtbl_size; + ++counter; + } + + return const_iterator(m_hashtbl_values, m_hashtbl_values + m_hashtbl_size, + begin_ptr); + } + + template + __forceinline__ __device__ iterator get_insert( + const key_type& k, aggregation_type op, counter_type* value_counter, + comparison_type keys_equal = key_equal(), bool precomputed_hash = false, + hash_value_type precomputed_hash_value = 0) { + const size_type hashtbl_size = m_hashtbl_size; + value_type* hashtbl_values = m_hashtbl_values; + + hash_value_type 
hash_value{0}; + + // If a precomputed hash value has been passed in, then use it to determine + // the write location of the new key + if (true == precomputed_hash) { + hash_value = precomputed_hash_value; + } + // Otherwise, compute the hash value from the new key + else { + hash_value = m_hf(k); + } + + size_type current_index = hash_value % hashtbl_size; + value_type* current_hash_bucket = &(hashtbl_values[current_index]); + + const key_type insert_key = k; + + bool insert_success = false; + + size_type counter = 0; + while (false == insert_success) { + // Situation %5: No slot: All slot in the hashtable is occupied by other + // key, both get and + // insert fail. Return empty iterator + if (counter++ >= hashtbl_size) { + return end(); + } + + key_type& existing_key = current_hash_bucket->first; + volatile mapped_type& existing_value = current_hash_bucket->second; + + // Try and set the existing_key for the current hash bucket to insert_key + const key_type old_key = atomicCAS(&existing_key, unused_key, insert_key); + + // If old_key == unused_key, the current hash bucket was empty + // and existing_key was updated to insert_key by the atomicCAS. + // If old_key == insert_key, this key has already been inserted. + // In either case, perform the atomic aggregation of existing_value and + // insert_value + // Because the hash table is initialized with the identity value of the + // aggregation + // operation, it is safe to perform the operation when the existing_value + // still + // has its initial value + // TODO: Use template specialization to make use of native atomic + // functions + // TODO: How to handle data types less than 32 bits? + + // Situation #1: Empty slot: this key never exist in the table, ready to + // insert. + if (keys_equal(unused_key, old_key)) { + // update_existing_value(existing_value, x, op); + existing_value = (mapped_type)(atomicAdd(value_counter, 1)); + break; + + } // Situation #2+#3: Target slot: This slot is the slot for this key + else if (keys_equal(insert_key, old_key)) { + while (existing_value == m_unused_element) { + // Situation #2: This slot is inserting by another CUDA thread and the + // value is not yet + // ready, just wait + } + // Situation #3: This slot is already ready, get successfully and return + // (iterator of) the + // value + break; + } + // Situation 4: Wrong slot: This slot is occupied by other key, get fail, + // do nothing and + // linear probing to next slot. 
+ + current_index = (current_index + 1) % hashtbl_size; + current_hash_bucket = &(hashtbl_values[current_index]); + } + + return iterator(m_hashtbl_values, m_hashtbl_values + hashtbl_size, + current_hash_bucket); + } + + int assign_async(const concurrent_unordered_map& other, + cudaStream_t stream = 0) { + m_collisions = other.m_collisions; + if (other.m_hashtbl_size <= m_hashtbl_capacity) { + m_hashtbl_size = other.m_hashtbl_size; + } else { + m_allocator.deallocate(m_hashtbl_values, m_hashtbl_capacity); + m_hashtbl_capacity = other.m_hashtbl_size; + m_hashtbl_size = other.m_hashtbl_size; + + m_hashtbl_values = m_allocator.allocate(m_hashtbl_capacity); + } + CUDA_RT_CALL(cudaMemcpyAsync(m_hashtbl_values, other.m_hashtbl_values, + m_hashtbl_size * sizeof(value_type), + cudaMemcpyDefault, stream)); + return 0; + } + + void clear_async(cudaStream_t stream = 0) { + constexpr int block_size = 128; + init_hashtbl<<<((m_hashtbl_size - 1) / block_size) + 1, block_size, 0, + stream>>>(m_hashtbl_values, m_hashtbl_size, unused_key, + m_unused_element); + if (count_collisions) m_collisions = 0; + } + + unsigned long long get_num_collisions() const { return m_collisions; } + + void print() { + for (size_type i = 0; i < m_hashtbl_size; ++i) { + std::cout << i << ": " << m_hashtbl_values[i].first << "," + << m_hashtbl_values[i].second << std::endl; + } + } + + int prefetch(const int dev_id, cudaStream_t stream = 0) { + cudaPointerAttributes hashtbl_values_ptr_attributes; + cudaError_t status = cudaPointerGetAttributes( + &hashtbl_values_ptr_attributes, m_hashtbl_values); + +#if CUDART_VERSION >= 10000 + if (cudaSuccess == status && + hashtbl_values_ptr_attributes.type == cudaMemoryTypeManaged) +#else + if (cudaSuccess == status && hashtbl_values_ptr_attributes.isManaged) +#endif + { + CUDA_RT_CALL(cudaMemPrefetchAsync(m_hashtbl_values, + m_hashtbl_size * sizeof(value_type), + dev_id, stream)); + } + CUDA_RT_CALL(cudaMemPrefetchAsync(this, sizeof(*this), dev_id, stream)); + + return 0; + } + + template + __forceinline__ __device__ const_iterator + accum(const value_type& x, comparison_type keys_equal = key_equal(), + bool precomputed_hash = false, + hash_value_type precomputed_hash_value = 0) { + const key_type& dst_key = x.first; + auto it = find(dst_key); + + if (it == end()) { + return it; + } + + value_type* dst = it.getter(); + + accum_existing_value_atomic(dst->second, x); + + return it; + } + + private: + const hasher m_hf; + const key_equal m_equal; + + const mapped_type m_unused_element; + + allocator_type m_allocator; + + size_type m_hashtbl_size; + size_type m_hashtbl_capacity; + value_type* m_hashtbl_values; + + unsigned long long m_collisions; +}; + +#endif // CONCURRENT_UNORDERED_MAP_CUH diff --git a/paddle/fluid/framework/fleet/heter_ps/cudf/hash_functions.cuh b/paddle/fluid/framework/fleet/heter_ps/cudf/hash_functions.cuh new file mode 100644 index 00000000000..9264bd0a21c --- /dev/null +++ b/paddle/fluid/framework/fleet/heter_ps/cudf/hash_functions.cuh @@ -0,0 +1,121 @@ +/* + * Copyright (c) 2017, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef HASH_FUNCTIONS_CUH +#define HASH_FUNCTIONS_CUH + +using hash_value_type = uint32_t; + +// MurmurHash3_32 implementation from +// https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp +//----------------------------------------------------------------------------- +// MurmurHash3 was written by Austin Appleby, and is placed in the public +// domain. The author hereby disclaims copyright to this source code. +// Note - The x86 and x64 versions do _not_ produce the same results, as the +// algorithms are optimized for their respective platforms. You can still +// compile and run any of them on any platform, but your performance with the +// non-native version will be less than optimal. +template +struct MurmurHash3_32 { + using argument_type = Key; + using result_type = hash_value_type; + + __forceinline__ __host__ __device__ MurmurHash3_32() : m_seed(0) {} + + __forceinline__ __host__ __device__ uint32_t rotl32(uint32_t x, int8_t r) const { + return (x << r) | (x >> (32 - r)); + } + + __forceinline__ __host__ __device__ uint32_t fmix32(uint32_t h) const { + h ^= h >> 16; + h *= 0x85ebca6b; + h ^= h >> 13; + h *= 0xc2b2ae35; + h ^= h >> 16; + return h; + } + + /* --------------------------------------------------------------------------*/ + /** + * @Synopsis Combines two hash values into a new single hash value. Called + * repeatedly to create a hash value from several variables. + * Taken from the Boost hash_combine function + * https://www.boost.org/doc/libs/1_35_0/doc/html/boost/hash_combine_id241013.html + * + * @Param lhs The first hash value to combine + * @Param rhs The second hash value to combine + * + * @Returns A hash value that intelligently combines the lhs and rhs hash values + */ + /* ----------------------------------------------------------------------------*/ + __host__ __device__ result_type hash_combine(result_type lhs, result_type rhs) { + result_type combined{lhs}; + + combined ^= rhs + 0x9e3779b9 + (combined << 6) + (combined >> 2); + + return combined; + } + + __forceinline__ __host__ __device__ result_type operator()(const Key& key) const { + constexpr int len = sizeof(argument_type); + const uint8_t* const data = (const uint8_t*)&key; + constexpr int nblocks = len / 4; + uint32_t h1 = m_seed; + constexpr uint32_t c1 = 0xcc9e2d51; + constexpr uint32_t c2 = 0x1b873593; + //---------- + // body + const uint32_t* const blocks = (const uint32_t*)(data + nblocks * 4); + for (int i = -nblocks; i; i++) { + uint32_t k1 = blocks[i]; // getblock32(blocks,i); + k1 *= c1; + k1 = rotl32(k1, 15); + k1 *= c2; + h1 ^= k1; + h1 = rotl32(h1, 13); + h1 = h1 * 5 + 0xe6546b64; + } + //---------- + // tail + const uint8_t* tail = (const uint8_t*)(data + nblocks * 4); + uint32_t k1 = 0; + switch (len & 3) { + case 3: + k1 ^= tail[2] << 16; + case 2: + k1 ^= tail[1] << 8; + case 1: + k1 ^= tail[0]; + k1 *= c1; + k1 = rotl32(k1, 15); + k1 *= c2; + h1 ^= k1; + }; + //---------- + // finalization + h1 ^= len; + h1 = fmix32(h1); + return h1; + } + + private: + const uint32_t m_seed; +}; + +template +using default_hash = MurmurHash3_32; + +#endif // HASH_FUNCTIONS_CUH diff --git a/paddle/fluid/framework/fleet/heter_ps/cudf/managed.cuh b/paddle/fluid/framework/fleet/heter_ps/cudf/managed.cuh new file mode 100644 index 00000000000..a0e34c66f0b --- /dev/null +++ b/paddle/fluid/framework/fleet/heter_ps/cudf/managed.cuh @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2017, 
NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MANAGED_CUH +#define MANAGED_CUH + +#include + +struct managed { + static void *operator new(size_t n) { + void *ptr = 0; + cudaError_t result = cudaMallocManaged(&ptr, n); + if (cudaSuccess != result || 0 == ptr) throw std::bad_alloc(); + return ptr; + } + + static void operator delete(void *ptr) noexcept { cudaFree(ptr); } +}; + +#endif // MANAGED_CUH diff --git a/paddle/fluid/framework/fleet/heter_ps/cudf/managed_allocator.cuh b/paddle/fluid/framework/fleet/heter_ps/cudf/managed_allocator.cuh new file mode 100644 index 00000000000..62c7d7aa74d --- /dev/null +++ b/paddle/fluid/framework/fleet/heter_ps/cudf/managed_allocator.cuh @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2017, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MANAGED_ALLOCATOR_CUH +#define MANAGED_ALLOCATOR_CUH + +#include + +template +struct managed_allocator { + typedef T value_type; + + managed_allocator() = default; + + template + constexpr managed_allocator(const managed_allocator&) noexcept {} + + T* allocate(std::size_t n) const { + T* ptr = 0; + cudaError_t result = cudaMallocManaged(&ptr, n * sizeof(T)); + if (cudaSuccess != result || nullptr == ptr) { + std::cerr << "ERROR: CUDA Runtime call in line " << __LINE__ << "of file " << __FILE__ + << " failed with " << cudaGetErrorString(result) << " (" << result << ") " + << " Attempted to allocate: " << n * sizeof(T) << " bytes.\n"; + throw std::bad_alloc(); + } + return ptr; + } + void deallocate(T* p, std::size_t) const { cudaFree(p); } +}; + +template +bool operator==(const managed_allocator&, const managed_allocator&) { + return true; +} +template +bool operator!=(const managed_allocator&, const managed_allocator&) { + return false; +} + +#endif diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h b/paddle/fluid/framework/fleet/heter_ps/feature_value.h new file mode 100644 index 00000000000..efdb90b3362 --- /dev/null +++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h @@ -0,0 +1,76 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#ifdef PADDLE_WITH_PSLIB + +#include + +namespace paddle { +namespace framework { +#define MF_DIM 8 + +typedef uint64_t FeatureKey; + +struct FeatureValue { + float delta_score; + float show; + float clk; + int slot; + float lr; + float lr_g2sum; + int mf_size; + float mf[MF_DIM + 1]; + + friend std::ostream& operator<<(std::ostream& out, FeatureValue& val) { + out << "show: " << val.show << " clk: " << val.clk << " slot: " << val.slot + << " lr: " << val.lr << " mf_size: " << val.mf_size << " mf:"; + for (int i = 0; i < val.mf_size; ++i) { + out << " " << val.mf[i]; + } + return out; + } +}; + +struct FeaturePushValue { + float show; + float clk; + int slot; + float lr_g; + float mf_g[MF_DIM]; +}; +// class DownpourFixedFeatureValue { +// public: +// DownpourFixedFeatureValue() {} +// ~DownpourFixedFeatureValue() {} +// float* data() { +// return _data.data(); +// } +// size_t size() { +// return _data.size(); +// } +// void resize(size_t size) { +// _data.resize(size); +// } +// void shrink_to_fit() { +// _data.shrink_to_fit(); +// } +// private: +// std::vector _data; +// }; + +} // end namespace framework +} // end namespace paddle +#endif diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable.h b/paddle/fluid/framework/fleet/heter_ps/hashtable.h new file mode 100644 index 00000000000..0c45edb57f8 --- /dev/null +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable.h @@ -0,0 +1,64 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#pragma once
+#include
+#include
+#include
+#include "thrust/pair.h"
+//#include "cudf/concurrent_unordered_map.cuh.h"
+#include "paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h"
+#ifdef PADDLE_WITH_PSLIB
+
+namespace paddle {
+namespace framework {
+
+template <typename KeyType, typename ValType>
+class TableContainer
+    : public concurrent_unordered_map<KeyType, ValType,
+                                      std::numeric_limits<KeyType>::max()> {
+ public:
+  TableContainer(size_t capacity)
+      : concurrent_unordered_map<KeyType, ValType,
+                                 std::numeric_limits<KeyType>::max()>(
+            capacity, ValType()) {}
+};
+
+template <typename KeyType, typename ValType>
+class HashTable {
+ public:
+  HashTable(size_t capacity);
+  virtual ~HashTable();
+  HashTable(const HashTable&) = delete;
+  HashTable& operator=(const HashTable&) = delete;
+  void insert(const KeyType* d_keys, const ValType* d_vals, size_t len,
+              cudaStream_t stream);
+  void get(const KeyType* d_keys, ValType* d_vals, size_t len,
+           cudaStream_t stream);
+  void show();
+
+  template <typename GradType, typename Sgd>
+  void update(const KeyType* d_keys, const GradType* d_grads, size_t len,
+              Sgd sgd, cudaStream_t stream);
+
+ private:
+  TableContainer<KeyType, ValType>* container_;
+  int BLOCK_SIZE_{256};
+  float LOAD_FACTOR{0.75f};
+  size_t capacity_;
+};
+} // end namespace framework
+} // end namespace paddle
+#include "hashtable.tpp"
+#endif
diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable.tpp b/paddle/fluid/framework/fleet/heter_ps/hashtable.tpp
new file mode 100644
index 00000000000..3c125701c6b
--- /dev/null
+++ b/paddle/fluid/framework/fleet/heter_ps/hashtable.tpp
@@ -0,0 +1,126 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/
+
+#ifdef PADDLE_WITH_PSLIB
+
+namespace paddle {
+namespace framework {
+
+template <typename value_type>
+struct ReplaceOp {
+  __host__ __device__ value_type operator()(value_type new_value,
+                                            value_type old_value) {
+    return new_value;
+  }
+};
+
+template <typename Table>
+__global__ void insert_kernel(Table* table,
+                              const typename Table::key_type* const keys,
+                              const typename Table::mapped_type* const vals,
+                              size_t len) {
+  ReplaceOp<typename Table::mapped_type> op;
+  thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;
+
+  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
+  if (i < len) {
+    kv.first = keys[i];
+    kv.second = vals[i];
+    auto it = table->insert(kv, op);
+    assert(it != table->end() && "error: insert fails: table is full");
+  }
+}
+
+template <typename Table>
+__global__ void search_kernel(Table* table,
+                              const typename Table::key_type* const keys,
+                              typename Table::mapped_type* const vals,
+                              size_t len) {
+  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
+  if (i < len) {
+    auto it = table->find(keys[i]);
+    if (it != table->end()) {
+      vals[i] = it->second;
+    }
+  }
+}
+
+template <typename Table, typename GradType, typename Sgd>
+__global__ void update_kernel(Table* table,
+                              const typename Table::key_type* const keys,
+                              const GradType* const grads, size_t len,
+                              Sgd sgd) {
+  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
+  if (i < len) {
+    auto it = table->find(keys[i]);
+    if (it != table->end()) {
+      sgd.update_value((it.getter())->second, grads[i]);
+    }
+  }
+}
+
+template <typename KeyType, typename ValType>
+HashTable<KeyType, ValType>::HashTable(size_t capacity) {
+  container_ = new TableContainer<KeyType, ValType>(capacity);
+}
+
+template <typename KeyType, typename ValType>
+HashTable<KeyType, ValType>::~HashTable() {
+  delete container_;
+}
+
+template <typename KeyType, typename ValType>
+void HashTable<KeyType, ValType>::show() {
+  container_->print();
+}
+
+template <typename KeyType, typename ValType>
+void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
+                                      size_t len, cudaStream_t stream) {
+  if (len == 0) {
+    return;
+  }
+  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
+  search_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys,
+                                                       d_vals, len);
+}
+
+template <typename KeyType, typename ValType>
+void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
+                                         const ValType* d_vals, size_t len,
+                                         cudaStream_t stream) {
+  if (len == 0) {
+    return;
+  }
+  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
+  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys,
+                                                       d_vals, len);
+}
+
+template <typename KeyType, typename ValType>
+template <typename GradType, typename Sgd>
+void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
+                                         const GradType* d_grads, size_t len,
+                                         Sgd sgd, cudaStream_t stream) {
+  if (len == 0) {
+    return;
+  }
+  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
+  update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys,
+                                                       d_grads, len, sgd);
+}
+
+} // end namespace framework
+} // end namespace paddle
+#endif
diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h
new file mode 100644
index 00000000000..70dae31c175
--- /dev/null
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h
@@ -0,0 +1,84 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/
+
+#pragma once
+#include
+#include "cub/cub.cuh"
+#include "hashtable.h"
+#include "heter_resource.h"
+#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh"
+#include "paddle/fluid/memory/memory.h"
+#include "paddle/fluid/platform/cuda_device_guard.h"
+#include "paddle/fluid/platform/place.h"
+#include "thrust/pair.h"
+
+#ifdef PADDLE_WITH_PSLIB
+
+namespace paddle {
+namespace framework {
+
+struct CustomGradMerger {
+  template <typename T>
+  CUB_RUNTIME_FUNCTION __forceinline__ __device__ T
+  operator()(const T& a, const T& b) const {
+    T out;
+    out.slot = a.slot;
+    out.show = a.show + b.show;
+    out.clk = a.clk + b.clk;
+    out.lr_g = a.lr_g + b.lr_g;
+    for (int i = 0; i < MF_DIM; ++i) {
+      out.mf_g[i] = a.mf_g[i] + b.mf_g[i];
+    }
+    return out;
+  }
+};
+
+template <typename KeyType, typename ValType, typename GradType>
+class HeterComm {
+ public:
+  HeterComm(size_t capacity, std::shared_ptr<HeterPsResource> resource);
+  virtual ~HeterComm();
+  HeterComm(const HeterComm&) = delete;
+  HeterComm& operator=(const HeterComm&) = delete;
+
+  void split_input_to_shard(KeyType* d_keys, int* d_idx_ptr, size_t len,
+                            int* left, int* right, int gpu_num);
+  void merge_grad(int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
+                  int& uniq_len);
+  void pull_sparse(int num, KeyType* d_keys, ValType* d_vals, size_t len);
+  void build_ps(int num, KeyType* h_keys, ValType* h_vals, size_t len,
+                size_t chunk_size, int stream_num);
+  void dump();
+  void show_one_table(int gpu_num);
+  int get_index_by_devid(int devid);
+
+  template <typename Sgd>
+  void push_sparse(int num, KeyType* d_keys, GradType* d_grads, size_t len,
+                   Sgd& sgd);
+
+  int log2i(int x);
+
+ private:
+  using Table = HashTable<KeyType, ValType>;
+  int block_size_{256};
+  float load_factor_{0.75};
+  std::vector<Table*> tables_;
+  std::shared_ptr<HeterPsResource> resource_;
+  CustomGradMerger merger_;
+};
+
+} // end namespace framework
+} // end namespace paddle
+#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.tpp"
+#endif
diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm.tpp b/paddle/fluid/framework/fleet/heter_ps/heter_comm.tpp
new file mode 100644
index 00000000000..781e3a3a714
--- /dev/null
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm.tpp
@@ -0,0 +1,494 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/ +#pragma once +#ifdef PADDLE_WITH_PSLIB +namespace paddle { +namespace framework { + +template +__global__ void fill_idx(T* idx, size_t len) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < len) { + idx[i] = i; + } +} + +template +void show_tensor(T* input, size_t len, cudaStream_t stream, std::string name) { + T tmp[len]; + cudaMemcpyAsync(&tmp, input, sizeof(T) * len, cudaMemcpyDeviceToHost, stream); + cudaStreamSynchronize(stream); + std::cout << name; + for (int i = 0; i < len; ++i) { + std::cout << ":" << tmp[i]; + } + std::cout << std::endl; +} + +template +__global__ void calc_shard_offset(T* idx, T* left, T* right, size_t len) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < len - 1) { + if (idx[i] != idx[i + 1]) { + right[idx[i]] = i; + left[idx[i + 1]] = i + 1; + } + } + if (i == 0) { + left[idx[i]] = i; + } + if (i == (len - 1)) { + right[idx[i]] = i; + } +} + +template +__global__ void calc_shard_index(KeyType* d_keys, size_t len, T* shard_index, + int total_gpu) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < len) { + shard_index[i] = d_keys[i] % total_gpu; + } +} + +template +__global__ void fill_shard_key(KeyType* d_shard_keys, KeyType* d_keys, T* idx, + size_t len) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < len) { + d_shard_keys[i] = d_keys[idx[i]]; + } +} + +template +__global__ void fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys, + GradType* d_shard_grads, GradType* d_grads, + T* idx, size_t len) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < len) { + d_shard_keys[i] = d_keys[idx[i]]; + d_shard_grads[i] = d_grads[idx[i]]; + } +} + +template +__global__ void fill_dvals(ValType* d_shard_vals, ValType* d_vals, T* idx, + size_t len) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < len) { + d_vals[idx[i]] = d_shard_vals[i]; + } +} + +template +HeterComm::HeterComm( + size_t capacity, std::shared_ptr resource) { + resource_ = resource; + for (int i = 0; i < resource_->total_gpu(); ++i) { + platform::CUDADeviceGuard guard(resource_->dev_id(i)); + auto table = new Table(capacity / load_factor_); + tables_.push_back(table); + } +} + +template +HeterComm::~HeterComm() { + for (auto& table : tables_) { + delete table; + table = nullptr; + } +} + +template +void HeterComm::show_one_table(int gpu_num) { + tables_[gpu_num]->show(); +} + +template +int HeterComm::log2i(int x) { + unsigned res = 0; + while (x >>= 1) { + ++res; + } + return res; +} + +template +int HeterComm::get_index_by_devid(int devid) { + return resource_->get_index_by_devid(devid); +} + +template +void HeterComm::build_ps(int num, KeyType* h_keys, + ValType* h_vals, size_t len, + size_t chunk_size, + int stream_num) { + if (len <= 0) { + return; + } + int dev_id = resource_->dev_id(num); + platform::CUDAPlace place = platform::CUDAPlace(dev_id); + platform::CUDADeviceGuard guard(dev_id); + + std::vector> d_key_bufs; + std::vector> d_val_bufs; + + cudaStream_t streams[stream_num]; + for (int i = 0; i < stream_num; ++i) { + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&(streams[i]))); + auto d_k_buf = memory::AllocShared(place, chunk_size * sizeof(KeyType)); + auto d_v_buf = memory::AllocShared(place, chunk_size * sizeof(ValType)); + d_key_bufs.push_back(d_k_buf); + d_val_bufs.push_back(d_v_buf); + } + + int cur_len = 0; + int cur_stream = 0; + + while (cur_len < len) { + cur_stream = cur_stream % stream_num; + int tmp_len = cur_len + chunk_size > len ? 
len - cur_len : chunk_size; + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaMemcpyAsync(d_key_bufs[cur_stream]->ptr(), h_keys + cur_len, + sizeof(KeyType) * tmp_len, cudaMemcpyHostToDevice, + streams[cur_stream])); + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaMemcpyAsync(d_val_bufs[cur_stream]->ptr(), h_vals + cur_len, + sizeof(ValType) * tmp_len, cudaMemcpyHostToDevice, + streams[cur_stream])); + tables_[num]->insert( + reinterpret_cast(d_key_bufs[cur_stream]->ptr()), + reinterpret_cast(d_val_bufs[cur_stream]->ptr()), tmp_len, + streams[cur_stream]); + cur_stream += 1; + cur_len += tmp_len; + } + + for (int i = 0; i < stream_num; ++i) { + cudaStreamSynchronize(streams[i]); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(streams[i])); + } +} + +template +void HeterComm::merge_grad(int gpu_num, KeyType* d_keys, + GradType* d_grads, + size_t len, int& uniq_len) { + int dev_id = resource_->dev_id(gpu_num); + platform::CUDAPlace place = platform::CUDAPlace(dev_id); + platform::CUDADeviceGuard guard(dev_id); + auto stream = resource_->stream(gpu_num); + + size_t temp_storage_bytes; + + auto d_merge_keys = memory::AllocShared(place, len * sizeof(KeyType)); + KeyType* d_merge_keys_ptr = reinterpret_cast(d_merge_keys->ptr()); + + auto d_merge_grads = memory::AllocShared(place, len * sizeof(GradType)); + GradType* d_merge_grads_ptr = + reinterpret_cast(d_merge_grads->ptr()); + + PADDLE_ENFORCE_CUDA_SUCCESS(cub::DeviceRadixSort::SortPairs( + NULL, temp_storage_bytes, d_keys, d_merge_keys_ptr, d_grads, + d_merge_grads_ptr, len, 0, 8 * sizeof(KeyType), stream, false)); + + void* d_buff = NULL; + auto d_temp_storage = memory::AllocShared(place, temp_storage_bytes); + + PADDLE_ENFORCE_CUDA_SUCCESS(cub::DeviceRadixSort::SortPairs( + d_temp_storage->ptr(), temp_storage_bytes, d_keys, d_merge_keys_ptr, + d_grads, d_merge_grads_ptr, len, 0, 8 * sizeof(KeyType), stream, false)); + temp_storage_bytes = 0; + + auto d_num_runs_out_mem = memory::AllocShared(place, sizeof(int)); + int* d_num_runs_out = reinterpret_cast(d_num_runs_out_mem->ptr()); + + PADDLE_ENFORCE_CUDA_SUCCESS(cub::DeviceReduce::ReduceByKey( + NULL, temp_storage_bytes, d_merge_keys_ptr, d_keys, d_merge_grads_ptr, + d_grads, d_num_runs_out, merger_, len, stream, false)); + + if (d_temp_storage->size() < temp_storage_bytes) { + d_temp_storage = NULL; + d_temp_storage = memory::AllocShared(place, temp_storage_bytes); + } + + PADDLE_ENFORCE_CUDA_SUCCESS(cub::DeviceReduce::ReduceByKey( + d_temp_storage->ptr(), temp_storage_bytes, d_merge_keys_ptr, d_keys, + d_merge_grads_ptr, d_grads, d_num_runs_out, merger_, len, stream, false)); + + cudaMemcpyAsync(&uniq_len, d_num_runs_out, sizeof(int), + cudaMemcpyDeviceToHost, stream); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); +} + +template +void HeterComm::split_input_to_shard( + KeyType* d_keys, int* d_idx_ptr, size_t len, int* left, int* right, + int gpu_num) { + int total_gpu = resource_->total_gpu(); + int dev_id = resource_->dev_id(gpu_num); + platform::CUDAPlace place = platform::CUDAPlace(dev_id); + platform::CUDADeviceGuard guard(dev_id); + auto stream = resource_->stream(gpu_num); + + auto d_idx_tmp = memory::AllocShared(place, len * sizeof(int)); + int* d_idx_tmp_ptr = reinterpret_cast(d_idx_tmp->ptr()); + + auto d_shard_index = memory::AllocShared(place, len * sizeof(int)); + int* d_shard_index_ptr = reinterpret_cast(d_shard_index->ptr()); + + auto d_shard_index_tmp = memory::AllocShared(place, len * sizeof(int)); + int* d_shard_index_tmp_ptr = reinterpret_cast(d_shard_index_tmp->ptr()); + + int 
grid_size = (len - 1) / block_size_ + 1; + fill_idx<<>>(d_idx_tmp_ptr, len); + calc_shard_index<<>>( + d_keys, len, d_shard_index_tmp_ptr, total_gpu); + + size_t temp_storage_bytes; + const int num_bits = 1 + log2i(total_gpu); + PADDLE_ENFORCE_CUDA_SUCCESS(cub::DeviceRadixSort::SortPairs( + NULL, temp_storage_bytes, d_shard_index_tmp_ptr, d_shard_index_ptr, + d_idx_tmp_ptr, d_idx_ptr, len, 0, num_bits, stream)); + + auto d_temp_storage = memory::AllocShared(place, temp_storage_bytes); + PADDLE_ENFORCE_CUDA_SUCCESS(cub::DeviceRadixSort::SortPairs( + d_temp_storage->ptr(), temp_storage_bytes, d_shard_index_tmp_ptr, + d_shard_index_ptr, d_idx_tmp_ptr, d_idx_ptr, len, 0, num_bits, stream)); + calc_shard_offset<<>>(d_shard_index_ptr, + left, right, len); + cudaStreamSynchronize(stream); +} + +template +void HeterComm::pull_sparse(int num, KeyType* d_keys, + ValType* d_vals, + size_t len) { + if (len == 0) { + return; + } + + int total_gpu = resource_->total_gpu(); + int dev_id = resource_->dev_id(num); + platform::CUDAPlace place = platform::CUDAPlace(dev_id); + platform::CUDADeviceGuard guard(dev_id); + auto stream = resource_->stream(num); + + int grid_size = (len - 1) / block_size_ + 1; + + int h_left[total_gpu]; + int h_right[total_gpu]; + + auto d_left = memory::AllocShared(place, total_gpu * sizeof(int)); + auto d_right = memory::AllocShared(place, total_gpu * sizeof(int)); + int* d_left_ptr = reinterpret_cast(d_left->ptr()); + int* d_right_ptr = reinterpret_cast(d_right->ptr()); + + cudaMemset(d_left_ptr, -1, total_gpu * sizeof(int)); + cudaMemset(d_right_ptr, -1, total_gpu * sizeof(int)); + // + auto d_idx = memory::AllocShared(place, len * sizeof(int)); + int* d_idx_ptr = reinterpret_cast(d_idx->ptr()); + + auto d_shard_keys = memory::AllocShared(place, len * sizeof(KeyType)); + KeyType* d_shard_keys_ptr = reinterpret_cast(d_shard_keys->ptr()); + auto d_shard_vals = memory::AllocShared(place, len * sizeof(ValType)); + ValType* d_shard_vals_ptr = reinterpret_cast(d_shard_vals->ptr()); + + split_input_to_shard(d_keys, d_idx_ptr, len, d_left_ptr, d_right_ptr, num); + + fill_shard_key<<>>(d_shard_keys_ptr, + d_keys, d_idx_ptr, len); + + cudaStreamSynchronize(stream); + + cudaMemcpy(h_left, d_left_ptr, total_gpu * sizeof(int), + cudaMemcpyDeviceToHost); + cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int), + cudaMemcpyDeviceToHost); + + std::vector d_remote_shard_keys_ptr; + std::vector d_remote_shard_vals_ptr; + std::vector> d_remote_shard_keys; + std::vector> d_remote_shard_vals; + + for (int i = 0; i < total_gpu; ++i) { + int shard_len = h_right[i] - h_left[i] + 1; + if (shard_len == 0) { + continue; + } + platform::CUDADeviceGuard guard(resource_->dev_id(i)); + platform::CUDAPlace remote_place = + platform::CUDAPlace(resource_->dev_id(i)); + d_remote_shard_keys.push_back( + memory::AllocShared(remote_place, shard_len * sizeof(KeyType))); + d_remote_shard_keys_ptr.push_back( + reinterpret_cast(d_remote_shard_keys[i]->ptr())); + + d_remote_shard_vals.push_back( + memory::AllocShared(remote_place, shard_len * sizeof(ValType))); + d_remote_shard_vals_ptr.push_back( + reinterpret_cast(d_remote_shard_vals[i]->ptr())); + } + + for (int i = 0; i < total_gpu; ++i) { + int shard_len = h_right[i] - h_left[i] + 1; + if (h_left[i] == -1 || h_right[i] == -1) { + continue; + } + cudaMemcpyAsync(d_remote_shard_keys_ptr[i], d_shard_keys_ptr + h_left[i], + shard_len * sizeof(KeyType), cudaMemcpyDefault, stream); + } + cudaStreamSynchronize(stream); + + for (int i = 0; i < total_gpu; ++i) { + if 
(h_left[i] == -1) { + continue; + } + platform::CUDADeviceGuard guard(resource_->dev_id(i)); + tables_[i]->get(d_remote_shard_keys_ptr[i], d_remote_shard_vals_ptr[i], + h_right[i] - h_left[i] + 1, resource_->stream(i)); + } + for (int i = 0; i < total_gpu; ++i) { + cudaStreamSynchronize(resource_->stream(i)); + } + + for (int i = 0; i < total_gpu; ++i) { + int shard_len = h_right[i] - h_left[i] + 1; + if (h_left[i] == -1 || h_right[i] == -1) { + continue; + } + platform::CUDADeviceGuard guard(resource_->dev_id(i)); + cudaMemcpyAsync(d_shard_vals_ptr + h_left[i], d_remote_shard_vals_ptr[i], + shard_len * sizeof(ValType), cudaMemcpyDefault, + resource_->stream(i)); + } + + for (int i = 0; i < total_gpu; ++i) { + cudaStreamSynchronize(resource_->stream(i)); + } + + fill_dvals<<>>(d_shard_vals_ptr, d_vals, + d_idx_ptr, len); + cudaStreamSynchronize(stream); +} + +template +template +void HeterComm::push_sparse(int gpu_num, + KeyType* d_keys, + GradType* d_grads, + size_t len, Sgd& sgd) { + if (len == 0) { + return; + } + + int total_gpu = resource_->total_gpu(); + int dev_id = resource_->dev_id(gpu_num); + platform::CUDAPlace place = platform::CUDAPlace(dev_id); + platform::CUDADeviceGuard guard(dev_id); + auto stream = resource_->stream(gpu_num); + + int h_left[total_gpu]; + int h_right[total_gpu]; + + auto d_left = memory::AllocShared(place, total_gpu * sizeof(int)); + auto d_right = memory::AllocShared(place, total_gpu * sizeof(int)); + int* d_left_ptr = reinterpret_cast(d_left->ptr()); + int* d_right_ptr = reinterpret_cast(d_right->ptr()); + + cudaMemset(d_left_ptr, -1, total_gpu * sizeof(int)); + cudaMemset(d_right_ptr, -1, total_gpu * sizeof(int)); + // + auto d_idx = memory::AllocShared(place, len * sizeof(int)); + int* d_idx_ptr = reinterpret_cast(d_idx->ptr()); + + auto d_shard_keys = memory::AllocShared(place, len * sizeof(KeyType)); + KeyType* d_shard_keys_ptr = reinterpret_cast(d_shard_keys->ptr()); + auto d_shard_grads = memory::AllocShared(place, len * sizeof(GradType)); + GradType* d_shard_grads_ptr = + reinterpret_cast(d_shard_grads->ptr()); + + int uniq_len = len; + merge_grad(gpu_num, d_keys, d_grads, len, uniq_len); + + int grid_size = (uniq_len - 1) / block_size_ + 1; + + split_input_to_shard(d_keys, d_idx_ptr, uniq_len, d_left_ptr, d_right_ptr, + gpu_num); + + fill_shard_grads<<>>( + d_shard_keys_ptr, d_keys, d_shard_grads_ptr, d_grads, d_idx_ptr, + uniq_len); + + cudaStreamSynchronize(stream); + + cudaMemcpy(h_left, d_left_ptr, total_gpu * sizeof(int), + cudaMemcpyDeviceToHost); + cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int), + cudaMemcpyDeviceToHost); + + std::vector d_remote_shard_keys_ptr; + std::vector d_remote_shard_grads_ptr; + std::vector> d_remote_shard_keys; + std::vector> d_remote_shard_grads; + + for (int i = 0; i < total_gpu; ++i) { + int shard_len = h_right[i] - h_left[i] + 1; + if (h_left[i] == -1 || h_right[i] == -1) { + continue; + } + platform::CUDADeviceGuard guard(resource_->dev_id(i)); + platform::CUDAPlace remote_place = + platform::CUDAPlace(resource_->dev_id(i)); + d_remote_shard_keys.push_back( + memory::AllocShared(remote_place, shard_len * sizeof(KeyType))); + d_remote_shard_keys_ptr.push_back( + reinterpret_cast(d_remote_shard_keys[i]->ptr())); + + d_remote_shard_grads.push_back( + memory::AllocShared(remote_place, shard_len * sizeof(GradType))); + d_remote_shard_grads_ptr.push_back( + reinterpret_cast(d_remote_shard_grads[i]->ptr())); + } + + for (int i = 0; i < total_gpu; ++i) { + int shard_len = h_right[i] - h_left[i] + 1; + if 
(h_left[i] == -1 || h_right[i] == -1) { + continue; + } + cudaMemcpyAsync(d_remote_shard_keys_ptr[i], d_shard_keys_ptr + h_left[i], + shard_len * sizeof(KeyType), cudaMemcpyDefault, stream); + cudaMemcpyAsync(d_remote_shard_grads_ptr[i], d_shard_grads_ptr + h_left[i], + shard_len * sizeof(GradType), cudaMemcpyDefault, stream); + } + + cudaStreamSynchronize(stream); + + for (int i = 0; i < total_gpu; ++i) { + if (h_left[i] == -1 || h_right[i] == -1) { + continue; + } + platform::CUDADeviceGuard guard(resource_->dev_id(i)); + tables_[i]->update(d_remote_shard_keys_ptr[i], d_remote_shard_grads_ptr[i], + h_right[i] - h_left[i] + 1, sgd, resource_->stream(i)); + } + for (int i = 0; i < total_gpu; ++i) { + cudaStreamSynchronize(resource_->stream(i)); + } +} + +} // end namespace framework +} // end namespace paddle +#endif diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu new file mode 100644 index 00000000000..a3f306f6100 --- /dev/null +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu @@ -0,0 +1,62 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/framework/fleet/heter_ps/heter_ps.h" + +#ifdef PADDLE_WITH_PSLIB + +namespace paddle { +namespace framework { + +HeterPsBase* HeterPsBase::get_instance( + size_t capacity, std::shared_ptr resource) { + return new HeterPs(capacity, resource); +} + +HeterPs::HeterPs(size_t capacity, std::shared_ptr resource) { + comm_ = + std::make_shared>( + capacity, resource); + opt_ = Optimizer(); +} + +HeterPs::~HeterPs() {} + +void HeterPs::pull_sparse(int num, FeatureKey* d_keys, FeatureValue* d_vals, + size_t len) { + comm_->pull_sparse(num, d_keys, d_vals, len); +} + +void HeterPs::build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals, + size_t len, size_t chunk_size, int stream_num) { + comm_->build_ps(num, h_keys, h_vals, len, chunk_size, stream_num); +} + +int HeterPs::get_index_by_devid(int devid) { + return comm_->get_index_by_devid(devid); +} + +void HeterPs::dump() {} + +void HeterPs::show_one_table(int gpu_num) { comm_->show_one_table(gpu_num); } + +void HeterPs::push_sparse(int num, FeatureKey* d_keys, + FeaturePushValue* d_grads, size_t len) { + comm_->push_sparse(num, d_keys, d_grads, len, opt_); +} + +} // end namespace framework +} // end namespace paddle +#endif diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.h b/paddle/fluid/framework/fleet/heter_ps/heter_ps.h new file mode 100644 index 00000000000..6c6d408a53b --- /dev/null +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.h @@ -0,0 +1,51 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h" +#include "paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h" +#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh" + +#ifdef PADDLE_WITH_PSLIB + +namespace paddle { +namespace framework { + +class HeterPs : public HeterPsBase { + public: + HeterPs() {} + HeterPs(size_t capacity, std::shared_ptr resource); + virtual ~HeterPs(); + HeterPs(const HeterPs&) = delete; + HeterPs& operator=(const HeterPs&) = delete; + + virtual void pull_sparse(int num, FeatureKey* d_keys, FeatureValue* d_vals, + size_t len) override; + virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals, + size_t len, size_t chunk_size, int stream_num) override; + virtual void dump() override; + virtual int get_index_by_devid(int devid) override; + virtual void show_one_table(int gpu_num) override; + virtual void push_sparse(int num, FeatureKey* d_keys, + FeaturePushValue* d_grads, size_t len) override; + + private: + std::shared_ptr> comm_; + Optimizer opt_; +}; + +} // end namespace framework +} // end namespace paddle +#endif diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h b/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h new file mode 100644 index 00000000000..a8802b00eac --- /dev/null +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h @@ -0,0 +1,47 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include +#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" + +#ifdef PADDLE_WITH_PSLIB + +namespace paddle { +namespace framework { + +class HeterPsBase { + public: + HeterPsBase(){}; + HeterPsBase(size_t capacity, std::shared_ptr resource){}; + virtual ~HeterPsBase(){}; + HeterPsBase(const HeterPsBase&) = delete; + HeterPsBase& operator=(const HeterPsBase&) = delete; + + virtual void pull_sparse(int num, FeatureKey* d_keys, FeatureValue* d_vals, + size_t len) = 0; + virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals, + size_t len, size_t chunk_size, int stream_num) = 0; + virtual int get_index_by_devid(int devid) = 0; + virtual void dump() = 0; + virtual void show_one_table(int gpu_num) = 0; + virtual void push_sparse(int num, FeatureKey* d_keys, + FeaturePushValue* d_grads, size_t len) = 0; + static HeterPsBase* get_instance(size_t capacity, + std::shared_ptr resource); +}; + +} // end namespace framework +} // end namespace paddle +#endif diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_resource.cc b/paddle/fluid/framework/fleet/heter_ps/heter_resource.cc new file mode 100644 index 00000000000..916ef5c5ee4 --- /dev/null +++ b/paddle/fluid/framework/fleet/heter_ps/heter_resource.cc @@ -0,0 +1,91 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef PADDLE_WITH_PSLIB +#include "heter_resource.h" +#include "paddle/fluid/platform/cuda_device_guard.h" + +namespace paddle { +namespace framework { + +GPUResource::GPUResource(int dev_id, int index) { + index_ = index; + dev_id_ = dev_id; + + platform::CUDADeviceGuard guard(dev_id_); + + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaStreamCreateWithFlags(©_stream_, cudaStreamNonBlocking)); +} + +GPUResource::~GPUResource() { + platform::CUDADeviceGuard guard(dev_id_); + + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(stream_)); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(copy_stream_)); +} + +void HeterPsResource::enable_p2p() { + for (size_t i = 0; i < dev_ids_.size(); ++i) { + platform::CUDADeviceGuard guard(dev_ids_[i]); + for (size_t j = 0; j < dev_ids_.size(); ++j) { + if (i != j) { + int p2p_flag; + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaDeviceCanAccessPeer(&p2p_flag, dev_ids_[i], dev_ids_[j])); + if (p2p_flag == 1) { + cudaError_t ret = cudaDeviceEnablePeerAccess(dev_ids_[j], 0); + if (ret != cudaSuccess && ret != cudaErrorPeerAccessAlreadyEnabled) { + VLOG(0) << " Cuda error(" << ret << "), " << cudaGetErrorString(ret) + << "."; + } else { + cudaGetLastError(); + } + } + } + } + } +} + +HeterPsResource::HeterPsResource(const std::vector& dev_ids) { + dev_ids_ = dev_ids; + for (size_t i = 0; i < dev_ids_.size(); ++i) { + std::shared_ptr resource = + std::make_shared(dev_ids_[i], i); + resources_.push_back(resource); + devid_2_index_[dev_ids_[i]] = i; + } +} + +cudaStream_t HeterPsResource::copy_stream(int num) { + return resources_[num]->copy_stream(); +} + +cudaStream_t HeterPsResource::stream(int num) { + return resources_[num]->stream(); +} + +int HeterPsResource::dev_id(int num) { return dev_ids_[num]; } + +int HeterPsResource::get_index_by_devid(int devid) { + return devid_2_index_[devid]; +} + +int HeterPsResource::total_gpu() { return dev_ids_.size(); } + +} // end namespace framework +} // end namespace paddle +#endif diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_resource.h b/paddle/fluid/framework/fleet/heter_ps/heter_resource.h new file mode 100644 index 00000000000..ca78888260d --- /dev/null +++ b/paddle/fluid/framework/fleet/heter_ps/heter_resource.h @@ -0,0 +1,66 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include +#include +#include +#include +#include "paddle/fluid/platform/cuda_device_guard.h" +#include "paddle/fluid/platform/enforce.h" + +#ifdef PADDLE_WITH_PSLIB + +namespace paddle { +namespace framework { + +class GPUResource { + public: + GPUResource(int device_id, int index); + virtual ~GPUResource(); + GPUResource(const GPUResource&) = delete; + GPUResource& operator=(const GPUResource&) = delete; + + int dev_id() const { return dev_id_; } + int index() const { return index_; } + cudaStream_t stream() { return stream_; } + cudaStream_t copy_stream() { return copy_stream_; } + + int dev_id_; + int index_; + cudaStream_t stream_; + cudaStream_t copy_stream_; +}; + +class HeterPsResource { + public: + HeterPsResource(const std::vector& dev_ids); + HeterPsResource(const HeterPsResource&) = delete; + HeterPsResource& operator=(const HeterPsResource&) = delete; + virtual ~HeterPsResource() {} + void enable_p2p(); + int total_gpu(); + int get_index_by_devid(int devid); + cudaStream_t stream(int num); + cudaStream_t copy_stream(int num); + int dev_id(int num); + + std::vector> resources_; + std::vector dev_ids_; + std::map devid_2_index_; +}; + +} // end namespace framework +} // end namespace paddle +#endif diff --git a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh new file mode 100644 index 00000000000..7263f610fcb --- /dev/null +++ b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh @@ -0,0 +1,122 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include +#include "optimizer_conf.h" +#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" + +#ifdef PADDLE_WITH_PSLIB + +namespace paddle { +namespace framework { + +__device__ double cuda_double_random(unsigned long long seed) { + // copy from MurmurHash3 + seed ^= seed >> 33; + seed *= 0xff51afd7ed558ccd; + seed ^= seed >> 33; + seed *= 0xc4ceb9fe1a85ec53; + seed ^= seed >> 33; + return ((double)seed / 18446744073709551615.0); +} + +__device__ float cuda_normal_random(unsigned long long idx) { + static double pi = 3.1415926897932384; + unsigned long long x = clock64() + idx; + double x1, x2, res; + while (1) { + x1 = cuda_double_random(x); + x2 = cuda_double_random(x + 33); + res = sqrt(-2.0 * log(x1)) * cos(2.0 * pi * x2); + if (-10 < res && res < 10) break; + x += 207; + } + return res; +} + +template +class Optimizer { + public: + Optimizer() {} + + ~Optimizer() {} + + void initialize() {} + + __device__ void update_lr(float& w, float& g2sum, float g, float scale) { + double add_g2sum = 0; + double ratio = optimizer_config::learning_rate * + sqrt(optimizer_config::initial_g2sum / + (optimizer_config::initial_g2sum + g2sum)); + double scaled_grad = g / scale; + + w += scaled_grad * ratio; + + if (w < optimizer_config::min_bound) w = optimizer_config::min_bound; + if (w > optimizer_config::max_bound) w = optimizer_config::max_bound; + + add_g2sum = scaled_grad * scaled_grad; + + g2sum += add_g2sum; + } + + __device__ void update_mf(int n, float* w, float& g2sum, const float* g, + float scale) { + double add_g2sum = 0; + double ratio = optimizer_config::mf_learning_rate * + sqrt(optimizer_config::mf_initial_g2sum / + (optimizer_config::mf_initial_g2sum + g2sum)); + for (int i = 0; i < n; ++i) { + double scaled_grad = g[i] / scale; + + w[i] += scaled_grad * ratio; + + if (w[i] < optimizer_config::mf_min_bound) + w[i] = optimizer_config::mf_min_bound; + if (w[i] > optimizer_config::mf_max_bound) + w[i] = optimizer_config::mf_max_bound; + add_g2sum = scaled_grad * scaled_grad; + } + + g2sum += add_g2sum / n; + } + __device__ void update_value(ValType& val, const GradType& grad) { + val.slot = grad.slot; + ; + val.show += grad.show; + val.clk += grad.clk; + + update_lr(val.lr, val.lr_g2sum, grad.lr_g, 1.0); + + if (val.mf_size == 0) { + if (optimizer_config::mf_create_thresholds <= + optimizer_config::nonclk_coeff * (val.show - val.clk) + + optimizer_config::clk_coeff * val.clk) { + val.mf_size = MF_DIM + 1; + val.mf[0] = 0; + for (int i = 0; i < MF_DIM; ++i) { + val.mf[i + 1] = (cuda_normal_random((int)grad.show) * 2 - 1) * + optimizer_config::mf_initial_range; + } + } + } else { + update_mf(MF_DIM, &val.mf[1], val.mf[0], grad.mf_g, 1.0); + } + } +}; + +} // end namespace framework +} // end namespace paddle +#endif diff --git a/paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h b/paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h new file mode 100644 index 00000000000..d63d59ad2c0 --- /dev/null +++ b/paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +namespace optimizer_config { +__constant__ float mf_create_thresholds = 1; +__constant__ float nonclk_coeff = 1; +__constant__ float clk_coeff = 1; +__constant__ float min_bound = -10000; +__constant__ float max_bound = 10000; +__constant__ float learning_rate = 1; +__constant__ float initial_g2sum = 1; +__constant__ float initial_range = 1; + +__constant__ float mf_learning_rate = 1; +__constant__ float mf_initial_g2sum = 1; +__constant__ float mf_initial_range = 1; +__constant__ float mf_min_bound = 1; +__constant__ float mf_max_bound = 1; +} diff --git a/paddle/fluid/framework/fleet/heter_ps/test_comm.cu b/paddle/fluid/framework/fleet/heter_ps/test_comm.cu new file mode 100644 index 00000000000..88b02a6947f --- /dev/null +++ b/paddle/fluid/framework/fleet/heter_ps/test_comm.cu @@ -0,0 +1,112 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include +#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" +#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h" +#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h" +#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh" +#include "paddle/fluid/platform/cuda_device_guard.h" + +using namespace paddle::framework; + +TEST(TEST_FLEET, heter_comm) { + int gpu_count = 3; + std::vector dev_ids; + dev_ids.push_back(0); + dev_ids.push_back(1); + dev_ids.push_back(2); + std::shared_ptr resource = + std::make_shared(dev_ids); + resource->enable_p2p(); + std::vector count; + std::vector> keys; + std::vector> vals; + count.resize(dev_ids.size(), 0); + keys.resize(dev_ids.size()); + vals.resize(dev_ids.size()); + + for (int i = 0; i < 10; i++) { + FeatureKey key; + FeatureValue val; + int gpu_num = i % gpu_count; + key = i; + val.lr = i; + val.lr_g2sum = val.mf_size = val.show = val.clk = val.slot = 0; + keys[gpu_num].push_back(key); + vals[gpu_num].push_back(val); + count[gpu_num] += 1; + } + + size_t size = 0; + for (size_t i = 0; i < count.size(); ++i) { + size = std::max(size, count[i]); + } + + auto heter_comm = + std::make_shared>( + size, resource); + for (int i = 0; i < gpu_count; ++i) { + std::cout << "building table: " << i << std::endl; + heter_comm->build_ps(i, keys[i].data(), vals[i].data(), count[i], 10, 1); + heter_comm->show_one_table(i); + } + + std::cout << "testing pull sparse:" << std::endl; + paddle::platform::CUDADeviceGuard guard(0); + FeatureKey* pull_keys; + FeatureValue* pull_vals; + cudaMallocManaged(&pull_keys, 5 * sizeof(FeatureKey)); + cudaMallocManaged(&pull_vals, 5 * sizeof(FeatureValue)); + + pull_keys[0] = 2; + pull_keys[1] = 3; + pull_keys[2] = 9; + pull_keys[3] = 1; + pull_keys[4] = 6; + + heter_comm->pull_sparse(0, pull_keys, pull_vals, 5); + for (int i = 0; i < 5; i++) { + std::cout << pull_keys[i] << ": " << pull_vals[i] << std::endl; + } + cudaFree(pull_keys); + cudaFree(pull_vals); + + std::cout << "testing push sparse:" << std::endl; + Optimizer opt; + FeatureKey* push_keys; + FeaturePushValue* push_vals; + cudaMallocManaged(&push_keys, 5 * sizeof(FeatureKey)); + cudaMallocManaged(&push_vals, 5 * sizeof(FeaturePushValue)); + push_keys[0] = 2; + push_keys[1] = 3; + push_keys[2] = 9; + push_keys[3] = 1; + push_keys[4] = 3; + for (int i = 0; i < 5; ++i) { + push_vals[i].lr_g = push_keys[i] * 100; + push_vals[i].slot = push_keys[i]; + push_vals[i].show = push_keys[i]; + push_vals[i].clk = push_keys[i]; + } + heter_comm->push_sparse(0, push_keys, push_vals, 5, opt); + for (int i = 0; i < gpu_count; ++i) { + std::cout << "table " << i << ";" << std::endl; + heter_comm->show_one_table(i); + } + + cudaFree(push_keys); + cudaFree(push_vals); +} diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc new file mode 100644 index 00000000000..e70b1ca84f9 --- /dev/null +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -0,0 +1,194 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB) +/* +#include +#include +#include "paddle/fluid/framework/io/fs.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/scope.h" +*/ +#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h" +#include "paddle/fluid/platform/timer.h" + +namespace paddle { +namespace framework { + +std::shared_ptr PSGPUWrapper::s_instance_ = NULL; +bool PSGPUWrapper::is_initialized_ = false; + +void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim, + std::shared_ptr gpu_task) { + platform::Timer timeline; + timeline.Start(); + int shard_num = gpu_task->feature_keys_.size(); + if (shard_num == 0) { + return; + } + + std::vector feature_keys_count(shard_num); + size_t size_max = 0; + for (int i = 0; i < shard_num; i++) { + feature_keys_count[i] = gpu_task->feature_keys_[i].size(); + size_max = std::max(size_max, feature_keys_count[i]); + } + if (HeterPs_) { + HeterPs_->show_one_table(0); + return; + } + HeterPs_ = HeterPsBase::get_instance(size_max, resource_); + for (int i = 0; i < shard_num; ++i) { + std::cout << "building table: " << i << std::endl; + HeterPs_->build_ps(i, gpu_task->feature_keys_[i].data(), + gpu_task->feature_values_[i].data(), + feature_keys_count[i], 10000, 2); + HeterPs_->show_one_table(i); + } + timeline.Pause(); + VLOG(0) << "GpuPs build table total costs: " << timeline.ElapsedSec() + << " s."; +} + +void PSGPUWrapper::PullSparse(const paddle::platform::Place& place, + const int table_id, + const std::vector& keys, + const std::vector& values, + const std::vector& slot_lengths, + const int hidden_size) { + VLOG(3) << "Begine Gpu Ps PullSparse"; + platform::Timer all_timer; + platform::Timer pull_gpups_timer; + all_timer.Start(); + int64_t total_length = + std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL); + auto buf = memory::AllocShared(place, total_length * sizeof(FeatureValue)); + FeatureValue* total_values_gpu = reinterpret_cast(buf->ptr()); + if (platform::is_cpu_place(place)) { + PADDLE_THROW(platform::errors::Unimplemented( + "Warning:: CPUPlace is not supported in GpuPs now.")); + } else if (platform::is_gpu_place(place)) { + VLOG(3) << "Begin copy keys, key_num[" << total_length << "]"; + int device_id = BOOST_GET_CONST(platform::CUDAPlace, place).GetDeviceId(); + int devid_2_index = HeterPs_->get_index_by_devid(device_id); + LoDTensor& total_keys_tensor = keys_tensor[devid_2_index]; + uint64_t* total_keys = reinterpret_cast( + total_keys_tensor.mutable_data({total_length, 1}, place)); + + // construct slot_level lod info + auto slot_lengths_lod = slot_lengths; + for (size_t i = 1; i < slot_lengths_lod.size(); i++) { + slot_lengths_lod[i] += slot_lengths_lod[i - 1]; + } + auto buf_key = memory::AllocShared(place, keys.size() * sizeof(uint64_t*)); + auto 
buf_length = + memory::AllocShared(place, slot_lengths.size() * sizeof(int64_t)); + uint64_t** gpu_keys = reinterpret_cast(buf_key->ptr()); + int64_t* gpu_len = reinterpret_cast(buf_length->ptr()); + cudaMemcpy(gpu_keys, keys.data(), keys.size() * sizeof(uint64_t*), + cudaMemcpyHostToDevice); + cudaMemcpy(gpu_len, slot_lengths_lod.data(), + slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice); + + this->CopyKeys(place, gpu_keys, total_keys, gpu_len, + static_cast(slot_lengths.size()), + static_cast(total_length)); + VLOG(3) << "Begin call PullSparseGPU in GPUPS, dev: " << devid_2_index + << " len: " << total_length; + pull_gpups_timer.Start(); + HeterPs_->pull_sparse(devid_2_index, total_keys, total_values_gpu, + static_cast(total_length)); + // PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet( + // "PullSparseGPU failed in GPUPS.")); + pull_gpups_timer.Pause(); + + VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length + << "]"; + this->CopyForPull(place, gpu_keys, values, total_values_gpu, gpu_len, + static_cast(slot_lengths.size()), hidden_size, + total_length); + } else { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "GpuPs: PullSparse Only Support CUDAPlace Now.")); + } + all_timer.Pause(); + VLOG(1) << "GpuPs PullSparse total costs: " << all_timer.ElapsedSec() + << " s, of which GPUPS costs: " << pull_gpups_timer.ElapsedSec() + << " s"; + VLOG(3) << "End PullSparse"; +} + +void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place, + const int table_id, + const std::vector& keys, + const std::vector& grad_values, + const std::vector& slot_lengths, + const int hidden_size, const int batch_size) { + VLOG(3) << "Begin GPUPS PushSparseGrad"; + platform::Timer all_timer; + platform::Timer push_gpups_timer; + all_timer.Start(); + int64_t total_length = + std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL); + auto buf = + memory::AllocShared(place, total_length * sizeof(FeaturePushValue)); + FeaturePushValue* total_grad_values_gpu = + reinterpret_cast(buf->ptr()); + if (platform::is_cpu_place(place)) { + PADDLE_THROW(platform::errors::Unimplemented( + "Warning:: CPUPlace is not supported in GPUPS now.")); + } else if (platform::is_gpu_place(place)) { + int device_id = BOOST_GET_CONST(platform::CUDAPlace, place).GetDeviceId(); + int devid_2_index = HeterPs_->get_index_by_devid(device_id); + LoDTensor& cached_total_keys_tensor = keys_tensor[devid_2_index]; + uint64_t* total_keys = + reinterpret_cast(cached_total_keys_tensor.data()); + VLOG(3) << "Begin copy grad tensor to gpups struct"; + this->CopyForPush(place, grad_values, total_grad_values_gpu, slot_lengths, + hidden_size, total_length, batch_size); + + VLOG(3) << "Begin call PushSparseGPU in GPUPS, dev: " << devid_2_index + << " len: " << total_length; + push_gpups_timer.Start(); + HeterPs_->push_sparse(devid_2_index, total_keys, total_grad_values_gpu, + static_cast(total_length)); + push_gpups_timer.Pause(); + } else { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "GPUPS: PushSparseGrad Only Support CUDAPlace Now.")); + } + all_timer.Pause(); + VLOG(1) << "PushSparseGrad total cost: " << all_timer.ElapsedSec() + << " s, of which GPUPS cost: " << push_gpups_timer.ElapsedSec() + << " s"; + VLOG(3) << "End PushSparseGrad"; +} + +} // end namespace framework +} // end namespace paddle +#endif diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu new file mode 100644 index 00000000000..9b7920acef3 --- 
/dev/null +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu @@ -0,0 +1,182 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_PSLIB +#include +#include +#include +#include +#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/platform/gpu_info.h" + +namespace paddle { +namespace framework { + +__global__ void PullCopy(float** dest, const FeatureValue* src, + const int64_t* len, int hidden, int slot_num, + int total_len, uint64_t** keys) { + CUDA_KERNEL_LOOP(i, total_len) { + int low = 0; + int high = slot_num - 1; + while (low < high) { + int mid = (low + high) / 2; + if (i < len[mid]) + high = mid; + else + low = mid + 1; + } + int x = low; + int y = i - (x ? len[x - 1] : 0); + if (*(keys[x] + y) == 0) { + *(dest[x] + y * hidden) = 0; + *(dest[x] + y * hidden + 1) = 0; + *(dest[x] + y * hidden + 2) = 0; + } else { + *(dest[x] + y * hidden) = (src + i)->show; + *(dest[x] + y * hidden + 1) = (src + i)->clk; + *(dest[x] + y * hidden + 2) = (src + i)->lr; + } + if ((src + i)->mf_size == 0 || *(keys[x] + y) == 0) { + for (int j = 0; j < 8; j++) { + *(dest[x] + y * hidden + 3 + j) = 0; + } + } else { + for (int j = 0; j < 8; j++) { + *(dest[x] + y * hidden + 3 + j) = (src + i)->mf[1 + j]; + } + } + } +} + +__global__ void CopyKeysKernel(uint64_t** src_keys, uint64_t* dest_total_keys, + const int64_t* len, int slot_num, + int total_len) { + CUDA_KERNEL_LOOP(i, total_len) { + int low = 0; + int high = slot_num - 1; + while (low < high) { + int mid = (low + high) / 2; + if (i < len[mid]) + high = mid; + else + low = mid + 1; + } + int x = low; + int y = i - (x ? len[x - 1] : 0); + dest_total_keys[i] = src_keys[x][y]; + } +} + +__global__ void PushCopy(FeaturePushValue* dest, float** src, int64_t* len, + int hidden, int slot_num, int total_len, int bs, + int* slot_vector) { + CUDA_KERNEL_LOOP(i, total_len) { + int low = 0; + int high = slot_num - 1; + while (low < high) { + int mid = (low + high) / 2; + if (i < len[mid]) + high = mid; + else + low = mid + 1; + } + int x = low; + int y = i - (x ? len[low - 1] : 0); + (dest + i)->slot = slot_vector[x]; + (dest + i)->show = *(src[x] + y * hidden); + (dest + i)->clk = *(src[x] + y * hidden + 1); + (dest + i)->lr_g = *(src[x] + y * hidden + 2) * -1. * bs; + for (int j = 0; j < 8; j++) { + (dest + i)->mf_g[j] = *(src[x] + y * hidden + 3 + j) * -1. 
* bs; + } + } +} + +void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place, + uint64_t** gpu_keys, + const std::vector& values, + const FeatureValue* total_values_gpu, + const int64_t* gpu_len, const int slot_num, + const int hidden_size, + const int64_t total_length) { + auto stream = dynamic_cast( + platform::DeviceContextPool::Instance().Get( + BOOST_GET_CONST(platform::CUDAPlace, place))) + ->stream(); + auto buf_value = memory::AllocShared(place, values.size() * sizeof(float*)); + float** gpu_values = reinterpret_cast(buf_value->ptr()); + cudaMemcpy(gpu_values, values.data(), values.size() * sizeof(float*), + cudaMemcpyHostToDevice); + + PullCopy<<<(total_length + 512 - 1) / 512, 512, 0, stream>>>( + gpu_values, total_values_gpu, gpu_len, hidden_size, slot_num, + total_length, gpu_keys); + cudaStreamSynchronize(stream); +} + +void PSGPUWrapper::CopyKeys(const paddle::platform::Place& place, + uint64_t** origin_keys, uint64_t* total_keys, + const int64_t* gpu_len, int slot_num, + int total_len) { + auto stream = dynamic_cast( + platform::DeviceContextPool::Instance().Get( + BOOST_GET_CONST(platform::CUDAPlace, place))) + ->stream(); + CopyKeysKernel<<<(total_len + 512 - 1) / 512, 512, 0, stream>>>( + origin_keys, total_keys, gpu_len, slot_num, total_len); + cudaStreamSynchronize(stream); +} + +void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place, + const std::vector& grad_values, + FeaturePushValue* total_grad_values_gpu, + const std::vector& slot_lengths, + const int hidden_size, + const int64_t total_length, + const int batch_size) { + auto stream = dynamic_cast( + platform::DeviceContextPool::Instance().Get( + BOOST_GET_CONST(platform::CUDAPlace, place))) + ->stream(); + auto slot_lengths_lod = slot_lengths; + for (int i = 1; i < slot_lengths_lod.size(); i++) { + slot_lengths_lod[i] += slot_lengths_lod[i - 1]; + } + auto buf_grad_value = + memory::AllocShared(place, grad_values.size() * sizeof(float*)); + auto buf_length = + memory::AllocShared(place, slot_lengths.size() * sizeof(int64_t)); + auto buf_slot_vector = + memory::AllocShared(place, slot_lengths_lod.size() * sizeof(int)); + + float** gpu_values = reinterpret_cast(buf_grad_value->ptr()); + int64_t* gpu_len = reinterpret_cast(buf_length->ptr()); + int* d_slot_vector = reinterpret_cast(buf_slot_vector->ptr()); + + cudaMemcpy(gpu_values, grad_values.data(), + grad_values.size() * sizeof(float*), cudaMemcpyHostToDevice); + cudaMemcpy(gpu_len, slot_lengths_lod.data(), + slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice); + cudaMemcpy(d_slot_vector, slot_vector_.data(), + slot_lengths_lod.size() * sizeof(int), cudaMemcpyHostToDevice); + + PushCopy<<<(total_length + 512 - 1) / 512, 512, 0, stream>>>( + total_grad_values_gpu, gpu_values, gpu_len, hidden_size, + slot_lengths.size(), total_length, batch_size, d_slot_vector); + cudaStreamSynchronize(stream); +} +} // end namespace framework +} // end namespace paddle +#endif diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h new file mode 100644 index 00000000000..df6af23d701 --- /dev/null +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -0,0 +1,118 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB) + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/fleet/heter_context.h" +#include "paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h" +#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/platform/gpu_info.h" +#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace framework { + +class PSGPUWrapper { + public: + virtual ~PSGPUWrapper() { delete HeterPs_; } + + PSGPUWrapper() { + HeterPs_ = NULL; + sleep_seconds_before_fail_exit_ = 300; + } + + void PullSparse(const paddle::platform::Place& place, const int table_id, + const std::vector& keys, + const std::vector& values, + const std::vector& slot_lengths, + const int hidden_size); + void PushSparseGrad(const paddle::platform::Place& place, const int table_id, + const std::vector& keys, + const std::vector& grad_values, + const std::vector& slot_lengths, + const int hidden_size, const int batch_size); + void CopyKeys(const paddle::platform::Place& place, uint64_t** origin_keys, + uint64_t* total_keys, const int64_t* gpu_len, int slot_num, + int total_len); + void CopyForPull(const paddle::platform::Place& place, uint64_t** gpu_keys, + const std::vector& values, + const FeatureValue* total_values_gpu, const int64_t* gpu_len, + const int slot_num, const int hidden_size, + const int64_t total_length); + + void CopyForPush(const paddle::platform::Place& place, + const std::vector& grad_values, + FeaturePushValue* total_grad_values_gpu, + const std::vector& slot_lengths, + const int hidden_size, const int64_t total_length, + const int batch_size); + + void BuildGPUPS(const uint64_t table_id, int feature_dim, + std::shared_ptr context); + void InitializeGPU(const std::vector& dev_ids) { + if (s_instance_ != NULL) { + VLOG(3) << "PSGPUWrapper Begin InitializeGPU"; + resource_ = std::make_shared(dev_ids); + resource_->enable_p2p(); + keys_tensor.resize(resource_->total_gpu()); + } + } + // PSGPUWrapper singleton + static std::shared_ptr GetInstance() { + if (NULL == s_instance_) { + s_instance_.reset(new paddle::framework::PSGPUWrapper()); + } + return s_instance_; + } + std::vector>>& GetLocalTable( + int table_id) { + return local_tables_[table_id]; + } + void SetSlotVector(const std::vector& slot_vector) { + slot_vector_ = slot_vector; + } + + private: + static std::shared_ptr s_instance_; + std::unordered_map< + uint64_t, std::vector>>> + local_tables_; + HeterPsBase* HeterPs_; + std::vector keys_tensor; // Cache for pull_sparse + std::shared_ptr resource_; + int32_t sleep_seconds_before_fail_exit_; + std::vector slot_vector_; + + protected: + static bool is_initialized_; +}; + +} // end namespace framework +} // end namespace paddle +#endif diff --git a/paddle/fluid/framework/ps_gpu_trainer.cc 
b/paddle/fluid/framework/ps_gpu_trainer.cc new file mode 100644 index 00000000000..530750d98ac --- /dev/null +++ b/paddle/fluid/framework/ps_gpu_trainer.cc @@ -0,0 +1,404 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include "io/fs.h" +#include "paddle/fluid/framework/data_feed_factory.h" +#include "paddle/fluid/framework/data_set.h" +#include "paddle/fluid/framework/device_worker_factory.h" +#include "paddle/fluid/framework/fleet/fleet_wrapper.h" +#include "paddle/fluid/framework/fleet/heter_context.h" +#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" +#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h" +#include "paddle/fluid/framework/trainer.h" +#if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB) +#include "paddle/fluid/platform/cuda_device_guard.h" + +namespace paddle { +namespace framework { + +void PSGPUTrainer::Initialize(const TrainerDesc& trainer_desc, + Dataset* dataset) { + dataset_ = dataset; + thread_num_ = trainer_desc.thread_num(); + param_ = trainer_desc.downpour_param(); + for (int i = 0; i < param_.dense_table_size(); ++i) { + uint64_t table_id = static_cast(param_.dense_table(i).table_id()); + auto table = param_.dense_table(i); + dense_grad_names_[table_id].resize(table.dense_grad_name_size()); + for (int j = 0; j < table.dense_grad_name_size(); ++j) { + dense_grad_names_[table_id][j] = table.dense_grad_name(j); + } + } + scale_datanorm_ = trainer_desc.scale_datanorm(); + int place_num = trainer_desc.worker_places_size(); + const std::vector readers = + dataset->GetReaders(); + std::vector dev_ids; + for (int i = 0; i < place_num; ++i) { + int num = trainer_desc.worker_places(i); + platform::CUDAPlace place = platform::CUDAPlace(num); + places_.push_back(place); + dev_ids.push_back(num); + } + for (int i = 0; i < trainer_desc.downpour_param().stat_var_names_size(); + i++) { + need_merge_var_names_.push_back( + trainer_desc.downpour_param().stat_var_names(i)); + } + VLOG(3) << "going to initialize pull dense worker"; + pull_dense_worker_ = PullDenseWorker::GetInstance(); + pull_dense_worker_->Initialize(trainer_desc); + SetDebug(trainer_desc.debug()); + fleet_ptr_ = FleetWrapper::GetInstance(); + trainer_desc_ = trainer_desc; + workers_.resize(place_num); + for (int i = 0; i < place_num; ++i) { + workers_[i] = DeviceWorkerFactory::CreateDeviceWorker( + trainer_desc.device_worker_name()); + workers_[i]->SetDeviceIndex(i); + workers_[i]->SetDataFeed(readers[i]); + workers_[i]->Initialize(trainer_desc); + workers_[i]->SetWorkerNum(place_num); + } + auto gpu_ps_wrapper = PSGPUWrapper::GetInstance(); + gpu_ps_wrapper->InitializeGPU(dev_ids); + return; +} + +void PSGPUTrainer::DumpWork(int tid) {} + +void PSGPUTrainer::RegisterHeterCallback() { + /* + auto fleet_ptr = FleetWrapper::GetInstance(); + fleet_ptr->RegisterHeterCallback([this](int worker, int taskid) { + // workers_[worker]->Schedule(taskid); + }); + */ +} + +void PSGPUTrainer::InitTrainerEnv(const 
ProgramDesc& main_program, + const platform::Place& place) { + for (size_t i = 0; i < places_.size(); ++i) { + workers_[i]->SetPlace(places_[i]); + workers_[i]->SetReaderPlace(places_[i]); + workers_[i]->SetRootScope(root_scope_); + workers_[i]->CreateDeviceResource(main_program); // Program + workers_[i]->BindingDataFeedMemory(); + } + for (size_t num = 0; num < places_.size(); ++num) { + auto place = places_[num]; + Scope* scope = workers_[num]->GetThreadScope(); + auto& block = main_program.Block(0); + for (auto& var : block.AllVars()) { + if (var->Persistable()) { + auto name = var->Name(); + Variable* root_var = root_scope_->FindVar(name); + if (!root_var) { + continue; + } + LoDTensor* root_tensor = root_var->GetMutable(); + auto* ptr = scope->Var(name); + InitializeVariable(ptr, proto::VarType::LOD_TENSOR); + LoDTensor* thread_tensor = ptr->GetMutable(); + TensorCopy(*root_tensor, place, thread_tensor); + } + } + } + place_ = place; + return; +} + +void PSGPUTrainer::InitOtherEnv(const ProgramDesc& main_program) { + pull_dense_worker_->SetRootScope(root_scope_); + for (size_t i = 0; i < places_.size(); ++i) { + pull_dense_worker_->AddThreadScope(workers_[i]->GetThreadScope()); + } + VLOG(3) << "init other env done."; +} + +void PSGPUTrainer::Run() { + BuildGPUPSTask(0, 8); + for (size_t thidx = 0; thidx < places_.size(); ++thidx) { + threads_.push_back( + std::thread(&DeviceWorker::TrainFiles, workers_[thidx].get())); + } +} +void PSGPUTrainer::BuildGPUPSTask(int table_id, int feadim) { + VLOG(3) << "PSGPUTrainer::BuildGPUPSTask begin"; + platform::Timer timeline; + timeline.Start(); + MultiSlotDataset* dataset = dynamic_cast(dataset_); + auto fleet_ptr = FleetWrapper::GetInstance(); + std::shared_ptr heter_context = + std::make_shared(); + auto& multi_output_channel = dataset->GetCurOutputChannel(); + auto& input_channel = dataset->GetInputChannelRef(); + int gen_shard_num = multi_output_channel.size(); + int device_num = places_.size(); + auto gpu_ps_wrapper = PSGPUWrapper::GetInstance(); + auto& local_keys = heter_context->feature_keys_; + local_keys.resize(device_num); + auto& local_values = heter_context->feature_values_; + local_values.resize(device_num); + auto& local_ptr = heter_context->value_ptr_; + local_ptr.resize(device_num); + for (auto& ks : local_keys) { + ks.reserve(100000); + } + // read thread + std::vector threads(gen_shard_num); + std::vector> consume_task_pool(device_num); + for (size_t i = 0; i < consume_task_pool.size(); i++) { + consume_task_pool[i].reset(new ::ThreadPool(1)); + } + auto consume_func = [&local_keys](int shard_id, int feadim, + std::vector& keys) { + local_keys[shard_id].insert(local_keys[shard_id].end(), keys.begin(), + keys.end()); + }; + + if (input_channel->Size() == 0) { + // output_channel_ should hold one pass instances now + uint64_t output_channels_data_size = 0; + for (size_t i = 0; i < multi_output_channel.size(); i++) { + int cur_channel_size = multi_output_channel[i]->Size(); + output_channels_data_size += cur_channel_size; + } + CHECK(output_channels_data_size > 0); + for (auto& ks : local_keys) { + ks.reserve(output_channels_data_size * 10); // magic number + } + auto gen_func = [&dataset, &device_num, &feadim, &consume_task_pool, + &multi_output_channel, &consume_func](int i) { + const std::deque& vec_data = multi_output_channel[i]->GetData(); + std::vector> task_keys(device_num); + std::vector> task_futures; + for (size_t j = 0; j < vec_data.size(); j++) { + for (auto& feature : vec_data[j].uint64_feasigns_) { + int shard = 
feature.sign().uint64_feasign_ % device_num; + task_keys[shard].push_back(feature.sign().uint64_feasign_); + } + } + + for (int shard_id = 0; shard_id < device_num; shard_id++) { + task_futures.emplace_back(consume_task_pool[shard_id]->enqueue( + consume_func, shard_id, feadim, task_keys[shard_id])); + } + + for (auto& tf : task_futures) { + tf.wait(); + } + for (auto& tk : task_keys) { + tk.clear(); + std::vector().swap(tk); + } + task_keys.clear(); + std::vector>().swap(task_keys); + }; + for (size_t i = 0; i < threads.size(); i++) { + threads[i] = std::thread(gen_func, i); + } + for (std::thread& t : threads) { + t.join(); + } + } else { + int input_channel_size = input_channel->Size(); + CHECK(input_channel_size > 0); + CHECK(gen_shard_num > 0); + for (auto& ks : local_keys) { + ks.reserve(input_channel_size * 10); // magic number + } + const std::deque& vec_data = input_channel->GetData(); + auto gen_func = [&dataset, &vec_data, &device_num, &gen_shard_num, + &input_channel_size, &feadim, &consume_task_pool, + multi_output_channel, &consume_func](int i) { + std::vector> task_keys(device_num); + std::vector> task_futures; + size_t per_shard_num = input_channel_size / gen_shard_num + 1; + size_t total_size = vec_data.size(); + size_t start_index = i * per_shard_num; + size_t end_index = + std::min(start_index + per_shard_num - 1, total_size - 1); + for (size_t j = start_index; j <= end_index; j++) { + for (auto& feature : vec_data[j].uint64_feasigns_) { + int shard = feature.sign().uint64_feasign_ % device_num; + task_keys[shard].push_back(feature.sign().uint64_feasign_); + } + } + + for (int shard_id = 0; shard_id < device_num; shard_id++) { + task_futures.emplace_back(consume_task_pool[shard_id]->enqueue( + consume_func, shard_id, feadim, task_keys[shard_id])); + } + + for (auto& tf : task_futures) { + tf.wait(); + } + for (auto& tk : task_keys) { + tk.clear(); + std::vector().swap(tk); + } + task_keys.clear(); + std::vector>().swap(task_keys); + }; + for (size_t i = 0; i < threads.size(); i++) { + threads[i] = std::thread(gen_func, i); + } + for (std::thread& t : threads) { + t.join(); + } + } + timeline.Pause(); + VLOG(0) << "GpuPs build task cost " << timeline.ElapsedSec() << " seconds."; + timeline.Start(); + auto unique_func = [&local_keys](int i) { + auto& cur_keys = local_keys[i]; + std::sort(cur_keys.begin(), cur_keys.end()); + cur_keys.erase(std::unique(cur_keys.begin(), cur_keys.end()), + cur_keys.end()); + }; + for (size_t i = 0; i < threads.size(); i++) { + threads[i] = std::thread(unique_func, i); + } + for (std::thread& t : threads) { + t.join(); + } + timeline.Pause(); + + VLOG(0) << "GpuPs task unique cost " << timeline.ElapsedSec() << " seconds."; + + timeline.Start(); + for (size_t i = 0; i < consume_task_pool.size(); i++) { + consume_task_pool[i].reset(); + } + consume_task_pool.clear(); + + for (int i = 0; i < device_num; i++) { + local_values[i].resize(local_keys[i].size()); + local_ptr[i].resize(local_keys[i].size()); + } + + auto ptl_func = [this, &local_keys, &local_values, &local_ptr, &table_id, + &fleet_ptr](int i) { + size_t key_size = local_keys[i].size(); + auto tt = fleet_ptr->pslib_ptr_->_worker_ptr->pull_sparse_ptr( + (char**)(local_ptr[i].data()), table_id, local_keys[i].data(), + key_size); + tt.wait(); + auto status = tt.get(); + // auto status = 0; + if (status != 0) { + LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]"; + sleep(300); + exit(-1); + } else { + VLOG(3) << "FleetWrapper Pull sparse to local done with table size: " + << 
local_keys[i].size(); + } + for (size_t num = 0; num < local_ptr[i].size(); ++num) { + float* ptr_val = local_ptr[i][num]->data(); + FeatureValue& val = local_values[i][num]; + size_t dim = local_ptr[i][num]->size(); + + val.delta_score = ptr_val[1]; + val.show = ptr_val[2]; + val.clk = ptr_val[3]; + val.slot = ptr_val[6]; + val.lr = ptr_val[4]; + val.lr_g2sum = ptr_val[5]; + + if (dim > 7) { + val.mf_size = MF_DIM + 1; + for (int x = 0; x < val.mf_size; x++) { + val.mf[x] = ptr_val[x + 7]; + } + } else { + val.mf_size = 0; + for (int x = 0; x < MF_DIM + 1; x++) { + val.mf[x] = 0; + } + } + } + }; + for (size_t i = 0; i < threads.size(); i++) { + threads[i] = std::thread(ptl_func, i); + } + for (std::thread& t : threads) { + t.join(); + } + timeline.Pause(); + VLOG(0) << "GpuPs pull sparse cost " << timeline.ElapsedSec() << " seconds."; + gpu_ps_wrapper->BuildGPUPS(table_id, feadim, heter_context); +} + +Scope* PSGPUTrainer::GetWorkerScope(int thread_id) { return nullptr; } + +template +void PSGPUTrainer::MergeToRootScope(LoDTensor* root_tensor, LoDTensor* tensor) { + LoDTensor tmp_root; + TensorCopy(*root_tensor, platform::CPUPlace(), &tmp_root); + T* tmp_root_data = tmp_root.data(); + LoDTensor tmp_tensor; + TensorCopy(*tensor, platform::CPUPlace(), &tmp_tensor); + T* data = tmp_tensor.data(); + for (int i = 0; i < tmp_tensor.numel(); i++) { + tmp_root_data[i] += data[i]; + } + TensorCopy(tmp_root, platform::CPUPlace(), root_tensor); +} + +void PSGPUTrainer::Finalize() { + for (auto& th : threads_) { + th.join(); + } + for (size_t i = 0; i < need_merge_var_names_.size(); i++) { + Variable* root_var = root_scope_->FindVar(need_merge_var_names_[i]); + if (root_var == nullptr) { + continue; + } + LoDTensor* root_tensor = root_var->GetMutable(); + + for (size_t j = 0; j < places_.size(); j++) { + Scope* cur_thread_scope = workers_[j]->GetThreadScope(); + Variable* thread_var = + cur_thread_scope->FindVar(need_merge_var_names_[i]); + if (thread_var == nullptr) { + continue; + } + LoDTensor* thread_tensor = thread_var->GetMutable(); +#define MergeCallback(cpp_type, proto_type) \ + do { \ + if (root_tensor->type() == proto_type) { \ + if (thread_tensor->type() != proto_type) { \ + VLOG(0) << "Error: thread id=" << j << ", need_merge_var_names_[" << i \ + << "] " << need_merge_var_names_[i] \ + << ", root tensor type=" << root_tensor->type() \ + << ", thread tensor type=" << thread_tensor->type(); \ + exit(-1); \ + } \ + MergeToRootScope(root_tensor, thread_tensor); \ + } \ + } while (0) + _ForEachDataType_(MergeCallback); + } + } + pull_dense_worker_->MergeDenseParam(); + root_scope_->DropKids(); +} +} // namespace framework +} // namespace paddle +#endif diff --git a/paddle/fluid/framework/ps_gpu_worker.cc b/paddle/fluid/framework/ps_gpu_worker.cc new file mode 100644 index 00000000000..b965b8a2dc8 --- /dev/null +++ b/paddle/fluid/framework/ps_gpu_worker.cc @@ -0,0 +1,196 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/framework/device_worker.h" +#include "paddle/fluid/framework/device_worker_factory.h" +#include "paddle/fluid/framework/fleet/fleet_wrapper.h" +#include "paddle/fluid/framework/fleet/heter_wrapper.h" +#include "paddle/fluid/platform/cpu_helper.h" +#include "paddle/fluid/string/string_helper.h" + +#if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB) +#include "paddle/fluid/platform/cuda_device_guard.h" + +#if defined _WIN32 || defined __APPLE__ +#else +#define _LINUX +#endif + +namespace paddle { +namespace framework { + +void PSGPUWorker::Initialize(const TrainerDesc& desc) { + param_ = desc.downpour_param(); + mpi_rank_ = desc.mpi_rank(); + trainer_desc_ = desc; + /* + for (int i = 0; i < trainer_desc_.xpu_recv_list_size(); ++i) { + send_var_list_.push_back(trainer_desc_.xpu_recv_list(i)); + } + */ + for (int i = 0; i < param_.sparse_table_size(); ++i) { + uint64_t table_id = + static_cast(param_.sparse_table(i).table_id()); + TableParameter table = param_.sparse_table(i); + sparse_key_names_[table_id].resize(table.sparse_key_name_size()); + for (int j = 0; j < table.sparse_key_name_size(); ++j) { + sparse_key_names_[table_id][j] = table.sparse_key_name(j); + } + sparse_value_names_[table_id].resize(table.sparse_value_name_size()); + for (int j = 0; j < table.sparse_value_name_size(); ++j) { + sparse_value_names_[table_id][j] = table.sparse_value_name(j); + } + sparse_grad_names_[table_id].resize(table.sparse_grad_name_size()); + for (int j = 0; j < table.sparse_grad_name_size(); ++j) { + sparse_grad_names_[table_id][j] = table.sparse_grad_name(j); + } + label_var_name_[table_id] = table.label_var_name(); + sparse_push_keys_[table_id] = std::vector(); + } + + for (int i = 0; i < param_.dense_table_size(); ++i) { + uint64_t table_id = static_cast(param_.dense_table(i).table_id()); + auto table = param_.dense_table(i); + dense_value_names_[table_id].resize(table.dense_value_name_size()); + for (int j = 0; j < table.dense_value_name_size(); ++j) { + dense_value_names_[table_id][j] = table.dense_value_name(j); + } + dense_grad_names_[table_id].resize(table.dense_grad_name_size()); + for (int j = 0; j < table.dense_grad_name_size(); ++j) { + dense_grad_names_[table_id][j] = table.dense_grad_name(j); + } + } + + skip_ops_.resize(param_.skip_ops_size()); + for (int i = 0; i < param_.skip_ops_size(); ++i) { + skip_ops_[i] = param_.skip_ops(i); + } + for (int i = 0; i < param_.stat_var_names_size(); ++i) { + stat_var_name_map_[param_.stat_var_names(i)] = 1; + } + + need_to_push_sparse_ = param_.push_sparse(); + need_to_push_dense_ = param_.push_dense(); + + fetch_config_ = desc.fetch_config(); + use_cvm_ = desc.use_cvm(); + // for sparse value accessor, embedding only + no_cvm_ = desc.no_cvm(); + scale_datanorm_ = desc.scale_datanorm(); + dump_slot_ = desc.dump_slot(); + dump_fields_.resize(desc.dump_fields_size()); + for (int i = 0; i < desc.dump_fields_size(); ++i) { + dump_fields_[i] = desc.dump_fields(i); + } + adjust_ins_weight_config_ = desc.adjust_ins_weight_config(); + need_dump_param_ = false; + dump_param_.resize(desc.dump_param_size()); + for (int i = 0; i < desc.dump_param_size(); ++i) { + dump_param_[i] = desc.dump_param(i); + } + if (desc.dump_param_size() != 0) { + need_dump_param_ = true; + } + for (int i = 0; i < desc.check_nan_var_names_size(); ++i) { + check_nan_var_names_.push_back(desc.check_nan_var_names(i)); + } + copy_table_config_ = desc.copy_table_config(); + for (int i = 0; i < copy_table_config_.src_sparse_tables_size(); ++i) { + 
uint64_t src_table = copy_table_config_.src_sparse_tables(i); + uint64_t dest_table = copy_table_config_.dest_sparse_tables(i); + VLOG(3) << "copy_sparse_tables_ push back " << src_table << "->" + << dest_table; + copy_sparse_tables_.push_back(std::make_pair(src_table, dest_table)); + } + for (int i = 0; i < copy_table_config_.src_dense_tables_size(); ++i) { + uint64_t src_table = copy_table_config_.src_dense_tables(i); + uint64_t dest_table = copy_table_config_.dest_dense_tables(i); + VLOG(3) << "copy_dense_tables_ push back " << src_table << "->" + << dest_table; + copy_dense_tables_.push_back(std::make_pair(src_table, dest_table)); + } + for (auto& m : copy_table_config_.table_denpendency_map()) { + if (sparse_key_names_.find(m.key()) != sparse_key_names_.end()) { + // currently only support one dependency + for (auto& value : m.values()) { + table_dependency_[m.key()] = value; + } + } + } + // pull_queue_ = paddle::framework::MakeChannel>(); + // push_queue_ = paddle::framework::MakeChannel>(); +} + +void PSGPUWorker::SetChannelWriter(ChannelObject* queue) { + writer_.Reset(queue); +} + +void PSGPUWorker::SetNeedDump(bool need_dump_field) { + need_dump_field_ = need_dump_field; +} + +void PSGPUWorker::DumpParam() {} + +void PSGPUWorker::TrainFiles() { + VLOG(3) << "train file A"; + platform::SetNumThreads(1); + + VLOG(3) << "train file B"; + // how to accumulate fetched values here + device_reader_->Start(); + VLOG(3) << "train file C"; + int cur_batch; + while ((cur_batch = device_reader_->Next()) > 0) { + VLOG(3) << "train file D"; + for (auto& op : ops_) { + bool need_skip = false; + for (auto t = 0u; t < skip_ops_.size(); ++t) { + if (op->Type().find(skip_ops_[t]) != std::string::npos) { + need_skip = true; + break; + } + } + if (!need_skip) { + op->Run(*thread_scope_, place_); + } + } + + PrintFetchVars(); + thread_scope_->DropKids(); + } + return; +} + +void PSGPUWorker::ResetStat() { + total_time_ = 0; + read_time_ = 0; + pack_time_ = 0; + pull_sparse_local_time_ = 0; + op_all_time_ = 0; + xpu_op_time_ = 0; + xpu_wait_time_ = 0; + cpu_op_time_ = 0; + collect_label_time_ = 0; + fill_sparse_time_ = 0; + push_sparse_time_ = 0; + gpu_2_cpu_time_ = 0; + cpu_2_gpu_time_ = 0; + total_inst_ = 0; +} + +void PSGPUWorker::ProduceTasks() { return; } + +} // end namespace framework +} // end namespace paddle +#endif diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index be85247c7ea..25b215df3e4 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -277,6 +277,55 @@ class HeterBoxTrainer : public TrainerBase { }; #endif +#if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB) +class PSGPUTrainer : public TrainerBase { + public: + PSGPUTrainer() {} + virtual ~PSGPUTrainer() {} + virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set); + virtual void InitTrainerEnv(const ProgramDesc& main_program, + const platform::Place& place); + virtual void InitOtherEnv(const ProgramDesc& main_program); + virtual void Run(); + virtual void Finalize(); + virtual void RegisterHeterCallback(); + virtual void DumpWork(int tid); + virtual Scope* GetWorkerScope(int thread_id); + virtual void CacheProgram(const ProgramDesc& main_program) { + new (&program_) ProgramDesc(main_program); + } + virtual std::string GetDumpPath(int tid) { return ""; } + virtual void InitDumpEnv() {} + void BuildGPUPSTask(int table_id, int feadim); + /* + template + void HeterMemCpy(LoDTensor* tensor, LoDTensor* root_tensor, + const 
paddle::platform::Place& thread_place, + cudaStream_t stream); + */ + + template + void MergeToRootScope(LoDTensor* root_tensor, LoDTensor* thread_tensor); + + protected: + Dataset* dataset_; + DownpourWorkerParameter param_; + std::map> dense_grad_names_; + std::vector need_merge_var_names_; + float scale_datanorm_; + paddle::platform::Place place_; + ProgramDesc program_; + std::shared_ptr fleet_ptr_; + std::shared_ptr pull_dense_worker_; + std::vector> workers_; + std::vector places_; + // ps-gpu + std::vector threads_; + int use_ps_gpu_; + int thread_num_; +}; +#endif + #if defined(PADDLE_WITH_NCCL) class PipelineTrainer : public TrainerBase { public: diff --git a/paddle/fluid/framework/trainer_factory.cc b/paddle/fluid/framework/trainer_factory.cc index 087d1ea0af8..226f62701d8 100644 --- a/paddle/fluid/framework/trainer_factory.cc +++ b/paddle/fluid/framework/trainer_factory.cc @@ -68,6 +68,9 @@ REGISTER_TRAINER_CLASS(DistMultiTrainer); REGISTER_TRAINER_CLASS(HeterXpuTrainer); REGISTER_TRAINER_CLASS(HeterBoxTrainer); #endif +#if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB) +REGISTER_TRAINER_CLASS(PSGPUTrainer); +#endif #if defined(PADDLE_WITH_NCCL) REGISTER_TRAINER_CLASS(PipelineTrainer); #endif diff --git a/paddle/fluid/operators/pull_box_sparse_op.cc b/paddle/fluid/operators/pull_box_sparse_op.cc index 5b62edda247..d680fe11047 100644 --- a/paddle/fluid/operators/pull_box_sparse_op.cc +++ b/paddle/fluid/operators/pull_box_sparse_op.cc @@ -64,12 +64,23 @@ class PullBoxSparseOp : public framework::OperatorWithKernel { class PullBoxSparseOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { + AddInput("W", + "(Tensor) The input represents embedding tensors, " + "which is a learnable parameter.") + .AsDispensable(); AddInput("Ids", "Input tensors with type int32 or int64 " "contains the ids to be looked up in BoxPS. " "The last dimension size must be 1.") .AsDuplicable(); AddOutput("Out", "The lookup results tensors.").AsDuplicable(); + AddAttr("is_sparse", + "(boolean, default false) " + "Sparse update.") + .SetDefault(false); + AddAttr("is_distributed", + "(boolean, default false) distributed lookup table.") + .SetDefault(false); AddAttr("size", "(int, the embedding hidden size").SetDefault(1); AddComment(R"DOC( Pull Box Sparse Operator. 
diff --git a/paddle/fluid/operators/pull_box_sparse_op.h b/paddle/fluid/operators/pull_box_sparse_op.h index 3b48341368c..48e42c32324 100644 --- a/paddle/fluid/operators/pull_box_sparse_op.h +++ b/paddle/fluid/operators/pull_box_sparse_op.h @@ -16,6 +16,7 @@ #include #include #include "paddle/fluid/framework/fleet/box_wrapper.h" +#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor.h" @@ -46,6 +47,12 @@ static void PullBoxSparseFunctor(const framework::ExecutionContext &ctx) { box_ptr->PullSparse(ctx.GetPlace(), all_keys, all_values, slot_lengths, hidden_size, 0); #endif +#if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB) + auto hidden_size = ctx.Attr("size"); + auto gpu_ps_ptr = paddle::framework::PSGPUWrapper::GetInstance(); + gpu_ps_ptr->PullSparse(ctx.GetPlace(), 0, all_keys, all_values, slot_lengths, + hidden_size); +#endif } template @@ -83,6 +90,12 @@ static void PushBoxSparseFunctor(const framework::ExecutionContext &ctx) { box_ptr->PushSparseGrad(ctx.GetPlace(), all_keys, all_grad_values, slot_lengths, hidden_size, 0, batch_size); #endif +#if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB) + auto hidden_size = ctx.Attr("size"); + auto gpu_ps_ptr = paddle::framework::PSGPUWrapper::GetInstance(); + gpu_ps_ptr->PushSparseGrad(ctx.GetPlace(), 0, all_keys, all_grad_values, + slot_lengths, hidden_size, batch_size); +#endif } using LoDTensor = framework::LoDTensor; diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index bc1ab96528c..e9bda383bb0 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -1,7 +1,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper prune feed_fetch_method pass_builder parallel_executor profiler layer tracer engine scope_pool analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context - gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry) + gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper) if (WITH_NCCL) set(PYBIND_DEPS ${PYBIND_DEPS} nccl_wrapper) @@ -33,6 +33,7 @@ set(PYBIND_SRCS reader_py.cc fleet_wrapper_py.cc heter_wrapper_py.cc + ps_gpu_wrapper_py.cc gloo_wrapper_py.cc box_helper_py.cc data_set_py.cc diff --git a/paddle/fluid/pybind/fleet_wrapper_py.cc b/paddle/fluid/pybind/fleet_wrapper_py.cc index 1e70bd9381b..4b72b09addd 100644 --- a/paddle/fluid/pybind/fleet_wrapper_py.cc +++ b/paddle/fluid/pybind/fleet_wrapper_py.cc @@ -57,11 +57,7 @@ void BindFleetWrapper(py::module* m) { .def("get_cache_threshold", &framework::FleetWrapper::GetCacheThreshold) .def("cache_shuffle", &framework::FleetWrapper::CacheShuffle) .def("save_cache", &framework::FleetWrapper::SaveCache) - .def("save_model_with_whitelist", - &framework::FleetWrapper::SaveWithWhitelist) .def("load_model", &framework::FleetWrapper::LoadModel) - .def("load_table_with_whitelist", - &framework::FleetWrapper::LoadWithWhitelist) .def("clear_model", &framework::FleetWrapper::ClearModel) .def("clear_one_table", &framework::FleetWrapper::ClearOneTable) .def("stop_server", &framework::FleetWrapper::StopServer) diff --git a/paddle/fluid/pybind/ps_gpu_wrapper_py.cc b/paddle/fluid/pybind/ps_gpu_wrapper_py.cc new file mode 100644 index 00000000000..0bbe8091975 --- /dev/null +++ b/paddle/fluid/pybind/ps_gpu_wrapper_py.cc @@ -0,0 +1,44 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include + +#ifdef _POSIX_C_SOURCE +#undef _POSIX_C_SOURCE +#endif + +#ifdef _XOPEN_SOURCE +#undef _XOPEN_SOURCE +#endif + +#include +#include + +#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h" +#include "paddle/fluid/pybind/ps_gpu_wrapper_py.h" + +namespace py = pybind11; + +namespace paddle { +namespace pybind { +#if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB) +void BindPSGPUWrapper(py::module* m) { + py::class_>( + *m, "PSGPU") + .def(py::init([]() { return framework::PSGPUWrapper::GetInstance(); })) + .def("set_slot_vector", &framework::PSGPUWrapper::SetSlotVector, + py::call_guard()); +} // end PSGPUWrapper +#endif +} // end namespace pybind +} // end namespace paddle diff --git a/paddle/fluid/pybind/ps_gpu_wrapper_py.h b/paddle/fluid/pybind/ps_gpu_wrapper_py.h new file mode 100644 index 00000000000..4048e88a55a --- /dev/null +++ b/paddle/fluid/pybind/ps_gpu_wrapper_py.h @@ -0,0 +1,29 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace py = pybind11; + +namespace paddle { +namespace pybind { + +#if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB) +void BindPSGPUWrapper(py::module* m); +#endif +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 5cefb26a4a3..f7b1c3523fd 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -79,6 +79,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/imperative.h" #include "paddle/fluid/pybind/inference_api.h" #include "paddle/fluid/pybind/ir.h" +#include "paddle/fluid/pybind/ps_gpu_wrapper_py.h" #include "paddle/fluid/pybind/pybind_boost_headers.h" #ifdef PADDLE_WITH_NCCL @@ -2809,8 +2810,12 @@ All parameter, weight, gradient are variables in Paddle. 
.def("device_count", &ParallelExecutor::DeviceCount); BindFleetWrapper(&m); + #ifdef PADDLE_WITH_PSLIB BindHeterWrapper(&m); +#endif +#if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB) + BindPSGPUWrapper(&m); #endif BindGlooWrapper(&m); BindBoxHelper(&m); diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 57e44fca9ca..9b17d61c33c 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -1375,8 +1375,6 @@ class Executor(object): is_heter = 1 if program._fleet_opt.get("trainer", "") == "HeterXpuTrainer": is_heter = 1 - if program._fleet_opt.get("use_ps_gpu", ""): - is_heter = 1 if scope is None: scope = global_scope() if fetch_list is None: diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py index 727cc2b1b54..f83dfd6a4eb 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py @@ -85,7 +85,7 @@ class DistributedAdam(DistributedOptimizerImplBase): ".batch_size@GRAD", ".batch_square_sum@GRAD", ".batch_sum@GRAD" ] self.supported_embedding_types = [ - "lookup_table", "pull_sparse", "pull_sparse_v2" + "lookup_table", "pull_sparse", "pull_sparse_v2", "pull_box_sparse" ] self.supported_embedding_grad_types = [ "lookup_table_grad", "push_sparse", "push_sparse_v2" diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 6c6820d52be..45f22460a9c 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -663,7 +663,11 @@ def _pull_sparse_v2(input, return outs -def _pull_box_sparse(input, size, dtype='float32'): +def _pull_box_sparse(input, + size, + dtype='float32', + is_distributed=False, + is_sparse=False): r""" **Pull Box Sparse Layer** @@ -701,11 +705,18 @@ def _pull_box_sparse(input, size, dtype='float32'): helper.create_variable_for_type_inference(dtype) for i in range(len(inputs)) ] + w = helper.create_parameter( + attr=helper.param_attr, shape=[size], dtype=dtype, is_bias=False) helper.append_op( type='pull_box_sparse', - inputs={'Ids': inputs}, + inputs={'Ids': inputs, + 'W': w}, outputs={'Out': outs}, - attrs={'size': size}) + attrs={ + 'size': size, + 'is_distributed': is_distributed, + 'is_sparse': is_sparse + }) if len(outs) == 1: return outs[0] return outs diff --git a/python/paddle/fluid/trainer_desc.py b/python/paddle/fluid/trainer_desc.py index d1fb843b566..989db9efea1 100644 --- a/python/paddle/fluid/trainer_desc.py +++ b/python/paddle/fluid/trainer_desc.py @@ -370,6 +370,30 @@ class HeterBoxTrainer(TrainerDesc): self._device_worker._gen_worker_desc(self.proto_desc) +class PSGPUTrainer(TrainerDesc): + """ + Implement of PSGPUTrainer. + It's for Distributed training. + """ + + def __init__(self): + super(PSGPUTrainer, self).__init__() + pass + + def _set_program(self, program): + super(PSGPUTrainer, self)._set_program(program) + self._program = program + + def _gen_trainer_desc(self): + super(PSGPUTrainer, self)._gen_trainer_desc() + self.proto_desc.class_name = "PSGPUTrainer" + if self._program == None: + raise RuntimeError("None Program") + self._device_worker._set_infer(self._infer) + self._device_worker._set_program(self._program) + self._device_worker._gen_worker_desc(self.proto_desc) + + class PipelineTrainer(TrainerDesc): """ Implement of PipelineTrainer. 
diff --git a/python/paddle/fluid/trainer_factory.py b/python/paddle/fluid/trainer_factory.py index 5aff7811330..c61141bcd32 100644 --- a/python/paddle/fluid/trainer_factory.py +++ b/python/paddle/fluid/trainer_factory.py @@ -22,7 +22,7 @@ from paddle.fluid.log_helper import get_logger local_logger = get_logger( __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') -from .trainer_desc import MultiTrainer, DistMultiTrainer, PipelineTrainer, HeterXpuTrainer, HeterBoxTrainer +from .trainer_desc import MultiTrainer, DistMultiTrainer, PipelineTrainer, HeterXpuTrainer, HeterBoxTrainer, PSGPUTrainer from .device_worker import Hogwild, DownpourSGD, Section, DownpourSGDOPT from .framework import Variable from multiprocessing import Process, Manager -- GitLab
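
Editorial note: the CUDA kernels added in ps_gpu_wrapper.cu (PullCopy, CopyKeysKernel, PushCopy) and the host code in PullSparse/CopyForPush share one addressing scheme: per-slot lengths are turned into an inclusive prefix sum (the "slot_lengths_lod" array), and each GPU thread binary-searches that array to map its flat feature index to a slot id and an in-slot offset. The sketch below is a minimal host-side C++ illustration of that mapping under assumed, hypothetical names; it is not Paddle code and does not touch the Paddle API.

```cpp
// Host-side sketch of the slot addressing used by PullCopy / CopyKeysKernel /
// PushCopy. Hypothetical names; illustration of the technique only.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Build the inclusive prefix sum that PullSparse/CopyForPush call
// "slot_lengths_lod" before copying it to the device.
std::vector<int64_t> BuildSlotLod(const std::vector<int64_t>& slot_lengths) {
  std::vector<int64_t> lod = slot_lengths;
  for (size_t i = 1; i < lod.size(); ++i) {
    lod[i] += lod[i - 1];
  }
  return lod;
}

// Map a flat feature index i to (slot x, in-slot offset y) with the same
// binary search each CUDA thread performs over the lod array.
std::pair<int, int64_t> MapIndexToSlot(int64_t i,
                                       const std::vector<int64_t>& lod) {
  int low = 0;
  int high = static_cast<int>(lod.size()) - 1;
  while (low < high) {
    int mid = (low + high) / 2;
    if (i < lod[mid]) {
      high = mid;
    } else {
      low = mid + 1;
    }
  }
  int x = low;                           // slot id
  int64_t y = i - (x ? lod[x - 1] : 0);  // offset inside that slot
  return {x, y};
}

int main() {
  // Three slots holding 4, 0 and 3 keys -> 7 flat indices in total.
  std::vector<int64_t> slot_lengths = {4, 0, 3};
  std::vector<int64_t> lod = BuildSlotLod(slot_lengths);  // {4, 4, 7}

  for (int64_t i = 0; i < lod.back(); ++i) {
    auto [slot, offset] = MapIndexToSlot(i, lod);
    std::cout << "flat index " << i << " -> slot " << slot << ", offset "
              << offset << "\n";
  }
  // Example check: index 5 lands in slot 2 at offset 1 (slot 1 is empty).
  assert(MapIndexToSlot(5, lod) == std::make_pair(2, int64_t{1}));
  return 0;
}
```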
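
A second editorial sketch, also hedged: PSGPUTrainer::BuildGPUPSTask shards feature signs across devices with `key % device_num` and then deduplicates each shard with sort+unique before pulling values from the parameter server. The single-threaded C++ sketch below shows that partition-and-dedup step under assumed names; the real code fans the work out over a ThreadPool per channel/device.

```cpp
// Single-threaded sketch of the key sharding + dedup performed in
// BuildGPUPSTask before BuildGPUPS. Hypothetical names, illustration only.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<std::vector<uint64_t>> ShardAndDedupKeys(
    const std::vector<uint64_t>& feasigns, int device_num) {
  std::vector<std::vector<uint64_t>> local_keys(device_num);
  // 1. Partition: each feature sign goes to device (key % device_num),
  //    mirroring the gen_func lambdas in the trainer.
  for (uint64_t key : feasigns) {
    local_keys[key % device_num].push_back(key);
  }
  // 2. Dedup per device: sort + unique, as in the unique_func lambda.
  for (auto& keys : local_keys) {
    std::sort(keys.begin(), keys.end());
    keys.erase(std::unique(keys.begin(), keys.end()), keys.end());
  }
  return local_keys;
}

int main() {
  std::vector<uint64_t> feasigns = {12, 7, 12, 8, 7, 101, 8, 3};
  auto local_keys = ShardAndDedupKeys(feasigns, /*device_num=*/2);
  for (size_t dev = 0; dev < local_keys.size(); ++dev) {
    std::cout << "device " << dev << ":";
    for (uint64_t k : local_keys[dev]) std::cout << " " << k;
    std::cout << "\n";
  }
  return 0;
}
```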