From 988b5fe16bab886fb64a4f36ff57c0b0e2efb1d1 Mon Sep 17 00:00:00 2001 From: Thunderbrook <52529258+Thunderbrook@users.noreply.github.com> Date: Thu, 27 May 2021 11:25:55 +0800 Subject: [PATCH] [PsCore] support ssd (#33031) * support ssd in PsCore * remove log * remove bz2 * defalut value * code style * parse table class * code style * add define --- cmake/external/rocksdb.cmake | 51 +++ cmake/third_party.cmake | 5 + paddle/fluid/distributed/fleet.cc | 11 +- paddle/fluid/distributed/fleet.h | 2 +- .../distributed/service/ps_local_client.cc | 11 +- .../distributed/service/ps_local_server.h | 7 +- paddle/fluid/distributed/service/server.h | 2 +- paddle/fluid/distributed/table/CMakeLists.txt | 15 +- .../distributed/table/common_sparse_table.cc | 104 +---- .../distributed/table/common_sparse_table.h | 94 ++++- .../table/depends/large_scale_kv.h | 32 +- .../table/depends/rocksdb_warpper.h | 158 ++++++++ .../distributed/table/ssd_sparse_table.cc | 362 ++++++++++++++++++ .../distributed/table/ssd_sparse_table.h | 61 +++ paddle/fluid/distributed/table/table.cc | 6 + paddle/fluid/operators/lookup_table_op.cc | 5 + paddle/fluid/pybind/fleet_py.cc | 2 + python/paddle/distributed/fleet/__init__.py | 1 + .../distributed/fleet/base/fleet_base.py | 23 ++ .../distributed/fleet/runtime/the_one_ps.py | 28 +- python/paddle/fluid/contrib/layers/nn.py | 8 +- .../fleet/parameter_server/ir/trainer_pass.py | 34 ++ 22 files changed, 914 insertions(+), 108 deletions(-) create mode 100644 cmake/external/rocksdb.cmake create mode 100644 paddle/fluid/distributed/table/depends/rocksdb_warpper.h create mode 100644 paddle/fluid/distributed/table/ssd_sparse_table.cc create mode 100644 paddle/fluid/distributed/table/ssd_sparse_table.h diff --git a/cmake/external/rocksdb.cmake b/cmake/external/rocksdb.cmake new file mode 100644 index 00000000000..f5b85cc71a2 --- /dev/null +++ b/cmake/external/rocksdb.cmake @@ -0,0 +1,51 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +INCLUDE(ExternalProject) + +SET(ROCKSDB_SOURCES_DIR ${THIRD_PARTY_PATH}/rocksdb) +SET(ROCKSDB_INSTALL_DIR ${THIRD_PARTY_PATH}/install/rocksdb) +SET(ROCKSDB_INCLUDE_DIR "${ROCKSDB_INSTALL_DIR}/include" CACHE PATH "rocksdb include directory." FORCE) +SET(ROCKSDB_LIBRARIES "${ROCKSDB_INSTALL_DIR}/lib/librocksdb.a" CACHE FILEPATH "rocksdb library." FORCE) +SET(ROCKSDB_CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") +INCLUDE_DIRECTORIES(${ROCKSDB_INCLUDE_DIR}) + +ExternalProject_Add( + extern_rocksdb + ${EXTERNAL_PROJECT_LOG_ARGS} + PREFIX ${ROCKSDB_SOURCES_DIR} + GIT_REPOSITORY "https://github.com/facebook/rocksdb" + GIT_TAG v6.10.1 + UPDATE_COMMAND "" + CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} + -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} + -DWITH_BZ2=OFF + -DWITH_GFLAGS=OFF + -DCMAKE_CXX_FLAGS=${ROCKSDB_CMAKE_CXX_FLAGS} + -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} +# BUILD_BYPRODUCTS ${ROCKSDB_SOURCES_DIR}/src/extern_rocksdb/librocksdb.a + INSTALL_COMMAND mkdir -p ${ROCKSDB_INSTALL_DIR}/lib/ + && cp ${ROCKSDB_SOURCES_DIR}/src/extern_rocksdb/librocksdb.a ${ROCKSDB_LIBRARIES} + && cp -r ${ROCKSDB_SOURCES_DIR}/src/extern_rocksdb/include ${ROCKSDB_INSTALL_DIR}/ + BUILD_IN_SOURCE 1 +) + +ADD_DEPENDENCIES(extern_rocksdb snappy) + +ADD_LIBRARY(rocksdb STATIC IMPORTED GLOBAL) +SET_PROPERTY(TARGET rocksdb PROPERTY IMPORTED_LOCATION ${ROCKSDB_LIBRARIES}) +ADD_DEPENDENCIES(rocksdb extern_rocksdb) + +LIST(APPEND external_project_dependencies rocksdb) + diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index 8adc7a4e396..2ae4518c9df 100644 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake @@ -317,6 +317,11 @@ if (WITH_PSCORE) include(external/libmct) # download, build, install libmct list(APPEND third_party_deps extern_libmct) + + if (WITH_HETERPS) + include(external/rocksdb) # download, build, install libmct + list(APPEND third_party_deps extern_rocksdb) + endif() endif() if(WITH_XBYAK) diff --git a/paddle/fluid/distributed/fleet.cc b/paddle/fluid/distributed/fleet.cc index dfd55f16e1a..9e2a0b35224 100644 --- a/paddle/fluid/distributed/fleet.cc +++ b/paddle/fluid/distributed/fleet.cc @@ -417,8 +417,10 @@ void FleetWrapper::PushSparseFromTensorWithLabelAsync( return; } -void FleetWrapper::LoadModel(const std::string& path, const int mode) { - auto ret = pserver_ptr_->_worker_ptr->load(path, std::to_string(mode)); +void FleetWrapper::LoadModel(const std::string& path, const std::string& mode) { + auto* communicator = Communicator::GetInstance(); + auto ret = communicator->_worker_ptr->load(path, mode); + // auto ret = pserver_ptr_->_worker_ptr->load(path, std::to_string(mode)); ret.wait(); if (ret.get() != 0) { LOG(ERROR) << "load model from path:" << path << " failed"; @@ -429,8 +431,11 @@ void FleetWrapper::LoadModel(const std::string& path, const int mode) { void FleetWrapper::LoadModelOneTable(const uint64_t table_id, const std::string& path, const int mode) { + auto* communicator = Communicator::GetInstance(); auto ret = - pserver_ptr_->_worker_ptr->load(table_id, path, std::to_string(mode)); + communicator->_worker_ptr->load(table_id, path, std::to_string(mode)); + // auto ret = + // pserver_ptr_->_worker_ptr->load(table_id, path, std::to_string(mode)); ret.wait(); if (ret.get() != 0) { LOG(ERROR) << "load model of table id: " << table_id diff --git a/paddle/fluid/distributed/fleet.h b/paddle/fluid/distributed/fleet.h index 0da5d1e2bf9..1b2bde85de0 100644 --- a/paddle/fluid/distributed/fleet.h +++ b/paddle/fluid/distributed/fleet.h @@ -200,7 +200,7 @@ class FleetWrapper { void PrintTableStat(const uint64_t table_id); // mode = 0, load all feature // mode = 1, load delta feature, which means load diff - void LoadModel(const std::string& path, const int mode); + void LoadModel(const std::string& path, const std::string& mode); // mode = 0, load all feature // mode = 1, load delta feature, which means load diff void LoadModelOneTable(const uint64_t table_id, const std::string& path, diff --git a/paddle/fluid/distributed/service/ps_local_client.cc b/paddle/fluid/distributed/service/ps_local_client.cc index 2acc845a508..e949b21b02e 100644 --- a/paddle/fluid/distributed/service/ps_local_client.cc +++ b/paddle/fluid/distributed/service/ps_local_client.cc @@ -42,17 +42,17 @@ int32_t PsLocalClient::initialize() { ::std::future PsLocalClient::load(const std::string& epoch, const std::string& mode) { // TODO - // for (auto& it : _table_map) { - // load(it.first, epoch, mode); - //} + for (auto& it : _table_map) { + load(it.first, epoch, mode); + } return done(); } ::std::future PsLocalClient::load(uint32_t table_id, const std::string& epoch, const std::string& mode) { // TODO - // auto* table_ptr = table(table_id); - // table_ptr->load(epoch, mode); + auto* table_ptr = table(table_id); + table_ptr->load(epoch, mode); return done(); } @@ -245,7 +245,6 @@ int32_t PsLocalClient::initialize() { ::std::future PsLocalClient::push_sparse_raw_gradient( size_t table_id, const uint64_t* keys, const float** update_values, size_t num, void* callback) { - VLOG(1) << "wxx push_sparse_raw_gradient"; PSClientClosure* closure = reinterpret_cast(callback); auto* accessor = table_accessor(table_id); auto* table_ptr = table(table_id); diff --git a/paddle/fluid/distributed/service/ps_local_server.h b/paddle/fluid/distributed/service/ps_local_server.h index dfbccc70900..33b0b5fa796 100644 --- a/paddle/fluid/distributed/service/ps_local_server.h +++ b/paddle/fluid/distributed/service/ps_local_server.h @@ -26,9 +26,14 @@ class PsLocalServer : public PSServer { PsLocalServer() {} virtual ~PsLocalServer() {} virtual uint64_t start() { return 0; } - virtual uint64_t start(const std::string& ip, uint32_t port) { return 0; } + virtual uint64_t start(const std::string &ip, uint32_t port) { return 0; } virtual int32_t stop() { return 0; } virtual int32_t port() { return 0; } + virtual int32_t configure( + const PSParameter &config, PSEnvironment &env, size_t server_rank, + const std::vector &server_sub_program = {}) { + return 0; + } private: virtual int32_t initialize() { return 0; } diff --git a/paddle/fluid/distributed/service/server.h b/paddle/fluid/distributed/service/server.h index 74a8cbe44b1..89b089386f5 100644 --- a/paddle/fluid/distributed/service/server.h +++ b/paddle/fluid/distributed/service/server.h @@ -70,7 +70,7 @@ class PSServer { virtual int32_t configure( const PSParameter &config, PSEnvironment &env, size_t server_rank, - const std::vector &server_sub_program = {}) final; + const std::vector &server_sub_program = {}); // return server_ip virtual std::string ip() { return butil::my_ip_cstr(); } diff --git a/paddle/fluid/distributed/table/CMakeLists.txt b/paddle/fluid/distributed/table/CMakeLists.txt index dab39095803..c928ebe90ce 100644 --- a/paddle/fluid/distributed/table/CMakeLists.txt +++ b/paddle/fluid/distributed/table/CMakeLists.txt @@ -9,15 +9,24 @@ set_source_files_properties(${graphDir}/graph_node.cc PROPERTIES COMPILE_FLAGS $ cc_library(graph_node SRCS ${graphDir}/graph_node.cc DEPS WeightedSampler) set_source_files_properties(common_dense_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(common_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(ssd_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(barrier_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(common_graph_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS) -cc_library(common_table SRCS common_sparse_table.cc common_dense_table.cc -sparse_geo_table.cc barrier_table.cc common_graph_table.cc DEPS ${TABLE_DEPS} -${RPC_DEPS} graph_edge graph_node device_context string_helper simple_threadpool xxhash generator) +set(EXTERN_DEP "") +if(WITH_HETERPS) + set(TABLE_SRC common_sparse_table.cc ssd_sparse_table.cc common_dense_table.cc sparse_geo_table.cc barrier_table.cc common_graph_table.cc) + set(EXTERN_DEP rocksdb) +else() + set(TABLE_SRC common_sparse_table.cc common_dense_table.cc sparse_geo_table.cc barrier_table.cc common_graph_table.cc) +endif() + +cc_library(common_table SRCS ${TABLE_SRC} DEPS ${TABLE_DEPS} +${RPC_DEPS} graph_edge graph_node device_context string_helper +simple_threadpool xxhash generator ${EXTERN_DEP}) set_source_files_properties(tensor_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(tensor_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) diff --git a/paddle/fluid/distributed/table/common_sparse_table.cc b/paddle/fluid/distributed/table/common_sparse_table.cc index b667aec186f..e1223face0f 100644 --- a/paddle/fluid/distributed/table/common_sparse_table.cc +++ b/paddle/fluid/distributed/table/common_sparse_table.cc @@ -25,83 +25,12 @@ class ValueBlock; } // namespace distributed } // namespace paddle -#define PSERVER_SAVE_SUFFIX ".shard" -using boost::lexical_cast; - namespace paddle { namespace distributed { -enum SaveMode { all, base, delta }; - -struct Meta { - std::string param; - int shard_id; - std::vector names; - std::vector dims; - uint64_t count; - std::unordered_map dims_map; - - explicit Meta(const std::string& metapath) { - std::ifstream file(metapath); - std::string line; - int num_lines = 0; - while (std::getline(file, line)) { - if (StartWith(line, "#")) { - continue; - } - auto pairs = paddle::string::split_string(line, "="); - PADDLE_ENFORCE_EQ( - pairs.size(), 2, - paddle::platform::errors::InvalidArgument( - "info in %s except k=v, but got %s", metapath, line)); - - if (pairs[0] == "param") { - param = pairs[1]; - } - if (pairs[0] == "shard_id") { - shard_id = std::stoi(pairs[1]); - } - if (pairs[0] == "row_names") { - names = paddle::string::split_string(pairs[1], ","); - } - if (pairs[0] == "row_dims") { - auto dims_strs = - paddle::string::split_string(pairs[1], ","); - for (auto& str : dims_strs) { - dims.push_back(std::stoi(str)); - } - } - if (pairs[0] == "count") { - count = std::stoull(pairs[1]); - } - } - for (int x = 0; x < names.size(); ++x) { - dims_map[names[x]] = dims[x]; - } - } - - Meta(std::string param, int shard_id, std::vector row_names, - std::vector dims, uint64_t count) { - this->param = param; - this->shard_id = shard_id; - this->names = row_names; - this->dims = dims; - this->count = count; - } - - std::string ToString() { - std::stringstream ss; - ss << "param=" << param << "\n"; - ss << "shard_id=" << shard_id << "\n"; - ss << "row_names=" << paddle::string::join_strings(names, ',') << "\n"; - ss << "row_dims=" << paddle::string::join_strings(dims, ',') << "\n"; - ss << "count=" << count << "\n"; - return ss.str(); - } -}; - -void ProcessALine(const std::vector& columns, const Meta& meta, - const int64_t id, std::vector>* values) { +void CommonSparseTable::ProcessALine(const std::vector& columns, + const Meta& meta, const int64_t id, + std::vector>* values) { auto colunmn_size = columns.size(); auto load_values = paddle::string::split_string(columns[colunmn_size - 1], ","); @@ -134,8 +63,10 @@ void ProcessALine(const std::vector& columns, const Meta& meta, } } -void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common, - const size_t shard_idx, const int64_t total) { +void CommonSparseTable::SaveMetaToText(std::ostream* os, + const CommonAccessorParameter& common, + const size_t shard_idx, + const int64_t total) { // save meta std::stringstream stream; stream << "param=" << common.table_name() << "\n"; @@ -148,8 +79,10 @@ void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common, os->write(stream.str().c_str(), sizeof(char) * stream.str().size()); } -int64_t SaveValueToText(std::ostream* os, std::shared_ptr block, - std::shared_ptr<::ThreadPool> pool, const int mode) { +int64_t CommonSparseTable::SaveValueToText(std::ostream* os, + std::shared_ptr block, + std::shared_ptr<::ThreadPool> pool, + const int mode, int shard_id) { int64_t save_num = 0; for (auto& table : block->values_) { for (auto& value : table) { @@ -186,10 +119,10 @@ int64_t SaveValueToText(std::ostream* os, std::shared_ptr block, return save_num; } -int64_t LoadFromText(const std::string& valuepath, const std::string& metapath, - const int pserver_id, const int pserver_num, - const int local_shard_num, - std::vector>* blocks) { +int64_t CommonSparseTable::LoadFromText( + const std::string& valuepath, const std::string& metapath, + const int pserver_id, const int pserver_num, const int local_shard_num, + std::vector>* blocks) { Meta meta = Meta(metapath); int num_lines = 0; @@ -198,7 +131,7 @@ int64_t LoadFromText(const std::string& valuepath, const std::string& metapath, while (std::getline(file, line)) { auto values = paddle::string::split_string(line, "\t"); - auto id = lexical_cast(values[0]); + auto id = lexical_cast(values[0]); if (id % pserver_num != pserver_id) { VLOG(3) << "will not load " << values[0] << " from " << valuepath @@ -388,8 +321,9 @@ int32_t CommonSparseTable::save(const std::string& dirname, int64_t total_ins = 0; for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) { // save values - auto shard_save_num = SaveValueToText(vs.get(), shard_values_[shard_id], - _shards_task_pool[shard_id], mode); + auto shard_save_num = + SaveValueToText(vs.get(), shard_values_[shard_id], + _shards_task_pool[shard_id], mode, shard_id); total_ins += shard_save_num; } vs->close(); diff --git a/paddle/fluid/distributed/table/common_sparse_table.h b/paddle/fluid/distributed/table/common_sparse_table.h index 50c295da534..ce3cc11686a 100644 --- a/paddle/fluid/distributed/table/common_sparse_table.h +++ b/paddle/fluid/distributed/table/common_sparse_table.h @@ -32,11 +32,83 @@ #include "paddle/fluid/framework/rw_lock.h" #include "paddle/fluid/string/string_helper.h" +#define PSERVER_SAVE_SUFFIX ".shard" +using boost::lexical_cast; + namespace paddle { namespace distributed { class SparseOptimizer; +enum SaveMode { all, base, delta }; + +struct Meta { + std::string param; + int shard_id; + std::vector names; + std::vector dims; + uint64_t count; + std::unordered_map dims_map; + + explicit Meta(const std::string& metapath) { + std::ifstream file(metapath); + std::string line; + int num_lines = 0; + while (std::getline(file, line)) { + if (StartWith(line, "#")) { + continue; + } + auto pairs = paddle::string::split_string(line, "="); + PADDLE_ENFORCE_EQ( + pairs.size(), 2, + paddle::platform::errors::InvalidArgument( + "info in %s except k=v, but got %s", metapath, line)); + + if (pairs[0] == "param") { + param = pairs[1]; + } + if (pairs[0] == "shard_id") { + shard_id = std::stoi(pairs[1]); + } + if (pairs[0] == "row_names") { + names = paddle::string::split_string(pairs[1], ","); + } + if (pairs[0] == "row_dims") { + auto dims_strs = + paddle::string::split_string(pairs[1], ","); + for (auto& str : dims_strs) { + dims.push_back(std::stoi(str)); + } + } + if (pairs[0] == "count") { + count = std::stoull(pairs[1]); + } + } + for (int x = 0; x < names.size(); ++x) { + dims_map[names[x]] = dims[x]; + } + } + + Meta(std::string param, int shard_id, std::vector row_names, + std::vector dims, uint64_t count) { + this->param = param; + this->shard_id = shard_id; + this->names = row_names; + this->dims = dims; + this->count = count; + } + + std::string ToString() { + std::stringstream ss; + ss << "param=" << param << "\n"; + ss << "shard_id=" << shard_id << "\n"; + ss << "row_names=" << paddle::string::join_strings(names, ',') << "\n"; + ss << "row_dims=" << paddle::string::join_strings(dims, ',') << "\n"; + ss << "count=" << count << "\n"; + return ss.str(); + } +}; + class CommonSparseTable : public SparseTable { public: CommonSparseTable() { rwlock_.reset(new framework::RWLock); } @@ -56,9 +128,25 @@ class CommonSparseTable : public SparseTable { virtual int32_t initialize_optimizer(); virtual int32_t initialize_recorder(); - int32_t load(const std::string& path, const std::string& param); + virtual int32_t load(const std::string& path, const std::string& param); + + virtual int32_t save(const std::string& path, const std::string& param); + + void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common, + const size_t shard_idx, const int64_t total); - int32_t save(const std::string& path, const std::string& param); + int64_t SaveValueToText(std::ostream* os, std::shared_ptr block, + std::shared_ptr<::ThreadPool> pool, const int mode, + int shard_id); + + virtual void ProcessALine(const std::vector& columns, + const Meta& meta, const int64_t id, + std::vector>* values); + + virtual int64_t LoadFromText( + const std::string& valuepath, const std::string& metapath, + const int pserver_id, const int pserver_num, const int local_shard_num, + std::vector>* blocks); virtual std::pair print_table_stat(); virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value); @@ -89,7 +177,7 @@ class CommonSparseTable : public SparseTable { virtual int32_t _push_sparse(const uint64_t* keys, const float** values, size_t num); - private: + protected: const int task_pool_size_ = 11; std::vector> _shards_task_pool; diff --git a/paddle/fluid/distributed/table/depends/large_scale_kv.h b/paddle/fluid/distributed/table/depends/large_scale_kv.h index 5c10fca98cd..ac11183d192 100644 --- a/paddle/fluid/distributed/table/depends/large_scale_kv.h +++ b/paddle/fluid/distributed/table/depends/large_scale_kv.h @@ -83,6 +83,7 @@ inline bool probility_entry(VALUE *value, float threshold) { class ValueBlock { public: + typedef typename robin_hood::unordered_map map_type; explicit ValueBlock(const std::vector &value_names, const std::vector &value_dims, const std::vector &value_offsets, @@ -261,6 +262,18 @@ class ValueBlock { value->is_entry_ = state; } + void erase(uint64_t feasign) { + size_t hash = _hasher(feasign); + size_t bucket = compute_bucket(hash); + auto &table = values_[bucket]; + + auto iter = table.find(feasign); + if (iter != table.end()) { + butil::return_object(iter->second); + iter = table.erase(iter); + } + } + void Shrink(const int threshold) { for (auto &table : values_) { for (auto iter = table.begin(); iter != table.end();) { @@ -289,6 +302,23 @@ class ValueBlock { } } + map_type::iterator end() { + return values_[SPARSE_SHARD_BUCKET_NUM - 1].end(); + } + + map_type::iterator Find(uint64_t id) { + size_t hash = _hasher(id); + size_t bucket = compute_bucket(hash); + auto &table = values_[bucket]; + + auto got = table.find(id); + if (got == table.end()) { + return end(); + } else { + return got; + } + } + private: bool Has(const uint64_t id) { size_t hash = _hasher(id); @@ -304,7 +334,7 @@ class ValueBlock { } public: - robin_hood::unordered_map values_[SPARSE_SHARD_BUCKET_NUM]; + map_type values_[SPARSE_SHARD_BUCKET_NUM]; size_t value_length_ = 0; std::hash _hasher; diff --git a/paddle/fluid/distributed/table/depends/rocksdb_warpper.h b/paddle/fluid/distributed/table/depends/rocksdb_warpper.h new file mode 100644 index 00000000000..0e25a89cb14 --- /dev/null +++ b/paddle/fluid/distributed/table/depends/rocksdb_warpper.h @@ -0,0 +1,158 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef PADDLE_WITH_HETERPS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace paddle { +namespace distributed { + +class RocksDBHandler { + public: + RocksDBHandler() {} + ~RocksDBHandler() {} + + static RocksDBHandler* GetInstance() { + static RocksDBHandler handler; + return &handler; + } + + int initialize(const std::string& db_path, const int colnum) { + VLOG(3) << "db path: " << db_path << " colnum: " << colnum; + rocksdb::Options options; + rocksdb::BlockBasedTableOptions bbto; + bbto.block_size = 4 * 1024; + bbto.block_cache = rocksdb::NewLRUCache(64 * 1024 * 1024); + bbto.block_cache_compressed = rocksdb::NewLRUCache(64 * 1024 * 1024); + bbto.cache_index_and_filter_blocks = false; + bbto.filter_policy.reset(rocksdb::NewBloomFilterPolicy(20, false)); + bbto.whole_key_filtering = true; + options.table_factory.reset(rocksdb::NewBlockBasedTableFactory(bbto)); + + options.keep_log_file_num = 100; + options.max_log_file_size = 50 * 1024 * 1024; // 50MB + options.create_if_missing = true; + options.use_direct_reads = true; + options.max_background_flushes = 5; + options.max_background_compactions = 5; + options.base_background_compactions = 10; + options.write_buffer_size = 256 * 1024 * 1024; // 256MB + options.max_write_buffer_number = 8; + options.max_bytes_for_level_base = + options.max_write_buffer_number * options.write_buffer_size; + options.min_write_buffer_number_to_merge = 1; + options.target_file_size_base = 1024 * 1024 * 1024; // 1024MB + options.memtable_prefix_bloom_size_ratio = 0.02; + options.num_levels = 4; + options.max_open_files = -1; + + options.compression = rocksdb::kNoCompression; + options.level0_file_num_compaction_trigger = 8; + options.level0_slowdown_writes_trigger = + 1.8 * options.level0_file_num_compaction_trigger; + options.level0_stop_writes_trigger = + 3.6 * options.level0_file_num_compaction_trigger; + + if (!db_path.empty()) { + std::string rm_cmd = "rm -rf " + db_path; + system(rm_cmd.c_str()); + } + + rocksdb::Status s = rocksdb::DB::Open(options, db_path, &_db); + assert(s.ok()); + _handles.resize(colnum); + for (int i = 0; i < colnum; i++) { + s = _db->CreateColumnFamily(options, "shard_" + std::to_string(i), + &_handles[i]); + assert(s.ok()); + } + LOG(INFO) << "DB initialize success, colnum:" << colnum; + return 0; + } + + int put(int id, const char* key, int key_len, const char* value, + int value_len) { + rocksdb::WriteOptions options; + options.disableWAL = true; + rocksdb::Status s = + _db->Put(options, _handles[id], rocksdb::Slice(key, key_len), + rocksdb::Slice(value, value_len)); + assert(s.ok()); + return 0; + } + + int put_batch(int id, std::vector>& ssd_keys, + std::vector>& ssd_values, int n) { + rocksdb::WriteOptions options; + options.disableWAL = true; + rocksdb::WriteBatch batch(n * 128); + for (int i = 0; i < n; i++) { + batch.Put(_handles[id], + rocksdb::Slice(ssd_keys[i].first, ssd_keys[i].second), + rocksdb::Slice(ssd_values[i].first, ssd_values[i].second)); + } + rocksdb::Status s = _db->Write(options, &batch); + assert(s.ok()); + return 0; + } + + int get(int id, const char* key, int key_len, std::string& value) { + rocksdb::Status s = _db->Get(rocksdb::ReadOptions(), _handles[id], + rocksdb::Slice(key, key_len), &value); + if (s.IsNotFound()) { + return 1; + } + assert(s.ok()); + return 0; + } + + int del_data(int id, const char* key, int key_len) { + rocksdb::WriteOptions options; + options.disableWAL = true; + rocksdb::Status s = + _db->Delete(options, _handles[id], rocksdb::Slice(key, key_len)); + assert(s.ok()); + return 0; + } + + int flush(int id) { + rocksdb::Status s = _db->Flush(rocksdb::FlushOptions(), _handles[id]); + assert(s.ok()); + return 0; + } + + rocksdb::Iterator* get_iterator(int id) { + return _db->NewIterator(rocksdb::ReadOptions(), _handles[id]); + } + + int get_estimate_key_num(uint64_t& num_keys) { + _db->GetAggregatedIntProperty("rocksdb.estimate-num-keys", &num_keys); + return 0; + } + + private: + std::vector _handles; + rocksdb::DB* _db; +}; +} +} +#endif diff --git a/paddle/fluid/distributed/table/ssd_sparse_table.cc b/paddle/fluid/distributed/table/ssd_sparse_table.cc new file mode 100644 index 00000000000..5de6de3d290 --- /dev/null +++ b/paddle/fluid/distributed/table/ssd_sparse_table.cc @@ -0,0 +1,362 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef PADDLE_WITH_HETERPS +#include "paddle/fluid/distributed/table/ssd_sparse_table.h" + +DEFINE_string(rocksdb_path, "database", "path of sparse table rocksdb file"); + +namespace paddle { +namespace distributed { + +int32_t SSDSparseTable::initialize() { + _shards_task_pool.resize(task_pool_size_); + for (int i = 0; i < _shards_task_pool.size(); ++i) { + _shards_task_pool[i].reset(new ::ThreadPool(1)); + } + + sync = _config.common().sync(); + VLOG(1) << "table " << _config.common().table_name() << " is sync: " << sync; + + _global_lr = new float(1.0); + + auto common = _config.common(); + int size = static_cast(common.params().size()); + + size_t offset = 0; + for (int x = 0; x < size; ++x) { + auto& varname = common.params()[x]; + auto& dim = common.dims()[x]; + + value_idx_[varname] = x; + value_names_.push_back(varname); + value_dims_.push_back(dim); + value_offsets_.push_back(offset); + initializer_attrs_.push_back(common.initializers()[x]); + + if (varname == "Param") { + param_dim_ = dim; + param_offset_ = offset; + } + + offset += dim; + } + + initialize_value(); + initialize_optimizer(); + initialize_recorder(); + _db = paddle::distributed::RocksDBHandler::GetInstance(); + _db->initialize(FLAGS_rocksdb_path, task_pool_size_); + return 0; +} + +int32_t SSDSparseTable::pull_sparse(float* pull_values, + const PullSparseValue& pull_value) { + auto shard_num = task_pool_size_; + std::vector> tasks(shard_num); + + for (int shard_id = 0; shard_id < shard_num; ++shard_id) { + tasks[shard_id] = _shards_task_pool[shard_id]->enqueue( + [this, shard_id, shard_num, &pull_value, &pull_values]() -> int { + auto& block = shard_values_[shard_id]; + + std::vector offsets; + pull_value.Fission(shard_id, shard_num, &offsets); + + for (auto& offset : offsets) { + auto feasign = pull_value.feasigns_[offset]; + auto frequencie = pull_value.frequencies_[offset]; + float* embedding = nullptr; + auto iter = block->Find(feasign); + // in mem + if (iter == block->end()) { + embedding = iter->second->data_.data(); + if (pull_value.is_training_) { + block->AttrUpdate(iter->second, frequencie); + } + } else { + // need create + std::string tmp_str(""); + if (_db->get(shard_id, (char*)&feasign, sizeof(uint64_t), + tmp_str) > 0) { + embedding = block->Init(feasign, true, frequencie); + } else { + // in db + int data_size = tmp_str.size() / sizeof(float); + int value_size = block->value_length_; + float* db_value = (float*)const_cast(tmp_str.c_str()); + VALUE* value = block->InitGet(feasign); + + // copy to mem + memcpy(value->data_.data(), db_value, + value_size * sizeof(float)); + embedding = db_value; + + // param, count, unseen_day + value->count_ = db_value[value_size]; + value->unseen_days_ = db_value[value_size + 1]; + value->is_entry_ = db_value[value_size + 2]; + if (pull_value.is_training_) { + block->AttrUpdate(value, frequencie); + } + } + } + std::copy_n(embedding + param_offset_, param_dim_, + pull_values + param_dim_ * offset); + } + return 0; + }); + } + + for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) { + tasks[shard_id].wait(); + } + return 0; +} + +int32_t SSDSparseTable::pull_sparse_ptr(char** pull_values, + const uint64_t* keys, size_t num) { + auto shard_num = task_pool_size_; + std::vector> tasks(shard_num); + + std::vector> offset_bucket; + offset_bucket.resize(task_pool_size_); + + for (int x = 0; x < num; ++x) { + auto y = keys[x] % task_pool_size_; + offset_bucket[y].push_back(x); + } + + for (int shard_id = 0; shard_id < shard_num; ++shard_id) { + tasks[shard_id] = _shards_task_pool[shard_id]->enqueue( + [this, shard_id, &keys, &pull_values, &offset_bucket]() -> int { + auto& block = shard_values_[shard_id]; + auto& offsets = offset_bucket[shard_id]; + + for (auto& offset : offsets) { + auto feasign = keys[offset]; + auto iter = block->Find(feasign); + VALUE* value = nullptr; + // in mem + if (iter != block->end()) { + value = iter->second; + } else { + // need create + std::string tmp_str(""); + if (_db->get(shard_id, (char*)&feasign, sizeof(uint64_t), + tmp_str) > 0) { + value = block->InitGet(feasign); + } else { + // in db + int data_size = tmp_str.size() / sizeof(float); + int value_size = block->value_length_; + float* db_value = (float*)const_cast(tmp_str.c_str()); + value = block->InitGet(feasign); + + // copy to mem + memcpy(value->data_.data(), db_value, + value_size * sizeof(float)); + + // param, count, unseen_day + value->count_ = db_value[value_size]; + value->unseen_days_ = db_value[value_size + 1]; + value->is_entry_ = db_value[value_size + 2]; + } + } + pull_values[offset] = (char*)value; + } + return 0; + }); + } + + for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) { + tasks[shard_id].wait(); + } + return 0; +} + +int32_t SSDSparseTable::shrink(const std::string& param) { return 0; } + +int32_t SSDSparseTable::update_table() { + int count = 0; + int value_size = shard_values_[0]->value_length_; + int db_size = 3 + value_size; + float tmp_value[db_size]; + + for (size_t i = 0; i < task_pool_size_; ++i) { + auto& block = shard_values_[i]; + + for (auto& table : block->values_) { + for (auto iter = table.begin(); iter != table.end();) { + VALUE* value = iter->second; + if (value->unseen_days_ >= 1) { + tmp_value[value_size] = value->count_; + tmp_value[value_size + 1] = value->unseen_days_; + tmp_value[value_size + 2] = value->is_entry_; + memcpy(tmp_value, value->data_.data(), sizeof(float) * value_size); + _db->put(i, (char*)&(iter->first), sizeof(uint64_t), (char*)tmp_value, + db_size * sizeof(float)); + count++; + + butil::return_object(iter->second); + iter = table.erase(iter); + } else { + ++iter; + } + } + } + _db->flush(i); + } + VLOG(1) << "Table>> update count: " << count; + return 0; +} + +int64_t SSDSparseTable::SaveValueToText(std::ostream* os, + std::shared_ptr block, + std::shared_ptr<::ThreadPool> pool, + const int mode, int shard_id) { + int64_t save_num = 0; + + for (auto& table : block->values_) { + for (auto& value : table) { + if (mode == SaveMode::delta && !value.second->need_save_) { + continue; + } + + ++save_num; + + std::stringstream ss; + auto* vs = value.second->data_.data(); + + auto id = value.first; + + ss << id << "\t" << value.second->count_ << "\t" + << value.second->unseen_days_ << "\t" << value.second->is_entry_ + << "\t"; + + for (int i = 0; i < block->value_length_ - 1; i++) { + ss << std::to_string(vs[i]) << ","; + } + + ss << std::to_string(vs[block->value_length_ - 1]); + ss << "\n"; + + os->write(ss.str().c_str(), sizeof(char) * ss.str().size()); + + if (mode == SaveMode::base || mode == SaveMode::delta) { + value.second->need_save_ = false; + } + } + } + + if (mode != 1) { + int value_size = block->value_length_; + auto* it = _db->get_iterator(shard_id); + + for (it->SeekToFirst(); it->Valid(); it->Next()) { + float* value = (float*)const_cast(it->value().data()); + std::stringstream ss; + ss << *((uint64_t*)const_cast(it->key().data())) << "\t" + << value[value_size] << "\t" << value[value_size + 1] << "\t" + << value[value_size + 2] << "\t"; + for (int i = 0; i < block->value_length_ - 1; i++) { + ss << std::to_string(value[i]) << ","; + } + + ss << std::to_string(value[block->value_length_ - 1]); + ss << "\n"; + + os->write(ss.str().c_str(), sizeof(char) * ss.str().size()); + } + } + + return save_num; +} + +int32_t SSDSparseTable::load(const std::string& path, + const std::string& param) { + rwlock_->WRLock(); + VLOG(3) << "ssd sparse table load with " << path << " with meta " << param; + LoadFromText(path, param, _shard_idx, _shard_num, task_pool_size_, + &shard_values_); + rwlock_->UNLock(); + return 0; +} + +int64_t SSDSparseTable::LoadFromText( + const std::string& valuepath, const std::string& metapath, + const int pserver_id, const int pserver_num, const int local_shard_num, + std::vector>* blocks) { + Meta meta = Meta(metapath); + + int num_lines = 0; + std::ifstream file(valuepath); + std::string line; + + int value_size = shard_values_[0]->value_length_; + int db_size = 3 + value_size; + float tmp_value[db_size]; + + while (std::getline(file, line)) { + auto values = paddle::string::split_string(line, "\t"); + auto id = lexical_cast(values[0]); + + if (id % pserver_num != pserver_id) { + VLOG(3) << "will not load " << values[0] << " from " << valuepath + << ", please check id distribution"; + continue; + } + + auto shard_id = id % local_shard_num; + auto block = blocks->at(shard_id); + + std::vector> kvalues; + ProcessALine(values, meta, id, &kvalues); + + block->Init(id, false); + + VALUE* value_instant = block->GetValue(id); + + if (values.size() == 5) { + value_instant->count_ = lexical_cast(values[1]); + value_instant->unseen_days_ = lexical_cast(values[2]); + value_instant->is_entry_ = + static_cast(lexical_cast(values[3])); + } + + std::vector block_values = block->Get(id, meta.names, meta.dims); + auto blas = GetBlas(); + for (int x = 0; x < meta.names.size(); ++x) { + blas.VCOPY(meta.dims[x], kvalues[x].data(), block_values[x]); + } + VLOG(3) << "loading: " << id + << "unseen day: " << value_instant->unseen_days_; + if (value_instant->unseen_days_ >= 1) { + tmp_value[value_size] = value_instant->count_; + tmp_value[value_size + 1] = value_instant->unseen_days_; + tmp_value[value_size + 2] = value_instant->is_entry_; + memcpy(tmp_value, value_instant->data_.data(), + sizeof(float) * value_size); + _db->put(shard_id, (char*)&(id), sizeof(uint64_t), (char*)tmp_value, + db_size * sizeof(float)); + block->erase(id); + } + } + + return 0; +} + +} // namespace ps +} // namespace paddle +#endif diff --git a/paddle/fluid/distributed/table/ssd_sparse_table.h b/paddle/fluid/distributed/table/ssd_sparse_table.h new file mode 100644 index 00000000000..5e85fa3ce59 --- /dev/null +++ b/paddle/fluid/distributed/table/ssd_sparse_table.h @@ -0,0 +1,61 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/distributed/table/common_sparse_table.h" +#include "paddle/fluid/distributed/table/depends/rocksdb_warpper.h" +#ifdef PADDLE_WITH_HETERPS +namespace paddle { +namespace distributed { +class SSDSparseTable : public CommonSparseTable { + public: + SSDSparseTable() {} + virtual ~SSDSparseTable() {} + + virtual int32_t initialize() override; + + void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common, + const size_t shard_idx, const int64_t total); + + int64_t SaveValueToText(std::ostream* os, std::shared_ptr block, + std::shared_ptr<::ThreadPool> pool, const int mode, + int shard_id); + + virtual int64_t LoadFromText( + const std::string& valuepath, const std::string& metapath, + const int pserver_id, const int pserver_num, const int local_shard_num, + std::vector>* blocks); + + virtual int32_t load(const std::string& path, const std::string& param); + + // exchange data + virtual int32_t update_table(); + + virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value); + + virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys, + size_t num); + + virtual int32_t flush() override { return 0; } + virtual int32_t shrink(const std::string& param) override; + virtual void clear() override {} + + private: + RocksDBHandler* _db; + int64_t _cache_tk_size; +}; + +} // namespace ps +} // namespace paddle +#endif diff --git a/paddle/fluid/distributed/table/table.cc b/paddle/fluid/distributed/table/table.cc index 600be954cb5..0f8753c0746 100644 --- a/paddle/fluid/distributed/table/table.cc +++ b/paddle/fluid/distributed/table/table.cc @@ -21,6 +21,9 @@ #include "paddle/fluid/distributed/table/common_graph_table.h" #include "paddle/fluid/distributed/table/common_sparse_table.h" #include "paddle/fluid/distributed/table/sparse_geo_table.h" +#ifdef PADDLE_WITH_HETERPS +#include "paddle/fluid/distributed/table/ssd_sparse_table.h" +#endif #include "paddle/fluid/distributed/table/tensor_accessor.h" #include "paddle/fluid/distributed/table/tensor_table.h" @@ -29,6 +32,9 @@ namespace distributed { REGISTER_PSCORE_CLASS(Table, GraphTable); REGISTER_PSCORE_CLASS(Table, CommonDenseTable); REGISTER_PSCORE_CLASS(Table, CommonSparseTable); +#ifdef PADDLE_WITH_HETERPS +REGISTER_PSCORE_CLASS(Table, SSDSparseTable); +#endif REGISTER_PSCORE_CLASS(Table, SparseGeoTable); REGISTER_PSCORE_CLASS(Table, BarrierTable); REGISTER_PSCORE_CLASS(Table, TensorTable); diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index 2e8b551ea4e..9a0ce3900ac 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -118,6 +118,11 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { ") for entry attribute.") .SetDefault("none"); + AddAttr("table_class", + "(std::string, default " + ") for table_class.") + .SetDefault("none"); + AddAttr>( "table_names", "(string vector, the split table names that will be fetched from " diff --git a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc index fa14ad4f63b..a6b542f53ae 100644 --- a/paddle/fluid/pybind/fleet_py.cc +++ b/paddle/fluid/pybind/fleet_py.cc @@ -58,6 +58,8 @@ void BindDistFleetWrapper(py::module* m) { "DistFleetWrapper") .def(py::init([]() { return FleetWrapper::GetInstance(); })) .def("load_sparse", &FleetWrapper::LoadSparseOnServer) + .def("load_model", &FleetWrapper::LoadModel) + .def("load_one_table", &FleetWrapper::LoadModelOneTable) .def("init_server", &FleetWrapper::InitServer) .def("run_server", (uint64_t (FleetWrapper::*)(void)) & FleetWrapper::RunServer) diff --git a/python/paddle/distributed/fleet/__init__.py b/python/paddle/distributed/fleet/__init__.py index 5f9a61371d3..3186df7db58 100644 --- a/python/paddle/distributed/fleet/__init__.py +++ b/python/paddle/distributed/fleet/__init__.py @@ -77,6 +77,7 @@ stop_worker = fleet.stop_worker distributed_optimizer = fleet.distributed_optimizer save_inference_model = fleet.save_inference_model save_persistables = fleet.save_persistables +load_model = fleet.load_model minimize = fleet.minimize distributed_model = fleet.distributed_model step = fleet.step diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 5e883f1ac6c..9e5a31d6899 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -540,6 +540,29 @@ class Fleet(object): """ self._runtime_handle._init_server(*args, **kwargs) + def load_model(self, path, mode): + """ + load fleet model from path + + + Returns: + None + + Examples: + + .. code-block:: python + + import paddle.distributed.fleet as fleet + fleet.init() + + # build net + # fleet.distributed_optimizer(...) + + fleet.load_model("path", "mode") + + """ + self._runtime_handle.load_model(path, mode) + @is_non_distributed_check @inited_runtime_handler def run_server(self): diff --git a/python/paddle/distributed/fleet/runtime/the_one_ps.py b/python/paddle/distributed/fleet/runtime/the_one_ps.py index f18b82eaecd..642d0e427fa 100644 --- a/python/paddle/distributed/fleet/runtime/the_one_ps.py +++ b/python/paddle/distributed/fleet/runtime/the_one_ps.py @@ -35,6 +35,23 @@ def conv_indent(indent): PSERVER_SAVE_SUFFIX = ".shard" +def parse_table_class(varname, o_main_program): + from paddle.fluid.incubate.fleet.parameter_server.ir.public import is_distributed_sparse_op + from paddle.fluid.incubate.fleet.parameter_server.ir.public import is_sparse_op + + for op in o_main_program.global_block().ops: + if not is_distributed_sparse_op(op) and not is_sparse_op(op): + continue + + param_name = op.input("W")[0] + + if param_name == varname and op.type == "lookup_table" or op.type == "lookup_table_v2": + if op.has_attr('table_class') and op.attr("table_class") != "none": + return op.attr('table_class') + else: + return "CommonSparseTable" + + class Accessor: def __init__(self): self.accessor_class = "" @@ -723,13 +740,15 @@ class TheOnePSRuntime(RuntimeBase): table.type = "PS_SPARSE_TABLE" table.shard_num = 256 + common.table_name = self.compiled_strategy.grad_name_to_param_name[ + ctx.origin_varnames()[0]] + if self.compiled_strategy.is_geo_mode(): table.table_class = "SparseGeoTable" else: - table.table_class = "CommonSparseTable" + table.table_class = parse_table_class( + common.table_name, self.origin_main_program) - common.table_name = self.compiled_strategy.grad_name_to_param_name[ - ctx.origin_varnames()[0]] else: table.type = "PS_DENSE_TABLE" table.table_class = "CommonDenseTable" @@ -1044,6 +1063,9 @@ class TheOnePSRuntime(RuntimeBase): def _save_persistables(self, *args, **kwargs): self._ps_inference_save_persistables(*args, **kwargs) + def load_model(self, path, mode): + self._worker.load_model(path, mode) + def _shrink(self, threshold): import paddle.distributed.fleet as fleet fleet.util.barrier() diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py index 8c48033fc46..30316b77adc 100644 --- a/python/paddle/fluid/contrib/layers/nn.py +++ b/python/paddle/fluid/contrib/layers/nn.py @@ -967,6 +967,7 @@ def sparse_embedding(input, padding_idx=None, is_test=False, entry=None, + table_class="CommonSparseTable", param_attr=None, dtype='float32'): helper = LayerHelper('sparse_embedding', **locals()) @@ -989,6 +990,10 @@ def sparse_embedding(input, padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else ( size[0] + padding_idx) + if table_class not in ["CommonSparseTable", "SSDSparseTable"]: + raise ValueError( + "table_class must be in [CommonSparseTable, SSDSparseTable]") + entry_str = "none" if entry is not None: @@ -1011,7 +1016,8 @@ def sparse_embedding(input, 'is_distributed': True, 'remote_prefetch': True, 'is_test': is_test, - 'entry': entry_str + 'entry': entry_str, + 'table_class': table_class }) return tmp diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py index d4af3e2f804..89b2a8237dc 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py @@ -365,7 +365,41 @@ def ps_gpu_pass(program): for name in remove_var: program.global_block()._remove_var(name) + def _remove_optimizer_var(program): + + embedding_w = {} + for idx, op in list(enumerate(program.global_block().ops)): + if op.type == "lookup_table_grad": + for name in op.input("W"): + embedding_w[name] = 1 + + optimize_vars = [] + optimize_op_role_vars = [] + optimize_need_delete_vars = [] + for op in _get_optimize_ops(program): + for name in op.input("Param"): + if name in embedding_w: + optimize_op_role_vars.extend(op.attr("op_role_var")) + for key_name in op.input_names: + if key_name == "LearningRate": + continue + for var in op.input(key_name): + optimize_vars.append(var) + + optimize_vars = list(set(optimize_vars)) + optimize_op_role_vars = list(set(optimize_op_role_vars)) + + for var in optimize_vars: + if var not in optimize_op_role_vars: + optimize_need_delete_vars.append(var) + need_delete_optimize_vars = list(set(optimize_need_delete_vars)) + + for name in need_delete_optimize_vars: + if program.global_block().has_var(name): + program.global_block()._remove_var(name) + _add_push_box_sparse_op(program) + _remove_optimizer_var(program) _remove_lookup_table_grad_op_and_var(program) return program -- GitLab